summaryrefslogtreecommitdiff
path: root/numpy/array_api/_array_object.py
diff options
context:
space:
mode:
authorThomas Green <tomgreen66@hotmail.com>2021-12-08 11:57:10 +0000
committerGitHub <noreply@github.com>2021-12-08 11:57:10 +0000
commitdc766fc1abb546ab883f76ef4e405e99e9287ab6 (patch)
tree9e7c7748ba8bfbb2ba5224633b0725909712d2fa /numpy/array_api/_array_object.py
parent1cfdac82ac793061d8ca4b07c046fc6b21ee7e54 (diff)
parentab7a1927353ab9dd52e3f2f7a1a889ae790667b9 (diff)
downloadnumpy-dc766fc1abb546ab883f76ef4e405e99e9287ab6.tar.gz
Merge branch 'numpy:main' into armcompiler
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r--numpy/array_api/_array_object.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 8794c5ea5..75baf34b0 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -33,6 +33,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
if TYPE_CHECKING:
from ._typing import Any, PyCapsule, Device, Dtype
+ import numpy.typing as npt
import numpy as np
@@ -108,6 +109,17 @@ class Array:
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix
+ # This function is not required by the spec, but we implement it here for
+ # convenience so that np.asarray(np.array_api.Array) will work.
+ def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]:
+ """
+ Warning: this method is NOT part of the array API spec. Implementers
+ of other libraries need not include it, and users should not assume it
+ will be present in other implementations.
+
+ """
+ return np.asarray(self._array, dtype=dtype)
+
# These are various helper functions to make the array behavior match the
# spec in places where it either deviates from or is more strict than
# NumPy behavior
@@ -1072,4 +1084,4 @@ class Array:
# https://data-apis.org/array-api/latest/API_specification/array_object.html#t
if self.ndim != 2:
raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.")
- return self._array.T
+ return self.__class__._new(self._array.T)