from inspect import getfullargspec

from numpy.testing import assert_raises

from .. import asarray, _elementwise_functions
from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
from .._dtypes import (
    _dtype_categories,
    _boolean_dtypes,
    _floating_dtypes,
    _integer_dtypes,
)


def nargs(func):
    """Return the number of positional parameters ``func`` declares.

    The count is the length of ``getfullargspec(func).args``; ``*args``,
    keyword-only parameters and ``**kwargs`` are not included.
    """
    return len(getfullargspec(func).args)


def test_function_types():
    """Check that elementwise functions reject unsupported input dtypes."""
    # Test that every function accepts only the required input types. We only
    # test the negative cases here (error). The positive cases are tested in
    # the array API test suite.

    # Maps each elementwise function name to the key of the dtype category
    # (looked up in ``_dtype_categories``) that its inputs must belong to.
    elementwise_function_input_types = {
        'abs': 'numeric',
        'acos': 'floating-point',
        'acosh': 'floating-point',
        'add': 'numeric',
        'asin': 'floating-point',
        'asinh': 'floating-point',
        'atan': 'floating-point',
        'atan2': 'floating-point',
        'atanh': 'floating-point',
        'bitwise_and': 'integer or boolean',
        'bitwise_invert': 'integer or boolean',
        'bitwise_left_shift': 'integer',
        'bitwise_or': 'integer or boolean',
        'bitwise_right_shift': 'integer',
        'bitwise_xor': 'integer or boolean',
        'ceil': 'numeric',
        'cos': 'floating-point',
        'cosh': 'floating-point',
        'divide': 'floating-point',
        'equal': 'all',
        'exp': 'floating-point',
        'expm1': 'floating-point',
        'floor': 'numeric',
        'floor_divide': 'numeric',
        'greater': 'numeric',
        'greater_equal': 'numeric',
        'isfinite': 'numeric',
        'isinf': 'numeric',
        'isnan': 'numeric',
        'less': 'numeric',
        'less_equal': 'numeric',
        'log': 'floating-point',
        'logaddexp': 'floating-point',
        'log10': 'floating-point',
        'log1p': 'floating-point',
        'log2': 'floating-point',
        'logical_and': 'boolean',
        'logical_not': 'boolean',
        'logical_or': 'boolean',
        'logical_xor': 'boolean',
        'multiply': 'numeric',
        'negative': 'numeric',
        'not_equal': 'all',
        'positive': 'numeric',
        'pow': 'floating-point',
        'remainder': 'numeric',
        'round': 'numeric',
        'sign': 'numeric',
        'sin': 'floating-point',
        'sinh': 'floating-point',
        'sqrt': 'floating-point',
        'square': 'numeric',
        'subtract': 'numeric',
        'tan': 'floating-point',
        'tanh': 'floating-point',
        'trunc': 'numeric',
    }

    def _array_vals():
        """Yield one 0-d array for every integer, boolean and float dtype."""
        for d in _integer_dtypes:
            yield asarray(1, dtype=d)
        for d in _boolean_dtypes:
            yield asarray(False, dtype=d)
        for d in _floating_dtypes:
            yield asarray(1., dtype=d)

    for x in _array_vals():
        for func_name, types in elementwise_function_input_types.items():
            dtypes = _dtype_categories[types]
            func = getattr(_elementwise_functions, func_name)
            if nargs(func) == 2:
                # Binary function: it must raise as soon as either operand's
                # dtype falls outside the allowed category.
                for y in _array_vals():
                    if x.dtype not in dtypes or y.dtype not in dtypes:
                        assert_raises(TypeError, func, x, y)
            elif x.dtype not in dtypes:
                # Unary function called with a disallowed dtype.
                assert_raises(TypeError, func, x)


def test_bitwise_shift_error():
    """Check that the bitwise shift functions reject negative shift counts."""
    # bitwise shift functions should raise when the second argument is negative
    assert_raises(
        ValueError, bitwise_left_shift, asarray([1, 1]), asarray([1, -1])
    )
    assert_raises(
        ValueError, bitwise_right_shift, asarray([1, 1]), asarray([1, -1])
    )