summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-07-16 15:02:10 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-07-16 15:02:10 -0600
commitd9b958259bda8db50949922770101217b2d0b50a (patch)
tree8964fc0ac699dbe87b6b1566f0aaf263e53c9486 /numpy/_array_api
parent185d06d45cbc21ad06d7719ff265b1132d1c41ff (diff)
downloadnumpy-d9b958259bda8db50949922770101217b2d0b50a.tar.gz
Move the dtype check to the array API Array._new constructor
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_array_object.py4
-rw-r--r--numpy/_array_api/_creation_functions.py2
2 files changed, 3 insertions, 3 deletions
diff --git a/numpy/_array_api/_array_object.py b/numpy/_array_api/_array_object.py
index 0659b7b05..267bd698f 100644
--- a/numpy/_array_api/_array_object.py
+++ b/numpy/_array_api/_array_object.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import operator
from enum import IntEnum
from ._creation_functions import asarray
-from ._dtypes import _boolean_dtypes, _integer_dtypes, _floating_dtypes
+from ._dtypes import _all_dtypes, _boolean_dtypes, _integer_dtypes, _floating_dtypes
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
if TYPE_CHECKING:
@@ -61,6 +61,8 @@ class Array:
xa = np.empty((), x.dtype)
xa[()] = x
x = xa
+ if x.dtype not in _all_dtypes:
+ raise TypeError(f"The array_api namespace does not support the dtype '{x.dtype}'")
obj._array = x
return obj
diff --git a/numpy/_array_api/_creation_functions.py b/numpy/_array_api/_creation_functions.py
index 2be0aea09..24d28a3fa 100644
--- a/numpy/_array_api/_creation_functions.py
+++ b/numpy/_array_api/_creation_functions.py
@@ -32,8 +32,6 @@ def asarray(obj: Union[Array, float, NestedSequence[bool|int|float], SupportsDLP
# to an object array.
raise OverflowError("Integer out of bounds for array dtypes")
res = np.asarray(obj, dtype=dtype)
- if res.dtype not in _all_dtypes:
- raise TypeError(f"The array_api namespace does not support the dtype '{res.dtype}'")
return Array._new(res)
def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: