summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-07-19 17:25:02 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-07-19 17:25:02 -0600
commitf20be6ad3239a2e7a611ad42c9b36df7863e9883 (patch)
tree84cc71b4f1ec0c6934d311ae0c32d1667345a5dc /numpy/_array_api
parenta566cd1c7110d36d0e7a1f2746ea61e45f49eb89 (diff)
downloadnumpy-f20be6ad3239a2e7a611ad42c9b36df7863e9883.tar.gz
Start adding tests for the array API submodule
The tests for the module will mostly focus on those things that aren't already tested by the official array API test suite (https://github.com/data-apis/array-api-tests). Currently, indexing tests are added, which test that the Array object correctly rejects otherwise valid indices that are not required by the array API spec.
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_array_object.py4
-rw-r--r--numpy/_array_api/tests/__init__.py7
-rw-r--r--numpy/_array_api/tests/test_array_object.py59
3 files changed, 69 insertions, 1 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index 4e3c7b344..54280ef37 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -158,7 +158,9 @@ class Array:
allowed by NumPy but not required by the array API specification. We
always raise ``IndexError`` on such indices (the spec does not require
any specific behavior on them, but this makes the NumPy array API
- namespace a minimal implementation of the spec).
+ namespace a minimal implementation of the spec). See
+ https://data-apis.org/array-api/latest/API_specification/indexing.html
+ for the full list of required indexing behavior
This function either raises IndexError if the index ``key`` is
invalid, or a new key to be used in place of ``key`` in indexing. It
diff --git a/numpy/_array_api/tests/__init__.py b/numpy/_array_api/tests/__init__.py
new file mode 100644
index 000000000..536062e38
--- /dev/null
+++ b/numpy/_array_api/tests/__init__.py
@@ -0,0 +1,7 @@
+"""
+Tests for the array API namespace.
+
+Note, full compliance with the array API can be tested with the official array API test
+suite https://github.com/data-apis/array-api-tests. This test suite primarily
+focuses on those things that are not tested by the official test suite.
+"""
diff --git a/numpy/_array_api/tests/test_array_object.py b/numpy/_array_api/tests/test_array_object.py
new file mode 100644
index 000000000..49ec3b37b
--- /dev/null
+++ b/numpy/_array_api/tests/test_array_object.py
@@ -0,0 +1,59 @@
+from numpy.testing import assert_raises
+import numpy as np
+
+from .. import ones, asarray
+
+def test_validate_index():
+ # The indexing tests in the official array API test suite test that the
+ # array object correctly handles the subset of indices that are required
+ # by the spec. But the NumPy array API implementation specifically
+ # disallows any index not required by the spec, via Array._validate_index.
+ # This test focuses on testing that non-valid indices are correctly
+ # rejected. See
+ # https://data-apis.org/array-api/latest/API_specification/indexing.html
+ # and the docstring of Array._validate_index for the exact indexing
+ # behavior that should be allowed. This does not test indices that are
+ # already invalid in NumPy itself because Array will generally just pass
+ # such indices directly to the underlying np.ndarray.
+
+ a = ones((3, 4))
+
+ # Out of bounds slices are not allowed
+ assert_raises(IndexError, lambda: a[:4])
+ assert_raises(IndexError, lambda: a[:-4])
+ assert_raises(IndexError, lambda: a[:3:-1])
+ assert_raises(IndexError, lambda: a[:-5:-1])
+ assert_raises(IndexError, lambda: a[3:])
+ assert_raises(IndexError, lambda: a[-4:])
+ assert_raises(IndexError, lambda: a[3::-1])
+ assert_raises(IndexError, lambda: a[-4::-1])
+
+ assert_raises(IndexError, lambda: a[...,:5])
+ assert_raises(IndexError, lambda: a[...,:-5])
+ assert_raises(IndexError, lambda: a[...,:4:-1])
+ assert_raises(IndexError, lambda: a[...,:-6:-1])
+ assert_raises(IndexError, lambda: a[...,4:])
+ assert_raises(IndexError, lambda: a[...,-5:])
+ assert_raises(IndexError, lambda: a[...,4::-1])
+ assert_raises(IndexError, lambda: a[...,-5::-1])
+
+ # Boolean indices cannot be part of a larger tuple index
+ assert_raises(IndexError, lambda: a[a[:,0]==1,0])
+ assert_raises(IndexError, lambda: a[a[:,0]==1,...])
+ assert_raises(IndexError, lambda: a[..., a[0]==1])
+ assert_raises(IndexError, lambda: a[[True, True, True]])
+ assert_raises(IndexError, lambda: a[(True, True, True),])
+
+ # Integer array indices are not allowed (except for 0-D)
+ idx = asarray([[0, 1]])
+ assert_raises(IndexError, lambda: a[idx])
+ assert_raises(IndexError, lambda: a[idx,])
+ assert_raises(IndexError, lambda: a[[0, 1]])
+ assert_raises(IndexError, lambda: a[(0, 1), (0, 1)])
+ assert_raises(IndexError, lambda: a[[0, 1]])
+ assert_raises(IndexError, lambda: a[np.array([[0, 1]])])
+
+ # np.newaxis is not allowed
+ assert_raises(IndexError, lambda: a[None])
+ assert_raises(IndexError, lambda: a[None, ...])
+ assert_raises(IndexError, lambda: a[..., None])