diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2021-08-23 21:32:20 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-08-23 21:32:20 -0600 |
commit | 098f874144161b6a49efa5108846a408ca8f39b8 (patch) | |
tree | d618c32a54705d38e5458669774c88d9f6225212 /numpy/array_api/tests/test_elementwise_functions.py | |
parent | a3ac75c6f92ed158777492f343dc59adeacb745c (diff) | |
parent | 7091e4c48ce7af8a5263b6808a6d7976d4af4c6f (diff) | |
download | numpy-098f874144161b6a49efa5108846a408ca8f39b8.tar.gz |
Merge pull request #18585 from data-apis/array-api
ENH: Implementation of the NEP 47 (adopting the array API standard)
Diffstat (limited to 'numpy/array_api/tests/test_elementwise_functions.py')
-rw-r--r-- | numpy/array_api/tests/test_elementwise_functions.py | 111 |
1 files changed, 111 insertions, 0 deletions
diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py new file mode 100644 index 000000000..a9274aec9 --- /dev/null +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -0,0 +1,111 @@ +from inspect import getfullargspec + +from numpy.testing import assert_raises + +from .. import asarray, _elementwise_functions +from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift +from .._dtypes import ( + _dtype_categories, + _boolean_dtypes, + _floating_dtypes, + _integer_dtypes, +) + + +def nargs(func): + return len(getfullargspec(func).args) + + +def test_function_types(): + # Test that every function accepts only the required input types. We only + # test the negative cases here (error). The positive cases are tested in + # the array API test suite. + + elementwise_function_input_types = { + "abs": "numeric", + "acos": "floating-point", + "acosh": "floating-point", + "add": "numeric", + "asin": "floating-point", + "asinh": "floating-point", + "atan": "floating-point", + "atan2": "floating-point", + "atanh": "floating-point", + "bitwise_and": "integer or boolean", + "bitwise_invert": "integer or boolean", + "bitwise_left_shift": "integer", + "bitwise_or": "integer or boolean", + "bitwise_right_shift": "integer", + "bitwise_xor": "integer or boolean", + "ceil": "numeric", + "cos": "floating-point", + "cosh": "floating-point", + "divide": "floating-point", + "equal": "all", + "exp": "floating-point", + "expm1": "floating-point", + "floor": "numeric", + "floor_divide": "numeric", + "greater": "numeric", + "greater_equal": "numeric", + "isfinite": "numeric", + "isinf": "numeric", + "isnan": "numeric", + "less": "numeric", + "less_equal": "numeric", + "log": "floating-point", + "logaddexp": "floating-point", + "log10": "floating-point", + "log1p": "floating-point", + "log2": "floating-point", + "logical_and": "boolean", + "logical_not": "boolean", + "logical_or": "boolean", + "logical_xor": "boolean", + "multiply": "numeric", + "negative": "numeric", + "not_equal": "all", + "positive": "numeric", + "pow": "floating-point", + "remainder": "numeric", + "round": "numeric", + "sign": "numeric", + "sin": "floating-point", + "sinh": "floating-point", + "sqrt": "floating-point", + "square": "numeric", + "subtract": "numeric", + "tan": "floating-point", + "tanh": "floating-point", + "trunc": "numeric", + } + + def _array_vals(): + for d in _integer_dtypes: + yield asarray(1, dtype=d) + for d in _boolean_dtypes: + yield asarray(False, dtype=d) + for d in _floating_dtypes: + yield asarray(1.0, dtype=d) + + for x in _array_vals(): + for func_name, types in elementwise_function_input_types.items(): + dtypes = _dtype_categories[types] + func = getattr(_elementwise_functions, func_name) + if nargs(func) == 2: + for y in _array_vals(): + if x.dtype not in dtypes or y.dtype not in dtypes: + assert_raises(TypeError, lambda: func(x, y)) + else: + if x.dtype not in dtypes: + assert_raises(TypeError, lambda: func(x)) + + +def test_bitwise_shift_error(): + # bitwise shift functions should raise when the second argument is negative + assert_raises( + ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1])) + ) + assert_raises( + ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) + ) |