summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/__init__.py18
-rw-r--r--numpy/array_api/_array_object.py54
-rw-r--r--numpy/array_api/_creation_functions.py23
-rw-r--r--numpy/array_api/_data_type_functions.py7
-rw-r--r--numpy/array_api/_linear_algebra_functions.py67
-rw-r--r--numpy/array_api/_searching_functions.py1
-rw-r--r--numpy/array_api/_set_functions.py89
-rw-r--r--numpy/array_api/_statistical_functions.py9
-rw-r--r--numpy/array_api/_typing.py26
-rw-r--r--numpy/array_api/linalg.py408
-rw-r--r--numpy/array_api/tests/test_array_object.py28
-rw-r--r--numpy/array_api/tests/test_creation_functions.py39
12 files changed, 616 insertions, 153 deletions
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py
index d8b29057e..bbe2fdce2 100644
--- a/numpy/array_api/__init__.py
+++ b/numpy/array_api/__init__.py
@@ -109,9 +109,6 @@ Still TODO in this module are:
- The spec is still in an RFC phase and may still have minor updates, which
will need to be reflected here.
-- The linear algebra extension in the spec will be added in a future pull
- request.
-
- Complex number support in array API spec is planned but not yet finalized,
as are the fft extension and certain linear algebra functions such as eig
that require complex dtypes.
@@ -169,6 +166,7 @@ __all__ += [
]
from ._data_type_functions import (
+ astype,
broadcast_arrays,
broadcast_to,
can_cast,
@@ -178,6 +176,7 @@ from ._data_type_functions import (
)
__all__ += [
+ "astype",
"broadcast_arrays",
"broadcast_to",
"can_cast",
@@ -332,12 +331,13 @@ __all__ += [
"trunc",
]
-# einsum is not yet implemented in the array API spec.
+# linalg is an extension in the array API spec, which is a sub-namespace. Only
+# a subset of functions in it are imported into the top-level namespace.
+from . import linalg
-# from ._linear_algebra_functions import einsum
-# __all__ += ['einsum']
+__all__ += ["linalg"]
-from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
+from .linalg import matmul, tensordot, matrix_transpose, vecdot
__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
@@ -358,9 +358,9 @@ from ._searching_functions import argmax, argmin, nonzero, where
__all__ += ["argmax", "argmin", "nonzero", "where"]
-from ._set_functions import unique
+from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
-__all__ += ["unique"]
+__all__ += ["unique_all", "unique_counts", "unique_inverse", "unique_values"]
from ._sorting_functions import argsort, sort
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 830319e8c..8794c5ea5 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -29,10 +29,10 @@ from ._dtypes import (
_dtype_categories,
)
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
if TYPE_CHECKING:
- from ._typing import PyCapsule, Device, Dtype
+ from ._typing import Any, PyCapsule, Device, Dtype
import numpy as np
@@ -99,9 +99,13 @@ class Array:
"""
Performs the operation __repr__.
"""
- prefix = "Array("
suffix = f", dtype={self.dtype.name})"
- mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
+ if 0 in self.shape:
+ prefix = "empty("
+ mid = str(self.shape)
+ else:
+ prefix = "Array("
+ mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix
# These are various helper functions to make the array behavior match the
@@ -244,6 +248,10 @@ class Array:
The following cases are allowed by NumPy, but not specified by the array
API specification:
+ - Indices to not include an implicit ellipsis at the end. That is,
+ every axis of an array must be explicitly indexed or an ellipsis
+ included.
+
- The start and stop of a slice may not be out of bounds. In
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
following are allowed:
@@ -270,6 +278,10 @@ class Array:
return key
if shape == ():
return key
+ if len(shape) > 1:
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
size = shape[0]
# Ensure invalid slice entries are passed through.
if key.start is not None:
@@ -277,7 +289,7 @@ class Array:
operator.index(key.start)
except TypeError:
return key
- if not (-size <= key.start <= max(0, size - 1)):
+ if not (-size <= key.start <= size):
raise IndexError(
"Slices with out-of-bounds start are not allowed in the array API namespace"
)
@@ -322,6 +334,10 @@ class Array:
zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
):
Array._validate_index(idx, (size,))
+ if n_ellipsis == 0 and len(key) < len(shape):
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
return key
elif isinstance(key, bool):
return key
@@ -339,7 +355,12 @@ class Array:
"newaxis indices are not allowed in the array API namespace"
)
try:
- return operator.index(key)
+ key = operator.index(key)
+ if shape is not None and len(shape) > 1:
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
+ return key
except TypeError:
# Note: This also omits boolean arrays that are not already in
# Array() form, like a list of booleans.
@@ -382,7 +403,7 @@ class Array:
def __array_namespace__(
self: Array, /, *, api_version: Optional[str] = None
- ) -> object:
+ ) -> Any:
if api_version is not None and not api_version.startswith("2021."):
raise ValueError(f"Unrecognized array API version: {api_version!r}")
return array_api
@@ -403,16 +424,14 @@ class Array:
"""
Performs the operation __dlpack__.
"""
- res = self._array.__dlpack__(stream=stream)
- return self.__class__._new(res)
+ return self._array.__dlpack__(stream=stream)
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
"""
Performs the operation __dlpack_device__.
"""
# Note: device support is required for this
- res = self._array.__dlpack_device__()
- return self.__class__._new(res)
+ return self._array.__dlpack_device__()
def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
@@ -527,13 +546,6 @@ class Array:
res = self._array.__le__(other._array)
return self.__class__._new(res)
- # Note: __len__ may end up being removed from the array API spec.
- def __len__(self, /) -> int:
- """
- Performs the operation __len__.
- """
- return self._array.__len__()
-
def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
"""
Performs the operation __lshift__.
@@ -995,7 +1007,9 @@ class Array:
res = self._array.__rxor__(other._array)
return self.__class__._new(res)
- def to_device(self: Array, device: Device, /) -> Array:
+ def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
+ if stream is not None:
+ raise ValueError("The stream argument to to_device() is not supported")
if device == 'cpu':
return self
raise ValueError(f"Unsupported device {device!r}")
@@ -1016,7 +1030,7 @@ class Array:
# Note: mT is new in array API spec (see matrix_transpose)
@property
def mT(self) -> Array:
- from ._linear_algebra_functions import matrix_transpose
+ from .linalg import matrix_transpose
return matrix_transpose(self)
@property
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index e36807468..741498ff6 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
Device,
Dtype,
NestedSequence,
- SupportsDLPack,
SupportsBufferProtocol,
)
from collections.abc import Sequence
@@ -36,14 +35,13 @@ def asarray(
int,
float,
NestedSequence[bool | int | float],
- SupportsDLPack,
SupportsBufferProtocol,
],
/,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
- copy: Optional[bool] = None,
+ copy: Optional[Union[bool, np._CopyMode]] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
@@ -57,11 +55,13 @@ def asarray(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
- if copy is False:
+ if copy in (False, np._CopyMode.IF_NEEDED):
# Note: copy=False is not yet implemented in np.asarray
raise NotImplementedError("copy=False is not yet implemented")
- if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype):
- if copy is True:
+ if isinstance(obj, Array):
+ if dtype is not None and obj.dtype != dtype:
+ copy = True
+ if copy in (True, np._CopyMode.ALWAYS):
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
return obj
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
@@ -152,8 +152,9 @@ def eye(
def from_dlpack(x: object, /) -> Array:
- # Note: dlpack support is not yet implemented on Array
- raise NotImplementedError("DLPack support is not yet implemented")
+ from ._array_object import Array
+
+ return Array._new(np._from_dlpack(x))
def full(
@@ -240,6 +241,12 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
"""
from ._array_object import Array
+ # Note: unlike np.meshgrid, only inputs with all the same dtype are
+ # allowed
+
+ if len({a.dtype for a in arrays}) > 1:
+ raise ValueError("meshgrid inputs must all have the same dtype")
+
return [
Array._new(array)
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py
index 7ccbe9469..e4d6db61b 100644
--- a/numpy/array_api/_data_type_functions.py
+++ b/numpy/array_api/_data_type_functions.py
@@ -13,6 +13,13 @@ if TYPE_CHECKING:
import numpy as np
+# Note: astype is a function, not an array method as in NumPy.
+def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
+ if not copy and dtype == x.dtype:
+ return x
+ return Array._new(x._array.astype(dtype=dtype, copy=copy))
+
+
def broadcast_arrays(*arrays: Array) -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py
deleted file mode 100644
index 7a6c9846c..000000000
--- a/numpy/array_api/_linear_algebra_functions.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from __future__ import annotations
-
-from ._array_object import Array
-from ._dtypes import _numeric_dtypes, _result_type
-
-from typing import Optional, Sequence, Tuple, Union
-
-import numpy as np
-
-# einsum is not yet implemented in the array API spec.
-
-# def einsum():
-# """
-# Array API compatible wrapper for :py:func:`np.einsum <numpy.einsum>`.
-#
-# See its docstring for more information.
-# """
-# return np.einsum()
-
-
-def matmul(x1: Array, x2: Array, /) -> Array:
- """
- Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
-
- See its docstring for more information.
- """
- # Note: the restriction to numeric dtypes only is different from
- # np.matmul.
- if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError("Only numeric dtypes are allowed in matmul")
- # Call result type here just to raise on disallowed type combinations
- _result_type(x1.dtype, x2.dtype)
-
- return Array._new(np.matmul(x1._array, x2._array))
-
-
-# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
-def tensordot(
- x1: Array,
- x2: Array,
- /,
- *,
- axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
-) -> Array:
- # Note: the restriction to numeric dtypes only is different from
- # np.tensordot.
- if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError("Only numeric dtypes are allowed in tensordot")
- # Call result type here just to raise on disallowed type combinations
- _result_type(x1.dtype, x2.dtype)
-
- return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
-
-
-# Note: this function is new in the array API spec. Unlike transpose, it only
-# transposes the last two axes.
-def matrix_transpose(x: Array, /) -> Array:
- if x.ndim < 2:
- raise ValueError("x must be at least 2-dimensional for matrix_transpose")
- return Array._new(np.swapaxes(x._array, -1, -2))
-
-
-# Note: vecdot is not in NumPy
-def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array:
- if axis is None:
- axis = -1
- return tensordot(x1, x2, axes=((axis,), (axis,)))
diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py
index 3dcef61c3..40f5a4d2e 100644
--- a/numpy/array_api/_searching_functions.py
+++ b/numpy/array_api/_searching_functions.py
@@ -43,4 +43,5 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
"""
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.where(condition._array, x1._array, x2._array))
diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py
index 357f238f5..05ee7e555 100644
--- a/numpy/array_api/_set_functions.py
+++ b/numpy/array_api/_set_functions.py
@@ -2,19 +2,82 @@ from __future__ import annotations
from ._array_object import Array
-from typing import Tuple, Union
+from typing import NamedTuple
import numpy as np
+# Note: np.unique() is split into four functions in the array API:
+# unique_all, unique_counts, unique_inverse, and unique_values (this is done
+# to remove polymorphic return types).
-def unique(
- x: Array,
- /,
- *,
- return_counts: bool = False,
- return_index: bool = False,
- return_inverse: bool = False,
-) -> Union[Array, Tuple[Array, ...]]:
+# Note: The various unique() functions are supposed to return multiple NaNs.
+# This does not match the NumPy behavior, however, this is currently left as a
+# TODO in this implementation as this behavior may be reverted in np.unique().
+# See https://github.com/numpy/numpy/issues/20326.
+
+# Note: The functions here return a namedtuple (np.unique() returns a normal
+# tuple).
+
+class UniqueAllResult(NamedTuple):
+ values: Array
+ indices: Array
+ inverse_indices: Array
+ counts: Array
+
+
+class UniqueCountsResult(NamedTuple):
+ values: Array
+ counts: Array
+
+
+class UniqueInverseResult(NamedTuple):
+ values: Array
+ inverse_indices: Array
+
+
+def unique_all(x: Array, /) -> UniqueAllResult:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ res = np.unique(
+ x._array,
+ return_counts=True,
+ return_index=True,
+ return_inverse=True,
+ )
+
+ return UniqueAllResult(*[Array._new(i) for i in res])
+
+
+def unique_counts(x: Array, /) -> UniqueCountsResult:
+ res = np.unique(
+ x._array,
+ return_counts=True,
+ return_index=False,
+ return_inverse=False,
+ )
+
+ return UniqueCountsResult(*[Array._new(i) for i in res])
+
+
+def unique_inverse(x: Array, /) -> UniqueInverseResult:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ res = np.unique(
+ x._array,
+ return_counts=False,
+ return_index=False,
+ return_inverse=True,
+ )
+ return UniqueInverseResult(*[Array._new(i) for i in res])
+
+
+def unique_values(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
@@ -22,10 +85,8 @@ def unique(
"""
res = np.unique(
x._array,
- return_counts=return_counts,
- return_index=return_index,
- return_inverse=return_inverse,
+ return_counts=False,
+ return_index=False,
+ return_inverse=False,
)
- if isinstance(res, tuple):
- return tuple(Array._new(i) for i in res)
return Array._new(res)
diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py
index c5abf9468..7bee3f4db 100644
--- a/numpy/array_api/_statistical_functions.py
+++ b/numpy/array_api/_statistical_functions.py
@@ -93,11 +93,12 @@ def sum(
) -> Array:
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in sum")
- # Note: sum() and prod() always upcast float32 to float64 for dtype=None
- # We need to do so here before summing to avoid overflow
+ # Note: sum() and prod() always upcast integers to (u)int64 and float32 to
+ # float64 for dtype=None. `np.sum` does that too for integers, but not for
+ # float32, so we need to special-case it here
if dtype is None and x.dtype == float32:
- x = asarray(x, dtype=float64)
- return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims))
+ dtype = float64
+ return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
def var(
diff --git a/numpy/array_api/_typing.py b/numpy/array_api/_typing.py
index 519e8463c..dfa87b358 100644
--- a/numpy/array_api/_typing.py
+++ b/numpy/array_api/_typing.py
@@ -6,6 +6,8 @@ annotations in the function signatures. The functions in the module are only
valid for inputs that match the given type annotations.
"""
+from __future__ import annotations
+
__all__ = [
"Array",
"Device",
@@ -16,7 +18,16 @@ __all__ = [
]
import sys
-from typing import Any, Literal, Sequence, Type, Union, TYPE_CHECKING, TypeVar
+from typing import (
+ Any,
+ Literal,
+ Sequence,
+ Type,
+ Union,
+ TYPE_CHECKING,
+ TypeVar,
+ Protocol,
+)
from ._array_object import Array
from numpy import (
@@ -33,10 +44,11 @@ from numpy import (
float64,
)
-# This should really be recursive, but that isn't supported yet. See the
-# similar comment in numpy/typing/_array_like.py
-_T = TypeVar("_T")
-NestedSequence = Sequence[Sequence[_T]]
+_T_co = TypeVar("_T_co", covariant=True)
+
+class NestedSequence(Protocol[_T_co]):
+ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
+ def __len__(self, /) -> int: ...
Device = Literal["cpu"]
if TYPE_CHECKING or sys.version_info >= (3, 9):
@@ -55,6 +67,8 @@ if TYPE_CHECKING or sys.version_info >= (3, 9):
else:
Dtype = dtype
-SupportsDLPack = Any
SupportsBufferProtocol = Any
PyCapsule = Any
+
+class SupportsDLPack(Protocol):
+ def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...
diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py
new file mode 100644
index 000000000..8d7ba659e
--- /dev/null
+++ b/numpy/array_api/linalg.py
@@ -0,0 +1,408 @@
+from __future__ import annotations
+
+from ._dtypes import _floating_dtypes, _numeric_dtypes
+from ._array_object import Array
+
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from ._typing import Literal, Optional, Sequence, Tuple, Union
+
+from typing import NamedTuple
+
+import numpy.linalg
+import numpy as np
+
+class EighResult(NamedTuple):
+ eigenvalues: Array
+ eigenvectors: Array
+
+class QRResult(NamedTuple):
+ Q: Array
+ R: Array
+
+class SlogdetResult(NamedTuple):
+ sign: Array
+ logabsdet: Array
+
+class SVDResult(NamedTuple):
+ U: Array
+ S: Array
+ Vh: Array
+
+# Note: the inclusion of the upper keyword is different from
+# np.linalg.cholesky, which does not have it.
+def cholesky(x: Array, /, *, upper: bool = False) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.cholesky <numpy.linalg.cholesky>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.cholesky.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in cholesky')
+ L = np.linalg.cholesky(x._array)
+ if upper:
+ return Array._new(L).mT
+ return Array._new(L)
+
+# Note: cross is the numpy top-level namespace, not np.linalg
+def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in cross')
+ # Note: this is different from np.cross(), which broadcasts
+ if x1.shape != x2.shape:
+ raise ValueError('x1 and x2 must have the same shape')
+ if x1.ndim == 0:
+ raise ValueError('cross() requires arrays of dimension at least 1')
+ # Note: this is different from np.cross(), which allows dimension 2
+ if x1.shape[axis] != 3:
+ raise ValueError('cross() dimension must equal 3')
+ return Array._new(np.cross(x1._array, x2._array, axis=axis))
+
+def det(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.det.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in det')
+ return Array._new(np.linalg.det(x._array))
+
+# Note: diagonal is the numpy top-level namespace, not np.linalg
+def diagonal(x: Array, /, *, offset: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
+
+ See its docstring for more information.
+ """
+ # Note: diagonal always operates on the last two axes, whereas np.diagonal
+ # operates on the first two axes by default
+ return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1))
+
+
+# Note: the keyword argument name upper is different from np.linalg.eigh
+def eigh(x: Array, /) -> EighResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.eigh.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in eigh')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.eigh, which only returns a tuple.
+ return EighResult(*map(Array._new, np.linalg.eigh(x._array)))
+
+
+# Note: the keyword argument name upper is different from np.linalg.eigvalsh
+def eigvalsh(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.eigvalsh.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in eigvalsh')
+
+ return Array._new(np.linalg.eigvalsh(x._array))
+
+def inv(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.inv.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in inv')
+
+ return Array._new(np.linalg.inv(x._array))
+
+
+# Note: matmul is the numpy top-level namespace but not in np.linalg
+def matmul(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to numeric dtypes only is different from
+ # np.matmul.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in matmul')
+
+ return Array._new(np.matmul(x1._array, x2._array))
+
+
+# Note: the name here is different from norm(). The array API norm is split
+# into matrix_norm and vector_norm().
+
+# The type for ord should be Optional[Union[int, float, Literal[np.inf,
+# -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point
+# literals.
+def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.norm.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in matrix_norm')
+
+ return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord))
+
+
+def matrix_power(x: Array, n: int, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.matrix_power <numpy.matrix_power>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.matrix_power.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed for the first argument of matrix_power')
+
+ # np.matrix_power already checks if n is an integer
+ return Array._new(np.linalg.matrix_power(x._array, n))
+
+# Note: the keyword argument name rtol is different from np.linalg.matrix_rank
+def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`.
+
+ See its docstring for more information.
+ """
+ # Note: this is different from np.linalg.matrix_rank, which supports 1
+ # dimensional arrays.
+ if x.ndim < 2:
+ raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
+ S = np.linalg.svd(x._array, compute_uv=False)
+ if rtol is None:
+ tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps
+ else:
+ if isinstance(rtol, Array):
+ rtol = rtol._array
+ # Note: this is different from np.linalg.matrix_rank, which does not multiply
+ # the tolerance by the largest singular value.
+ tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis]
+ return Array._new(np.count_nonzero(S > tol, axis=-1))
+
+
+# Note: this function is new in the array API spec. Unlike transpose, it only
+# transposes the last two axes.
+def matrix_transpose(x: Array, /) -> Array:
+ if x.ndim < 2:
+ raise ValueError("x must be at least 2-dimensional for matrix_transpose")
+ return Array._new(np.swapaxes(x._array, -1, -2))
+
+# Note: outer is the numpy top-level namespace, not np.linalg
+def outer(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to numeric dtypes only is different from
+ # np.outer.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in outer')
+
+ # Note: the restriction to only 1-dim arrays is different from np.outer
+ if x1.ndim != 1 or x2.ndim != 1:
+ raise ValueError('The input arrays to outer must be 1-dimensional')
+
+ return Array._new(np.outer(x1._array, x2._array))
+
+# Note: the keyword argument name rtol is different from np.linalg.pinv
+def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.pinv <numpy.linalg.pinv>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.pinv.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in pinv')
+
+ # Note: this is different from np.linalg.pinv, which does not multiply the
+ # default tolerance by max(M, N).
+ if rtol is None:
+ rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps
+ return Array._new(np.linalg.pinv(x._array, rcond=rtol))
+
+def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.qr <numpy.linalg.qr>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.qr.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in qr')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.linalg.qr, which only returns a tuple.
+ return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode)))
+
+def slogdet(x: Array, /) -> SlogdetResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.slogdet <numpy.linalg.slogdet>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.slogdet.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in slogdet')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.linalg.slogdet, which only returns a tuple.
+ return SlogdetResult(*map(Array._new, np.linalg.slogdet(x._array)))
+
+# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
+# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
+# of matrices. The np.linalg.solve behavior of allowing stacks of both
+# matrices and vectors is ambiguous c.f.
+# https://github.com/numpy/numpy/issues/15349 and
+# https://github.com/data-apis/array-api/issues/285.
+
+# To workaround this, the below is the code from np.linalg.solve except
+# only calling solve1 in the exactly 1D case.
+def _solve(a, b):
+ from ..linalg.linalg import (_makearray, _assert_stacked_2d,
+ _assert_stacked_square, _commonType,
+ isComplexType, get_linalg_error_extobj,
+ _raise_linalgerror_singular)
+ from ..linalg import _umath_linalg
+
+ a, _ = _makearray(a)
+ _assert_stacked_2d(a)
+ _assert_stacked_square(a)
+ b, wrap = _makearray(b)
+ t, result_t = _commonType(a, b)
+
+ # This part is different from np.linalg.solve
+ if b.ndim == 1:
+ gufunc = _umath_linalg.solve1
+ else:
+ gufunc = _umath_linalg.solve
+
+ # This does nothing currently but is left in because it will be relevant
+ # when complex dtype support is added to the spec in 2022.
+ signature = 'DD->D' if isComplexType(t) else 'dd->d'
+ extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
+ r = gufunc(a, b, signature=signature, extobj=extobj)
+
+ return wrap(r.astype(result_t, copy=False))
+
+def solve(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.solve <numpy.linalg.solve>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.solve.
+ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in solve')
+
+ return Array._new(_solve(x1._array, x2._array))
+
+def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.svd <numpy.linalg.svd>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.svd.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in svd')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.svd, which only returns a tuple.
+ return SVDResult(*map(Array._new, np.linalg.svd(x._array, full_matrices=full_matrices)))
+
+# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
+# np.linalg.svd(compute_uv=False).
+def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
+ return Array._new(np.linalg.svd(x._array, compute_uv=False))
+
+# Note: tensordot is the numpy top-level namespace but not in np.linalg
+
+# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
+def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
+ # Note: the restriction to numeric dtypes only is different from
+ # np.tensordot.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in tensordot')
+
+ return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
+
+# Note: trace is the numpy top-level namespace, not np.linalg
+def trace(x: Array, /, *, offset: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
+
+ See its docstring for more information.
+ """
+ # Note: trace always operates on the last two axes, whereas np.trace
+ # operates on the first two axes by default
+ return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))
+
+# Note: vecdot is not in NumPy
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+ return tensordot(x1, x2, axes=((axis,), (axis,)))
+
+
+# Note: the name here is different from norm(). The array API norm is split
+# into matrix_norm and vector_norm().
+
+# The type for ord should be Optional[Union[int, float, Literal[np.inf,
+# -np.inf]]] but Literal does not support floating-point literals.
+def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.norm.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in norm')
+
+ a = x._array
+ if axis is None:
+ a = a.flatten()
+ axis = 0
+ elif isinstance(axis, tuple):
+ # Note: The axis argument supports any number of axes, whereas norm()
+ # only supports a single axis for vector norm.
+ rest = tuple(i for i in range(a.ndim) if i not in axis)
+ newshape = axis + rest
+ a = np.transpose(a, newshape).reshape((np.prod([a.shape[i] for i in axis]), *[a.shape[i] for i in rest]))
+ axis = 0
+ return Array._new(np.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord))
+
+
+__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py
index 7959f92b4..12479d765 100644
--- a/numpy/array_api/tests/test_array_object.py
+++ b/numpy/array_api/tests/test_array_object.py
@@ -3,7 +3,7 @@ import operator
from numpy.testing import assert_raises
import numpy as np
-from .. import ones, asarray, result_type
+from .. import ones, asarray, result_type, all, equal
from .._dtypes import (
_all_dtypes,
_boolean_dtypes,
@@ -39,18 +39,18 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[:-4])
assert_raises(IndexError, lambda: a[:3:-1])
assert_raises(IndexError, lambda: a[:-5:-1])
- assert_raises(IndexError, lambda: a[3:])
+ assert_raises(IndexError, lambda: a[4:])
assert_raises(IndexError, lambda: a[-4:])
- assert_raises(IndexError, lambda: a[3::-1])
+ assert_raises(IndexError, lambda: a[4::-1])
assert_raises(IndexError, lambda: a[-4::-1])
assert_raises(IndexError, lambda: a[...,:5])
assert_raises(IndexError, lambda: a[...,:-5])
- assert_raises(IndexError, lambda: a[...,:4:-1])
+ assert_raises(IndexError, lambda: a[...,:5:-1])
assert_raises(IndexError, lambda: a[...,:-6:-1])
- assert_raises(IndexError, lambda: a[...,4:])
+ assert_raises(IndexError, lambda: a[...,5:])
assert_raises(IndexError, lambda: a[...,-5:])
- assert_raises(IndexError, lambda: a[...,4::-1])
+ assert_raises(IndexError, lambda: a[...,5::-1])
assert_raises(IndexError, lambda: a[...,-5::-1])
# Boolean indices cannot be part of a larger tuple index
@@ -74,6 +74,11 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[None, ...])
assert_raises(IndexError, lambda: a[..., None])
+ # Multiaxis indices must contain exactly as many indices as dimensions
+ assert_raises(IndexError, lambda: a[()])
+ assert_raises(IndexError, lambda: a[0,])
+ assert_raises(IndexError, lambda: a[0])
+ assert_raises(IndexError, lambda: a[:])
def test_operators():
# For every operator, we test that it works for the required type
@@ -285,3 +290,14 @@ def test_python_scalar_construtors():
assert_raises(TypeError, lambda: operator.index(b))
assert_raises(TypeError, lambda: operator.index(f))
+
+
+def test_device_property():
+ a = ones((3, 4))
+ assert a.device == 'cpu'
+
+ assert all(equal(a.to_device('cpu'), a))
+ assert_raises(ValueError, lambda: a.to_device('gpu'))
+
+ assert all(equal(asarray(a, device='cpu'), a))
+ assert_raises(ValueError, lambda: asarray(a, device='gpu'))
diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py
index 3cb8865cd..be9eaa383 100644
--- a/numpy/array_api/tests/test_creation_functions.py
+++ b/numpy/array_api/tests/test_creation_functions.py
@@ -8,7 +8,6 @@ from .._creation_functions import (
empty,
empty_like,
eye,
- from_dlpack,
full,
full_like,
linspace,
@@ -18,20 +17,8 @@ from .._creation_functions import (
zeros,
zeros_like,
)
+from .._dtypes import float32, float64
from .._array_object import Array
-from .._dtypes import (
- _all_dtypes,
- _boolean_dtypes,
- _floating_dtypes,
- _integer_dtypes,
- _integer_or_boolean_dtypes,
- _numeric_dtypes,
- int8,
- int16,
- int32,
- int64,
- uint64,
-)
def test_asarray_errors():
@@ -56,12 +43,18 @@ def test_asarray_copy():
a[0] = 0
assert all(b[0] == 1)
assert all(a[0] == 0)
- # Once copy=False is implemented, replace this with
- # a = asarray([1])
- # b = asarray(a, copy=False)
- # a[0] = 0
- # assert all(b[0] == 0)
+ a = asarray([1])
+ b = asarray(a, copy=np._CopyMode.ALWAYS)
+ a[0] = 0
+ assert all(b[0] == 1)
+ assert all(a[0] == 0)
+ a = asarray([1])
+ b = asarray(a, copy=np._CopyMode.NEVER)
+ a[0] = 0
+ assert all(b[0] == 0)
assert_raises(NotImplementedError, lambda: asarray(a, copy=False))
+ assert_raises(NotImplementedError,
+ lambda: asarray(a, copy=np._CopyMode.IF_NEEDED))
def test_arange_errors():
@@ -139,3 +132,11 @@ def test_zeros_like_errors():
assert_raises(ValueError, lambda: zeros_like(asarray(1), device="gpu"))
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int))
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i"))
+
+def test_meshgrid_dtype_errors():
+ # Doesn't raise
+ meshgrid()
+ meshgrid(asarray([1.], dtype=float32))
+ meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32))
+
+ assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64)))