From 18fe695dbc66df660039aca9f76a949de8b9348e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 Dec 2021 17:02:43 -0700 Subject: Add tests for T and mT in array_api --- numpy/array_api/tests/test_array_object.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'numpy/array_api/tests/test_array_object.py') diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 12479d765..deab50693 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -4,6 +4,7 @@ from numpy.testing import assert_raises import numpy as np from .. import ones, asarray, result_type, all, equal +from .._array_object import Array from .._dtypes import ( _all_dtypes, _boolean_dtypes, @@ -301,3 +302,16 @@ def test_device_property(): assert all(equal(asarray(a, device='cpu'), a)) assert_raises(ValueError, lambda: asarray(a, device='gpu')) + +def test_array_properties(): + a = ones((1, 2, 3)) + b = ones((2, 3)) + assert_raises(ValueError, lambda: a.T) + + assert isinstance(b.T, Array) + assert b.T.shape == (3, 2) + + assert isinstance(a.mT, Array) + assert a.mT.shape == (1, 3, 2) + assert isinstance(b.mT, Array) + assert b.mT.shape == (3, 2) -- cgit v1.2.1 From 74a3ee7a8b75bf6dc271c9a1a4b55d2ad9758420 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 Dec 2021 13:59:08 -0700 Subject: ENH: Add __array__ to the array_api Array object This is *NOT* part of the array API spec (so it should not be relied on for portable code). However, without this, np.asarray(np.array_api.Array) produces an object array instead of doing the conversion to a NumPy array as expected. This would work once np.asarray() implements dlpack support, but until then, it seems reasonable to make the conversion work. Note that the reverse, calling np.array_api.asarray(np.array), already works because np.array_api.asarray() is just a wrapper for np.asarray(). --- numpy/array_api/tests/test_array_object.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'numpy/array_api/tests/test_array_object.py') diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index deab50693..b980bacca 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -315,3 +315,10 @@ def test_array_properties(): assert a.mT.shape == (1, 3, 2) assert isinstance(b.mT, Array) assert b.mT.shape == (3, 2) + +def test___array__(): + a = ones((2, 3), dtype=int16) + assert np.asarray(a) is a._array + b = np.asarray(a, dtype=np.float64) + assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) + assert b.dtype == np.float64 -- cgit v1.2.1