diff options
author | Matthew Barber <quitesimplymatt@gmail.com> | 2022-05-06 09:27:27 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-06 11:27:27 +0200 |
commit | befef7b26773eddd2b656a3ab87f504e6cc173db (patch) | |
tree | 95c9ea72750052d759f87f35abe43019f0478b71 /numpy/array_api/tests/test_array_object.py | |
parent | ba54f569cecf17812695f17812d238af2bb91000 (diff) | |
download | numpy-befef7b26773eddd2b656a3ab87f504e6cc173db.tar.gz |
API: Allow newaxis indexing for `array_api` arrays (#21377)
* TST: Add test checking if newaxis indexing works for `array_api`
Also removes previous check against newaxis indexing, which is now outdated
* TST, BUG: Allow `None` in `array_api` indexing
Introduces test for validating flat indexing when `None` is present
* MAINT,DOC,TST: Rework of `_validate_index()` in `numpy.array_api`
_validate_index() is now called as self._validate_index(shape), and does not
return a key. This rework removes the recursive pattern used. Tests are
introduced to cover some edge cases. Additionally, its internal docstring
reflects new behaviour, and extends the flat indexing note.
* MAINT: `advance` -> `advanced` (integer indexing)
Co-authored-by: Aaron Meurer <asmeurer@gmail.com>
* BUG: array_api arrays use internal arrays from array_api array keys
When an array_api array is passed as the key for get/setitem, we access the
key's internal np.ndarray array to be used as the key for the internal
get/setitem operation. This behaviour was initially removed when
`_validate_index()` was reworked.
* MAINT: Better flat indexing error message for `array_api` arrays
Also better semantics for its prior ellipsis count condition
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
* MAINT: `array_api` arrays don't special case multi-ellipsis errors
This gets handled by NumPy-proper.
Co-authored-by: Aaron Meurer <asmeurer@gmail.com>
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
Diffstat (limited to 'numpy/array_api/tests/test_array_object.py')
-rw-r--r-- | numpy/array_api/tests/test_array_object.py | 63 |
1 files changed, 57 insertions, 6 deletions
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 1fe1dfddf..ba9223532 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -2,8 +2,9 @@ import operator from numpy.testing import assert_raises import numpy as np +import pytest -from .. import ones, asarray, result_type, all, equal +from .. import ones, asarray, reshape, result_type, all, equal from .._array_object import Array from .._dtypes import ( _all_dtypes, @@ -17,6 +18,7 @@ from .._dtypes import ( int32, int64, uint64, + bool as bool_, ) @@ -70,11 +72,6 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[[0, 1]]) assert_raises(IndexError, lambda: a[np.array([[0, 1]])]) - # np.newaxis is not allowed - assert_raises(IndexError, lambda: a[None]) - assert_raises(IndexError, lambda: a[None, ...]) - assert_raises(IndexError, lambda: a[..., None]) - # Multiaxis indices must contain exactly as many indices as dimensions assert_raises(IndexError, lambda: a[()]) assert_raises(IndexError, lambda: a[0,]) @@ -322,3 +319,57 @@ def test___array__(): b = np.asarray(a, dtype=np.float64) assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) assert b.dtype == np.float64 + +def test_allow_newaxis(): + a = ones(5) + indexed_a = a[None, :] + assert indexed_a.shape == (1, 5) + +def test_disallow_flat_indexing_with_newaxis(): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[None, 0, 0] + +def test_disallow_mask_with_newaxis(): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[None, asarray(True)] + +@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)]) +@pytest.mark.parametrize("index", ["string", False, True]) +def test_error_on_invalid_index(shape, index): + a = ones(shape) + with pytest.raises(IndexError): + a[index] + +def test_mask_0d_array_without_errors(): + a = ones(()) + a[asarray(True)] + +@pytest.mark.parametrize( + "i", [slice(5), slice(5, 0), asarray(True), asarray([0, 1])] +) +def test_error_on_invalid_index_with_ellipsis(i): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[..., i] + with pytest.raises(IndexError): + a[i, ...] + +def test_array_keys_use_private_array(): + """ + Indexing operations convert array keys before indexing the internal array + + Fails when array_api array keys are not converted into NumPy-proper arrays + in __getitem__(). This is achieved by passing array_api arrays with 0-sized + dimensions, which NumPy-proper treats erroneously - not sure why! + + TODO: Find and use appropiate __setitem__() case. + """ + a = ones((0, 0), dtype=bool_) + assert a[a].shape == (0,) + + a = ones((0,), dtype=bool_) + key = ones((0, 0), dtype=bool_) + with pytest.raises(IndexError): + a[key] |