summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/tests/test_multiarray.py31
-rw-r--r--numpy/core/tests/test_scalarmath.py54
2 files changed, 42 insertions, 43 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 048b1688f..b30fcb812 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -8576,23 +8576,22 @@ def test_equal_override():
assert_equal(array != my_always_equal, 'ne')
-def test_npymath_complex():
+@pytest.mark.parametrize(
+ ["fun", "npfun"],
+ [
+ (_multiarray_tests.npy_cabs, np.absolute),
+ (_multiarray_tests.npy_carg, np.angle)
+ ]
+)
+@pytest.mark.parametrize("x", [1, np.inf, -np.inf, np.nan])
+@pytest.mark.parametrize("y", [1, np.inf, -np.inf, np.nan])
+@pytest.mark.parametrize("test_dtype", np.complexfloating.__subclasses__())
+def test_npymath_complex(fun, npfun, x, y, test_dtype):
# Smoketest npymath functions
- from numpy.core._multiarray_tests import (
- npy_cabs, npy_carg)
-
- funcs = {npy_cabs: np.absolute,
- npy_carg: np.angle}
- vals = (1, np.inf, -np.inf, np.nan)
- types = (np.complex64, np.complex128, np.clongdouble)
-
- for fun, npfun in funcs.items():
- for x, y in itertools.product(vals, vals):
- for t in types:
- z = t(complex(x, y))
- got = fun(z)
- expected = npfun(z)
- assert_allclose(got, expected)
+ z = test_dtype(complex(x, y))
+ got = fun(z)
+ expected = npfun(z)
+ assert_allclose(got, expected)
def test_npymath_real():
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index d8529418e..0b615edfa 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -653,33 +653,33 @@ class TestSubtract:
class TestAbs:
- def _test_abs_func(self, absfunc):
- for tp in floating_types + complex_floating_types:
- x = tp(-1.5)
- assert_equal(absfunc(x), 1.5)
- x = tp(0.0)
- res = absfunc(x)
- # assert_equal() checks zero signedness
- assert_equal(res, 0.0)
- x = tp(-0.0)
- res = absfunc(x)
- assert_equal(res, 0.0)
-
- x = tp(np.finfo(tp).max)
- assert_equal(absfunc(x), x.real)
-
- x = tp(np.finfo(tp).tiny)
- assert_equal(absfunc(x), x.real)
-
- x = tp(np.finfo(tp).min)
- assert_equal(absfunc(x), -x.real)
-
- def test_builtin_abs(self):
- self._test_abs_func(abs)
-
- def test_numpy_abs(self):
- self._test_abs_func(np.abs)
-
+ def _test_abs_func(self, absfunc, test_dtype):
+ x = test_dtype(-1.5)
+ assert_equal(absfunc(x), 1.5)
+ x = test_dtype(0.0)
+ res = absfunc(x)
+ # assert_equal() checks zero signedness
+ assert_equal(res, 0.0)
+ x = test_dtype(-0.0)
+ res = absfunc(x)
+ assert_equal(res, 0.0)
+
+ x = test_dtype(np.finfo(test_dtype).max)
+ assert_equal(absfunc(x), x.real)
+
+ x = test_dtype(np.finfo(test_dtype).tiny)
+ assert_equal(absfunc(x), x.real)
+
+ x = test_dtype(np.finfo(test_dtype).min)
+ assert_equal(absfunc(x), -x.real)
+
+ @pytest.mark.parametrize("dtype", floating_types + complex_floating_types)
+ def test_builtin_abs(self, dtype):
+ self._test_abs_func(abs, dtype)
+
+ @pytest.mark.parametrize("dtype", floating_types + complex_floating_types)
+ def test_numpy_abs(self, dtype):
+ self._test_abs_func(np.abs, dtype)
class TestBitShifts: