From bc20d334b575f897157b1cf3eecda77f3e40e049 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 4 Aug 2021 20:01:11 -0600 Subject: Move the array API dtype categories into the top level They are not an official part of the spec but are useful for various parts of the implementation. --- numpy/array_api/_array_object.py | 17 ++++------------- numpy/array_api/_dtypes.py | 10 ++++++++++ numpy/array_api/tests/test_elementwise_functions.py | 16 +++------------- 3 files changed, 17 insertions(+), 26 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index af70058e6..50906642d 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -98,23 +98,14 @@ class Array: if other is NotImplemented: return other """ - from ._dtypes import _result_type - - _dtypes = { - 'all': _all_dtypes, - 'numeric': _numeric_dtypes, - 'integer': _integer_dtypes, - 'integer or boolean': _integer_or_boolean_dtypes, - 'boolean': _boolean_dtypes, - 'floating-point': _floating_dtypes, - } - - if self.dtype not in _dtypes[dtype_category]: + from ._dtypes import _result_type, _dtype_categories + + if self.dtype not in _dtype_categories[dtype_category]: raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') if isinstance(other, (int, float, bool)): other = self._promote_scalar(other) elif isinstance(other, Array): - if other.dtype not in _dtypes[dtype_category]: + if other.dtype not in _dtype_categories[dtype_category]: raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}') else: return NotImplemented diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py index fcdb562da..07be267da 100644 --- a/numpy/array_api/_dtypes.py +++ b/numpy/array_api/_dtypes.py @@ -23,6 +23,16 @@ _integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) _integer_or_boolean_dtypes = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64) _numeric_dtypes = (float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_dtype_categories = { + 'all': _all_dtypes, + 'numeric': _numeric_dtypes, + 'integer': _integer_dtypes, + 'integer or boolean': _integer_or_boolean_dtypes, + 'boolean': _boolean_dtypes, + 'floating-point': _floating_dtypes, +} + + # Note: the spec defines a restricted type promotion table compared to NumPy. # In particular, cross-kind promotions like integer + float or boolean + # integer are not allowed, even for functions that accept both kinds. diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index 994cb0bf0..2a5ddbc87 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -4,9 +4,8 @@ from numpy.testing import assert_raises from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift -from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes, - _integer_dtypes, _integer_or_boolean_dtypes, - _numeric_dtypes) +from .._dtypes import (_dtype_categories, _boolean_dtypes, _floating_dtypes, + _integer_dtypes) def nargs(func): return len(getfullargspec(func).args) @@ -75,15 +74,6 @@ def test_function_types(): 'trunc': 'numeric', } - _dtypes = { - 'all': _all_dtypes, - 'numeric': _numeric_dtypes, - 'integer': _integer_dtypes, - 'integer_or_boolean': _integer_or_boolean_dtypes, - 'boolean': _boolean_dtypes, - 'floating': _floating_dtypes, - } - def _array_vals(): for d in _integer_dtypes: yield asarray(1, dtype=d) @@ -94,7 +84,7 @@ def test_function_types(): for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): - dtypes = _dtypes[types] + dtypes = _dtype_categories[types] func = getattr(_elementwise_functions, func_name) if nargs(func) == 2: for y in _array_vals(): -- cgit v1.2.1