summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-03-05 14:35:35 -0700
committerAaron Meurer <asmeurer@gmail.com>2021-03-05 14:35:35 -0700
commitcdd6bbcdf260a4d6947901604dc8dd64c864c8d4 (patch)
tree947c48482ec001d0dd4f1fa5055bd7d6970bfc30 /numpy/_array_api
parent58c2a996afd13f729ec5d2aed77151c8e799548b (diff)
downloadnumpy-cdd6bbcdf260a4d6947901604dc8dd64c864c8d4.tar.gz
Support the ndarray object in the remaining array API functions
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_data_type_functions.py4
-rw-r--r--numpy/_array_api/_linear_algebra_functions.py17
-rw-r--r--numpy/_array_api/_searching_functions.py9
-rw-r--r--numpy/_array_api/_set_functions.py3
-rw-r--r--numpy/_array_api/_sorting_functions.py9
-rw-r--r--numpy/_array_api/_statistical_functions.py15
-rw-r--r--numpy/_array_api/_utility_functions.py5
7 files changed, 35 insertions, 27 deletions
diff --git a/numpy/_array_api/_data_type_functions.py b/numpy/_array_api/_data_type_functions.py
index 18f741ebd..9e4dcba95 100644
--- a/numpy/_array_api/_data_type_functions.py
+++ b/numpy/_array_api/_data_type_functions.py
@@ -1,6 +1,8 @@
from __future__ import annotations
from ._types import Union, array, dtype
+from ._array_object import ndarray
+
from collections.abc import Sequence
import numpy as np
@@ -27,4 +29,4 @@ def result_type(*arrays_and_dtypes: Sequence[Union[array, dtype]]) -> dtype:
See its docstring for more information.
"""
- return np.result_type(*arrays_and_dtypes)
+ return np.result_type(*(a._array if isinstance(a, ndarray) else a for a in arrays_and_dtypes))
diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py
index ec67f9c0b..95ed00fd5 100644
--- a/numpy/_array_api/_linear_algebra_functions.py
+++ b/numpy/_array_api/_linear_algebra_functions.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from ._types import Literal, Optional, Tuple, Union, array
+from ._array_object import ndarray
import numpy as np
@@ -18,7 +19,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
See its docstring for more information.
"""
- return np.cross(x1, x2, axis=axis)
+ return ndarray._new(np.cross(x1._array, x2._array, axis=axis))
def det(x: array, /) -> array:
"""
@@ -27,7 +28,7 @@ def det(x: array, /) -> array:
See its docstring for more information.
"""
# Note: this function is being imported from a nondefault namespace
- return np.linalg.det(x)
+ return ndarray._new(np.linalg.det(x._array))
def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
"""
@@ -35,7 +36,7 @@ def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) ->
See its docstring for more information.
"""
- return np.diagonal(x, axis1=axis1, axis2=axis2, offset=offset)
+ return ndarray._new(np.diagonal(x._array, axis1=axis1, axis2=axis2, offset=offset))
# def dot():
# """
@@ -76,7 +77,7 @@ def inv(x: array, /) -> array:
See its docstring for more information.
"""
# Note: this function is being imported from a nondefault namespace
- return np.linalg.inv(x)
+ return ndarray._new(np.linalg.inv(x._array))
# def lstsq():
# """
@@ -120,7 +121,7 @@ def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, kee
if axis == None and x.ndim > 2:
x = x.flatten()
# Note: this function is being imported from a nondefault namespace
- return np.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord)
+ return ndarray._new(np.linalg.norm(x._array, axis=axis, keepdims=keepdims, ord=ord))
def outer(x1: array, x2: array, /) -> array:
"""
@@ -128,7 +129,7 @@ def outer(x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
- return np.outer(x1, x2)
+ return ndarray._new(np.outer(x1._array, x2._array))
# def pinv():
# """
@@ -176,7 +177,7 @@ def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> ar
See its docstring for more information.
"""
- return np.asarray(np.trace(x, axis1=axis1, axis2=axis2, offset=offset))
+ return ndarray._new(np.asarray(np.trace(x._array, axis1=axis1, axis2=axis2, offset=offset)))
def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
"""
@@ -184,4 +185,4 @@ def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
See its docstring for more information.
"""
- return np.transpose(x, axes=axes)
+ return ndarray._new(np.transpose(x._array, axes=axes))
diff --git a/numpy/_array_api/_searching_functions.py b/numpy/_array_api/_searching_functions.py
index d5128cca9..44e5b2775 100644
--- a/numpy/_array_api/_searching_functions.py
+++ b/numpy/_array_api/_searching_functions.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from ._types import Tuple, array
+from ._array_object import ndarray
import numpy as np
@@ -11,7 +12,7 @@ def argmax(x: array, /, *, axis: int = None, keepdims: bool = False) -> array:
See its docstring for more information.
"""
# Note: this currently fails as np.argmax does not implement keepdims
- return np.asarray(np.argmax(x, axis=axis, keepdims=keepdims))
+ return ndarray._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)))
def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array:
"""
@@ -20,7 +21,7 @@ def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array:
See its docstring for more information.
"""
# Note: this currently fails as np.argmin does not implement keepdims
- return np.asarray(np.argmin(x, axis=axis, keepdims=keepdims))
+ return ndarray._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims)))
def nonzero(x: array, /) -> Tuple[array, ...]:
"""
@@ -28,7 +29,7 @@ def nonzero(x: array, /) -> Tuple[array, ...]:
See its docstring for more information.
"""
- return np.nonzero(x)
+ return ndarray._new(np.nonzero(x._array))
def where(condition: array, x1: array, x2: array, /) -> array:
"""
@@ -36,4 +37,4 @@ def where(condition: array, x1: array, x2: array, /) -> array:
See its docstring for more information.
"""
- return np.where(condition, x1, x2)
+ return ndarray._new(np.where(condition._array, x1._array, x2._array))
diff --git a/numpy/_array_api/_set_functions.py b/numpy/_array_api/_set_functions.py
index 91927a3a0..f5cd6d324 100644
--- a/numpy/_array_api/_set_functions.py
+++ b/numpy/_array_api/_set_functions.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from ._types import Tuple, Union, array
+from ._array_object import ndarray
import numpy as np
@@ -10,4 +11,4 @@ def unique(x: array, /, *, return_counts: bool = False, return_index: bool = Fal
See its docstring for more information.
"""
- return np.unique(x, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse)
+ return ndarray._new(np.unique(x._array, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse))
diff --git a/numpy/_array_api/_sorting_functions.py b/numpy/_array_api/_sorting_functions.py
index cddfd1598..2e054b03a 100644
--- a/numpy/_array_api/_sorting_functions.py
+++ b/numpy/_array_api/_sorting_functions.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from ._types import array
+from ._array_object import ndarray
import numpy as np
@@ -12,10 +13,10 @@ def argsort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bo
"""
# Note: this keyword argument is different, and the default is different.
kind = 'stable' if stable else 'quicksort'
- res = np.argsort(x, axis=axis, kind=kind)
+ res = np.argsort(x._array, axis=axis, kind=kind)
if descending:
res = np.flip(res, axis=axis)
- return res
+ return ndarray._new(res)
def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array:
"""
@@ -25,7 +26,7 @@ def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool
"""
# Note: this keyword argument is different, and the default is different.
kind = 'stable' if stable else 'quicksort'
- res = np.sort(x, axis=axis, kind=kind)
+ res = np.sort(x._array, axis=axis, kind=kind)
if descending:
res = np.flip(res, axis=axis)
- return res
+ return ndarray._new(res)
diff --git a/numpy/_array_api/_statistical_functions.py b/numpy/_array_api/_statistical_functions.py
index e62410d01..fa3551248 100644
--- a/numpy/_array_api/_statistical_functions.py
+++ b/numpy/_array_api/_statistical_functions.py
@@ -1,28 +1,29 @@
from __future__ import annotations
from ._types import Optional, Tuple, Union, array
+from ._array_object import ndarray
import numpy as np
def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- return np.max(x, axis=axis, keepdims=keepdims)
+ return ndarray._new(np.max(x._array, axis=axis, keepdims=keepdims))
def mean(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- return np.asarray(np.mean(x, axis=axis, keepdims=keepdims))
+ return ndarray._new(np.asarray(np.mean(x._array, axis=axis, keepdims=keepdims)))
def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- return np.min(x, axis=axis, keepdims=keepdims)
+ return ndarray._new(np.min(x._array, axis=axis, keepdims=keepdims))
def prod(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- return np.asarray(np.prod(x, axis=axis, keepdims=keepdims))
+ return ndarray._new(np.asarray(np.prod(x._array, axis=axis, keepdims=keepdims)))
def std(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array:
# Note: the keyword argument correction is different here
- return np.asarray(np.std(x, axis=axis, ddof=correction, keepdims=keepdims))
+ return ndarray._new(np.asarray(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)))
def sum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- return np.asarray(np.sum(x, axis=axis, keepdims=keepdims))
+ return ndarray._new(np.asarray(np.sum(x._array, axis=axis, keepdims=keepdims)))
def var(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array:
# Note: the keyword argument correction is different here
- return np.asarray(np.var(x, axis=axis, ddof=correction, keepdims=keepdims))
+ return ndarray._new(np.asarray(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)))
diff --git a/numpy/_array_api/_utility_functions.py b/numpy/_array_api/_utility_functions.py
index 51a04dc8b..c4721ad7e 100644
--- a/numpy/_array_api/_utility_functions.py
+++ b/numpy/_array_api/_utility_functions.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from ._types import Optional, Tuple, Union, array
+from ._array_object import ndarray
import numpy as np
@@ -10,7 +11,7 @@ def all(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
See its docstring for more information.
"""
- return np.asarray(np.all(x, axis=axis, keepdims=keepdims))
+ return ndarray._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims)))
def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
"""
@@ -18,4 +19,4 @@ def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
See its docstring for more information.
"""
- return np.asarray(np.any(x, axis=axis, keepdims=keepdims))
+ return ndarray._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)))