summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/__init__.py10
-rw-r--r--numpy/array_api/_array_object.py48
-rw-r--r--numpy/array_api/_creation_functions.py19
-rw-r--r--numpy/array_api/_data_type_functions.py7
-rw-r--r--numpy/array_api/_searching_functions.py1
-rw-r--r--numpy/array_api/_set_functions.py89
-rw-r--r--numpy/array_api/tests/test_array_object.py21
-rw-r--r--numpy/array_api/tests/test_creation_functions.py10
8 files changed, 156 insertions, 49 deletions
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py
index 89f5e9cba..36e3f3ed5 100644
--- a/numpy/array_api/__init__.py
+++ b/numpy/array_api/__init__.py
@@ -136,7 +136,7 @@ from ._creation_functions import (
empty,
empty_like,
eye,
- _from_dlpack,
+ from_dlpack,
full,
full_like,
linspace,
@@ -155,7 +155,7 @@ __all__ += [
"empty",
"empty_like",
"eye",
- "_from_dlpack",
+ "from_dlpack",
"full",
"full_like",
"linspace",
@@ -169,6 +169,7 @@ __all__ += [
]
from ._data_type_functions import (
+ astype,
broadcast_arrays,
broadcast_to,
can_cast,
@@ -178,6 +179,7 @@ from ._data_type_functions import (
)
__all__ += [
+ "astype",
"broadcast_arrays",
"broadcast_to",
"can_cast",
@@ -358,9 +360,9 @@ from ._searching_functions import argmax, argmin, nonzero, where
__all__ += ["argmax", "argmin", "nonzero", "where"]
-from ._set_functions import unique
+from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
-__all__ += ["unique"]
+__all__ += ["unique_all", "unique_counts", "unique_inverse", "unique_values"]
from ._sorting_functions import argsort, sort
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index ef66c5efd..dc74bb8c5 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -32,7 +32,7 @@ from ._dtypes import (
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
if TYPE_CHECKING:
- from ._typing import PyCapsule, Device, Dtype
+ from ._typing import Any, PyCapsule, Device, Dtype
import numpy as np
@@ -99,9 +99,13 @@ class Array:
"""
Performs the operation __repr__.
"""
- prefix = "Array("
suffix = f", dtype={self.dtype.name})"
- mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
+ if 0 in self.shape:
+ prefix = "empty("
+ mid = str(self.shape)
+ else:
+ prefix = "Array("
+ mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix
# These are various helper functions to make the array behavior match the
@@ -244,6 +248,10 @@ class Array:
The following cases are allowed by NumPy, but not specified by the array
API specification:
+ - Indices to not include an implicit ellipsis at the end. That is,
+ every axis of an array must be explicitly indexed or an ellipsis
+ included.
+
- The start and stop of a slice may not be out of bounds. In
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
following are allowed:
@@ -270,6 +278,10 @@ class Array:
return key
if shape == ():
return key
+ if len(shape) > 1:
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
size = shape[0]
# Ensure invalid slice entries are passed through.
if key.start is not None:
@@ -277,7 +289,7 @@ class Array:
operator.index(key.start)
except TypeError:
return key
- if not (-size <= key.start <= max(0, size - 1)):
+ if not (-size <= key.start <= size):
raise IndexError(
"Slices with out-of-bounds start are not allowed in the array API namespace"
)
@@ -322,6 +334,10 @@ class Array:
zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
):
Array._validate_index(idx, (size,))
+ if n_ellipsis == 0 and len(key) < len(shape):
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
return key
elif isinstance(key, bool):
return key
@@ -339,7 +355,12 @@ class Array:
"newaxis indices are not allowed in the array API namespace"
)
try:
- return operator.index(key)
+ key = operator.index(key)
+ if shape is not None and len(shape) > 1:
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
+ return key
except TypeError:
# Note: This also omits boolean arrays that are not already in
# Array() form, like a list of booleans.
@@ -403,16 +424,14 @@ class Array:
"""
Performs the operation __dlpack__.
"""
- res = self._array.__dlpack__(stream=stream)
- return self.__class__._new(res)
+ return self._array.__dlpack__(stream=stream)
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
"""
Performs the operation __dlpack_device__.
"""
# Note: device support is required for this
- res = self._array.__dlpack_device__()
- return self.__class__._new(res)
+ return self._array.__dlpack_device__()
def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
@@ -527,13 +546,6 @@ class Array:
res = self._array.__le__(other._array)
return self.__class__._new(res)
- # Note: __len__ may end up being removed from the array API spec.
- def __len__(self, /) -> int:
- """
- Performs the operation __len__.
- """
- return self._array.__len__()
-
def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
"""
Performs the operation __lshift__.
@@ -995,7 +1007,9 @@ class Array:
res = self._array.__rxor__(other._array)
return self.__class__._new(res)
- def to_device(self: Array, device: Device, /) -> Array:
+ def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
+ if stream is not None:
+ raise ValueError("The stream argument to to_device() is not supported")
if device == 'cpu':
return self
raise ValueError(f"Unsupported device {device!r}")
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index c3644ac2c..23beec444 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
Device,
Dtype,
NestedSequence,
- SupportsDLPack,
SupportsBufferProtocol,
)
from collections.abc import Sequence
@@ -36,7 +35,6 @@ def asarray(
int,
float,
NestedSequence[bool | int | float],
- SupportsDLPack,
SupportsBufferProtocol,
],
/,
@@ -60,7 +58,9 @@ def asarray(
if copy is False:
# Note: copy=False is not yet implemented in np.asarray
raise NotImplementedError("copy=False is not yet implemented")
- if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype):
+ if isinstance(obj, Array):
+ if dtype is not None and obj.dtype != dtype:
+ copy = True
if copy is True:
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
return obj
@@ -151,9 +151,10 @@ def eye(
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
-def _from_dlpack(x: object, /) -> Array:
- # Note: dlpack support is not yet implemented on Array
- raise NotImplementedError("DLPack support is not yet implemented")
+def from_dlpack(x: object, /) -> Array:
+ from ._array_object import Array
+
+ return Array._new(np._from_dlpack(x))
def full(
@@ -240,6 +241,12 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
"""
from ._array_object import Array
+ # Note: unlike np.meshgrid, only inputs with all the same dtype are
+ # allowed
+
+ if len({a.dtype for a in arrays}) > 1:
+ raise ValueError("meshgrid inputs must all have the same dtype")
+
return [
Array._new(array)
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py
index 7ccbe9469..e4d6db61b 100644
--- a/numpy/array_api/_data_type_functions.py
+++ b/numpy/array_api/_data_type_functions.py
@@ -13,6 +13,13 @@ if TYPE_CHECKING:
import numpy as np
+# Note: astype is a function, not an array method as in NumPy.
+def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
+ if not copy and dtype == x.dtype:
+ return x
+ return Array._new(x._array.astype(dtype=dtype, copy=copy))
+
+
def broadcast_arrays(*arrays: Array) -> List[Array]:
"""
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py
index 3dcef61c3..40f5a4d2e 100644
--- a/numpy/array_api/_searching_functions.py
+++ b/numpy/array_api/_searching_functions.py
@@ -43,4 +43,5 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
"""
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.where(condition._array, x1._array, x2._array))
diff --git a/numpy/array_api/_set_functions.py b/numpy/array_api/_set_functions.py
index 357f238f5..05ee7e555 100644
--- a/numpy/array_api/_set_functions.py
+++ b/numpy/array_api/_set_functions.py
@@ -2,19 +2,82 @@ from __future__ import annotations
from ._array_object import Array
-from typing import Tuple, Union
+from typing import NamedTuple
import numpy as np
+# Note: np.unique() is split into four functions in the array API:
+# unique_all, unique_counts, unique_inverse, and unique_values (this is done
+# to remove polymorphic return types).
-def unique(
- x: Array,
- /,
- *,
- return_counts: bool = False,
- return_index: bool = False,
- return_inverse: bool = False,
-) -> Union[Array, Tuple[Array, ...]]:
+# Note: The various unique() functions are supposed to return multiple NaNs.
+# This does not match the NumPy behavior, however, this is currently left as a
+# TODO in this implementation as this behavior may be reverted in np.unique().
+# See https://github.com/numpy/numpy/issues/20326.
+
+# Note: The functions here return a namedtuple (np.unique() returns a normal
+# tuple).
+
+class UniqueAllResult(NamedTuple):
+ values: Array
+ indices: Array
+ inverse_indices: Array
+ counts: Array
+
+
+class UniqueCountsResult(NamedTuple):
+ values: Array
+ counts: Array
+
+
+class UniqueInverseResult(NamedTuple):
+ values: Array
+ inverse_indices: Array
+
+
+def unique_all(x: Array, /) -> UniqueAllResult:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ res = np.unique(
+ x._array,
+ return_counts=True,
+ return_index=True,
+ return_inverse=True,
+ )
+
+ return UniqueAllResult(*[Array._new(i) for i in res])
+
+
+def unique_counts(x: Array, /) -> UniqueCountsResult:
+ res = np.unique(
+ x._array,
+ return_counts=True,
+ return_index=False,
+ return_inverse=False,
+ )
+
+ return UniqueCountsResult(*[Array._new(i) for i in res])
+
+
+def unique_inverse(x: Array, /) -> UniqueInverseResult:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ res = np.unique(
+ x._array,
+ return_counts=False,
+ return_index=False,
+ return_inverse=True,
+ )
+ return UniqueInverseResult(*[Array._new(i) for i in res])
+
+
+def unique_values(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
@@ -22,10 +85,8 @@ def unique(
"""
res = np.unique(
x._array,
- return_counts=return_counts,
- return_index=return_index,
- return_inverse=return_inverse,
+ return_counts=False,
+ return_index=False,
+ return_inverse=False,
)
- if isinstance(res, tuple):
- return tuple(Array._new(i) for i in res)
return Array._new(res)
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py
index fb42cf621..12479d765 100644
--- a/numpy/array_api/tests/test_array_object.py
+++ b/numpy/array_api/tests/test_array_object.py
@@ -3,7 +3,7 @@ import operator
from numpy.testing import assert_raises
import numpy as np
-from .. import ones, asarray, result_type
+from .. import ones, asarray, result_type, all, equal
from .._dtypes import (
_all_dtypes,
_boolean_dtypes,
@@ -39,18 +39,18 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[:-4])
assert_raises(IndexError, lambda: a[:3:-1])
assert_raises(IndexError, lambda: a[:-5:-1])
- assert_raises(IndexError, lambda: a[3:])
+ assert_raises(IndexError, lambda: a[4:])
assert_raises(IndexError, lambda: a[-4:])
- assert_raises(IndexError, lambda: a[3::-1])
+ assert_raises(IndexError, lambda: a[4::-1])
assert_raises(IndexError, lambda: a[-4::-1])
assert_raises(IndexError, lambda: a[...,:5])
assert_raises(IndexError, lambda: a[...,:-5])
- assert_raises(IndexError, lambda: a[...,:4:-1])
+ assert_raises(IndexError, lambda: a[...,:5:-1])
assert_raises(IndexError, lambda: a[...,:-6:-1])
- assert_raises(IndexError, lambda: a[...,4:])
+ assert_raises(IndexError, lambda: a[...,5:])
assert_raises(IndexError, lambda: a[...,-5:])
- assert_raises(IndexError, lambda: a[...,4::-1])
+ assert_raises(IndexError, lambda: a[...,5::-1])
assert_raises(IndexError, lambda: a[...,-5::-1])
# Boolean indices cannot be part of a larger tuple index
@@ -74,6 +74,11 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[None, ...])
assert_raises(IndexError, lambda: a[..., None])
+ # Multiaxis indices must contain exactly as many indices as dimensions
+ assert_raises(IndexError, lambda: a[()])
+ assert_raises(IndexError, lambda: a[0,])
+ assert_raises(IndexError, lambda: a[0])
+ assert_raises(IndexError, lambda: a[:])
def test_operators():
# For every operator, we test that it works for the required type
@@ -291,8 +296,8 @@ def test_device_property():
a = ones((3, 4))
assert a.device == 'cpu'
- assert np.array_equal(a.to_device('cpu'), a)
+ assert all(equal(a.to_device('cpu'), a))
assert_raises(ValueError, lambda: a.to_device('gpu'))
- assert np.array_equal(asarray(a, device='cpu'), a)
+ assert all(equal(asarray(a, device='cpu'), a))
assert_raises(ValueError, lambda: asarray(a, device='gpu'))
diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py
index 7b633eaf1..ebbb6aab3 100644
--- a/numpy/array_api/tests/test_creation_functions.py
+++ b/numpy/array_api/tests/test_creation_functions.py
@@ -11,11 +11,13 @@ from .._creation_functions import (
full,
full_like,
linspace,
+ meshgrid,
ones,
ones_like,
zeros,
zeros_like,
)
+from .._dtypes import float32, float64
from .._array_object import Array
@@ -124,3 +126,11 @@ def test_zeros_like_errors():
assert_raises(ValueError, lambda: zeros_like(asarray(1), device="gpu"))
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int))
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i"))
+
+def test_meshgrid_dtype_errors():
+ # Doesn't raise
+ meshgrid()
+ meshgrid(asarray([1.], dtype=float32))
+ meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32))
+
+ assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64)))