summaryrefslogtreecommitdiff
path: root/numpy/array_api/tests/test_elementwise_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-08-04 16:47:05 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-08-04 16:50:30 -0600
commit6e57d829cb6628610e163524f203245b247a2839 (patch)
treef15f4900f995835bbd8526d7a4918a4d776d63e2 /numpy/array_api/tests/test_elementwise_functions.py
parent1596415c32f6008fcacc14a3a5394787aeb44265 (diff)
downloadnumpy-6e57d829cb6628610e163524f203245b247a2839.tar.gz
Rename numpy._array_api to numpy.array_api
Instead of the leading underscore, the experimentalness of the module will be indicated by omitting a warning on import. That we, we do not have to change the API from underscore to no underscore when the module is no longer experimental.
Diffstat (limited to 'numpy/array_api/tests/test_elementwise_functions.py')
-rw-r--r--numpy/array_api/tests/test_elementwise_functions.py110
1 files changed, 110 insertions, 0 deletions
diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py
new file mode 100644
index 000000000..994cb0bf0
--- /dev/null
+++ b/numpy/array_api/tests/test_elementwise_functions.py
@@ -0,0 +1,110 @@
+from inspect import getfullargspec
+
+from numpy.testing import assert_raises
+
+from .. import asarray, _elementwise_functions
+from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
+from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes,
+ _integer_dtypes, _integer_or_boolean_dtypes,
+ _numeric_dtypes)
+
+def nargs(func):
+ return len(getfullargspec(func).args)
+
+def test_function_types():
+ # Test that every function accepts only the required input types. We only
+ # test the negative cases here (error). The positive cases are tested in
+ # the array API test suite.
+
+ elementwise_function_input_types = {
+ 'abs': 'numeric',
+ 'acos': 'floating',
+ 'acosh': 'floating',
+ 'add': 'numeric',
+ 'asin': 'floating',
+ 'asinh': 'floating',
+ 'atan': 'floating',
+ 'atan2': 'floating',
+ 'atanh': 'floating',
+ 'bitwise_and': 'integer_or_boolean',
+ 'bitwise_invert': 'integer_or_boolean',
+ 'bitwise_left_shift': 'integer',
+ 'bitwise_or': 'integer_or_boolean',
+ 'bitwise_right_shift': 'integer',
+ 'bitwise_xor': 'integer_or_boolean',
+ 'ceil': 'numeric',
+ 'cos': 'floating',
+ 'cosh': 'floating',
+ 'divide': 'floating',
+ 'equal': 'all',
+ 'exp': 'floating',
+ 'expm1': 'floating',
+ 'floor': 'numeric',
+ 'floor_divide': 'numeric',
+ 'greater': 'numeric',
+ 'greater_equal': 'numeric',
+ 'isfinite': 'numeric',
+ 'isinf': 'numeric',
+ 'isnan': 'numeric',
+ 'less': 'numeric',
+ 'less_equal': 'numeric',
+ 'log': 'floating',
+ 'logaddexp': 'floating',
+ 'log10': 'floating',
+ 'log1p': 'floating',
+ 'log2': 'floating',
+ 'logical_and': 'boolean',
+ 'logical_not': 'boolean',
+ 'logical_or': 'boolean',
+ 'logical_xor': 'boolean',
+ 'multiply': 'numeric',
+ 'negative': 'numeric',
+ 'not_equal': 'all',
+ 'positive': 'numeric',
+ 'pow': 'floating',
+ 'remainder': 'numeric',
+ 'round': 'numeric',
+ 'sign': 'numeric',
+ 'sin': 'floating',
+ 'sinh': 'floating',
+ 'sqrt': 'floating',
+ 'square': 'numeric',
+ 'subtract': 'numeric',
+ 'tan': 'floating',
+ 'tanh': 'floating',
+ 'trunc': 'numeric',
+ }
+
+ _dtypes = {
+ 'all': _all_dtypes,
+ 'numeric': _numeric_dtypes,
+ 'integer': _integer_dtypes,
+ 'integer_or_boolean': _integer_or_boolean_dtypes,
+ 'boolean': _boolean_dtypes,
+ 'floating': _floating_dtypes,
+ }
+
+ def _array_vals():
+ for d in _integer_dtypes:
+ yield asarray(1, dtype=d)
+ for d in _boolean_dtypes:
+ yield asarray(False, dtype=d)
+ for d in _floating_dtypes:
+ yield asarray(1., dtype=d)
+
+ for x in _array_vals():
+ for func_name, types in elementwise_function_input_types.items():
+ dtypes = _dtypes[types]
+ func = getattr(_elementwise_functions, func_name)
+ if nargs(func) == 2:
+ for y in _array_vals():
+ if x.dtype not in dtypes or y.dtype not in dtypes:
+ assert_raises(TypeError, lambda: func(x, y))
+ else:
+ if x.dtype not in dtypes:
+ assert_raises(TypeError, lambda: func(x))
+
+def test_bitwise_shift_error():
+ # bitwise shift functions should raise when the second argument is negative
+ assert_raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1])))
+ assert_raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])))