diff options
author | Matthew Barber <quitesimplymatt@gmail.com> | 2022-05-06 09:27:27 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-06 11:27:27 +0200 |
commit | befef7b26773eddd2b656a3ab87f504e6cc173db (patch) | |
tree | 95c9ea72750052d759f87f35abe43019f0478b71 /numpy/array_api/_array_object.py | |
parent | ba54f569cecf17812695f17812d238af2bb91000 (diff) | |
download | numpy-befef7b26773eddd2b656a3ab87f504e6cc173db.tar.gz |
API: Allow newaxis indexing for `array_api` arrays (#21377)
* TST: Add test checking if newaxis indexing works for `array_api`
Also removes previous check against newaxis indexing, which is now outdated
* TST, BUG: Allow `None` in `array_api` indexing
Introduces test for validating flat indexing when `None` is present
* MAINT,DOC,TST: Rework of `_validate_index()` in `numpy.array_api`
_validate_index() is now called as self._validate_index(shape), and does not
return a key. This rework removes the recursive pattern used. Tests are
introduced to cover some edge cases. Additionally, its internal docstring
reflects new behaviour, and extends the flat indexing note.
* MAINT: `advance` -> `advanced` (integer indexing)
Co-authored-by: Aaron Meurer <asmeurer@gmail.com>
* BUG: array_api arrays use internal arrays from array_api array keys
When an array_api array is passed as the key for get/setitem, we access the
key's internal np.ndarray array to be used as the key for the internal
get/setitem operation. This behaviour was initially removed when
`_validate_index()` was reworked.
* MAINT: Better flat indexing error message for `array_api` arrays
Also better semantics for its prior ellipsis count condition
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
* MAINT: `array_api` arrays don't special case multi-ellipsis errors
This gets handled by NumPy-proper.
Co-authored-by: Aaron Meurer <asmeurer@gmail.com>
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r-- | numpy/array_api/_array_object.py | 212 |
1 files changed, 119 insertions, 93 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 6cf9ec6f3..c4746fad9 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -29,7 +29,7 @@ from ._dtypes import ( _dtype_categories, ) -from typing import TYPE_CHECKING, Optional, Tuple, Union, Any +from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex import types if TYPE_CHECKING: @@ -243,8 +243,7 @@ class Array: # Note: A large fraction of allowed indices are disallowed here (see the # docstring below) - @staticmethod - def _validate_index(key, shape): + def _validate_index(self, key): """ Validate an index according to the array API. @@ -257,8 +256,7 @@ class Array: https://data-apis.org/array-api/latest/API_specification/indexing.html for the full list of required indexing behavior - This function either raises IndexError if the index ``key`` is - invalid, or a new key to be used in place of ``key`` in indexing. It + This function raises IndexError if the index ``key`` is invalid. It only raises ``IndexError`` on indices that are not already rejected by NumPy, as NumPy will already raise the appropriate error on such indices. ``shape`` may be None, in which case, only cases that are @@ -269,7 +267,7 @@ class Array: - Indices to not include an implicit ellipsis at the end. That is, every axis of an array must be explicitly indexed or an ellipsis - included. + included. This behaviour is sometimes referred to as flat indexing. - The start and stop of a slice may not be out of bounds. In particular, for a slice ``i:j:k`` on an axis of size ``n``, only the @@ -292,100 +290,122 @@ class Array: ``Array._new`` constructor, not this function. """ - if isinstance(key, slice): - if shape is None: - return key - if shape == (): - return key - if len(shape) > 1: + _key = key if isinstance(key, tuple) else (key,) + for i in _key: + if isinstance(i, bool) or not ( + isinstance(i, SupportsIndex) # i.e. ints + or isinstance(i, slice) + or i == Ellipsis + or i is None + or isinstance(i, Array) + or isinstance(i, np.ndarray) + ): raise IndexError( - "Multidimensional arrays must include an index for every axis or use an ellipsis" + f"Single-axes index {i} has {type(i)=}, but only " + "integers, slices (:), ellipsis (...), newaxis (None), " + "zero-dimensional integer arrays and boolean arrays " + "are specified in the Array API." ) - size = shape[0] - # Ensure invalid slice entries are passed through. - if key.start is not None: - try: - operator.index(key.start) - except TypeError: - return key - if not (-size <= key.start <= size): - raise IndexError( - "Slices with out-of-bounds start are not allowed in the array API namespace" - ) - if key.stop is not None: - try: - operator.index(key.stop) - except TypeError: - return key - step = 1 if key.step is None else key.step - if (step > 0 and not (-size <= key.stop <= size) - or step < 0 and not (-size - 1 <= key.stop <= max(0, size - 1))): - raise IndexError("Slices with out-of-bounds stop are not allowed in the array API namespace") - return key - - elif isinstance(key, tuple): - key = tuple(Array._validate_index(idx, None) for idx in key) - - for idx in key: - if ( - isinstance(idx, np.ndarray) - and idx.dtype in _boolean_dtypes - or isinstance(idx, (bool, np.bool_)) - ): - if len(key) == 1: - return key - raise IndexError( - "Boolean array indices combined with other indices are not allowed in the array API namespace" - ) - if isinstance(idx, tuple): - raise IndexError( - "Nested tuple indices are not allowed in the array API namespace" - ) - - if shape is None: - return key - n_ellipsis = key.count(...) - if n_ellipsis > 1: - return key - ellipsis_i = key.index(...) if n_ellipsis else len(key) - for idx, size in list(zip(key[:ellipsis_i], shape)) + list( - zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1]) - ): - Array._validate_index(idx, (size,)) - if n_ellipsis == 0 and len(key) < len(shape): + nonexpanding_key = [] + single_axes = [] + n_ellipsis = 0 + key_has_mask = False + for i in _key: + if i is not None: + nonexpanding_key.append(i) + if isinstance(i, Array) or isinstance(i, np.ndarray): + if i.dtype in _boolean_dtypes: + key_has_mask = True + single_axes.append(i) + else: + # i must not be an array here, to avoid elementwise equals + if i == Ellipsis: + n_ellipsis += 1 + else: + single_axes.append(i) + + n_single_axes = len(single_axes) + if n_ellipsis > 1: + return # handled by ndarray + elif n_ellipsis == 0: + # Note boolean masks must be the sole index, which we check for + # later on. + if not key_has_mask and n_single_axes < self.ndim: raise IndexError( - "Multidimensional arrays must include an index for every axis or use an ellipsis" + f"{self.ndim=}, but the multi-axes index only specifies " + f"{n_single_axes} dimensions. If this was intentional, " + "add a trailing ellipsis (...) which expands into as many " + "slices (:) as necessary - this is what np.ndarray arrays " + "implicitly do, but such flat indexing behaviour is not " + "specified in the Array API." ) - return key - elif isinstance(key, bool): - return key - elif isinstance(key, Array): - if key.dtype in _integer_dtypes: - if key.ndim != 0: + + if n_ellipsis == 0: + indexed_shape = self.shape + else: + ellipsis_start = None + for pos, i in enumerate(nonexpanding_key): + if not (isinstance(i, Array) or isinstance(i, np.ndarray)): + if i == Ellipsis: + ellipsis_start = pos + break + assert ellipsis_start is not None # sanity check + ellipsis_end = self.ndim - (n_single_axes - ellipsis_start) + indexed_shape = ( + self.shape[:ellipsis_start] + self.shape[ellipsis_end:] + ) + for i, side in zip(single_axes, indexed_shape): + if isinstance(i, slice): + if side == 0: + f_range = "0 (or None)" + else: + f_range = f"between -{side} and {side - 1} (or None)" + if i.start is not None: + try: + start = operator.index(i.start) + except TypeError: + pass # handled by ndarray + else: + if not (-side <= start <= side): + raise IndexError( + f"Slice {i} contains {start=}, but should be " + f"{f_range} for an axis of size {side} " + "(out-of-bounds starts are not specified in " + "the Array API)" + ) + if i.stop is not None: + try: + stop = operator.index(i.stop) + except TypeError: + pass # handled by ndarray + else: + if not (-side <= stop <= side): + raise IndexError( + f"Slice {i} contains {stop=}, but should be " + f"{f_range} for an axis of size {side} " + "(out-of-bounds stops are not specified in " + "the Array API)" + ) + elif isinstance(i, Array): + if i.dtype in _boolean_dtypes and len(_key) != 1: + assert isinstance(key, tuple) # sanity check raise IndexError( - "Non-zero dimensional integer array indices are not allowed in the array API namespace" + f"Single-axes index {i} is a boolean array and " + f"{len(key)=}, but masking is only specified in the " + "Array API when the array is the sole index." ) - return key._array - elif key is Ellipsis: - return key - elif key is None: - raise IndexError( - "newaxis indices are not allowed in the array API namespace" - ) - try: - key = operator.index(key) - if shape is not None and len(shape) > 1: + elif i.dtype in _integer_dtypes and i.ndim != 0: + raise IndexError( + f"Single-axes index {i} is a non-zero-dimensional " + "integer array, but advanced integer indexing is not " + "specified in the Array API." + ) + elif isinstance(i, tuple): raise IndexError( - "Multidimensional arrays must include an index for every axis or use an ellipsis" + f"Single-axes index {i} is a tuple, but nested tuple " + "indices are not specified in the Array API." ) - return key - except TypeError: - # Note: This also omits boolean arrays that are not already in - # Array() form, like a list of booleans. - raise IndexError( - "Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace" - ) # Everything below this line is required by the spec. @@ -511,7 +531,10 @@ class Array: """ # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index - key = self._validate_index(key, self.shape) + self._validate_index(key) + if isinstance(key, Array): + # Indexing self._array with array_api arrays can be erroneous + key = key._array res = self._array.__getitem__(key) return self._new(res) @@ -698,7 +721,10 @@ class Array: """ # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index - key = self._validate_index(key, self.shape) + self._validate_index(key) + if isinstance(key, Array): + # Indexing self._array with array_api arrays can be erroneous + key = key._array self._array.__setitem__(key, asarray(value)._array) def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: |