summaryrefslogtreecommitdiff
path: root/numpy/array_api/_array_object.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r--numpy/array_api/_array_object.py17
1 files changed, 4 insertions, 13 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index af70058e6..50906642d 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -98,23 +98,14 @@ class Array:
if other is NotImplemented:
return other
"""
- from ._dtypes import _result_type
-
- _dtypes = {
- 'all': _all_dtypes,
- 'numeric': _numeric_dtypes,
- 'integer': _integer_dtypes,
- 'integer or boolean': _integer_or_boolean_dtypes,
- 'boolean': _boolean_dtypes,
- 'floating-point': _floating_dtypes,
- }
-
- if self.dtype not in _dtypes[dtype_category]:
+ from ._dtypes import _result_type, _dtype_categories
+
+ if self.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
if isinstance(other, (int, float, bool)):
other = self._promote_scalar(other)
elif isinstance(other, Array):
- if other.dtype not in _dtypes[dtype_category]:
+ if other.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
else:
return NotImplemented