diff options
Diffstat (limited to 'numpy/array_api/tests/test_array_object.py')
-rw-r--r-- | numpy/array_api/tests/test_array_object.py | 63 |
1 files changed, 57 insertions, 6 deletions
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 1fe1dfddf..ba9223532 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -2,8 +2,9 @@ import operator from numpy.testing import assert_raises import numpy as np +import pytest -from .. import ones, asarray, result_type, all, equal +from .. import ones, asarray, reshape, result_type, all, equal from .._array_object import Array from .._dtypes import ( _all_dtypes, @@ -17,6 +18,7 @@ from .._dtypes import ( int32, int64, uint64, + bool as bool_, ) @@ -70,11 +72,6 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[[0, 1]]) assert_raises(IndexError, lambda: a[np.array([[0, 1]])]) - # np.newaxis is not allowed - assert_raises(IndexError, lambda: a[None]) - assert_raises(IndexError, lambda: a[None, ...]) - assert_raises(IndexError, lambda: a[..., None]) - # Multiaxis indices must contain exactly as many indices as dimensions assert_raises(IndexError, lambda: a[()]) assert_raises(IndexError, lambda: a[0,]) @@ -322,3 +319,57 @@ def test___array__(): b = np.asarray(a, dtype=np.float64) assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) assert b.dtype == np.float64 + +def test_allow_newaxis(): + a = ones(5) + indexed_a = a[None, :] + assert indexed_a.shape == (1, 5) + +def test_disallow_flat_indexing_with_newaxis(): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[None, 0, 0] + +def test_disallow_mask_with_newaxis(): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[None, asarray(True)] + +@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)]) +@pytest.mark.parametrize("index", ["string", False, True]) +def test_error_on_invalid_index(shape, index): + a = ones(shape) + with pytest.raises(IndexError): + a[index] + +def test_mask_0d_array_without_errors(): + a = ones(()) + a[asarray(True)] + +@pytest.mark.parametrize( + "i", [slice(5), slice(5, 0), asarray(True), asarray([0, 1])] +) +def test_error_on_invalid_index_with_ellipsis(i): + a = ones((3, 3, 3)) + with pytest.raises(IndexError): + a[..., i] + with pytest.raises(IndexError): + a[i, ...] + +def test_array_keys_use_private_array(): + """ + Indexing operations convert array keys before indexing the internal array + + Fails when array_api array keys are not converted into NumPy-proper arrays + in __getitem__(). This is achieved by passing array_api arrays with 0-sized + dimensions, which NumPy-proper treats erroneously - not sure why! + + TODO: Find and use appropiate __setitem__() case. + """ + a = ones((0, 0), dtype=bool_) + assert a[a].shape == (0,) + + a = ones((0,), dtype=bool_) + key = ones((0, 0), dtype=bool_) + with pytest.raises(IndexError): + a[key] |