summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
authorThomas Green <tomgreen66@hotmail.com>2021-12-08 11:57:10 +0000
committerGitHub <noreply@github.com>2021-12-08 11:57:10 +0000
commitdc766fc1abb546ab883f76ef4e405e99e9287ab6 (patch)
tree9e7c7748ba8bfbb2ba5224633b0725909712d2fa /numpy/array_api
parent1cfdac82ac793061d8ca4b07c046fc6b21ee7e54 (diff)
parentab7a1927353ab9dd52e3f2f7a1a889ae790667b9 (diff)
downloadnumpy-dc766fc1abb546ab883f76ef4e405e99e9287ab6.tar.gz
Merge branch 'numpy:main' into armcompiler
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_array_object.py14
-rw-r--r--numpy/array_api/_statistical_functions.py4
-rw-r--r--numpy/array_api/tests/test_array_object.py21
3 files changed, 36 insertions, 3 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 8794c5ea5..75baf34b0 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -33,6 +33,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union, Any
if TYPE_CHECKING:
from ._typing import Any, PyCapsule, Device, Dtype
+ import numpy.typing as npt
import numpy as np
@@ -108,6 +109,17 @@ class Array:
mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
return prefix + mid + suffix
+ # This function is not required by the spec, but we implement it here for
+ # convenience so that np.asarray(np.array_api.Array) will work.
+ def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]:
+ """
+ Warning: this method is NOT part of the array API spec. Implementers
+ of other libraries need not include it, and users should not assume it
+ will be present in other implementations.
+
+ """
+ return np.asarray(self._array, dtype=dtype)
+
# These are various helper functions to make the array behavior match the
# spec in places where it either deviates from or is more strict than
# NumPy behavior
@@ -1072,4 +1084,4 @@ class Array:
# https://data-apis.org/array-api/latest/API_specification/array_object.html#t
if self.ndim != 2:
raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.")
- return self._array.T
+ return self.__class__._new(self._array.T)
diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py
index 7bee3f4db..5bc831ac2 100644
--- a/numpy/array_api/_statistical_functions.py
+++ b/numpy/array_api/_statistical_functions.py
@@ -65,8 +65,8 @@ def prod(
# Note: sum() and prod() always upcast float32 to float64 for dtype=None
# We need to do so here before computing the product to avoid overflow
if dtype is None and x.dtype == float32:
- x = asarray(x, dtype=float64)
- return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims))
+ dtype = float64
+ return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
def std(
diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py
index 12479d765..b980bacca 100644
--- a/numpy/array_api/tests/test_array_object.py
+++ b/numpy/array_api/tests/test_array_object.py
@@ -4,6 +4,7 @@ from numpy.testing import assert_raises
import numpy as np
from .. import ones, asarray, result_type, all, equal
+from .._array_object import Array
from .._dtypes import (
_all_dtypes,
_boolean_dtypes,
@@ -301,3 +302,23 @@ def test_device_property():
assert all(equal(asarray(a, device='cpu'), a))
assert_raises(ValueError, lambda: asarray(a, device='gpu'))
+
+def test_array_properties():
+ a = ones((1, 2, 3))
+ b = ones((2, 3))
+ assert_raises(ValueError, lambda: a.T)
+
+ assert isinstance(b.T, Array)
+ assert b.T.shape == (3, 2)
+
+ assert isinstance(a.mT, Array)
+ assert a.mT.shape == (1, 3, 2)
+ assert isinstance(b.mT, Array)
+ assert b.mT.shape == (3, 2)
+
+def test___array__():
+ a = ones((2, 3), dtype=int16)
+ assert np.asarray(a) is a._array
+ b = np.asarray(a, dtype=np.float64)
+ assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64)))
+ assert b.dtype == np.float64