summaryrefslogtreecommitdiff
path: root/numpy/array_api/_array_object.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-12-07 08:56:26 -0700
committerGitHub <noreply@github.com>2021-12-07 08:56:26 -0700
commitb5331ea9ece515d45e5a5adcb8e06117c9d33569 (patch)
treeb4664d6e82ece902843a7b3862be12b127a2e906 /numpy/array_api/_array_object.py
parent7ca1d1ad0dc86f0d20414c946eb2b6e8dc19c367 (diff)
parent5f21063cc317d92a866c7259a9509f5e5d6189c2 (diff)
downloadnumpy-b5331ea9ece515d45e5a5adcb8e06117c9d33569.tar.gz
Merge pull request #20527 from asmeurer/array_api-__array__
ENH: Add __array__ to the array_api Array object
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r--numpy/array_api/_array_object.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index ead061882..75baf34b0 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -33,6 +33,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
if TYPE_CHECKING:
from ._typing import Any, PyCapsule, Device, Dtype
+ import numpy.typing as npt
import numpy as np
@@ -108,6 +109,17 @@ class Array:
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix
+ # This function is not required by the spec, but we implement it here for
+ # convenience so that np.asarray(np.array_api.Array) will work.
+ def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]:
+ """
+ Warning: this method is NOT part of the array API spec. Implementers
+ of other libraries need not include it, and users should not assume it
+ will be present in other implementations.
+
+ """
+ return np.asarray(self._array, dtype=dtype)
+
# These are various helper functions to make the array behavior match the
# spec in places where it either deviates from or is more strict than
# NumPy behavior