diff options
| author | Aaron Meurer <asmeurer@gmail.com> | 2021-03-05 14:35:35 -0700 |
|---|---|---|
| committer | Aaron Meurer <asmeurer@gmail.com> | 2021-03-05 14:35:35 -0700 |
| commit | cdd6bbcdf260a4d6947901604dc8dd64c864c8d4 (patch) | |
| tree | 947c48482ec001d0dd4f1fa5055bd7d6970bfc30 /numpy/_array_api | |
| parent | 58c2a996afd13f729ec5d2aed77151c8e799548b (diff) | |
| download | numpy-cdd6bbcdf260a4d6947901604dc8dd64c864c8d4.tar.gz | |
Support the ndarray object in the remaining array API functions
Diffstat (limited to 'numpy/_array_api')
| -rw-r--r-- | numpy/_array_api/_data_type_functions.py | 4 | ||||
| -rw-r--r-- | numpy/_array_api/_linear_algebra_functions.py | 17 | ||||
| -rw-r--r-- | numpy/_array_api/_searching_functions.py | 9 | ||||
| -rw-r--r-- | numpy/_array_api/_set_functions.py | 3 | ||||
| -rw-r--r-- | numpy/_array_api/_sorting_functions.py | 9 | ||||
| -rw-r--r-- | numpy/_array_api/_statistical_functions.py | 15 | ||||
| -rw-r--r-- | numpy/_array_api/_utility_functions.py | 5 |
7 files changed, 35 insertions, 27 deletions
diff --git a/numpy/_array_api/_data_type_functions.py b/numpy/_array_api/_data_type_functions.py index 18f741ebd..9e4dcba95 100644 --- a/numpy/_array_api/_data_type_functions.py +++ b/numpy/_array_api/_data_type_functions.py @@ -1,6 +1,8 @@ from __future__ import annotations from ._types import Union, array, dtype +from ._array_object import ndarray + from collections.abc import Sequence import numpy as np @@ -27,4 +29,4 @@ def result_type(*arrays_and_dtypes: Sequence[Union[array, dtype]]) -> dtype: See its docstring for more information. """ - return np.result_type(*arrays_and_dtypes) + return np.result_type(*(a._array if isinstance(a, ndarray) else a for a in arrays_and_dtypes)) diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py index ec67f9c0b..95ed00fd5 100644 --- a/numpy/_array_api/_linear_algebra_functions.py +++ b/numpy/_array_api/_linear_algebra_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._types import Literal, Optional, Tuple, Union, array +from ._array_object import ndarray import numpy as np @@ -18,7 +19,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: See its docstring for more information. """ - return np.cross(x1, x2, axis=axis) + return ndarray._new(np.cross(x1._array, x2._array, axis=axis)) def det(x: array, /) -> array: """ @@ -27,7 +28,7 @@ def det(x: array, /) -> array: See its docstring for more information. """ # Note: this function is being imported from a nondefault namespace - return np.linalg.det(x) + return ndarray._new(np.linalg.det(x._array)) def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array: """ @@ -35,7 +36,7 @@ def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> See its docstring for more information. """ - return np.diagonal(x, axis1=axis1, axis2=axis2, offset=offset) + return ndarray._new(np.diagonal(x._array, axis1=axis1, axis2=axis2, offset=offset)) # def dot(): # """ @@ -76,7 +77,7 @@ def inv(x: array, /) -> array: See its docstring for more information. """ # Note: this function is being imported from a nondefault namespace - return np.linalg.inv(x) + return ndarray._new(np.linalg.inv(x._array)) # def lstsq(): # """ @@ -120,7 +121,7 @@ def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, kee if axis == None and x.ndim > 2: x = x.flatten() # Note: this function is being imported from a nondefault namespace - return np.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord) + return ndarray._new(np.linalg.norm(x._array, axis=axis, keepdims=keepdims, ord=ord)) def outer(x1: array, x2: array, /) -> array: """ @@ -128,7 +129,7 @@ def outer(x1: array, x2: array, /) -> array: See its docstring for more information. """ - return np.outer(x1, x2) + return ndarray._new(np.outer(x1._array, x2._array)) # def pinv(): # """ @@ -176,7 +177,7 @@ def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> ar See its docstring for more information. """ - return np.asarray(np.trace(x, axis1=axis1, axis2=axis2, offset=offset)) + return ndarray._new(np.asarray(np.trace(x._array, axis1=axis1, axis2=axis2, offset=offset))) def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array: """ @@ -184,4 +185,4 @@ def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array: See its docstring for more information. """ - return np.transpose(x, axes=axes) + return ndarray._new(np.transpose(x._array, axes=axes)) diff --git a/numpy/_array_api/_searching_functions.py b/numpy/_array_api/_searching_functions.py index d5128cca9..44e5b2775 100644 --- a/numpy/_array_api/_searching_functions.py +++ b/numpy/_array_api/_searching_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._types import Tuple, array +from ._array_object import ndarray import numpy as np @@ -11,7 +12,7 @@ def argmax(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: See its docstring for more information. """ # Note: this currently fails as np.argmax does not implement keepdims - return np.asarray(np.argmax(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: """ @@ -20,7 +21,7 @@ def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: See its docstring for more information. """ # Note: this currently fails as np.argmin does not implement keepdims - return np.asarray(np.argmin(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) def nonzero(x: array, /) -> Tuple[array, ...]: """ @@ -28,7 +29,7 @@ def nonzero(x: array, /) -> Tuple[array, ...]: See its docstring for more information. """ - return np.nonzero(x) + return ndarray._new(np.nonzero(x._array)) def where(condition: array, x1: array, x2: array, /) -> array: """ @@ -36,4 +37,4 @@ def where(condition: array, x1: array, x2: array, /) -> array: See its docstring for more information. """ - return np.where(condition, x1, x2) + return ndarray._new(np.where(condition._array, x1._array, x2._array)) diff --git a/numpy/_array_api/_set_functions.py b/numpy/_array_api/_set_functions.py index 91927a3a0..f5cd6d324 100644 --- a/numpy/_array_api/_set_functions.py +++ b/numpy/_array_api/_set_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._types import Tuple, Union, array +from ._array_object import ndarray import numpy as np @@ -10,4 +11,4 @@ def unique(x: array, /, *, return_counts: bool = False, return_index: bool = Fal See its docstring for more information. """ - return np.unique(x, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse) + return ndarray._new(np.unique(x._array, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse)) diff --git a/numpy/_array_api/_sorting_functions.py b/numpy/_array_api/_sorting_functions.py index cddfd1598..2e054b03a 100644 --- a/numpy/_array_api/_sorting_functions.py +++ b/numpy/_array_api/_sorting_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._types import array +from ._array_object import ndarray import numpy as np @@ -12,10 +13,10 @@ def argsort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bo """ # Note: this keyword argument is different, and the default is different. kind = 'stable' if stable else 'quicksort' - res = np.argsort(x, axis=axis, kind=kind) + res = np.argsort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) - return res + return ndarray._new(res) def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array: """ @@ -25,7 +26,7 @@ def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool """ # Note: this keyword argument is different, and the default is different. kind = 'stable' if stable else 'quicksort' - res = np.sort(x, axis=axis, kind=kind) + res = np.sort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) - return res + return ndarray._new(res) diff --git a/numpy/_array_api/_statistical_functions.py b/numpy/_array_api/_statistical_functions.py index e62410d01..fa3551248 100644 --- a/numpy/_array_api/_statistical_functions.py +++ b/numpy/_array_api/_statistical_functions.py @@ -1,28 +1,29 @@ from __future__ import annotations from ._types import Optional, Tuple, Union, array +from ._array_object import ndarray import numpy as np def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - return np.max(x, axis=axis, keepdims=keepdims) + return ndarray._new(np.max(x._array, axis=axis, keepdims=keepdims)) def mean(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - return np.asarray(np.mean(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.mean(x._array, axis=axis, keepdims=keepdims))) def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - return np.min(x, axis=axis, keepdims=keepdims) + return ndarray._new(np.min(x._array, axis=axis, keepdims=keepdims)) def prod(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - return np.asarray(np.prod(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.prod(x._array, axis=axis, keepdims=keepdims))) def std(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array: # Note: the keyword argument correction is different here - return np.asarray(np.std(x, axis=axis, ddof=correction, keepdims=keepdims)) + return ndarray._new(np.asarray(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))) def sum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - return np.asarray(np.sum(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.sum(x._array, axis=axis, keepdims=keepdims))) def var(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array: # Note: the keyword argument correction is different here - return np.asarray(np.var(x, axis=axis, ddof=correction, keepdims=keepdims)) + return ndarray._new(np.asarray(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))) diff --git a/numpy/_array_api/_utility_functions.py b/numpy/_array_api/_utility_functions.py index 51a04dc8b..c4721ad7e 100644 --- a/numpy/_array_api/_utility_functions.py +++ b/numpy/_array_api/_utility_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._types import Optional, Tuple, Union, array +from ._array_object import ndarray import numpy as np @@ -10,7 +11,7 @@ def all(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep See its docstring for more information. """ - return np.asarray(np.all(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims))) def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: """ @@ -18,4 +19,4 @@ def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep See its docstring for more information. """ - return np.asarray(np.any(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims))) |
