summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_array_object.py4
-rw-r--r--numpy/_array_api/tests/__init__.py7
-rw-r--r--numpy/_array_api/tests/test_array_object.py59
3 files changed, 69 insertions, 1 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index 4e3c7b344..54280ef37 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -158,7 +158,9 @@ class Array:
allowed by NumPy but not required by the array API specification. We
always raise ``IndexError`` on such indices (the spec does not require
any specific behavior on them, but this makes the NumPy array API
- namespace a minimal implementation of the spec).
+ namespace a minimal implementation of the spec). See
+ https://data-apis.org/array-api/latest/API_specification/indexing.html
+ for the full list of required indexing behavior
This function either raises IndexError if the index ``key`` is
invalid, or a new key to be used in place of ``key`` in indexing. It
diff --git a/numpy/_array_api/tests/__init__.py b/numpy/_array_api/tests/__init__.py
new file mode 100644
index 000000000..536062e38
--- /dev/null
+++ b/numpy/_array_api/tests/__init__.py
@@ -0,0 +1,7 @@
+"""
+Tests for the array API namespace.
+
+Note, full compliance with the array API can be tested with the official array API test
+suite https://github.com/data-apis/array-api-tests. This test suite primarily
+focuses on those things that are not tested by the official test suite.
+"""
diff --git a/numpy/_array_api/tests/test_array_object.py b/numpy/_array_api/tests/test_array_object.py
new file mode 100644
index 000000000..49ec3b37b
--- /dev/null
+++ b/numpy/_array_api/tests/test_array_object.py
@@ -0,0 +1,59 @@
+from numpy.testing import assert_raises
+import numpy as np
+
+from .. import ones, asarray
+
+def test_validate_index():
+ # The indexing tests in the official array API test suite test that the
+ # array object correctly handles the subset of indices that are required
+ # by the spec. But the NumPy array API implementation specifically
+ # disallows any index not required by the spec, via Array._validate_index.
+ # This test focuses on testing that non-valid indices are correctly
+ # rejected. See
+ # https://data-apis.org/array-api/latest/API_specification/indexing.html
+ # and the docstring of Array._validate_index for the exact indexing
+ # behavior that should be allowed. This does not test indices that are
+ # already invalid in NumPy itself because Array will generally just pass
+ # such indices directly to the underlying np.ndarray.
+
+ a = ones((3, 4))
+
+ # Out of bounds slices are not allowed
+ assert_raises(IndexError, lambda: a[:4])
+ assert_raises(IndexError, lambda: a[:-4])
+ assert_raises(IndexError, lambda: a[:3:-1])
+ assert_raises(IndexError, lambda: a[:-5:-1])
+ assert_raises(IndexError, lambda: a[3:])
+ assert_raises(IndexError, lambda: a[-4:])
+ assert_raises(IndexError, lambda: a[3::-1])
+ assert_raises(IndexError, lambda: a[-4::-1])
+
+ assert_raises(IndexError, lambda: a[...,:5])
+ assert_raises(IndexError, lambda: a[...,:-5])
+ assert_raises(IndexError, lambda: a[...,:4:-1])
+ assert_raises(IndexError, lambda: a[...,:-6:-1])
+ assert_raises(IndexError, lambda: a[...,4:])
+ assert_raises(IndexError, lambda: a[...,-5:])
+ assert_raises(IndexError, lambda: a[...,4::-1])
+ assert_raises(IndexError, lambda: a[...,-5::-1])
+
+ # Boolean indices cannot be part of a larger tuple index
+ assert_raises(IndexError, lambda: a[a[:,0]==1,0])
+ assert_raises(IndexError, lambda: a[a[:,0]==1,...])
+ assert_raises(IndexError, lambda: a[..., a[0]==1])
+ assert_raises(IndexError, lambda: a[[True, True, True]])
+ assert_raises(IndexError, lambda: a[(True, True, True),])
+
+ # Integer array indices are not allowed (except for 0-D)
+ idx = asarray([[0, 1]])
+ assert_raises(IndexError, lambda: a[idx])
+ assert_raises(IndexError, lambda: a[idx,])
+ assert_raises(IndexError, lambda: a[[0, 1]])
+ assert_raises(IndexError, lambda: a[(0, 1), (0, 1)])
+ assert_raises(IndexError, lambda: a[[0, 1]])
+ assert_raises(IndexError, lambda: a[np.array([[0, 1]])])
+
+ # np.newaxis is not allowed
+ assert_raises(IndexError, lambda: a[None])
+ assert_raises(IndexError, lambda: a[None, ...])
+ assert_raises(IndexError, lambda: a[..., None])