summaryrefslogtreecommitdiff
path: root/numpy/array_api/tests/test_array_object.py
diff options
context:
space:
mode:
authorMatthew Barber <quitesimplymatt@gmail.com>2022-05-06 09:27:27 +0000
committerGitHub <noreply@github.com>2022-05-06 11:27:27 +0200
commitbefef7b26773eddd2b656a3ab87f504e6cc173db (patch)
tree95c9ea72750052d759f87f35abe43019f0478b71 /numpy/array_api/tests/test_array_object.py
parentba54f569cecf17812695f17812d238af2bb91000 (diff)
downloadnumpy-befef7b26773eddd2b656a3ab87f504e6cc173db.tar.gz
API: Allow newaxis indexing for `array_api` arrays (#21377)
* TST: Add test checking if newaxis indexing works for `array_api` Also removes previous check against newaxis indexing, which is now outdated * TST, BUG: Allow `None` in `array_api` indexing Introduces test for validating flat indexing when `None` is present * MAINT,DOC,TST: Rework of `_validate_index()` in `numpy.array_api` _validate_index() is now called as self._validate_index(shape), and does not return a key. This rework removes the recursive pattern used. Tests are introduced to cover some edge cases. Additionally, its internal docstring reflects new behaviour, and extends the flat indexing note. * MAINT: `advance` -> `advanced` (integer indexing) Co-authored-by: Aaron Meurer <asmeurer@gmail.com> * BUG: array_api arrays use internal arrays from array_api array keys When an array_api array is passed as the key for get/setitem, we access the key's internal np.ndarray array to be used as the key for the internal get/setitem operation. This behaviour was initially removed when `_validate_index()` was reworked. * MAINT: Better flat indexing error message for `array_api` arrays Also better semantics for its prior ellipsis count condition Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net> * MAINT: `array_api` arrays don't special case multi-ellipsis errors This gets handled by NumPy-proper. Co-authored-by: Aaron Meurer <asmeurer@gmail.com> Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
Diffstat (limited to 'numpy/array_api/tests/test_array_object.py')
-rw-r--r--numpy/array_api/tests/test_array_object.py63
1 files changed, 57 insertions, 6 deletions
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py
index 1fe1dfddf..ba9223532 100644
--- a/numpy/array_api/tests/test_array_object.py
+++ b/numpy/array_api/tests/test_array_object.py
@@ -2,8 +2,9 @@ import operator
from numpy.testing import assert_raises
import numpy as np
+import pytest
-from .. import ones, asarray, result_type, all, equal
+from .. import ones, asarray, reshape, result_type, all, equal
from .._array_object import Array
from .._dtypes import (
_all_dtypes,
@@ -17,6 +18,7 @@ from .._dtypes import (
int32,
int64,
uint64,
+ bool as bool_,
)
@@ -70,11 +72,6 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[[0, 1]])
assert_raises(IndexError, lambda: a[np.array([[0, 1]])])
- # np.newaxis is not allowed
- assert_raises(IndexError, lambda: a[None])
- assert_raises(IndexError, lambda: a[None, ...])
- assert_raises(IndexError, lambda: a[..., None])
-
# Multiaxis indices must contain exactly as many indices as dimensions
assert_raises(IndexError, lambda: a[()])
assert_raises(IndexError, lambda: a[0,])
@@ -322,3 +319,57 @@ def test___array__():
b = np.asarray(a, dtype=np.float64)
assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64)))
assert b.dtype == np.float64
+
+def test_allow_newaxis():
+ a = ones(5)
+ indexed_a = a[None, :]
+ assert indexed_a.shape == (1, 5)
+
+def test_disallow_flat_indexing_with_newaxis():
+ a = ones((3, 3, 3))
+ with pytest.raises(IndexError):
+ a[None, 0, 0]
+
+def test_disallow_mask_with_newaxis():
+ a = ones((3, 3, 3))
+ with pytest.raises(IndexError):
+ a[None, asarray(True)]
+
+@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)])
+@pytest.mark.parametrize("index", ["string", False, True])
+def test_error_on_invalid_index(shape, index):
+ a = ones(shape)
+ with pytest.raises(IndexError):
+ a[index]
+
+def test_mask_0d_array_without_errors():
+ a = ones(())
+ a[asarray(True)]
+
+@pytest.mark.parametrize(
+ "i", [slice(5), slice(5, 0), asarray(True), asarray([0, 1])]
+)
+def test_error_on_invalid_index_with_ellipsis(i):
+ a = ones((3, 3, 3))
+ with pytest.raises(IndexError):
+ a[..., i]
+ with pytest.raises(IndexError):
+ a[i, ...]
+
+def test_array_keys_use_private_array():
+ """
+ Indexing operations convert array keys before indexing the internal array
+
+ Fails when array_api array keys are not converted into NumPy-proper arrays
+ in __getitem__(). This is achieved by passing array_api arrays with 0-sized
+ dimensions, which NumPy-proper treats erroneously - not sure why!
+
+ TODO: Find and use appropiate __setitem__() case.
+ """
+ a = ones((0, 0), dtype=bool_)
+ assert a[a].shape == (0,)
+
+ a = ones((0,), dtype=bool_)
+ key = ones((0, 0), dtype=bool_)
+ with pytest.raises(IndexError):
+ a[key]