summaryrefslogtreecommitdiff
path: root/numpy/array_api/tests
diff options
context:
space:
mode:
authorRalf Gommers <ralf.gommers@gmail.com>2021-11-12 16:15:35 +0100
committerRalf Gommers <ralf.gommers@gmail.com>2021-11-12 16:15:35 +0100
commiteccb8dfbd9b07183e16a1144e8d5d76936671bfc (patch)
tree647a9477b4f3b8b7205f2f7f2feb99eaa482e806 /numpy/array_api/tests
parentd0d75f39f28ac26d4cc1aa3a4cbea63a6a027929 (diff)
parentff2e2a1e7eea29d925063b13922e096d14331222 (diff)
downloadnumpy-eccb8dfbd9b07183e16a1144e8d5d76936671bfc.tar.gz
Merge branch 'main' into never_copy
Diffstat (limited to 'numpy/array_api/tests')
-rw-r--r--numpy/array_api/tests/test_array_object.py28
-rw-r--r--numpy/array_api/tests/test_creation_functions.py10
2 files changed, 32 insertions, 6 deletions
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py
index 7959f92b4..12479d765 100644
--- a/numpy/array_api/tests/test_array_object.py
+++ b/numpy/array_api/tests/test_array_object.py
@@ -3,7 +3,7 @@ import operator
from numpy.testing import assert_raises
import numpy as np
-from .. import ones, asarray, result_type
+from .. import ones, asarray, result_type, all, equal
from .._dtypes import (
_all_dtypes,
_boolean_dtypes,
@@ -39,18 +39,18 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[:-4])
assert_raises(IndexError, lambda: a[:3:-1])
assert_raises(IndexError, lambda: a[:-5:-1])
- assert_raises(IndexError, lambda: a[3:])
+ assert_raises(IndexError, lambda: a[4:])
assert_raises(IndexError, lambda: a[-4:])
- assert_raises(IndexError, lambda: a[3::-1])
+ assert_raises(IndexError, lambda: a[4::-1])
assert_raises(IndexError, lambda: a[-4::-1])
assert_raises(IndexError, lambda: a[...,:5])
assert_raises(IndexError, lambda: a[...,:-5])
- assert_raises(IndexError, lambda: a[...,:4:-1])
+ assert_raises(IndexError, lambda: a[...,:5:-1])
assert_raises(IndexError, lambda: a[...,:-6:-1])
- assert_raises(IndexError, lambda: a[...,4:])
+ assert_raises(IndexError, lambda: a[...,5:])
assert_raises(IndexError, lambda: a[...,-5:])
- assert_raises(IndexError, lambda: a[...,4::-1])
+ assert_raises(IndexError, lambda: a[...,5::-1])
assert_raises(IndexError, lambda: a[...,-5::-1])
# Boolean indices cannot be part of a larger tuple index
@@ -74,6 +74,11 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[None, ...])
assert_raises(IndexError, lambda: a[..., None])
+ # Multiaxis indices must contain exactly as many indices as dimensions
+ assert_raises(IndexError, lambda: a[()])
+ assert_raises(IndexError, lambda: a[0,])
+ assert_raises(IndexError, lambda: a[0])
+ assert_raises(IndexError, lambda: a[:])
def test_operators():
# For every operator, we test that it works for the required type
@@ -285,3 +290,14 @@ def test_python_scalar_construtors():
assert_raises(TypeError, lambda: operator.index(b))
assert_raises(TypeError, lambda: operator.index(f))
+
+
+def test_device_property():
+ a = ones((3, 4))
+ assert a.device == 'cpu'
+
+ assert all(equal(a.to_device('cpu'), a))
+ assert_raises(ValueError, lambda: a.to_device('gpu'))
+
+ assert all(equal(asarray(a, device='cpu'), a))
+ assert_raises(ValueError, lambda: asarray(a, device='gpu'))
diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py
index c13bc4262..be9eaa383 100644
--- a/numpy/array_api/tests/test_creation_functions.py
+++ b/numpy/array_api/tests/test_creation_functions.py
@@ -11,11 +11,13 @@ from .._creation_functions import (
full,
full_like,
linspace,
+ meshgrid,
ones,
ones_like,
zeros,
zeros_like,
)
+from .._dtypes import float32, float64
from .._array_object import Array
@@ -130,3 +132,11 @@ def test_zeros_like_errors():
assert_raises(ValueError, lambda: zeros_like(asarray(1), device="gpu"))
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int))
assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i"))
+
+def test_meshgrid_dtype_errors():
+ # Doesn't raise
+ meshgrid()
+ meshgrid(asarray([1.], dtype=float32))
+ meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32))
+
+ assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64)))