From cdd6bbcdf260a4d6947901604dc8dd64c864c8d4 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 5 Mar 2021 14:35:35 -0700 Subject: Support the ndarray object in the remaining array API functions --- numpy/_array_api/_data_type_functions.py | 4 +++- numpy/_array_api/_linear_algebra_functions.py | 17 +++++++++-------- numpy/_array_api/_searching_functions.py | 9 +++++---- numpy/_array_api/_set_functions.py | 3 ++- numpy/_array_api/_sorting_functions.py | 9 +++++---- numpy/_array_api/_statistical_functions.py | 15 ++++++++------- numpy/_array_api/_utility_functions.py | 5 +++-- 7 files changed, 35 insertions(+), 27 deletions(-) (limited to 'numpy/_array_api') diff --git a/numpy/_array_api/_data_type_functions.py b/numpy/_array_api/_data_type_functions.py index 18f741ebd..9e4dcba95 100644 --- a/numpy/_array_api/_data_type_functions.py +++ b/numpy/_array_api/_data_type_functions.py @@ -1,6 +1,8 @@ from __future__ import annotations from ._types import Union, array, dtype +from ._array_object import ndarray + from collections.abc import Sequence import numpy as np @@ -27,4 +29,4 @@ def result_type(*arrays_and_dtypes: Sequence[Union[array, dtype]]) -> dtype: See its docstring for more information. """ - return np.result_type(*arrays_and_dtypes) + return np.result_type(*(a._array if isinstance(a, ndarray) else a for a in arrays_and_dtypes)) diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py index ec67f9c0b..95ed00fd5 100644 --- a/numpy/_array_api/_linear_algebra_functions.py +++ b/numpy/_array_api/_linear_algebra_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._types import Literal, Optional, Tuple, Union, array +from ._array_object import ndarray import numpy as np @@ -18,7 +19,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: See its docstring for more information. """ - return np.cross(x1, x2, axis=axis) + return ndarray._new(np.cross(x1._array, x2._array, axis=axis)) def det(x: array, /) -> array: """ @@ -27,7 +28,7 @@ def det(x: array, /) -> array: See its docstring for more information. """ # Note: this function is being imported from a nondefault namespace - return np.linalg.det(x) + return ndarray._new(np.linalg.det(x._array)) def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array: """ @@ -35,7 +36,7 @@ def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> See its docstring for more information. """ - return np.diagonal(x, axis1=axis1, axis2=axis2, offset=offset) + return ndarray._new(np.diagonal(x._array, axis1=axis1, axis2=axis2, offset=offset)) # def dot(): # """ @@ -76,7 +77,7 @@ def inv(x: array, /) -> array: See its docstring for more information. """ # Note: this function is being imported from a nondefault namespace - return np.linalg.inv(x) + return ndarray._new(np.linalg.inv(x._array)) # def lstsq(): # """ @@ -120,7 +121,7 @@ def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, kee if axis == None and x.ndim > 2: x = x.flatten() # Note: this function is being imported from a nondefault namespace - return np.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord) + return ndarray._new(np.linalg.norm(x._array, axis=axis, keepdims=keepdims, ord=ord)) def outer(x1: array, x2: array, /) -> array: """ @@ -128,7 +129,7 @@ def outer(x1: array, x2: array, /) -> array: See its docstring for more information. """ - return np.outer(x1, x2) + return ndarray._new(np.outer(x1._array, x2._array)) # def pinv(): # """ @@ -176,7 +177,7 @@ def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> ar See its docstring for more information. """ - return np.asarray(np.trace(x, axis1=axis1, axis2=axis2, offset=offset)) + return ndarray._new(np.asarray(np.trace(x._array, axis1=axis1, axis2=axis2, offset=offset))) def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array: """ @@ -184,4 +185,4 @@ def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array: See its docstring for more information. """ - return np.transpose(x, axes=axes) + return ndarray._new(np.transpose(x._array, axes=axes)) diff --git a/numpy/_array_api/_searching_functions.py b/numpy/_array_api/_searching_functions.py index d5128cca9..44e5b2775 100644 --- a/numpy/_array_api/_searching_functions.py +++ b/numpy/_array_api/_searching_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._types import Tuple, array +from ._array_object import ndarray import numpy as np @@ -11,7 +12,7 @@ def argmax(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: See its docstring for more information. """ # Note: this currently fails as np.argmax does not implement keepdims - return np.asarray(np.argmax(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: """ @@ -20,7 +21,7 @@ def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array: See its docstring for more information. """ # Note: this currently fails as np.argmin does not implement keepdims - return np.asarray(np.argmin(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) def nonzero(x: array, /) -> Tuple[array, ...]: """ @@ -28,7 +29,7 @@ def nonzero(x: array, /) -> Tuple[array, ...]: See its docstring for more information. """ - return np.nonzero(x) + return ndarray._new(np.nonzero(x._array)) def where(condition: array, x1: array, x2: array, /) -> array: """ @@ -36,4 +37,4 @@ def where(condition: array, x1: array, x2: array, /) -> array: See its docstring for more information. """ - return np.where(condition, x1, x2) + return ndarray._new(np.where(condition._array, x1._array, x2._array)) diff --git a/numpy/_array_api/_set_functions.py b/numpy/_array_api/_set_functions.py index 91927a3a0..f5cd6d324 100644 --- a/numpy/_array_api/_set_functions.py +++ b/numpy/_array_api/_set_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._types import Tuple, Union, array +from ._array_object import ndarray import numpy as np @@ -10,4 +11,4 @@ def unique(x: array, /, *, return_counts: bool = False, return_index: bool = Fal See its docstring for more information. """ - return np.unique(x, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse) + return ndarray._new(np.unique(x._array, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse)) diff --git a/numpy/_array_api/_sorting_functions.py b/numpy/_array_api/_sorting_functions.py index cddfd1598..2e054b03a 100644 --- a/numpy/_array_api/_sorting_functions.py +++ b/numpy/_array_api/_sorting_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._types import array +from ._array_object import ndarray import numpy as np @@ -12,10 +13,10 @@ def argsort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bo """ # Note: this keyword argument is different, and the default is different. kind = 'stable' if stable else 'quicksort' - res = np.argsort(x, axis=axis, kind=kind) + res = np.argsort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) - return res + return ndarray._new(res) def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array: """ @@ -25,7 +26,7 @@ def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool """ # Note: this keyword argument is different, and the default is different. kind = 'stable' if stable else 'quicksort' - res = np.sort(x, axis=axis, kind=kind) + res = np.sort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) - return res + return ndarray._new(res) diff --git a/numpy/_array_api/_statistical_functions.py b/numpy/_array_api/_statistical_functions.py index e62410d01..fa3551248 100644 --- a/numpy/_array_api/_statistical_functions.py +++ b/numpy/_array_api/_statistical_functions.py @@ -1,28 +1,29 @@ from __future__ import annotations from ._types import Optional, Tuple, Union, array +from ._array_object import ndarray import numpy as np def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - return np.max(x, axis=axis, keepdims=keepdims) + return ndarray._new(np.max(x._array, axis=axis, keepdims=keepdims)) def mean(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - return np.asarray(np.mean(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.mean(x._array, axis=axis, keepdims=keepdims))) def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - return np.min(x, axis=axis, keepdims=keepdims) + return ndarray._new(np.min(x._array, axis=axis, keepdims=keepdims)) def prod(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - return np.asarray(np.prod(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.prod(x._array, axis=axis, keepdims=keepdims))) def std(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array: # Note: the keyword argument correction is different here - return np.asarray(np.std(x, axis=axis, ddof=correction, keepdims=keepdims)) + return ndarray._new(np.asarray(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))) def sum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: - return np.asarray(np.sum(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.sum(x._array, axis=axis, keepdims=keepdims))) def var(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array: # Note: the keyword argument correction is different here - return np.asarray(np.var(x, axis=axis, ddof=correction, keepdims=keepdims)) + return ndarray._new(np.asarray(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))) diff --git a/numpy/_array_api/_utility_functions.py b/numpy/_array_api/_utility_functions.py index 51a04dc8b..c4721ad7e 100644 --- a/numpy/_array_api/_utility_functions.py +++ b/numpy/_array_api/_utility_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._types import Optional, Tuple, Union, array +from ._array_object import ndarray import numpy as np @@ -10,7 +11,7 @@ def all(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep See its docstring for more information. """ - return np.asarray(np.all(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims))) def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: """ @@ -18,4 +19,4 @@ def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep See its docstring for more information. """ - return np.asarray(np.any(x, axis=axis, keepdims=keepdims)) + return ndarray._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims))) -- cgit v1.2.1