diff options
Diffstat (limited to 'numpy/_array_api')
| -rw-r--r-- | numpy/_array_api/_array_object.py | 4 | ||||
| -rw-r--r-- | numpy/_array_api/tests/__init__.py | 7 | ||||
| -rw-r--r-- | numpy/_array_api/tests/test_array_object.py | 59 |
3 files changed, 69 insertions, 1 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py index 4e3c7b344..54280ef37 100644 --- a/numpy/_array_api/_array_object.py +++ b/numpy/_array_api/_array_object.py @@ -158,7 +158,9 @@ class Array: allowed by NumPy but not required by the array API specification. We always raise ``IndexError`` on such indices (the spec does not require any specific behavior on them, but this makes the NumPy array API - namespace a minimal implementation of the spec). + namespace a minimal implementation of the spec). See + https://data-apis.org/array-api/latest/API_specification/indexing.html + for the full list of required indexing behavior This function either raises IndexError if the index ``key`` is invalid, or a new key to be used in place of ``key`` in indexing. It diff --git a/numpy/_array_api/tests/__init__.py b/numpy/_array_api/tests/__init__.py new file mode 100644 index 000000000..536062e38 --- /dev/null +++ b/numpy/_array_api/tests/__init__.py @@ -0,0 +1,7 @@ +""" +Tests for the array API namespace. + +Note, full compliance with the array API can be tested with the official array API test +suite https://github.com/data-apis/array-api-tests. This test suite primarily +focuses on those things that are not tested by the official test suite. +""" diff --git a/numpy/_array_api/tests/test_array_object.py b/numpy/_array_api/tests/test_array_object.py new file mode 100644 index 000000000..49ec3b37b --- /dev/null +++ b/numpy/_array_api/tests/test_array_object.py @@ -0,0 +1,59 @@ +from numpy.testing import assert_raises +import numpy as np + +from .. import ones, asarray + +def test_validate_index(): + # The indexing tests in the official array API test suite test that the + # array object correctly handles the subset of indices that are required + # by the spec. But the NumPy array API implementation specifically + # disallows any index not required by the spec, via Array._validate_index. + # This test focuses on testing that non-valid indices are correctly + # rejected. See + # https://data-apis.org/array-api/latest/API_specification/indexing.html + # and the docstring of Array._validate_index for the exact indexing + # behavior that should be allowed. This does not test indices that are + # already invalid in NumPy itself because Array will generally just pass + # such indices directly to the underlying np.ndarray. + + a = ones((3, 4)) + + # Out of bounds slices are not allowed + assert_raises(IndexError, lambda: a[:4]) + assert_raises(IndexError, lambda: a[:-4]) + assert_raises(IndexError, lambda: a[:3:-1]) + assert_raises(IndexError, lambda: a[:-5:-1]) + assert_raises(IndexError, lambda: a[3:]) + assert_raises(IndexError, lambda: a[-4:]) + assert_raises(IndexError, lambda: a[3::-1]) + assert_raises(IndexError, lambda: a[-4::-1]) + + assert_raises(IndexError, lambda: a[...,:5]) + assert_raises(IndexError, lambda: a[...,:-5]) + assert_raises(IndexError, lambda: a[...,:4:-1]) + assert_raises(IndexError, lambda: a[...,:-6:-1]) + assert_raises(IndexError, lambda: a[...,4:]) + assert_raises(IndexError, lambda: a[...,-5:]) + assert_raises(IndexError, lambda: a[...,4::-1]) + assert_raises(IndexError, lambda: a[...,-5::-1]) + + # Boolean indices cannot be part of a larger tuple index + assert_raises(IndexError, lambda: a[a[:,0]==1,0]) + assert_raises(IndexError, lambda: a[a[:,0]==1,...]) + assert_raises(IndexError, lambda: a[..., a[0]==1]) + assert_raises(IndexError, lambda: a[[True, True, True]]) + assert_raises(IndexError, lambda: a[(True, True, True),]) + + # Integer array indices are not allowed (except for 0-D) + idx = asarray([[0, 1]]) + assert_raises(IndexError, lambda: a[idx]) + assert_raises(IndexError, lambda: a[idx,]) + assert_raises(IndexError, lambda: a[[0, 1]]) + assert_raises(IndexError, lambda: a[(0, 1), (0, 1)]) + assert_raises(IndexError, lambda: a[[0, 1]]) + assert_raises(IndexError, lambda: a[np.array([[0, 1]])]) + + # np.newaxis is not allowed + assert_raises(IndexError, lambda: a[None]) + assert_raises(IndexError, lambda: a[None, ...]) + assert_raises(IndexError, lambda: a[..., None]) |
