summaryrefslogtreecommitdiff
path: root/numpy/array_api/_array_object.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r--numpy/array_api/_array_object.py48
1 files changed, 31 insertions, 17 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index ef66c5efd..dc74bb8c5 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -32,7 +32,7 @@ from ._dtypes import (
from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
if TYPE_CHECKING:
- from ._typing import PyCapsule, Device, Dtype
+ from ._typing import Any, PyCapsule, Device, Dtype
import numpy as np
@@ -99,9 +99,13 @@ class Array:
"""
Performs the operation __repr__.
"""
- prefix = "Array("
suffix = f", dtype={self.dtype.name})"
- mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
+ if 0 in self.shape:
+ prefix = "empty("
+ mid = str(self.shape)
+ else:
+ prefix = "Array("
+ mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix
# These are various helper functions to make the array behavior match the
@@ -244,6 +248,10 @@ class Array:
The following cases are allowed by NumPy, but not specified by the array
API specification:
+ - Indices to not include an implicit ellipsis at the end. That is,
+ every axis of an array must be explicitly indexed or an ellipsis
+ included.
+
- The start and stop of a slice may not be out of bounds. In
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
following are allowed:
@@ -270,6 +278,10 @@ class Array:
return key
if shape == ():
return key
+ if len(shape) > 1:
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
size = shape[0]
# Ensure invalid slice entries are passed through.
if key.start is not None:
@@ -277,7 +289,7 @@ class Array:
operator.index(key.start)
except TypeError:
return key
- if not (-size <= key.start <= max(0, size - 1)):
+ if not (-size <= key.start <= size):
raise IndexError(
"Slices with out-of-bounds start are not allowed in the array API namespace"
)
@@ -322,6 +334,10 @@ class Array:
zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
):
Array._validate_index(idx, (size,))
+ if n_ellipsis == 0 and len(key) < len(shape):
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
return key
elif isinstance(key, bool):
return key
@@ -339,7 +355,12 @@ class Array:
"newaxis indices are not allowed in the array API namespace"
)
try:
- return operator.index(key)
+ key = operator.index(key)
+ if shape is not None and len(shape) > 1:
+ raise IndexError(
+ "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ )
+ return key
except TypeError:
# Note: This also omits boolean arrays that are not already in
# Array() form, like a list of booleans.
@@ -403,16 +424,14 @@ class Array:
"""
Performs the operation __dlpack__.
"""
- res = self._array.__dlpack__(stream=stream)
- return self.__class__._new(res)
+ return self._array.__dlpack__(stream=stream)
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
"""
Performs the operation __dlpack_device__.
"""
# Note: device support is required for this
- res = self._array.__dlpack_device__()
- return self.__class__._new(res)
+ return self._array.__dlpack_device__()
def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
@@ -527,13 +546,6 @@ class Array:
res = self._array.__le__(other._array)
return self.__class__._new(res)
- # Note: __len__ may end up being removed from the array API spec.
- def __len__(self, /) -> int:
- """
- Performs the operation __len__.
- """
- return self._array.__len__()
-
def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
"""
Performs the operation __lshift__.
@@ -995,7 +1007,9 @@ class Array:
res = self._array.__rxor__(other._array)
return self.__class__._new(res)
- def to_device(self: Array, device: Device, /) -> Array:
+ def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
+ if stream is not None:
+ raise ValueError("The stream argument to to_device() is not supported")
if device == 'cpu':
return self
raise ValueError(f"Unsupported device {device!r}")