diff options
Diffstat (limited to 'numpy/_array_api')
| -rw-r--r-- | numpy/_array_api/__init__.py | 10 | ||||
| -rw-r--r-- | numpy/_array_api/_array_object.py | 32 | ||||
| -rw-r--r-- | numpy/_array_api/_creation_functions.py | 84 | ||||
| -rw-r--r-- | numpy/_array_api/_data_type_functions.py | 18 | ||||
| -rw-r--r-- | numpy/_array_api/_elementwise_functions.py | 164 | ||||
| -rw-r--r-- | numpy/_array_api/_linear_algebra_functions.py | 12 | ||||
| -rw-r--r-- | numpy/_array_api/_manipulation_functions.py | 18 | ||||
| -rw-r--r-- | numpy/_array_api/_searching_functions.py | 12 | ||||
| -rw-r--r-- | numpy/_array_api/_set_functions.py | 6 | ||||
| -rw-r--r-- | numpy/_array_api/_sorting_functions.py | 10 | ||||
| -rw-r--r-- | numpy/_array_api/_statistical_functions.py | 18 | ||||
| -rw-r--r-- | numpy/_array_api/_types.py | 2 | ||||
| -rw-r--r-- | numpy/_array_api/_utility_functions.py | 8 |
13 files changed, 192 insertions, 202 deletions
diff --git a/numpy/_array_api/__init__.py b/numpy/_array_api/__init__.py index e39a2c7d0..320c8df19 100644 --- a/numpy/_array_api/__init__.py +++ b/numpy/_array_api/__init__.py @@ -58,7 +58,7 @@ request. guaranteed to give a comprehensive coverage of the spec. Therefore, those reviewing this submodule should refer to the standard documents themselves. -- There is a custom array object, numpy._array_api.ndarray, which is returned +- There is a custom array object, numpy._array_api.Array, which is returned by all functions in this module. All functions in the array API namespace implicitly assume that they will only receive this object as input. The only way to create instances of this object is to use one of the array creation @@ -69,14 +69,14 @@ request. limit/change certain behavior that differs in the spec. In particular: - Indexing: Only a subset of indices supported by NumPy are required by the - spec. The ndarray object restricts indexing to only allow those types of + spec. The Array object restricts indexing to only allow those types of indices that are required by the spec. See the docstring of the - numpy._array_api.ndarray._validate_indices helper function for more + numpy._array_api.Array._validate_indices helper function for more information. - Type promotion: Some type promotion rules are different in the spec. In particular, the spec does not have any value-based casting. Note that the - code to correct the type promotion rules on numpy._array_api.ndarray is + code to correct the type promotion rules on numpy._array_api.Array is not yet implemented. - All functions include type annotations, corresponding to those given in the @@ -93,7 +93,7 @@ request. Still TODO in this module are: -- Implement the spec type promotion rules on the ndarray object. +- Implement the spec type promotion rules on the Array object. - Disable NumPy warnings in the API functions. diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index 89ec3ba1a..9ea0eef18 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -22,13 +22,13 @@ from ._dtypes import _boolean_dtypes, _integer_dtypes, _floating_dtypes from typing import TYPE_CHECKING if TYPE_CHECKING: - from ._types import Any, Optional, PyCapsule, Tuple, Union, Array + from ._types import Any, Optional, PyCapsule, Tuple, Union, Device, Dtype import numpy as np -class ndarray: +class Array: """ - ndarray object for the array API namespace. + n-d array object for the array API namespace. See the docstring of :py:obj:`np.ndarray <numpy.ndarray>` for more information. @@ -46,7 +46,7 @@ class ndarray: @classmethod def _new(cls, x, /): """ - This is a private method for initializing the array API ndarray + This is a private method for initializing the array API Array object. Functions outside of the array_api submodule should not use this @@ -64,9 +64,9 @@ class ndarray: obj._array = x return obj - # Prevent ndarray() from working + # Prevent Array() from working def __new__(cls, *args, **kwargs): - raise TypeError("The array_api ndarray object should not be instantiated directly. Use an array creation function, such as asarray(), instead.") + raise TypeError("The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead.") # These functions are not required by the spec, but are implemented for # the sake of usability. @@ -75,13 +75,13 @@ class ndarray: """ Performs the operation __str__. """ - return self._array.__str__().replace('array', 'ndarray') + return self._array.__str__().replace('array', 'Array') def __repr__(self: Array, /) -> str: """ Performs the operation __repr__. """ - return self._array.__repr__().replace('array', 'ndarray') + return self._array.__repr__().replace('array', 'Array') # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): @@ -109,7 +109,7 @@ class ndarray: # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either # cast or raise OverflowError). - return ndarray._new(np.array(scalar, self.dtype)) + return Array._new(np.array(scalar, self.dtype)) @staticmethod def _normalize_two_args(x1, x2): @@ -135,9 +135,9 @@ class ndarray: # performant. broadcast_to(x1._array, x2.shape) is much slower. We # could also manually type promote x2, but that is more complicated # and about the same performance as this. - x1 = ndarray._new(x1._array[None]) + x1 = Array._new(x1._array[None]) elif x2.shape == () and x1.shape != (): - x2 = ndarray._new(x2._array[None]) + x2 = Array._new(x2._array[None]) return (x1, x2) # Everything below this line is required by the spec. @@ -284,7 +284,7 @@ class ndarray: Additionally, it should be noted that indices that would return a scalar in NumPy will return a shape () array. Array scalars are not allowed in the specification, only shape () arrays. This is done in the - ``ndarray._new`` constructor, not this function. + ``Array._new`` constructor, not this function. """ if isinstance(key, slice): @@ -313,7 +313,7 @@ class ndarray: return key elif isinstance(key, tuple): - key = tuple(ndarray._validate_index(idx, None) for idx in key) + key = tuple(Array._validate_index(idx, None) for idx in key) for idx in key: if isinstance(idx, np.ndarray) and idx.dtype in _boolean_dtypes or isinstance(idx, (bool, np.bool_)): @@ -329,11 +329,11 @@ class ndarray: ellipsis_i = key.index(...) if n_ellipsis else len(key) for idx, size in list(zip(key[:ellipsis_i], shape)) + list(zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])): - ndarray._validate_index(idx, (size,)) + Array._validate_index(idx, (size,)) return key elif isinstance(key, bool): return key - elif isinstance(key, ndarray): + elif isinstance(key, Array): if key.dtype in _integer_dtypes: if key.shape != (): raise IndexError("Integer array indices with shape != () are not allowed in the array API namespace") @@ -346,7 +346,7 @@ class ndarray: return operator.index(key) except TypeError: # Note: This also omits boolean arrays that are not already in - # ndarray() form, like a list of booleans. + # Array() form, like a list of booleans. raise IndexError("Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace") def __getitem__(self: Array, key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], /) -> Array: diff --git a/numpy/_array_api/_creation_functions.py b/numpy/_array_api/_creation_functions.py index 9845dd70f..8fb2a8b12 100644 --- a/numpy/_array_api/_creation_functions.py +++ b/numpy/_array_api/_creation_functions.py @@ -19,14 +19,14 @@ def asarray(obj: Union[float, NestedSequence[bool|int|float], SupportsDLPack, Su """ # _array_object imports in this file are inside the functions to avoid # circular imports - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") if copy is not None: # Note: copy is not yet implemented in np.asarray raise NotImplementedError("The copy keyword argument to asarray is not yet implemented") - if isinstance(obj, ndarray) and (dtype is None or obj.dtype == dtype): + if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype): return obj if dtype is None and isinstance(obj, int) and (obj > 2**64 or obj < -2**63): # Give a better error message in this case. NumPy would convert this @@ -35,7 +35,7 @@ def asarray(obj: Union[float, NestedSequence[bool|int|float], SupportsDLPack, Su res = np.asarray(obj, dtype=dtype) if res.dtype not in _all_dtypes: raise TypeError(f"The array_api namespace does not support the dtype '{res.dtype}'") - return ndarray._new(res) + return Array._new(res) def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: """ @@ -43,11 +43,11 @@ def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.arange(start, stop=stop, step=step, dtype=dtype)) + return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype)) def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: """ @@ -55,11 +55,11 @@ def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.empty(shape, dtype=dtype)) + return Array._new(np.empty(shape, dtype=dtype)) def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: """ @@ -67,11 +67,11 @@ def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[D See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.empty_like(x._array, dtype=dtype)) + return Array._new(np.empty_like(x._array, dtype=dtype)) def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: """ @@ -79,14 +79,14 @@ def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, d See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) + return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) def from_dlpack(x: object, /) -> Array: - # Note: dlpack support is not yet implemented on ndarray + # Note: dlpack support is not yet implemented on Array raise NotImplementedError("DLPack support is not yet implemented") def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: @@ -95,18 +95,18 @@ def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, d See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") - if isinstance(fill_value, ndarray) and fill_value.ndim == 0: - fill_value = fill_value._array[...] + if isinstance(fill_value, Array) and fill_value.ndim == 0: + fill_value = fill_value._array[...] res = np.full(shape, fill_value, dtype=dtype) if res.dtype not in _all_dtypes: # This will happen if the fill value is not something that NumPy # coerces to one of the acceptable dtypes. raise TypeError("Invalid input to full") - return ndarray._new(res) + return Array._new(res) def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: """ @@ -114,16 +114,16 @@ def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dty See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") res = np.full_like(x._array, fill_value, dtype=dtype) if res.dtype not in _all_dtypes: # This will happen if the fill value is not something that NumPy # coerces to one of the acceptable dtypes. raise TypeError("Invalid input to full_like") - return ndarray._new(res) + return Array._new(res) def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, endpoint: bool = True) -> Array: """ @@ -131,11 +131,11 @@ def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) + return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) def meshgrid(*arrays: Sequence[Array], indexing: str = 'xy') -> List[Array, ...]: """ @@ -143,8 +143,8 @@ def meshgrid(*arrays: Sequence[Array], indexing: str = 'xy') -> List[Array, ...] See its docstring for more information. """ - from ._array_object import ndarray - return [ndarray._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)] + from ._array_object import Array + return [Array._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)] def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: """ @@ -152,11 +152,11 @@ def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, d See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.ones(shape, dtype=dtype)) + return Array._new(np.ones(shape, dtype=dtype)) def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: """ @@ -164,11 +164,11 @@ def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[De See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.ones_like(x._array, dtype=dtype)) + return Array._new(np.ones_like(x._array, dtype=dtype)) def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: """ @@ -176,11 +176,11 @@ def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.zeros(shape, dtype=dtype)) + return Array._new(np.zeros(shape, dtype=dtype)) def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: """ @@ -188,8 +188,8 @@ def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[D See its docstring for more information. """ - from ._array_object import ndarray + from ._array_object import Array if device is not None: - # Note: Device support is not yet implemented on ndarray + # Note: Device support is not yet implemented on Array raise NotImplementedError("Device support is not yet implemented") - return ndarray._new(np.zeros_like(x._array, dtype=dtype)) + return Array._new(np.zeros_like(x._array, dtype=dtype)) diff --git a/numpy/_array_api/_data_type_functions.py b/numpy/_array_api/_data_type_functions.py index 5ab611fd3..2f304bf49 100644 --- a/numpy/_array_api/_data_type_functions.py +++ b/numpy/_array_api/_data_type_functions.py @@ -1,10 +1,10 @@ from __future__ import annotations -from ._array_object import ndarray +from ._array_object import Array from typing import TYPE_CHECKING if TYPE_CHECKING: - from ._types import List, Tuple, Union, Array, Dtype + from ._types import List, Tuple, Union, Dtype from collections.abc import Sequence import numpy as np @@ -15,8 +15,8 @@ def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]: See its docstring for more information. """ - from ._array_object import ndarray - return [ndarray._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays])] + from ._array_object import Array + return [Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays])] def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: """ @@ -24,8 +24,8 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: See its docstring for more information. """ - from ._array_object import ndarray - return ndarray._new(np.broadcast_to(x._array, shape)) + from ._array_object import Array + return Array._new(np.broadcast_to(x._array, shape)) def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: """ @@ -33,8 +33,8 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: See its docstring for more information. """ - from ._array_object import ndarray - if isinstance(from_, ndarray): + from ._array_object import Array + if isinstance(from_, Array): from_ = from_._array return np.can_cast(from_, to) @@ -60,4 +60,4 @@ def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype: See its docstring for more information. """ - return np.result_type(*(a._array if isinstance(a, ndarray) else a for a in arrays_and_dtypes)) + return np.result_type(*(a._array if isinstance(a, Array) else a for a in arrays_and_dtypes)) diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py index ae265181a..8dedc77fb 100644 --- a/numpy/_array_api/_elementwise_functions.py +++ b/numpy/_array_api/_elementwise_functions.py @@ -3,11 +3,7 @@ from __future__ import annotations from ._dtypes import (_boolean_dtypes, _floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, _numeric_dtypes) -from ._array_object import ndarray - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ._types import Array +from ._array_object import Array import numpy as np @@ -19,7 +15,7 @@ def abs(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in abs') - return ndarray._new(np.abs(x._array)) + return Array._new(np.abs(x._array)) # Note: the function name is different here @np.errstate(all='ignore') @@ -31,7 +27,7 @@ def acos(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in acos') - return ndarray._new(np.arccos(x._array)) + return Array._new(np.arccos(x._array)) # Note: the function name is different here @np.errstate(all='ignore') @@ -43,7 +39,7 @@ def acosh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in acosh') - return ndarray._new(np.arccosh(x._array)) + return Array._new(np.arccosh(x._array)) @np.errstate(all='ignore') def add(x1: Array, x2: Array, /) -> Array: @@ -54,8 +50,8 @@ def add(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in add') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.add(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.add(x1._array, x2._array)) # Note: the function name is different here @np.errstate(all='ignore') @@ -67,7 +63,7 @@ def asin(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in asin') - return ndarray._new(np.arcsin(x._array)) + return Array._new(np.arcsin(x._array)) # Note: the function name is different here @np.errstate(all='ignore') @@ -79,7 +75,7 @@ def asinh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in asinh') - return ndarray._new(np.arcsinh(x._array)) + return Array._new(np.arcsinh(x._array)) # Note: the function name is different here def atan(x: Array, /) -> Array: @@ -90,7 +86,7 @@ def atan(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in atan') - return ndarray._new(np.arctan(x._array)) + return Array._new(np.arctan(x._array)) # Note: the function name is different here def atan2(x1: Array, x2: Array, /) -> Array: @@ -101,8 +97,8 @@ def atan2(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in atan2') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.arctan2(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.arctan2(x1._array, x2._array)) # Note: the function name is different here @np.errstate(all='ignore') @@ -114,7 +110,7 @@ def atanh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in atanh') - return ndarray._new(np.arctanh(x._array)) + return Array._new(np.arctanh(x._array)) def bitwise_and(x1: Array, x2: Array, /) -> Array: """ @@ -124,8 +120,8 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: raise TypeError('Only integer_or_boolean dtypes are allowed in bitwise_and') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.bitwise_and(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.bitwise_and(x1._array, x2._array)) # Note: the function name is different here def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: @@ -136,14 +132,14 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError('Only integer dtypes are allowed in bitwise_left_shift') - x1, x2 = ndarray._normalize_two_args(x1, x2) + x1, x2 = Array._normalize_two_args(x1, x2) # Note: bitwise_left_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): raise ValueError('bitwise_left_shift(x1, x2) is only defined for x2 >= 0') # Note: The spec requires the return dtype of bitwise_left_shift to be the # same as the first argument. np.left_shift() returns a type that is the # type promotion of the two input types. - return ndarray._new(np.left_shift(x1._array, x2._array).astype(x1.dtype)) + return Array._new(np.left_shift(x1._array, x2._array).astype(x1.dtype)) # Note: the function name is different here def bitwise_invert(x: Array, /) -> Array: @@ -154,7 +150,7 @@ def bitwise_invert(x: Array, /) -> Array: """ if x.dtype not in _integer_or_boolean_dtypes: raise TypeError('Only integer or boolean dtypes are allowed in bitwise_invert') - return ndarray._new(np.invert(x._array)) + return Array._new(np.invert(x._array)) def bitwise_or(x1: Array, x2: Array, /) -> Array: """ @@ -164,8 +160,8 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: raise TypeError('Only integer or boolean dtypes are allowed in bitwise_or') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.bitwise_or(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.bitwise_or(x1._array, x2._array)) # Note: the function name is different here def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: @@ -176,14 +172,14 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError('Only integer dtypes are allowed in bitwise_right_shift') - x1, x2 = ndarray._normalize_two_args(x1, x2) + x1, x2 = Array._normalize_two_args(x1, x2) # Note: bitwise_right_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): raise ValueError('bitwise_right_shift(x1, x2) is only defined for x2 >= 0') # Note: The spec requires the return dtype of bitwise_left_shift to be the # same as the first argument. np.left_shift() returns a type that is the # type promotion of the two input types. - return ndarray._new(np.right_shift(x1._array, x2._array).astype(x1.dtype)) + return Array._new(np.right_shift(x1._array, x2._array).astype(x1.dtype)) def bitwise_xor(x1: Array, x2: Array, /) -> Array: """ @@ -193,8 +189,8 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes: raise TypeError('Only integer or boolean dtypes are allowed in bitwise_xor') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.bitwise_xor(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.bitwise_xor(x1._array, x2._array)) def ceil(x: Array, /) -> Array: """ @@ -207,7 +203,7 @@ def ceil(x: Array, /) -> Array: if x.dtype in _integer_dtypes: # Note: The return dtype of ceil is the same as the input return x - return ndarray._new(np.ceil(x._array)) + return Array._new(np.ceil(x._array)) @np.errstate(all='ignore') def cos(x: Array, /) -> Array: @@ -218,7 +214,7 @@ def cos(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in cos') - return ndarray._new(np.cos(x._array)) + return Array._new(np.cos(x._array)) @np.errstate(all='ignore') def cosh(x: Array, /) -> Array: @@ -229,7 +225,7 @@ def cosh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in cosh') - return ndarray._new(np.cosh(x._array)) + return Array._new(np.cosh(x._array)) @np.errstate(all='ignore') def divide(x1: Array, x2: Array, /) -> Array: @@ -240,8 +236,8 @@ def divide(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in divide') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.divide(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.divide(x1._array, x2._array)) def equal(x1: Array, x2: Array, /) -> Array: """ @@ -249,8 +245,8 @@ def equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.equal(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.equal(x1._array, x2._array)) @np.errstate(all='ignore') def exp(x: Array, /) -> Array: @@ -261,7 +257,7 @@ def exp(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in exp') - return ndarray._new(np.exp(x._array)) + return Array._new(np.exp(x._array)) @np.errstate(all='ignore') def expm1(x: Array, /) -> Array: @@ -272,7 +268,7 @@ def expm1(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in expm1') - return ndarray._new(np.expm1(x._array)) + return Array._new(np.expm1(x._array)) def floor(x: Array, /) -> Array: """ @@ -285,7 +281,7 @@ def floor(x: Array, /) -> Array: if x.dtype in _integer_dtypes: # Note: The return dtype of floor is the same as the input return x - return ndarray._new(np.floor(x._array)) + return Array._new(np.floor(x._array)) @np.errstate(all='ignore') def floor_divide(x1: Array, x2: Array, /) -> Array: @@ -296,8 +292,8 @@ def floor_divide(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in floor_divide') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.floor_divide(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.floor_divide(x1._array, x2._array)) def greater(x1: Array, x2: Array, /) -> Array: """ @@ -307,8 +303,8 @@ def greater(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in greater') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.greater(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.greater(x1._array, x2._array)) def greater_equal(x1: Array, x2: Array, /) -> Array: """ @@ -318,8 +314,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in greater_equal') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.greater_equal(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.greater_equal(x1._array, x2._array)) def isfinite(x: Array, /) -> Array: """ @@ -329,7 +325,7 @@ def isfinite(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in isfinite') - return ndarray._new(np.isfinite(x._array)) + return Array._new(np.isfinite(x._array)) def isinf(x: Array, /) -> Array: """ @@ -339,7 +335,7 @@ def isinf(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in isinf') - return ndarray._new(np.isinf(x._array)) + return Array._new(np.isinf(x._array)) def isnan(x: Array, /) -> Array: """ @@ -349,7 +345,7 @@ def isnan(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in isnan') - return ndarray._new(np.isnan(x._array)) + return Array._new(np.isnan(x._array)) def less(x1: Array, x2: Array, /) -> Array: """ @@ -359,8 +355,8 @@ def less(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in less') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.less(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.less(x1._array, x2._array)) def less_equal(x1: Array, x2: Array, /) -> Array: """ @@ -370,8 +366,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in less_equal') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.less_equal(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.less_equal(x1._array, x2._array)) @np.errstate(all='ignore') def log(x: Array, /) -> Array: @@ -382,7 +378,7 @@ def log(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in log') - return ndarray._new(np.log(x._array)) + return Array._new(np.log(x._array)) @np.errstate(all='ignore') def log1p(x: Array, /) -> Array: @@ -393,7 +389,7 @@ def log1p(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in log1p') - return ndarray._new(np.log1p(x._array)) + return Array._new(np.log1p(x._array)) @np.errstate(all='ignore') def log2(x: Array, /) -> Array: @@ -404,7 +400,7 @@ def log2(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in log2') - return ndarray._new(np.log2(x._array)) + return Array._new(np.log2(x._array)) @np.errstate(all='ignore') def log10(x: Array, /) -> Array: @@ -415,7 +411,7 @@ def log10(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in log10') - return ndarray._new(np.log10(x._array)) + return Array._new(np.log10(x._array)) def logaddexp(x1: Array, x2: Array) -> Array: """ @@ -425,8 +421,8 @@ def logaddexp(x1: Array, x2: Array) -> Array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in logaddexp') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.logaddexp(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.logaddexp(x1._array, x2._array)) def logical_and(x1: Array, x2: Array, /) -> Array: """ @@ -436,8 +432,8 @@ def logical_and(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError('Only boolean dtypes are allowed in logical_and') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.logical_and(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.logical_and(x1._array, x2._array)) def logical_not(x: Array, /) -> Array: """ @@ -447,7 +443,7 @@ def logical_not(x: Array, /) -> Array: """ if x.dtype not in _boolean_dtypes: raise TypeError('Only boolean dtypes are allowed in logical_not') - return ndarray._new(np.logical_not(x._array)) + return Array._new(np.logical_not(x._array)) def logical_or(x1: Array, x2: Array, /) -> Array: """ @@ -457,8 +453,8 @@ def logical_or(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError('Only boolean dtypes are allowed in logical_or') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.logical_or(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.logical_or(x1._array, x2._array)) def logical_xor(x1: Array, x2: Array, /) -> Array: """ @@ -468,8 +464,8 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError('Only boolean dtypes are allowed in logical_xor') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.logical_xor(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.logical_xor(x1._array, x2._array)) @np.errstate(all='ignore') def multiply(x1: Array, x2: Array, /) -> Array: @@ -480,8 +476,8 @@ def multiply(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in multiply') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.multiply(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.multiply(x1._array, x2._array)) def negative(x: Array, /) -> Array: """ @@ -491,7 +487,7 @@ def negative(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in negative') - return ndarray._new(np.negative(x._array)) + return Array._new(np.negative(x._array)) def not_equal(x1: Array, x2: Array, /) -> Array: """ @@ -499,8 +495,8 @@ def not_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.not_equal(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.not_equal(x1._array, x2._array)) def positive(x: Array, /) -> Array: """ @@ -510,7 +506,7 @@ def positive(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in positive') - return ndarray._new(np.positive(x._array)) + return Array._new(np.positive(x._array)) # Note: the function name is different here @np.errstate(all='ignore') @@ -522,8 +518,8 @@ def pow(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in pow') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.power(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.power(x1._array, x2._array)) @np.errstate(all='ignore') def remainder(x1: Array, x2: Array, /) -> Array: @@ -534,8 +530,8 @@ def remainder(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in remainder') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.remainder(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.remainder(x1._array, x2._array)) def round(x: Array, /) -> Array: """ @@ -545,7 +541,7 @@ def round(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in round') - return ndarray._new(np.round(x._array)) + return Array._new(np.round(x._array)) def sign(x: Array, /) -> Array: """ @@ -555,7 +551,7 @@ def sign(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in sign') - return ndarray._new(np.sign(x._array)) + return Array._new(np.sign(x._array)) @np.errstate(all='ignore') def sin(x: Array, /) -> Array: @@ -566,7 +562,7 @@ def sin(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in sin') - return ndarray._new(np.sin(x._array)) + return Array._new(np.sin(x._array)) @np.errstate(all='ignore') def sinh(x: Array, /) -> Array: @@ -577,7 +573,7 @@ def sinh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in sinh') - return ndarray._new(np.sinh(x._array)) + return Array._new(np.sinh(x._array)) @np.errstate(all='ignore') def square(x: Array, /) -> Array: @@ -588,7 +584,7 @@ def square(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in square') - return ndarray._new(np.square(x._array)) + return Array._new(np.square(x._array)) @np.errstate(all='ignore') def sqrt(x: Array, /) -> Array: @@ -599,7 +595,7 @@ def sqrt(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in sqrt') - return ndarray._new(np.sqrt(x._array)) + return Array._new(np.sqrt(x._array)) @np.errstate(all='ignore') def subtract(x1: Array, x2: Array, /) -> Array: @@ -610,8 +606,8 @@ def subtract(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in subtract') - x1, x2 = ndarray._normalize_two_args(x1, x2) - return ndarray._new(np.subtract(x1._array, x2._array)) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.subtract(x1._array, x2._array)) @np.errstate(all='ignore') def tan(x: Array, /) -> Array: @@ -622,7 +618,7 @@ def tan(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in tan') - return ndarray._new(np.tan(x._array)) + return Array._new(np.tan(x._array)) def tanh(x: Array, /) -> Array: """ @@ -632,7 +628,7 @@ def tanh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in tanh') - return ndarray._new(np.tanh(x._array)) + return Array._new(np.tanh(x._array)) def trunc(x: Array, /) -> Array: """ @@ -642,4 +638,4 @@ def trunc(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in trunc') - return ndarray._new(np.trunc(x._array)) + return Array._new(np.trunc(x._array)) diff --git a/numpy/_array_api/_linear_algebra_functions.py b/numpy/_array_api/_linear_algebra_functions.py index b6b0c6f6e..b4b2af134 100644 --- a/numpy/_array_api/_linear_algebra_functions.py +++ b/numpy/_array_api/_linear_algebra_functions.py @@ -1,11 +1,9 @@ from __future__ import annotations -from ._array_object import ndarray +from ._array_object import Array from ._dtypes import _numeric_dtypes -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ._types import Optional, Sequence, Tuple, Union, Array +from typing import Optional, Sequence, Tuple, Union import numpy as np @@ -30,7 +28,7 @@ def matmul(x1: Array, x2: Array, /) -> Array: if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in matmul') - return ndarray._new(np.matmul(x1._array, x2._array)) + return Array._new(np.matmul(x1._array, x2._array)) # Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: @@ -39,7 +37,7 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in tensordot') - return ndarray._new(np.tensordot(x1._array, x2._array, axes=axes)) + return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array: """ @@ -47,7 +45,7 @@ def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array: See its docstring for more information. """ - return ndarray._new(np.transpose(x._array, axes=axes)) + return Array._new(np.transpose(x._array, axes=axes)) # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array: diff --git a/numpy/_array_api/_manipulation_functions.py b/numpy/_array_api/_manipulation_functions.py index da02155f9..c569d2834 100644 --- a/numpy/_array_api/_manipulation_functions.py +++ b/numpy/_array_api/_manipulation_functions.py @@ -1,10 +1,10 @@ from __future__ import annotations -from ._array_object import ndarray +from ._array_object import Array from typing import TYPE_CHECKING if TYPE_CHECKING: - from ._types import Optional, Tuple, Union, Array + from ._types import Optional, Tuple, Union import numpy as np @@ -16,7 +16,7 @@ def concat(arrays: Tuple[Array, ...], /, *, axis: Optional[int] = 0) -> Array: See its docstring for more information. """ arrays = tuple(a._array for a in arrays) - return ndarray._new(np.concatenate(arrays, axis=axis)) + return Array._new(np.concatenate(arrays, axis=axis)) def expand_dims(x: Array, /, *, axis: int) -> Array: """ @@ -24,7 +24,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array: See its docstring for more information. """ - return ndarray._new(np.expand_dims(x._array, axis)) + return Array._new(np.expand_dims(x._array, axis)) def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: """ @@ -32,7 +32,7 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> See its docstring for more information. """ - return ndarray._new(np.flip(x._array, axis=axis)) + return Array._new(np.flip(x._array, axis=axis)) def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: """ @@ -40,7 +40,7 @@ def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array: See its docstring for more information. """ - return ndarray._new(np.reshape(x._array, shape)) + return Array._new(np.reshape(x._array, shape)) def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: """ @@ -48,7 +48,7 @@ def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Unio See its docstring for more information. """ - return ndarray._new(np.roll(x._array, shift, axis=axis)) + return Array._new(np.roll(x._array, shift, axis=axis)) def squeeze(x: Array, /, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: """ @@ -56,7 +56,7 @@ def squeeze(x: Array, /, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> See its docstring for more information. """ - return ndarray._new(np.squeeze(x._array, axis=axis)) + return Array._new(np.squeeze(x._array, axis=axis)) def stack(arrays: Tuple[Array, ...], /, *, axis: int = 0) -> Array: """ @@ -65,4 +65,4 @@ def stack(arrays: Tuple[Array, ...], /, *, axis: int = 0) -> Array: See its docstring for more information. """ arrays = tuple(a._array for a in arrays) - return ndarray._new(np.stack(arrays, axis=axis)) + return Array._new(np.stack(arrays, axis=axis)) diff --git a/numpy/_array_api/_searching_functions.py b/numpy/_array_api/_searching_functions.py index 690256430..727f3013f 100644 --- a/numpy/_array_api/_searching_functions.py +++ b/numpy/_array_api/_searching_functions.py @@ -1,10 +1,10 @@ from __future__ import annotations -from ._array_object import ndarray +from ._array_object import Array from typing import TYPE_CHECKING if TYPE_CHECKING: - from ._types import Tuple, Array + from ._types import Tuple import numpy as np @@ -15,7 +15,7 @@ def argmax(x: Array, /, *, axis: int = None, keepdims: bool = False) -> Array: See its docstring for more information. """ # Note: this currently fails as np.argmax does not implement keepdims - return ndarray._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) def argmin(x: Array, /, *, axis: int = None, keepdims: bool = False) -> Array: """ @@ -24,7 +24,7 @@ def argmin(x: Array, /, *, axis: int = None, keepdims: bool = False) -> Array: See its docstring for more information. """ # Note: this currently fails as np.argmin does not implement keepdims - return ndarray._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) def nonzero(x: Array, /) -> Tuple[Array, ...]: """ @@ -32,7 +32,7 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: See its docstring for more information. """ - return ndarray._new(np.nonzero(x._array)) + return Array._new(np.nonzero(x._array)) def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ @@ -40,4 +40,4 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - return ndarray._new(np.where(condition._array, x1._array, x2._array)) + return Array._new(np.where(condition._array, x1._array, x2._array)) diff --git a/numpy/_array_api/_set_functions.py b/numpy/_array_api/_set_functions.py index 719d54e5f..098145866 100644 --- a/numpy/_array_api/_set_functions.py +++ b/numpy/_array_api/_set_functions.py @@ -1,10 +1,10 @@ from __future__ import annotations -from ._array_object import ndarray +from ._array_object import Array from typing import TYPE_CHECKING if TYPE_CHECKING: - from ._types import Tuple, Union, Array + from ._types import Tuple, Union import numpy as np @@ -14,4 +14,4 @@ def unique(x: Array, /, *, return_counts: bool = False, return_index: bool = Fal See its docstring for more information. """ - return ndarray._new(np.unique(x._array, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse)) + return Array._new(np.unique(x._array, return_counts=return_counts, return_index=return_index, return_inverse=return_inverse)) diff --git a/numpy/_array_api/_sorting_functions.py b/numpy/_array_api/_sorting_functions.py index 3dc0ec444..a125e0718 100644 --- a/numpy/_array_api/_sorting_functions.py +++ b/numpy/_array_api/_sorting_functions.py @@ -1,10 +1,6 @@ from __future__ import annotations -from ._array_object import ndarray - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ._types import Array +from ._array_object import Array import numpy as np @@ -19,7 +15,7 @@ def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bo res = np.argsort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) - return ndarray._new(res) + return Array._new(res) def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: """ @@ -32,4 +28,4 @@ def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool res = np.sort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) - return ndarray._new(res) + return Array._new(res) diff --git a/numpy/_array_api/_statistical_functions.py b/numpy/_array_api/_statistical_functions.py index e6a791fe6..4f6b1c034 100644 --- a/numpy/_array_api/_statistical_functions.py +++ b/numpy/_array_api/_statistical_functions.py @@ -1,32 +1,32 @@ from __future__ import annotations -from ._array_object import ndarray +from ._array_object import Array from typing import TYPE_CHECKING if TYPE_CHECKING: - from ._types import Optional, Tuple, Union, Array + from ._types import Optional, Tuple, Union import numpy as np def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: - return ndarray._new(np.max(x._array, axis=axis, keepdims=keepdims)) + return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: - return ndarray._new(np.asarray(np.mean(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.mean(x._array, axis=axis, keepdims=keepdims))) def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: - return ndarray._new(np.min(x._array, axis=axis, keepdims=keepdims)) + return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: - return ndarray._new(np.asarray(np.prod(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.prod(x._array, axis=axis, keepdims=keepdims))) def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array: # Note: the keyword argument correction is different here - return ndarray._new(np.asarray(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))) + return Array._new(np.asarray(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))) def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: - return ndarray._new(np.asarray(np.sum(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.sum(x._array, axis=axis, keepdims=keepdims))) def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> Array: # Note: the keyword argument correction is different here - return ndarray._new(np.asarray(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))) + return Array._new(np.asarray(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))) diff --git a/numpy/_array_api/_types.py b/numpy/_array_api/_types.py index 602c1df3e..d365e5e8c 100644 --- a/numpy/_array_api/_types.py +++ b/numpy/_array_api/_types.py @@ -12,7 +12,7 @@ __all__ = ['Any', 'List', 'Literal', 'Optional', 'Tuple', 'Union', 'Array', from typing import Any, List, Literal, Optional, Tuple, Union, TypeVar -from . import (ndarray, int8, int16, int32, int64, uint8, uint16, uint32, +from . import (Array, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64) Array = ndarray diff --git a/numpy/_array_api/_utility_functions.py b/numpy/_array_api/_utility_functions.py index a6a7721dd..ba77d3668 100644 --- a/numpy/_array_api/_utility_functions.py +++ b/numpy/_array_api/_utility_functions.py @@ -1,10 +1,10 @@ from __future__ import annotations -from ._array_object import ndarray +from ._array_object import Array from typing import TYPE_CHECKING if TYPE_CHECKING: - from ._types import Optional, Tuple, Union, Array + from ._types import Optional, Tuple, Union import numpy as np @@ -14,7 +14,7 @@ def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep See its docstring for more information. """ - return ndarray._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims))) def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: """ @@ -22,4 +22,4 @@ def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep See its docstring for more information. """ - return ndarray._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims))) |
