diff options
Diffstat (limited to 'numpy/array_api/tests/test_elementwise_functions.py')
-rw-r--r-- | numpy/array_api/tests/test_elementwise_functions.py | 133 |
1 files changed, 72 insertions, 61 deletions
diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index ec76cb7a7..a9274aec9 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -4,74 +4,80 @@ from numpy.testing import assert_raises from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift -from .._dtypes import (_dtype_categories, _boolean_dtypes, _floating_dtypes, - _integer_dtypes) +from .._dtypes import ( + _dtype_categories, + _boolean_dtypes, + _floating_dtypes, + _integer_dtypes, +) + def nargs(func): return len(getfullargspec(func).args) + def test_function_types(): # Test that every function accepts only the required input types. We only # test the negative cases here (error). The positive cases are tested in # the array API test suite. elementwise_function_input_types = { - 'abs': 'numeric', - 'acos': 'floating-point', - 'acosh': 'floating-point', - 'add': 'numeric', - 'asin': 'floating-point', - 'asinh': 'floating-point', - 'atan': 'floating-point', - 'atan2': 'floating-point', - 'atanh': 'floating-point', - 'bitwise_and': 'integer or boolean', - 'bitwise_invert': 'integer or boolean', - 'bitwise_left_shift': 'integer', - 'bitwise_or': 'integer or boolean', - 'bitwise_right_shift': 'integer', - 'bitwise_xor': 'integer or boolean', - 'ceil': 'numeric', - 'cos': 'floating-point', - 'cosh': 'floating-point', - 'divide': 'floating-point', - 'equal': 'all', - 'exp': 'floating-point', - 'expm1': 'floating-point', - 'floor': 'numeric', - 'floor_divide': 'numeric', - 'greater': 'numeric', - 'greater_equal': 'numeric', - 'isfinite': 'numeric', - 'isinf': 'numeric', - 'isnan': 'numeric', - 'less': 'numeric', - 'less_equal': 'numeric', - 'log': 'floating-point', - 'logaddexp': 'floating-point', - 'log10': 'floating-point', - 'log1p': 'floating-point', - 'log2': 'floating-point', - 'logical_and': 'boolean', - 'logical_not': 'boolean', - 'logical_or': 'boolean', - 'logical_xor': 'boolean', - 'multiply': 'numeric', - 'negative': 'numeric', - 'not_equal': 'all', - 'positive': 'numeric', - 'pow': 'floating-point', - 'remainder': 'numeric', - 'round': 'numeric', - 'sign': 'numeric', - 'sin': 'floating-point', - 'sinh': 'floating-point', - 'sqrt': 'floating-point', - 'square': 'numeric', - 'subtract': 'numeric', - 'tan': 'floating-point', - 'tanh': 'floating-point', - 'trunc': 'numeric', + "abs": "numeric", + "acos": "floating-point", + "acosh": "floating-point", + "add": "numeric", + "asin": "floating-point", + "asinh": "floating-point", + "atan": "floating-point", + "atan2": "floating-point", + "atanh": "floating-point", + "bitwise_and": "integer or boolean", + "bitwise_invert": "integer or boolean", + "bitwise_left_shift": "integer", + "bitwise_or": "integer or boolean", + "bitwise_right_shift": "integer", + "bitwise_xor": "integer or boolean", + "ceil": "numeric", + "cos": "floating-point", + "cosh": "floating-point", + "divide": "floating-point", + "equal": "all", + "exp": "floating-point", + "expm1": "floating-point", + "floor": "numeric", + "floor_divide": "numeric", + "greater": "numeric", + "greater_equal": "numeric", + "isfinite": "numeric", + "isinf": "numeric", + "isnan": "numeric", + "less": "numeric", + "less_equal": "numeric", + "log": "floating-point", + "logaddexp": "floating-point", + "log10": "floating-point", + "log1p": "floating-point", + "log2": "floating-point", + "logical_and": "boolean", + "logical_not": "boolean", + "logical_or": "boolean", + "logical_xor": "boolean", + "multiply": "numeric", + "negative": "numeric", + "not_equal": "all", + "positive": "numeric", + "pow": "floating-point", + "remainder": "numeric", + "round": "numeric", + "sign": "numeric", + "sin": "floating-point", + "sinh": "floating-point", + "sqrt": "floating-point", + "square": "numeric", + "subtract": "numeric", + "tan": "floating-point", + "tanh": "floating-point", + "trunc": "numeric", } def _array_vals(): @@ -80,7 +86,7 @@ def test_function_types(): for d in _boolean_dtypes: yield asarray(False, dtype=d) for d in _floating_dtypes: - yield asarray(1., dtype=d) + yield asarray(1.0, dtype=d) for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): @@ -94,7 +100,12 @@ def test_function_types(): if x.dtype not in dtypes: assert_raises(TypeError, lambda: func(x)) + def test_bitwise_shift_error(): # bitwise shift functions should raise when the second argument is negative - assert_raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))) - assert_raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))) + assert_raises( + ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1])) + ) + assert_raises( + ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) + ) |