summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_array_object.py27
-rw-r--r--numpy/array_api/_creation_functions.py2
-rw-r--r--numpy/array_api/_data_type_functions.py20
-rw-r--r--numpy/array_api/_elementwise_functions.py4
-rw-r--r--numpy/array_api/_set_functions.py20
-rw-r--r--numpy/array_api/_sorting_functions.py20
-rw-r--r--numpy/array_api/linalg.py10
-rw-r--r--numpy/array_api/tests/test_array_object.py2
-rw-r--r--numpy/array_api/tests/test_data_type_functions.py19
-rw-r--r--numpy/array_api/tests/test_elementwise_functions.py2
-rw-r--r--numpy/array_api/tests/test_set_functions.py19
-rw-r--r--numpy/array_api/tests/test_sorting_functions.py23
-rw-r--r--numpy/array_api/tests/test_validation.py27
13 files changed, 163 insertions, 32 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 75baf34b0..6cf9ec6f3 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -30,6 +30,7 @@ from ._dtypes import (
)
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
+import types
if TYPE_CHECKING:
from ._typing import Any, PyCapsule, Device, Dtype
@@ -55,6 +56,7 @@ class Array:
functions, such as asarray().
"""
+ _array: np.ndarray
# Use a custom constructor instead of __init__, as manually initializing
# this class is not supported API.
@@ -124,7 +126,7 @@ class Array:
# spec in places where it either deviates from or is more strict than
# NumPy behavior
- def _check_allowed_dtypes(self, other, dtype_category, op):
+ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array:
"""
Helper function for operators to only allow specific input dtypes
@@ -175,6 +177,8 @@ class Array:
integer that is too large to fit in a NumPy integer dtype, or
TypeError when the scalar type is incompatible with the dtype of self.
"""
+ # Note: Only Python scalar types that match the array dtype are
+ # allowed.
if isinstance(scalar, bool):
if self.dtype not in _boolean_dtypes:
raise TypeError(
@@ -193,6 +197,9 @@ class Array:
else:
raise TypeError("'scalar' must be a Python scalar")
+ # Note: scalars are unconditionally cast to the same dtype as the
+ # array.
+
# Note: the spec only specifies integer-dtype/int promotion
# behavior for integers within the bounds of the integer dtype.
# Outside of those bounds we use the default NumPy behavior (either
@@ -200,7 +207,7 @@ class Array:
return Array._new(np.array(scalar, self.dtype))
@staticmethod
- def _normalize_two_args(x1, x2):
+ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
"""
Normalize inputs to two arg functions to fix type promotion rules
@@ -415,7 +422,7 @@ class Array:
def __array_namespace__(
self: Array, /, *, api_version: Optional[str] = None
- ) -> Any:
+ ) -> types.ModuleType:
if api_version is not None and not api_version.startswith("2021."):
raise ValueError(f"Unrecognized array API version: {api_version!r}")
return array_api
@@ -654,15 +661,13 @@ class Array:
res = self._array.__pos__()
return self.__class__._new(res)
- # PEP 484 requires int to be a subtype of float, but __pow__ should not
- # accept int.
- def __pow__(self: Array, other: Union[float, Array], /) -> Array:
+ def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __pow__.
"""
from ._elementwise_functions import pow
- other = self._check_allowed_dtypes(other, "floating-point", "__pow__")
+ other = self._check_allowed_dtypes(other, "numeric", "__pow__")
if other is NotImplemented:
return other
# Note: NumPy's __pow__ does not follow type promotion rules for 0-d
@@ -912,23 +917,23 @@ class Array:
res = self._array.__ror__(other._array)
return self.__class__._new(res)
- def __ipow__(self: Array, other: Union[float, Array], /) -> Array:
+ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ipow__.
"""
- other = self._check_allowed_dtypes(other, "floating-point", "__ipow__")
+ other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
if other is NotImplemented:
return other
self._array.__ipow__(other._array)
return self
- def __rpow__(self: Array, other: Union[float, Array], /) -> Array:
+ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __rpow__.
"""
from ._elementwise_functions import pow
- other = self._check_allowed_dtypes(other, "floating-point", "__rpow__")
+ other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
if other is NotImplemented:
return other
# Note: NumPy's __pow__ does not follow the spec type promotion rules
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index 741498ff6..3b014d37b 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -154,7 +154,7 @@ def eye(
def from_dlpack(x: object, /) -> Array:
from ._array_object import Array
- return Array._new(np._from_dlpack(x))
+ return Array._new(np.from_dlpack(x))
def full(
diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py
index e4d6db61b..7026bd489 100644
--- a/numpy/array_api/_data_type_functions.py
+++ b/numpy/array_api/_data_type_functions.py
@@ -50,11 +50,23 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
See its docstring for more information.
"""
- from ._array_object import Array
-
if isinstance(from_, Array):
- from_ = from_._array
- return np.can_cast(from_, to)
+ from_ = from_.dtype
+ elif from_ not in _all_dtypes:
+ raise TypeError(f"{from_=}, but should be an array_api array or dtype")
+ if to not in _all_dtypes:
+ raise TypeError(f"{to=}, but should be a dtype")
+ # Note: We avoid np.can_cast() as it has discrepancies with the array API,
+ # since NumPy allows cross-kind casting (e.g., NumPy allows bool -> int8).
+ # See https://github.com/numpy/numpy/issues/20870
+ try:
+ # We promote `from_` and `to` together. We then check if the promoted
+ # dtype is `to`, which indicates if `from_` can (up)cast to `to`.
+ dtype = _result_type(from_, to)
+ return to == dtype
+ except TypeError:
+ # _result_type() raises if the dtypes don't promote together
+ return False
# These are internal objects for the return types of finfo and iinfo, since
diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py
index 4408fe833..c758a0944 100644
--- a/numpy/array_api/_elementwise_functions.py
+++ b/numpy/array_api/_elementwise_functions.py
@@ -591,8 +591,8 @@ def pow(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
- if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
- raise TypeError("Only floating-point dtypes are allowed in pow")
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in pow")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
x1, x2 = Array._normalize_two_args(x1, x2)
diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py
index 05ee7e555..db9370f84 100644
--- a/numpy/array_api/_set_functions.py
+++ b/numpy/array_api/_set_functions.py
@@ -41,14 +41,21 @@ def unique_all(x: Array, /) -> UniqueAllResult:
See its docstring for more information.
"""
- res = np.unique(
+ values, indices, inverse_indices, counts = np.unique(
x._array,
return_counts=True,
return_index=True,
return_inverse=True,
)
-
- return UniqueAllResult(*[Array._new(i) for i in res])
+ # np.unique() flattens inverse indices, but they need to share x's shape
+ # See https://github.com/numpy/numpy/issues/20638
+ inverse_indices = inverse_indices.reshape(x.shape)
+ return UniqueAllResult(
+ Array._new(values),
+ Array._new(indices),
+ Array._new(inverse_indices),
+ Array._new(counts),
+ )
def unique_counts(x: Array, /) -> UniqueCountsResult:
@@ -68,13 +75,16 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult:
See its docstring for more information.
"""
- res = np.unique(
+ values, inverse_indices = np.unique(
x._array,
return_counts=False,
return_index=False,
return_inverse=True,
)
- return UniqueInverseResult(*[Array._new(i) for i in res])
+ # np.unique() flattens inverse indices, but they need to share x's shape
+ # See https://github.com/numpy/numpy/issues/20638
+ inverse_indices = inverse_indices.reshape(x.shape)
+ return UniqueInverseResult(Array._new(values), Array._new(inverse_indices))
def unique_values(x: Array, /) -> Array:
diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py
index 9cd49786c..afbb412f7 100644
--- a/numpy/array_api/_sorting_functions.py
+++ b/numpy/array_api/_sorting_functions.py
@@ -5,6 +5,7 @@ from ._array_object import Array
import numpy as np
+# Note: the descending keyword argument is new in this function
def argsort(
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
@@ -15,12 +16,23 @@ def argsort(
"""
# Note: this keyword argument is different, and the default is different.
kind = "stable" if stable else "quicksort"
- res = np.argsort(x._array, axis=axis, kind=kind)
- if descending:
- res = np.flip(res, axis=axis)
+ if not descending:
+ res = np.argsort(x._array, axis=axis, kind=kind)
+ else:
+ # As NumPy has no native descending sort, we imitate it here. Note that
+ # simply flipping the results of np.argsort(x._array, ...) would not
+ # respect the relative order like it would in native descending sorts.
+ res = np.flip(
+ np.argsort(np.flip(x._array, axis=axis), axis=axis, kind=kind),
+ axis=axis,
+ )
+ # Rely on flip()/argsort() to validate axis
+ normalised_axis = axis if axis >= 0 else x.ndim + axis
+ max_i = x.shape[normalised_axis] - 1
+ res = max_i - res
return Array._new(res)
-
+# Note: the descending keyword argument is new in this function
def sort(
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py
index 8d7ba659e..f422e1c27 100644
--- a/numpy/array_api/linalg.py
+++ b/numpy/array_api/linalg.py
@@ -89,7 +89,6 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array:
return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1))
-# Note: the keyword argument name upper is different from np.linalg.eigh
def eigh(x: Array, /) -> EighResult:
"""
Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`.
@@ -106,7 +105,6 @@ def eigh(x: Array, /) -> EighResult:
return EighResult(*map(Array._new, np.linalg.eigh(x._array)))
-# Note: the keyword argument name upper is different from np.linalg.eigvalsh
def eigvalsh(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`.
@@ -346,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# np.linalg.svd(compute_uv=False).
def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in svdvals')
return Array._new(np.linalg.svd(x._array, compute_uv=False))
# Note: tensordot is the numpy top-level namespace but not in np.linalg
@@ -366,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
See its docstring for more information.
"""
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in trace')
# Note: trace always operates on the last two axes, whereas np.trace
# operates on the first two axes by default
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))
# Note: vecdot is not in NumPy
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in vecdot')
return tensordot(x1, x2, axes=((axis,), (axis,)))
@@ -380,7 +384,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
# The type for ord should be Optional[Union[int, float, Literal[np.inf,
# -np.inf]]] but Literal does not support floating-point literals.
-def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
+def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
"""
Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py
index b980bacca..1fe1dfddf 100644
--- a/numpy/array_api/tests/test_array_object.py
+++ b/numpy/array_api/tests/test_array_object.py
@@ -98,7 +98,7 @@ def test_operators():
"__mul__": "numeric",
"__ne__": "all",
"__or__": "integer_or_boolean",
- "__pow__": "floating",
+ "__pow__": "numeric",
"__rshift__": "integer",
"__sub__": "numeric",
"__truediv__": "floating",
diff --git a/numpy/array_api/tests/test_data_type_functions.py b/numpy/array_api/tests/test_data_type_functions.py
new file mode 100644
index 000000000..efe3d0abd
--- /dev/null
+++ b/numpy/array_api/tests/test_data_type_functions.py
@@ -0,0 +1,19 @@
+import pytest
+
+from numpy import array_api as xp
+
+
+@pytest.mark.parametrize(
+ "from_, to, expected",
+ [
+ (xp.int8, xp.int16, True),
+ (xp.int16, xp.int8, False),
+ (xp.bool, xp.int8, False),
+ (xp.asarray(0, dtype=xp.uint8), xp.int8, False),
+ ],
+)
+def test_can_cast(from_, to, expected):
+ """
+ can_cast() returns correct result
+ """
+ assert xp.can_cast(from_, to) == expected
diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py
index a9274aec9..b2fb44e76 100644
--- a/numpy/array_api/tests/test_elementwise_functions.py
+++ b/numpy/array_api/tests/test_elementwise_functions.py
@@ -66,7 +66,7 @@ def test_function_types():
"negative": "numeric",
"not_equal": "all",
"positive": "numeric",
- "pow": "floating-point",
+ "pow": "numeric",
"remainder": "numeric",
"round": "numeric",
"sign": "numeric",
diff --git a/numpy/array_api/tests/test_set_functions.py b/numpy/array_api/tests/test_set_functions.py
new file mode 100644
index 000000000..b8eb65d43
--- /dev/null
+++ b/numpy/array_api/tests/test_set_functions.py
@@ -0,0 +1,19 @@
+import pytest
+from hypothesis import given
+from hypothesis.extra.array_api import make_strategies_namespace
+
+from numpy import array_api as xp
+
+xps = make_strategies_namespace(xp)
+
+
+@pytest.mark.parametrize("func", [xp.unique_all, xp.unique_inverse])
+@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=xps.array_shapes()))
+def test_inverse_indices_shape(func, x):
+ """
+ Inverse indices share shape of input array
+
+ See https://github.com/numpy/numpy/issues/20638
+ """
+ out = func(x)
+ assert out.inverse_indices.shape == x.shape
diff --git a/numpy/array_api/tests/test_sorting_functions.py b/numpy/array_api/tests/test_sorting_functions.py
new file mode 100644
index 000000000..9848bbfeb
--- /dev/null
+++ b/numpy/array_api/tests/test_sorting_functions.py
@@ -0,0 +1,23 @@
+import pytest
+
+from numpy import array_api as xp
+
+
+@pytest.mark.parametrize(
+ "obj, axis, expected",
+ [
+ ([0, 0], -1, [0, 1]),
+ ([0, 1, 0], -1, [1, 0, 2]),
+ ([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]),
+ ([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]),
+ ],
+)
+def test_stable_desc_argsort(obj, axis, expected):
+ """
+ Indices respect relative order of a descending stable-sort
+
+ See https://github.com/numpy/numpy/issues/20778
+ """
+ x = xp.asarray(obj)
+ out = xp.argsort(x, axis=axis, stable=True, descending=True)
+ assert xp.all(out == xp.asarray(expected))
diff --git a/numpy/array_api/tests/test_validation.py b/numpy/array_api/tests/test_validation.py
new file mode 100644
index 000000000..0dd100d15
--- /dev/null
+++ b/numpy/array_api/tests/test_validation.py
@@ -0,0 +1,27 @@
+from typing import Callable
+
+import pytest
+
+from numpy import array_api as xp
+
+
+def p(func: Callable, *args, **kwargs):
+ f_sig = ", ".join(
+ [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs.items()]
+ )
+ id_ = f"{func.__name__}({f_sig})"
+ return pytest.param(func, args, kwargs, id=id_)
+
+
+@pytest.mark.parametrize(
+ "func, args, kwargs",
+ [
+ p(xp.can_cast, 42, xp.int8),
+ p(xp.can_cast, xp.int8, 42),
+ p(xp.result_type, 42),
+ ],
+)
+def test_raises_on_invalid_types(func, args, kwargs):
+ """Function raises TypeError when passed invalidly-typed inputs"""
+ with pytest.raises(TypeError):
+ func(*args, **kwargs)