summaryrefslogtreecommitdiff
path: root/numpy/array_api/_array_object.py
diff options
context:
space:
mode:
authorMatthew Barber <quitesimplymatt@gmail.com>2022-05-06 09:27:27 +0000
committerGitHub <noreply@github.com>2022-05-06 11:27:27 +0200
commitbefef7b26773eddd2b656a3ab87f504e6cc173db (patch)
tree95c9ea72750052d759f87f35abe43019f0478b71 /numpy/array_api/_array_object.py
parentba54f569cecf17812695f17812d238af2bb91000 (diff)
downloadnumpy-befef7b26773eddd2b656a3ab87f504e6cc173db.tar.gz
API: Allow newaxis indexing for `array_api` arrays (#21377)
* TST: Add test checking if newaxis indexing works for `array_api` Also removes previous check against newaxis indexing, which is now outdated * TST, BUG: Allow `None` in `array_api` indexing Introduces test for validating flat indexing when `None` is present * MAINT,DOC,TST: Rework of `_validate_index()` in `numpy.array_api` _validate_index() is now called as self._validate_index(shape), and does not return a key. This rework removes the recursive pattern used. Tests are introduced to cover some edge cases. Additionally, its internal docstring reflects new behaviour, and extends the flat indexing note. * MAINT: `advance` -> `advanced` (integer indexing) Co-authored-by: Aaron Meurer <asmeurer@gmail.com> * BUG: array_api arrays use internal arrays from array_api array keys When an array_api array is passed as the key for get/setitem, we access the key's internal np.ndarray array to be used as the key for the internal get/setitem operation. This behaviour was initially removed when `_validate_index()` was reworked. * MAINT: Better flat indexing error message for `array_api` arrays Also better semantics for its prior ellipsis count condition Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net> * MAINT: `array_api` arrays don't special case multi-ellipsis errors This gets handled by NumPy-proper. Co-authored-by: Aaron Meurer <asmeurer@gmail.com> Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r--numpy/array_api/_array_object.py212
1 files changed, 119 insertions, 93 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 6cf9ec6f3..c4746fad9 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -29,7 +29,7 @@ from ._dtypes import (
_dtype_categories,
)
-from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
+from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex
import types
if TYPE_CHECKING:
@@ -243,8 +243,7 @@ class Array:
# Note: A large fraction of allowed indices are disallowed here (see the
# docstring below)
- @staticmethod
- def _validate_index(key, shape):
+ def _validate_index(self, key):
"""
Validate an index according to the array API.
@@ -257,8 +256,7 @@ class Array:
https://data-apis.org/array-api/latest/API_specification/indexing.html
for the full list of required indexing behavior
- This function either raises IndexError if the index ``key`` is
- invalid, or a new key to be used in place of ``key`` in indexing. It
+ This function raises IndexError if the index ``key`` is invalid. It
only raises ``IndexError`` on indices that are not already rejected by
NumPy, as NumPy will already raise the appropriate error on such
indices. ``shape`` may be None, in which case, only cases that are
@@ -269,7 +267,7 @@ class Array:
- Indices to not include an implicit ellipsis at the end. That is,
every axis of an array must be explicitly indexed or an ellipsis
- included.
+ included. This behaviour is sometimes referred to as flat indexing.
- The start and stop of a slice may not be out of bounds. In
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
@@ -292,100 +290,122 @@ class Array:
``Array._new`` constructor, not this function.
"""
- if isinstance(key, slice):
- if shape is None:
- return key
- if shape == ():
- return key
- if len(shape) > 1:
+ _key = key if isinstance(key, tuple) else (key,)
+ for i in _key:
+ if isinstance(i, bool) or not (
+ isinstance(i, SupportsIndex) # i.e. ints
+ or isinstance(i, slice)
+ or i == Ellipsis
+ or i is None
+ or isinstance(i, Array)
+ or isinstance(i, np.ndarray)
+ ):
raise IndexError(
- "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ f"Single-axes index {i} has {type(i)=}, but only "
+ "integers, slices (:), ellipsis (...), newaxis (None), "
+ "zero-dimensional integer arrays and boolean arrays "
+ "are specified in the Array API."
)
- size = shape[0]
- # Ensure invalid slice entries are passed through.
- if key.start is not None:
- try:
- operator.index(key.start)
- except TypeError:
- return key
- if not (-size <= key.start <= size):
- raise IndexError(
- "Slices with out-of-bounds start are not allowed in the array API namespace"
- )
- if key.stop is not None:
- try:
- operator.index(key.stop)
- except TypeError:
- return key
- step = 1 if key.step is None else key.step
- if (step > 0 and not (-size <= key.stop <= size)
- or step < 0 and not (-size - 1 <= key.stop <= max(0, size - 1))):
- raise IndexError("Slices with out-of-bounds stop are not allowed in the array API namespace")
- return key
-
- elif isinstance(key, tuple):
- key = tuple(Array._validate_index(idx, None) for idx in key)
-
- for idx in key:
- if (
- isinstance(idx, np.ndarray)
- and idx.dtype in _boolean_dtypes
- or isinstance(idx, (bool, np.bool_))
- ):
- if len(key) == 1:
- return key
- raise IndexError(
- "Boolean array indices combined with other indices are not allowed in the array API namespace"
- )
- if isinstance(idx, tuple):
- raise IndexError(
- "Nested tuple indices are not allowed in the array API namespace"
- )
-
- if shape is None:
- return key
- n_ellipsis = key.count(...)
- if n_ellipsis > 1:
- return key
- ellipsis_i = key.index(...) if n_ellipsis else len(key)
- for idx, size in list(zip(key[:ellipsis_i], shape)) + list(
- zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])
- ):
- Array._validate_index(idx, (size,))
- if n_ellipsis == 0 and len(key) < len(shape):
+ nonexpanding_key = []
+ single_axes = []
+ n_ellipsis = 0
+ key_has_mask = False
+ for i in _key:
+ if i is not None:
+ nonexpanding_key.append(i)
+ if isinstance(i, Array) or isinstance(i, np.ndarray):
+ if i.dtype in _boolean_dtypes:
+ key_has_mask = True
+ single_axes.append(i)
+ else:
+ # i must not be an array here, to avoid elementwise equals
+ if i == Ellipsis:
+ n_ellipsis += 1
+ else:
+ single_axes.append(i)
+
+ n_single_axes = len(single_axes)
+ if n_ellipsis > 1:
+ return # handled by ndarray
+ elif n_ellipsis == 0:
+ # Note boolean masks must be the sole index, which we check for
+ # later on.
+ if not key_has_mask and n_single_axes < self.ndim:
raise IndexError(
- "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ f"{self.ndim=}, but the multi-axes index only specifies "
+ f"{n_single_axes} dimensions. If this was intentional, "
+ "add a trailing ellipsis (...) which expands into as many "
+ "slices (:) as necessary - this is what np.ndarray arrays "
+ "implicitly do, but such flat indexing behaviour is not "
+ "specified in the Array API."
)
- return key
- elif isinstance(key, bool):
- return key
- elif isinstance(key, Array):
- if key.dtype in _integer_dtypes:
- if key.ndim != 0:
+
+ if n_ellipsis == 0:
+ indexed_shape = self.shape
+ else:
+ ellipsis_start = None
+ for pos, i in enumerate(nonexpanding_key):
+ if not (isinstance(i, Array) or isinstance(i, np.ndarray)):
+ if i == Ellipsis:
+ ellipsis_start = pos
+ break
+ assert ellipsis_start is not None # sanity check
+ ellipsis_end = self.ndim - (n_single_axes - ellipsis_start)
+ indexed_shape = (
+ self.shape[:ellipsis_start] + self.shape[ellipsis_end:]
+ )
+ for i, side in zip(single_axes, indexed_shape):
+ if isinstance(i, slice):
+ if side == 0:
+ f_range = "0 (or None)"
+ else:
+ f_range = f"between -{side} and {side - 1} (or None)"
+ if i.start is not None:
+ try:
+ start = operator.index(i.start)
+ except TypeError:
+ pass # handled by ndarray
+ else:
+ if not (-side <= start <= side):
+ raise IndexError(
+ f"Slice {i} contains {start=}, but should be "
+ f"{f_range} for an axis of size {side} "
+ "(out-of-bounds starts are not specified in "
+ "the Array API)"
+ )
+ if i.stop is not None:
+ try:
+ stop = operator.index(i.stop)
+ except TypeError:
+ pass # handled by ndarray
+ else:
+ if not (-side <= stop <= side):
+ raise IndexError(
+ f"Slice {i} contains {stop=}, but should be "
+ f"{f_range} for an axis of size {side} "
+ "(out-of-bounds stops are not specified in "
+ "the Array API)"
+ )
+ elif isinstance(i, Array):
+ if i.dtype in _boolean_dtypes and len(_key) != 1:
+ assert isinstance(key, tuple) # sanity check
raise IndexError(
- "Non-zero dimensional integer array indices are not allowed in the array API namespace"
+ f"Single-axes index {i} is a boolean array and "
+ f"{len(key)=}, but masking is only specified in the "
+ "Array API when the array is the sole index."
)
- return key._array
- elif key is Ellipsis:
- return key
- elif key is None:
- raise IndexError(
- "newaxis indices are not allowed in the array API namespace"
- )
- try:
- key = operator.index(key)
- if shape is not None and len(shape) > 1:
+ elif i.dtype in _integer_dtypes and i.ndim != 0:
+ raise IndexError(
+ f"Single-axes index {i} is a non-zero-dimensional "
+ "integer array, but advanced integer indexing is not "
+ "specified in the Array API."
+ )
+ elif isinstance(i, tuple):
raise IndexError(
- "Multidimensional arrays must include an index for every axis or use an ellipsis"
+ f"Single-axes index {i} is a tuple, but nested tuple "
+ "indices are not specified in the Array API."
)
- return key
- except TypeError:
- # Note: This also omits boolean arrays that are not already in
- # Array() form, like a list of booleans.
- raise IndexError(
- "Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace"
- )
# Everything below this line is required by the spec.
@@ -511,7 +531,10 @@ class Array:
"""
# Note: Only indices required by the spec are allowed. See the
# docstring of _validate_index
- key = self._validate_index(key, self.shape)
+ self._validate_index(key)
+ if isinstance(key, Array):
+ # Indexing self._array with array_api arrays can be erroneous
+ key = key._array
res = self._array.__getitem__(key)
return self._new(res)
@@ -698,7 +721,10 @@ class Array:
"""
# Note: Only indices required by the spec are allowed. See the
# docstring of _validate_index
- key = self._validate_index(key, self.shape)
+ self._validate_index(key)
+ if isinstance(key, Array):
+ # Indexing self._array with array_api arrays can be erroneous
+ key = key._array
self._array.__setitem__(key, asarray(value)._array)
def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: