diff options
Diffstat (limited to 'numpy/array_api')
-rw-r--r-- | numpy/array_api/_array_object.py | 27 | ||||
-rw-r--r-- | numpy/array_api/_creation_functions.py | 2 | ||||
-rw-r--r-- | numpy/array_api/_data_type_functions.py | 20 | ||||
-rw-r--r-- | numpy/array_api/_elementwise_functions.py | 4 | ||||
-rw-r--r-- | numpy/array_api/_set_functions.py | 20 | ||||
-rw-r--r-- | numpy/array_api/_sorting_functions.py | 20 | ||||
-rw-r--r-- | numpy/array_api/linalg.py | 10 | ||||
-rw-r--r-- | numpy/array_api/tests/test_array_object.py | 2 | ||||
-rw-r--r-- | numpy/array_api/tests/test_data_type_functions.py | 19 | ||||
-rw-r--r-- | numpy/array_api/tests/test_elementwise_functions.py | 2 | ||||
-rw-r--r-- | numpy/array_api/tests/test_set_functions.py | 19 | ||||
-rw-r--r-- | numpy/array_api/tests/test_sorting_functions.py | 23 | ||||
-rw-r--r-- | numpy/array_api/tests/test_validation.py | 27 |
13 files changed, 163 insertions, 32 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 75baf34b0..6cf9ec6f3 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -30,6 +30,7 @@ from ._dtypes import ( ) from typing import TYPE_CHECKING, Optional, Tuple, Union, Any +import types if TYPE_CHECKING: from ._typing import Any, PyCapsule, Device, Dtype @@ -55,6 +56,7 @@ class Array: functions, such as asarray(). """ + _array: np.ndarray # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @@ -124,7 +126,7 @@ class Array: # spec in places where it either deviates from or is more strict than # NumPy behavior - def _check_allowed_dtypes(self, other, dtype_category, op): + def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array: """ Helper function for operators to only allow specific input dtypes @@ -175,6 +177,8 @@ class Array: integer that is too large to fit in a NumPy integer dtype, or TypeError when the scalar type is incompatible with the dtype of self. """ + # Note: Only Python scalar types that match the array dtype are + # allowed. if isinstance(scalar, bool): if self.dtype not in _boolean_dtypes: raise TypeError( @@ -193,6 +197,9 @@ class Array: else: raise TypeError("'scalar' must be a Python scalar") + # Note: scalars are unconditionally cast to the same dtype as the + # array. + # Note: the spec only specifies integer-dtype/int promotion # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either @@ -200,7 +207,7 @@ class Array: return Array._new(np.array(scalar, self.dtype)) @staticmethod - def _normalize_two_args(x1, x2): + def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: """ Normalize inputs to two arg functions to fix type promotion rules @@ -415,7 +422,7 @@ class Array: def __array_namespace__( self: Array, /, *, api_version: Optional[str] = None - ) -> Any: + ) -> types.ModuleType: if api_version is not None and not api_version.startswith("2021."): raise ValueError(f"Unrecognized array API version: {api_version!r}") return array_api @@ -654,15 +661,13 @@ class Array: res = self._array.__pos__() return self.__class__._new(res) - # PEP 484 requires int to be a subtype of float, but __pow__ should not - # accept int. - def __pow__(self: Array, other: Union[float, Array], /) -> Array: + def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __pow__. """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, "floating-point", "__pow__") + other = self._check_allowed_dtypes(other, "numeric", "__pow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow type promotion rules for 0-d @@ -912,23 +917,23 @@ class Array: res = self._array.__ror__(other._array) return self.__class__._new(res) - def __ipow__(self: Array, other: Union[float, Array], /) -> Array: + def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ipow__. """ - other = self._check_allowed_dtypes(other, "floating-point", "__ipow__") + other = self._check_allowed_dtypes(other, "numeric", "__ipow__") if other is NotImplemented: return other self._array.__ipow__(other._array) return self - def __rpow__(self: Array, other: Union[float, Array], /) -> Array: + def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __rpow__. """ from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, "floating-point", "__rpow__") + other = self._check_allowed_dtypes(other, "numeric", "__rpow__") if other is NotImplemented: return other # Note: NumPy's __pow__ does not follow the spec type promotion rules diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index 741498ff6..3b014d37b 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -154,7 +154,7 @@ def eye( def from_dlpack(x: object, /) -> Array: from ._array_object import Array - return Array._new(np._from_dlpack(x)) + return Array._new(np.from_dlpack(x)) def full( diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index e4d6db61b..7026bd489 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -50,11 +50,23 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: See its docstring for more information. """ - from ._array_object import Array - if isinstance(from_, Array): - from_ = from_._array - return np.can_cast(from_, to) + from_ = from_.dtype + elif from_ not in _all_dtypes: + raise TypeError(f"{from_=}, but should be an array_api array or dtype") + if to not in _all_dtypes: + raise TypeError(f"{to=}, but should be a dtype") + # Note: We avoid np.can_cast() as it has discrepancies with the array API, + # since NumPy allows cross-kind casting (e.g., NumPy allows bool -> int8). + # See https://github.com/numpy/numpy/issues/20870 + try: + # We promote `from_` and `to` together. We then check if the promoted + # dtype is `to`, which indicates if `from_` can (up)cast to `to`. + dtype = _result_type(from_, to) + return to == dtype + except TypeError: + # _result_type() raises if the dtypes don't promote together + return False # These are internal objects for the return types of finfo and iinfo, since diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py index 4408fe833..c758a0944 100644 --- a/numpy/array_api/_elementwise_functions.py +++ b/numpy/array_api/_elementwise_functions.py @@ -591,8 +591,8 @@ def pow(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in pow") + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in pow") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py index 05ee7e555..db9370f84 100644 --- a/numpy/array_api/_set_functions.py +++ b/numpy/array_api/_set_functions.py @@ -41,14 +41,21 @@ def unique_all(x: Array, /) -> UniqueAllResult: See its docstring for more information. """ - res = np.unique( + values, indices, inverse_indices, counts = np.unique( x._array, return_counts=True, return_index=True, return_inverse=True, ) - - return UniqueAllResult(*[Array._new(i) for i in res]) + # np.unique() flattens inverse indices, but they need to share x's shape + # See https://github.com/numpy/numpy/issues/20638 + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueAllResult( + Array._new(values), + Array._new(indices), + Array._new(inverse_indices), + Array._new(counts), + ) def unique_counts(x: Array, /) -> UniqueCountsResult: @@ -68,13 +75,16 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult: See its docstring for more information. """ - res = np.unique( + values, inverse_indices = np.unique( x._array, return_counts=False, return_index=False, return_inverse=True, ) - return UniqueInverseResult(*[Array._new(i) for i in res]) + # np.unique() flattens inverse indices, but they need to share x's shape + # See https://github.com/numpy/numpy/issues/20638 + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueInverseResult(Array._new(values), Array._new(inverse_indices)) def unique_values(x: Array, /) -> Array: diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py index 9cd49786c..afbb412f7 100644 --- a/numpy/array_api/_sorting_functions.py +++ b/numpy/array_api/_sorting_functions.py @@ -5,6 +5,7 @@ from ._array_object import Array import numpy as np +# Note: the descending keyword argument is new in this function def argsort( x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> Array: @@ -15,12 +16,23 @@ def argsort( """ # Note: this keyword argument is different, and the default is different. kind = "stable" if stable else "quicksort" - res = np.argsort(x._array, axis=axis, kind=kind) - if descending: - res = np.flip(res, axis=axis) + if not descending: + res = np.argsort(x._array, axis=axis, kind=kind) + else: + # As NumPy has no native descending sort, we imitate it here. Note that + # simply flipping the results of np.argsort(x._array, ...) would not + # respect the relative order like it would in native descending sorts. + res = np.flip( + np.argsort(np.flip(x._array, axis=axis), axis=axis, kind=kind), + axis=axis, + ) + # Rely on flip()/argsort() to validate axis + normalised_axis = axis if axis >= 0 else x.ndim + axis + max_i = x.shape[normalised_axis] - 1 + res = max_i - res return Array._new(res) - +# Note: the descending keyword argument is new in this function def sort( x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True ) -> Array: diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index 8d7ba659e..f422e1c27 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -89,7 +89,6 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array: return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) -# Note: the keyword argument name upper is different from np.linalg.eigh def eigh(x: Array, /) -> EighResult: """ Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`. @@ -106,7 +105,6 @@ def eigh(x: Array, /) -> EighResult: return EighResult(*map(Array._new, np.linalg.eigh(x._array))) -# Note: the keyword argument name upper is different from np.linalg.eigvalsh def eigvalsh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`. @@ -346,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: + if x.dtype not in _floating_dtypes: + raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False)) # Note: tensordot is the numpy top-level namespace but not in np.linalg @@ -366,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: See its docstring for more information. """ + if x.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in trace') # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1))) # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in vecdot') return tensordot(x1, x2, axes=((axis,), (axis,))) @@ -380,7 +384,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf]]] but Literal does not support floating-point literals. -def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: +def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`. diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index b980bacca..1fe1dfddf 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -98,7 +98,7 @@ def test_operators(): "__mul__": "numeric", "__ne__": "all", "__or__": "integer_or_boolean", - "__pow__": "floating", + "__pow__": "numeric", "__rshift__": "integer", "__sub__": "numeric", "__truediv__": "floating", diff --git a/numpy/array_api/tests/test_data_type_functions.py b/numpy/array_api/tests/test_data_type_functions.py new file mode 100644 index 000000000..efe3d0abd --- /dev/null +++ b/numpy/array_api/tests/test_data_type_functions.py @@ -0,0 +1,19 @@ +import pytest + +from numpy import array_api as xp + + +@pytest.mark.parametrize( + "from_, to, expected", + [ + (xp.int8, xp.int16, True), + (xp.int16, xp.int8, False), + (xp.bool, xp.int8, False), + (xp.asarray(0, dtype=xp.uint8), xp.int8, False), + ], +) +def test_can_cast(from_, to, expected): + """ + can_cast() returns correct result + """ + assert xp.can_cast(from_, to) == expected diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index a9274aec9..b2fb44e76 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -66,7 +66,7 @@ def test_function_types(): "negative": "numeric", "not_equal": "all", "positive": "numeric", - "pow": "floating-point", + "pow": "numeric", "remainder": "numeric", "round": "numeric", "sign": "numeric", diff --git a/numpy/array_api/tests/test_set_functions.py b/numpy/array_api/tests/test_set_functions.py new file mode 100644 index 000000000..b8eb65d43 --- /dev/null +++ b/numpy/array_api/tests/test_set_functions.py @@ -0,0 +1,19 @@ +import pytest +from hypothesis import given +from hypothesis.extra.array_api import make_strategies_namespace + +from numpy import array_api as xp + +xps = make_strategies_namespace(xp) + + +@pytest.mark.parametrize("func", [xp.unique_all, xp.unique_inverse]) +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=xps.array_shapes())) +def test_inverse_indices_shape(func, x): + """ + Inverse indices share shape of input array + + See https://github.com/numpy/numpy/issues/20638 + """ + out = func(x) + assert out.inverse_indices.shape == x.shape diff --git a/numpy/array_api/tests/test_sorting_functions.py b/numpy/array_api/tests/test_sorting_functions.py new file mode 100644 index 000000000..9848bbfeb --- /dev/null +++ b/numpy/array_api/tests/test_sorting_functions.py @@ -0,0 +1,23 @@ +import pytest + +from numpy import array_api as xp + + +@pytest.mark.parametrize( + "obj, axis, expected", + [ + ([0, 0], -1, [0, 1]), + ([0, 1, 0], -1, [1, 0, 2]), + ([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]), + ([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]), + ], +) +def test_stable_desc_argsort(obj, axis, expected): + """ + Indices respect relative order of a descending stable-sort + + See https://github.com/numpy/numpy/issues/20778 + """ + x = xp.asarray(obj) + out = xp.argsort(x, axis=axis, stable=True, descending=True) + assert xp.all(out == xp.asarray(expected)) diff --git a/numpy/array_api/tests/test_validation.py b/numpy/array_api/tests/test_validation.py new file mode 100644 index 000000000..0dd100d15 --- /dev/null +++ b/numpy/array_api/tests/test_validation.py @@ -0,0 +1,27 @@ +from typing import Callable + +import pytest + +from numpy import array_api as xp + + +def p(func: Callable, *args, **kwargs): + f_sig = ", ".join( + [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs.items()] + ) + id_ = f"{func.__name__}({f_sig})" + return pytest.param(func, args, kwargs, id=id_) + + +@pytest.mark.parametrize( + "func, args, kwargs", + [ + p(xp.can_cast, 42, xp.int8), + p(xp.can_cast, xp.int8, 42), + p(xp.result_type, 42), + ], +) +def test_raises_on_invalid_types(func, args, kwargs): + """Function raises TypeError when passed invalidly-typed inputs""" + with pytest.raises(TypeError): + func(*args, **kwargs) |