summaryrefslogtreecommitdiff
path: root/numpy/array_api/_array_object.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-11-12 08:07:23 -0700
committerGitHub <noreply@github.com>2021-11-12 16:07:23 +0100
commitff2e2a1e7eea29d925063b13922e096d14331222 (patch)
treebba7de819e8fbeb145d091a31f0881ea7549133f /numpy/array_api/_array_object.py
parent689d905f501b7abbddf0fdef241fa586a83e5cd6 (diff)
downloadnumpy-ff2e2a1e7eea29d925063b13922e096d14331222.tar.gz
MAINT: A few updates to the array_api (#20066)
* Allow casting in the array API asarray() * Restrict multidimensional indexing in the array API namespace The spec has recently been updated to only require multiaxis (i.e., tuple) indices in the case where every axis is indexed, meaning there are either as many indices as axes or the index has an ellipsis. * Fix type promotion for numpy.array_api.where where does value-based promotion for 0-dimensional arrays, so we use the same trick as in the Array operators to avoid this. * Print empty array_api arrays using empty() Printing behavior isn't required by the spec. This is just to make things easier to understand, especially with the array API test suite. * Fix an incorrect slice bounds guard in the array API * Disallow multiple different dtypes in the input to np.array_api.meshgrid * Remove DLPack support from numpy.array_api.asarray() from_dlpack() should be used to create arrays using DLPack. * Remove __len__ from the array API array object * Add astype() to numpy.array_api * Update the unique_* functions in numpy.array_api unique() in the array API was replaced with three separate functions, unique_all(), unique_inverse(), and unique_values(), in order to avoid polymorphic return types. Additionally, it should be noted that these functions to not currently conform to the spec with respect to NaN behavior. The spec requires multiple NaNs to be returned, but np.unique() returns a single NaN. Since this is currently an open issue in NumPy to possibly revert, I have not yet worked around this. See https://github.com/numpy/numpy/issues/20326. * Add the stream argument to the array API to_device method This does nothing in NumPy, and is just present so that the signature is valid according to the spec. * Use the NamedTuple classes for the type signatures * Add unique_counts to the array API namespace * Remove some unused imports * Update the array_api indexing restrictions The "multiaxis indexing must index every axis explicitly or use an ellipsis" was supposed to include any type of index, not just tuple indices. * Use a simpler type annotation for the array API to_device method * Fix a test failure in the array_api submodule The array_api cannot use the NumPy testing functions because array_api arrays do not mix with NumPy arrays, and also NumPy testing functions may use APIs that aren't supported in the array API. * Add dlpack support to the array_api submodule
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r--numpy/array_api/_array_object.py48
1 files changed, 31 insertions, 17 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index ef66c5efd..dc74bb8c5 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -32,7 +32,7 @@ from ._dtypes import (
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
if TYPE_CHECKING:
- from ._typing import PyCapsule, Device, Dtype
+ from ._typing import Any, PyCapsule, Device, Dtype
import numpy as np
@@ -99,9 +99,13 @@ class Array:
"""
Performs the operation __repr__.
"""
- prefix = "Array("
suffix = f", dtype={self.dtype.name})"
- mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
+ if 0 in self.shape:
+ prefix = "empty("
+ mid = str(self.shape)
+ else:
+ prefix = "Array("
+ mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix
# These are various helper functions to make the array behavior match the
@@ -244,6 +248,10 @@ class Array:
The following cases are allowed by NumPy, but not specified by the array
API specification:
+ - Indices to not include an implicit ellipsis at the end. That is,
+ every axis of an array must be explicitly indexed or an ellipsis
+ included.
+
- The start and stop of a slice may not be out of bounds. In
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
following are allowed:
@@ -270,6 +278,10 @@ class Array:
return key
if shape == ():
return key
+ if len(shape) > 1:
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
size = shape[0]
# Ensure invalid slice entries are passed through.
if key.start is not None:
@@ -277,7 +289,7 @@ class Array:
operator.index(key.start)
except TypeError:
return key
- if not (-size <= key.start <= max(0, size - 1)):
+ if not (-size <= key.start <= size):
raise IndexError(
"Slices with out-of-bounds start are not allowed in the array API namespace"
)
@@ -322,6 +334,10 @@ class Array:
zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
):
Array._validate_index(idx, (size,))
+ if n_ellipsis == 0 and len(key) < len(shape):
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
return key
elif isinstance(key, bool):
return key
@@ -339,7 +355,12 @@ class Array:
"newaxis indices are not allowed in the array API namespace"
)
try:
- return operator.index(key)
+ key = operator.index(key)
+ if shape is not None and len(shape) > 1:
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
+ return key
except TypeError:
# Note: This also omits boolean arrays that are not already in
# Array() form, like a list of booleans.
@@ -403,16 +424,14 @@ class Array:
"""
Performs the operation __dlpack__.
"""
- res = self._array.__dlpack__(stream=stream)
- return self.__class__._new(res)
+ return self._array.__dlpack__(stream=stream)
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
"""
Performs the operation __dlpack_device__.
"""
# Note: device support is required for this
- res = self._array.__dlpack_device__()
- return self.__class__._new(res)
+ return self._array.__dlpack_device__()
def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
@@ -527,13 +546,6 @@ class Array:
res = self._array.__le__(other._array)
return self.__class__._new(res)
- # Note: __len__ may end up being removed from the array API spec.
- def __len__(self, /) -> int:
- """
- Performs the operation __len__.
- """
- return self._array.__len__()
-
def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
"""
Performs the operation __lshift__.
@@ -995,7 +1007,9 @@ class Array:
res = self._array.__rxor__(other._array)
return self.__class__._new(res)
- def to_device(self: Array, device: Device, /) -> Array:
+ def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
+ if stream is not None:
+ raise ValueError("The stream argument to to_device() is not supported")
if device == 'cpu':
return self
raise ValueError(f"Unsupported device {device!r}")