diff options
Diffstat (limited to 'numpy/array_api')
-rw-r--r-- | numpy/array_api/_array_object.py | 212 | ||||
-rw-r--r-- | numpy/array_api/tests/test_array_object.py | 63 |
2 files changed, 176 insertions, 99 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 6cf9ec6f3..c4746fad9 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -29,7 +29,7 @@ from ._dtypes import ( _dtype_categories, ) -from typing import TYPE_CHECKING, Optional, Tuple, Union, Any +from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex import types if TYPE_CHECKING: @@ -243,8 +243,7 @@ class Array: # Note: A large fraction of allowed indices are disallowed here (see the # docstring below) - @staticmethod - def _validate_index(key, shape): + def _validate_index(self, key): """ Validate an index according to the array API. @@ -257,8 +256,7 @@ class Array: https://data-apis.org/array-api/latest/API_specification/indexing.html for the full list of required indexing behavior - This function either raises IndexError if the index ``key`` is - invalid, or a new key to be used in place of ``key`` in indexing. It + This function raises IndexError if the index ``key`` is invalid. It only raises ``IndexError`` on indices that are not already rejected by NumPy, as NumPy will already raise the appropriate error on such indices. ``shape`` may be None, in which case, only cases that are @@ -269,7 +267,7 @@ class Array: - Indices to not include an implicit ellipsis at the end. That is, every axis of an array must be explicitly indexed or an ellipsis - included. + included. This behaviour is sometimes referred to as flat indexing. - The start and stop of a slice may not be out of bounds. In particular, for a slice ``i:j:k`` on an axis of size ``n``, only the @@ -292,100 +290,122 @@ class Array: ``Array._new`` constructor, not this function. """ - if isinstance(key, slice): - if shape is None: - return key - if shape == (): - return key - if len(shape) > 1: + _key = key if isinstance(key, tuple) else (key,) + for i in _key: + if isinstance(i, bool) or not ( + isinstance(i, SupportsIndex) # i.e. ints + or isinstance(i, slice) + or i == Ellipsis + or i is None + or isinstance(i, Array) + or isinstance(i, np.ndarray) + ): raise IndexError( - "Multidimensional arrays must include an index for every axis or use an ellipsis" + f"Single-axes index {i} has {type(i)=}, but only " + "integers, slices (:), ellipsis (...), newaxis (None), " + "zero-dimensional integer arrays and boolean arrays " + "are specified in the Array API." ) - size = shape[0] - # Ensure invalid slice entries are passed through. - if key.start is not None: - try: - operator.index(key.start) - except TypeError: - return key - if not (-size <= key.start <= size): - raise IndexError( - "Slices with out-of-bounds start are not allowed in the array API namespace" - ) - if key.stop is not None: - try: - operator.index(key.stop) - except TypeError: - return key - step = 1 if key.step is None else key.step - if (step > 0 and not (-size <= key.stop <= size) - or step < 0 and not (-size - 1 <= key.stop <= max(0, size - 1))): - raise IndexError("Slices with out-of-bounds stop are not allowed in the array API namespace") - return key - - elif isinstance(key, tuple): - key = tuple(Array._validate_index(idx, None) for idx in key) - - for idx in key: - if ( - isinstance(idx, np.ndarray) - and idx.dtype in _boolean_dtypes - or isinstance(idx, (bool, np.bool_)) - ): - if len(key) == 1: - return key - raise IndexError( - "Boolean array indices combined with other indices are not allowed in the array API namespace" - ) - if isinstance(idx, tuple): - raise IndexError( - "Nested tuple indices are not allowed in the array API namespace" - ) - - if shape is None: - return key - n_ellipsis = key.count(...) - if n_ellipsis > 1: - return key - ellipsis_i = key.index(...) if n_ellipsis else len(key) - for idx, size in list(zip(key[:ellipsis_i], shape)) + list( - zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1]) - ): - Array._validate_index(idx, (size,)) - if n_ellipsis == 0 and len(key) < len(shape): + nonexpanding_key = [] + single_axes = [] + n_ellipsis = 0 + key_has_mask = False + for i in _key: + if i is not None: + nonexpanding_key.append(i) + if isinstance(i, Array) or isinstance(i, np.ndarray): + if i.dtype in _boolean_dtypes: + key_has_mask = True + single_axes.append(i) + else: + # i must not be an array here, to avoid elementwise equals + if i == Ellipsis: + n_ellipsis += 1 + else: + single_axes.append(i) + + n_single_axes = len(single_axes) + if n_ellipsis > 1: + return # handled by ndarray + elif n_ellipsis == 0: + # Note boolean masks must be the sole index, which we check for + # later on. + if not key_has_mask and n_single_axes < self.ndim: raise IndexError( - "Multidimensional arrays must include an index for every axis or use an ellipsis" + f"{self.ndim=}, but the multi-axes index only specifies " + f"{n_single_axes} dimensions. If this was intentional, " + "add a trailing ellipsis (...) which expands into as many " + "slices (:) as necessary - this is what np.ndarray arrays " + "implicitly do, but such flat indexing behaviour is not " + "specified in the Array API." ) - return key - elif isinstance(key, bool): - return key - elif isinstance(key, Array): - if key.dtype in _integer_dtypes: - if key.ndim != 0: + + if n_ellipsis == 0: + indexed_shape = self.shape + else: + ellipsis_start = None + for pos, i in enumerate(nonexpanding_key): + if not (isinstance(i, Array) or isinstance(i, np.ndarray)): + if i == Ellipsis: + ellipsis_start = pos + break + assert ellipsis_start is not None # sanity check + ellipsis_end = self.ndim - (n_single_axes - ellipsis_start) + indexed_shape = ( + self.shape[:ellipsis_start] + self.shape[ellipsis_end:] + ) + for i, side in zip(single_axes, indexed_shape): + if isinstance(i, slice): + if side == 0: + f_range = "0 (or None)" + else: + f_range = f"between -{side} and {side - 1} (or None)" + if i.start is not None: + try: + start = operator.index(i.start) + except TypeError: + pass # handled by ndarray + else: + if not (-side <= start <= side): + raise IndexError( + f"Slice {i} contains {start=}, but should be " + f"{f_range} for an axis of size {side} " + "(out-of-bounds starts are not specified in " + "the Array API)" + ) + if i.stop is not None: + try: + stop = operator.index(i.stop) + except TypeError: + pass # handled by ndarray + else: + if not (-side <= stop <= side): + raise IndexError( + f"Slice {i} contains {stop=}, but should be " + f"{f_range} for an axis of size {side} " + "(out-of-bounds stops are not specified in " + "the Array API)" + ) + elif isinstance(i, Array): + if i.dtype in _boolean_dtypes and len(_key) != 1: + assert isinstance(key, tuple) # sanity check raise IndexError( - "Non-zero dimensional integer array indices are not allowed in the array API namespace" + f"Single-axes index {i} is a boolean array and " + f"{len(key)=}, but masking is only specified in the " + "Array API when the array is the sole index." ) - return key._array - elif key is Ellipsis: - return key - elif key is None: - raise IndexError( - "newaxis indices are not allowed in the array API namespace" - ) - try: - key = operator.index(key) - if shape is not None and len(shape) > 1: + elif i.dtype in _integer_dtypes and i.ndim != 0: + raise IndexError( + f"Single-axes index {i} is a non-zero-dimensional " + "integer array, but advanced integer indexing is not " + "specified in the Array API." + ) + elif isinstance(i, tuple): raise IndexError( - "Multidimensional arrays must include an index for every axis or use an ellipsis" + f"Single-axes index {i} is a tuple, but nested tuple " + "indices are not specified in the Array API." ) - return key - except TypeError: - # Note: This also omits boolean arrays that are not already in - # Array() form, like a list of booleans. - raise IndexError( - "Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace" - ) # Everything below this line is required by the spec. @@ -511,7 +531,10 @@ class Array: """ # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index - key = self._validate_index(key, self.shape) + self._validate_index(key) + if isinstance(key, Array): + # Indexing self._array with array_api arrays can be erroneous + key = key._array res = self._array.__getitem__(key) return self._new(res) @@ -698,7 +721,10 @@ class Array: """ # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index - key = self._validate_index(key, self.shape) + self._validate_index(key) + if isinstance(key, Array): + # Indexing self._array with array_api arrays can be erroneous + key = key._array self._array.__setitem__(key, asarray(value)._array) def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 1fe1dfddf..ba9223532 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -2,8 +2,9 @@ import operator from numpy.testing import assert_raises import numpy as np +import pytest -from .. import ones, asarray, result_type, all, equal +from .. import ones, asarray, reshape, result_type, all, equal from .._array_object import Array from .._dtypes import ( _all_dtypes, @@ -17,6 +18,7 @@ from .._dtypes import ( int32, int64, uint64, + bool as bool_, ) @@ -70,11 +72,6 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[[0, 1]]) assert_raises(IndexError, lambda: a[np.array([[0, 1]])]) - # np.newaxis is not allowed - assert_raises(IndexError, lambda: a[None]) - assert_raises(IndexError, lambda: a[None, ...]) - assert_raises(IndexError, lambda: a[..., None]) - # Multiaxis indices must contain exactly as many indices as dimensions assert_raises(IndexError, lambda: a[()]) assert_raises(IndexError, lambda: a[0,]) @@ -322,3 +319,57 @@ def test___array__(): b = np.asarray(a, dtype=np.float64) assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) assert b.dtype == np.float64 + +def test_allow_newaxis(): + a = ones(5) + indexed_a = a[None, :] + assert indexed_a.shape == (1, 5) + +def test_disallow_flat_indexing_with_newaxis(): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[None, 0, 0] + +def test_disallow_mask_with_newaxis(): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[None, asarray(True)] + +@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)]) +@pytest.mark.parametrize("index", ["string", False, True]) +def test_error_on_invalid_index(shape, index): + a = ones(shape) + with pytest.raises(IndexError): + a[index] + +def test_mask_0d_array_without_errors(): + a = ones(()) + a[asarray(True)] + +@pytest.mark.parametrize( + "i", [slice(5), slice(5, 0), asarray(True), asarray([0, 1])] +) +def test_error_on_invalid_index_with_ellipsis(i): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[..., i] + with pytest.raises(IndexError): + a[i, ...] + +def test_array_keys_use_private_array(): + """ + Indexing operations convert array keys before indexing the internal array + + Fails when array_api array keys are not converted into NumPy-proper arrays + in __getitem__(). This is achieved by passing array_api arrays with 0-sized + dimensions, which NumPy-proper treats erroneously - not sure why! + + TODO: Find and use appropiate __setitem__() case. + """ + a = ones((0, 0), dtype=bool_) + assert a[a].shape == (0,) + + a = ones((0,), dtype=bool_) + key = ones((0, 0), dtype=bool_) + with pytest.raises(IndexError): + a[key] |