diff options
Diffstat (limited to 'numpy/array_api')
| -rw-r--r-- | numpy/array_api/_array_object.py | 2 | ||||
| -rw-r--r-- | numpy/array_api/tests/test_array_object.py | 14 |
2 files changed, 15 insertions, 1 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index 8794c5ea5..ead061882 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -1072,4 +1072,4 @@ class Array: # https://data-apis.org/array-api/latest/API_specification/array_object.html#t if self.ndim != 2: raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.") - return self._array.T + return self.__class__._new(self._array.T) diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index 12479d765..deab50693 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -4,6 +4,7 @@ from numpy.testing import assert_raises import numpy as np from .. import ones, asarray, result_type, all, equal +from .._array_object import Array from .._dtypes import ( _all_dtypes, _boolean_dtypes, @@ -301,3 +302,16 @@ def test_device_property(): assert all(equal(asarray(a, device='cpu'), a)) assert_raises(ValueError, lambda: asarray(a, device='gpu')) + +def test_array_properties(): + a = ones((1, 2, 3)) + b = ones((2, 3)) + assert_raises(ValueError, lambda: a.T) + + assert isinstance(b.T, Array) + assert b.T.shape == (3, 2) + + assert isinstance(a.mT, Array) + assert a.mT.shape == (1, 3, 2) + assert isinstance(b.mT, Array) + assert b.mT.shape == (3, 2) |
