summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/__init__.py11
-rw-r--r--numpy/array_api/_array_object.py38
-rw-r--r--numpy/array_api/_creation_functions.py34
-rw-r--r--numpy/array_api/_data_type_functions.py4
-rw-r--r--numpy/array_api/_linear_algebra_functions.py13
-rw-r--r--numpy/array_api/_manipulation_functions.py11
-rw-r--r--numpy/array_api/_statistical_functions.py35
-rw-r--r--numpy/array_api/_typing.py52
-rw-r--r--numpy/array_api/tests/test_array_object.py30
-rw-r--r--numpy/array_api/tests/test_creation_functions.py15
10 files changed, 192 insertions, 51 deletions
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py
index 790157504..d8b29057e 100644
--- a/numpy/array_api/__init__.py
+++ b/numpy/array_api/__init__.py
@@ -143,6 +143,8 @@ from ._creation_functions import (
meshgrid,
ones,
ones_like,
+ tril,
+ triu,
zeros,
zeros_like,
)
@@ -160,6 +162,8 @@ __all__ += [
"meshgrid",
"ones",
"ones_like",
+ "tril",
+ "triu",
"zeros",
"zeros_like",
]
@@ -333,21 +337,22 @@ __all__ += [
# from ._linear_algebra_functions import einsum
# __all__ += ['einsum']
-from ._linear_algebra_functions import matmul, tensordot, transpose, vecdot
+from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
-__all__ += ["matmul", "tensordot", "transpose", "vecdot"]
+__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
from ._manipulation_functions import (
concat,
expand_dims,
flip,
+ permute_dims,
reshape,
roll,
squeeze,
stack,
)
-__all__ += ["concat", "expand_dims", "flip", "reshape", "roll", "squeeze", "stack"]
+__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"]
from ._searching_functions import argmax, argmin, nonzero, where
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 2d746e78b..ef66c5efd 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -29,7 +29,7 @@ from ._dtypes import (
_dtype_categories,
)
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
if TYPE_CHECKING:
from ._typing import PyCapsule, Device, Dtype
@@ -99,7 +99,10 @@ class Array:
"""
Performs the operation __repr__.
"""
- return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})"
+ prefix = "Array("
+ suffix = f", dtype={self.dtype.name})"
+ mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
+ return prefix + mid + suffix
# These are various helper functions to make the array behavior match the
# spec in places where it either deviates from or is more strict than
@@ -379,7 +382,7 @@ class Array:
def __array_namespace__(
self: Array, /, *, api_version: Optional[str] = None
- ) -> object:
+ ) -> Any:
if api_version is not None and not api_version.startswith("2021."):
raise ValueError(f"Unrecognized array API version: {api_version!r}")
return array_api
@@ -391,6 +394,8 @@ class Array:
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("bool is only allowed on arrays with 0 dimensions")
+ if self.dtype not in _boolean_dtypes:
+ raise ValueError("bool is only allowed on boolean arrays")
res = self._array.__bool__()
return res
@@ -429,6 +434,8 @@ class Array:
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("float is only allowed on arrays with 0 dimensions")
+ if self.dtype not in _floating_dtypes:
+ raise ValueError("float is only allowed on floating-point arrays")
res = self._array.__float__()
return res
@@ -488,9 +495,18 @@ class Array:
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("int is only allowed on arrays with 0 dimensions")
+ if self.dtype not in _integer_dtypes:
+ raise ValueError("int is only allowed on integer arrays")
res = self._array.__int__()
return res
+ def __index__(self: Array, /) -> int:
+ """
+ Performs the operation __index__.
+ """
+ res = self._array.__index__()
+ return res
+
def __invert__(self: Array, /) -> Array:
"""
Performs the operation __invert__.
@@ -979,6 +995,11 @@ class Array:
res = self._array.__rxor__(other._array)
return self.__class__._new(res)
+ def to_device(self: Array, device: Device, /) -> Array:
+ if device == 'cpu':
+ return self
+ raise ValueError(f"Unsupported device {device!r}")
+
@property
def dtype(self) -> Dtype:
"""
@@ -992,6 +1013,12 @@ class Array:
def device(self) -> Device:
return "cpu"
+ # Note: mT is new in array API spec (see matrix_transpose)
+ @property
+ def mT(self) -> Array:
+ from ._linear_algebra_functions import matrix_transpose
+ return matrix_transpose(self)
+
@property
def ndim(self) -> int:
"""
@@ -1026,4 +1053,9 @@ class Array:
See its docstring for more information.
"""
+ # Note: T only works on 2-dimensional arrays. See the corresponding
+ # note in the specification:
+ # https://data-apis.org/array-api/latest/API_specification/array_object.html#t
+ if self.ndim != 2:
+ raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.")
return self._array.T
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index 2d6cf4414..d760bf2fc 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -22,7 +22,7 @@ def _check_valid_dtype(dtype):
# Note: Only spelling dtypes as the dtype objects is supported.
# We use this instead of "dtype in _all_dtypes" because the dtype objects
- # define equality with the sorts of things we want to disallw.
+ # define equality with the sorts of things we want to disallow.
for d in (None,) + _all_dtypes:
if dtype is d:
return
@@ -134,7 +134,7 @@ def eye(
n_cols: Optional[int] = None,
/,
*,
- k: Optional[int] = 0,
+ k: int = 0,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
) -> Array:
@@ -232,7 +232,7 @@ def linspace(
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
-def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]:
+def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
@@ -281,6 +281,34 @@ def ones_like(
return Array._new(np.ones_like(x._array, dtype=dtype))
+def tril(x: Array, /, *, k: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.tril <numpy.tril>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ if x.ndim < 2:
+ # Note: Unlike np.tril, x must be at least 2-D
+ raise ValueError("x must be at least 2-dimensional for tril")
+ return Array._new(np.tril(x._array, k=k))
+
+
+def triu(x: Array, /, *, k: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.triu <numpy.triu>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ if x.ndim < 2:
+ # Note: Unlike np.triu, x must be at least 2-D
+ raise ValueError("x must be at least 2-dimensional for triu")
+ return Array._new(np.triu(x._array, k=k))
+
+
def zeros(
shape: Union[int, Tuple[int, ...]],
*,
diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py
index fd92aa250..7ccbe9469 100644
--- a/numpy/array_api/_data_type_functions.py
+++ b/numpy/array_api/_data_type_functions.py
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
import numpy as np
-def broadcast_arrays(*arrays: Sequence[Array]) -> List[Array]:
+def broadcast_arrays(*arrays: Array) -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
@@ -98,7 +98,7 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
return iinfo_object(ii.bits, ii.max, ii.min)
-def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype:
+def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
"""
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py
index 089081725..7a6c9846c 100644
--- a/numpy/array_api/_linear_algebra_functions.py
+++ b/numpy/array_api/_linear_algebra_functions.py
@@ -52,13 +52,12 @@ def tensordot(
return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
-def transpose(x: Array, /, *, axes: Optional[Tuple[int, ...]] = None) -> Array:
- """
- Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
-
- See its docstring for more information.
- """
- return Array._new(np.transpose(x._array, axes=axes))
+# Note: this function is new in the array API spec. Unlike transpose, it only
+# transposes the last two axes.
+def matrix_transpose(x: Array, /) -> Array:
+ if x.ndim < 2:
+ raise ValueError("x must be at least 2-dimensional for matrix_transpose")
+ return Array._new(np.swapaxes(x._array, -1, -2))
# Note: vecdot is not in NumPy
diff --git a/numpy/array_api/_manipulation_functions.py b/numpy/array_api/_manipulation_functions.py
index c11866261..4f2114ff5 100644
--- a/numpy/array_api/_manipulation_functions.py
+++ b/numpy/array_api/_manipulation_functions.py
@@ -41,6 +41,17 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
return Array._new(np.flip(x._array, axis=axis))
+# Note: The function name is different here (see also matrix_transpose).
+# Unlike transpose(), the axes argument is required.
+def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.transpose(x._array, axes))
+
+
def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
"""
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py
index 63790b447..c5abf9468 100644
--- a/numpy/array_api/_statistical_functions.py
+++ b/numpy/array_api/_statistical_functions.py
@@ -1,8 +1,17 @@
from __future__ import annotations
+from ._dtypes import (
+ _floating_dtypes,
+ _numeric_dtypes,
+)
from ._array_object import Array
+from ._creation_functions import asarray
+from ._dtypes import float32, float64
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+ from ._typing import Dtype
import numpy as np
@@ -14,6 +23,8 @@ def max(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> Array:
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in max")
return Array._new(np.max(x._array, axis=axis, keepdims=keepdims))
@@ -24,6 +35,8 @@ def mean(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> Array:
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in mean")
return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims))
@@ -34,6 +47,8 @@ def min(
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
) -> Array:
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in min")
return Array._new(np.min(x._array, axis=axis, keepdims=keepdims))
@@ -42,8 +57,15 @@ def prod(
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ dtype: Optional[Dtype] = None,
keepdims: bool = False,
) -> Array:
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in prod")
+ # Note: sum() and prod() always upcast float32 to float64 for dtype=None
+ # We need to do so here before computing the product to avoid overflow
+ if dtype is None and x.dtype == float32:
+ x = asarray(x, dtype=float64)
return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims))
@@ -56,6 +78,8 @@ def std(
keepdims: bool = False,
) -> Array:
# Note: the keyword argument correction is different here
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in std")
return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))
@@ -64,8 +88,15 @@ def sum(
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ dtype: Optional[Dtype] = None,
keepdims: bool = False,
) -> Array:
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in sum")
+ # Note: sum() and prod() always upcast float32 to float64 for dtype=None
+ # We need to do so here before summing to avoid overflow
+ if dtype is None and x.dtype == float32:
+ x = asarray(x, dtype=float64)
return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims))
@@ -78,4 +109,6 @@ def var(
keepdims: bool = False,
) -> Array:
# Note: the keyword argument correction is different here
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in var")
return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py
index d530a91ae..dfa87b358 100644
--- a/numpy/array_api/_typing.py
+++ b/numpy/array_api/_typing.py
@@ -6,6 +6,8 @@ annotations in the function signatures. The functions in the module are only
valid for inputs that match the given type annotations.
"""
+from __future__ import annotations
+
__all__ = [
"Array",
"Device",
@@ -15,10 +17,21 @@ __all__ = [
"PyCapsule",
]
-from typing import Any, Sequence, Type, Union
+import sys
+from typing import (
+ Any,
+ Literal,
+ Sequence,
+ Type,
+ Union,
+ TYPE_CHECKING,
+ TypeVar,
+ Protocol,
+)
-from . import (
- Array,
+from ._array_object import Array
+from numpy import (
+ dtype,
int8,
int16,
int32,
@@ -31,14 +44,31 @@ from . import (
float64,
)
-# This should really be recursive, but that isn't supported yet. See the
-# similar comment in numpy/typing/_array_like.py
-NestedSequence = Sequence[Sequence[Any]]
+_T_co = TypeVar("_T_co", covariant=True)
+
+class NestedSequence(Protocol[_T_co]):
+ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
+ def __len__(self, /) -> int: ...
+
+Device = Literal["cpu"]
+if TYPE_CHECKING or sys.version_info >= (3, 9):
+ Dtype = dtype[Union[
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+ ]]
+else:
+ Dtype = dtype
-Device = Any
-Dtype = Type[
- Union[[int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64]]
-]
-SupportsDLPack = Any
SupportsBufferProtocol = Any
PyCapsule = Any
+
+class SupportsDLPack(Protocol):
+ def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py
index 088e09b9f..7959f92b4 100644
--- a/numpy/array_api/tests/test_array_object.py
+++ b/numpy/array_api/tests/test_array_object.py
@@ -1,3 +1,5 @@
+import operator
+
from numpy.testing import assert_raises
import numpy as np
@@ -255,15 +257,31 @@ def test_operators():
def test_python_scalar_construtors():
- a = asarray(False)
- b = asarray(0)
- c = asarray(0.0)
+ b = asarray(False)
+ i = asarray(0)
+ f = asarray(0.0)
- assert bool(a) == bool(b) == bool(c) == False
- assert int(a) == int(b) == int(c) == 0
- assert float(a) == float(b) == float(c) == 0.0
+ assert bool(b) == False
+ assert int(i) == 0
+ assert float(f) == 0.0
+ assert operator.index(i) == 0
# bool/int/float should only be allowed on 0-D arrays.
assert_raises(TypeError, lambda: bool(asarray([False])))
assert_raises(TypeError, lambda: int(asarray([0])))
assert_raises(TypeError, lambda: float(asarray([0.0])))
+ assert_raises(TypeError, lambda: operator.index(asarray([0])))
+
+ # bool/int/float should only be allowed on arrays of the corresponding
+ # dtype
+ assert_raises(ValueError, lambda: bool(i))
+ assert_raises(ValueError, lambda: bool(f))
+
+ assert_raises(ValueError, lambda: int(b))
+ assert_raises(ValueError, lambda: int(f))
+
+ assert_raises(ValueError, lambda: float(b))
+ assert_raises(ValueError, lambda: float(i))
+
+ assert_raises(TypeError, lambda: operator.index(b))
+ assert_raises(TypeError, lambda: operator.index(f))
diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py
index 2ee23a47b..c13bc4262 100644
--- a/numpy/array_api/tests/test_creation_functions.py
+++ b/numpy/array_api/tests/test_creation_functions.py
@@ -8,30 +8,15 @@ from .._creation_functions import (
empty,
empty_like,
eye,
- from_dlpack,
full,
full_like,
linspace,
- meshgrid,
ones,
ones_like,
zeros,
zeros_like,
)
from .._array_object import Array
-from .._dtypes import (
- _all_dtypes,
- _boolean_dtypes,
- _floating_dtypes,
- _integer_dtypes,
- _integer_or_boolean_dtypes,
- _numeric_dtypes,
- int8,
- int16,
- int32,
- int64,
- uint64,
-)
def test_asarray_errors():