diff options
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 31 | ||||
-rw-r--r-- | numpy/core/tests/test_scalarmath.py | 54 |
2 files changed, 42 insertions, 43 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 048b1688f..b30fcb812 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -8576,23 +8576,22 @@ def test_equal_override(): assert_equal(array != my_always_equal, 'ne') -def test_npymath_complex(): +@pytest.mark.parametrize( + ["fun", "npfun"], + [ + (_multiarray_tests.npy_cabs, np.absolute), + (_multiarray_tests.npy_carg, np.angle) + ] +) +@pytest.mark.parametrize("x", [1, np.inf, -np.inf, np.nan]) +@pytest.mark.parametrize("y", [1, np.inf, -np.inf, np.nan]) +@pytest.mark.parametrize("test_dtype", np.complexfloating.__subclasses__()) +def test_npymath_complex(fun, npfun, x, y, test_dtype): # Smoketest npymath functions - from numpy.core._multiarray_tests import ( - npy_cabs, npy_carg) - - funcs = {npy_cabs: np.absolute, - npy_carg: np.angle} - vals = (1, np.inf, -np.inf, np.nan) - types = (np.complex64, np.complex128, np.clongdouble) - - for fun, npfun in funcs.items(): - for x, y in itertools.product(vals, vals): - for t in types: - z = t(complex(x, y)) - got = fun(z) - expected = npfun(z) - assert_allclose(got, expected) + z = test_dtype(complex(x, y)) + got = fun(z) + expected = npfun(z) + assert_allclose(got, expected) def test_npymath_real(): diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py index d8529418e..0b615edfa 100644 --- a/numpy/core/tests/test_scalarmath.py +++ b/numpy/core/tests/test_scalarmath.py @@ -653,33 +653,33 @@ class TestSubtract: class TestAbs: - def _test_abs_func(self, absfunc): - for tp in floating_types + complex_floating_types: - x = tp(-1.5) - assert_equal(absfunc(x), 1.5) - x = tp(0.0) - res = absfunc(x) - # assert_equal() checks zero signedness - assert_equal(res, 0.0) - x = tp(-0.0) - res = absfunc(x) - assert_equal(res, 0.0) - - x = tp(np.finfo(tp).max) - assert_equal(absfunc(x), x.real) - - x = tp(np.finfo(tp).tiny) - assert_equal(absfunc(x), x.real) - - x = tp(np.finfo(tp).min) - assert_equal(absfunc(x), -x.real) - - def test_builtin_abs(self): - self._test_abs_func(abs) - - def test_numpy_abs(self): - self._test_abs_func(np.abs) - + def _test_abs_func(self, absfunc, test_dtype): + x = test_dtype(-1.5) + assert_equal(absfunc(x), 1.5) + x = test_dtype(0.0) + res = absfunc(x) + # assert_equal() checks zero signedness + assert_equal(res, 0.0) + x = test_dtype(-0.0) + res = absfunc(x) + assert_equal(res, 0.0) + + x = test_dtype(np.finfo(test_dtype).max) + assert_equal(absfunc(x), x.real) + + x = test_dtype(np.finfo(test_dtype).tiny) + assert_equal(absfunc(x), x.real) + + x = test_dtype(np.finfo(test_dtype).min) + assert_equal(absfunc(x), -x.real) + + @pytest.mark.parametrize("dtype", floating_types + complex_floating_types) + def test_builtin_abs(self, dtype): + self._test_abs_func(abs, dtype) + + @pytest.mark.parametrize("dtype", floating_types + complex_floating_types) + def test_numpy_abs(self, dtype): + self._test_abs_func(np.abs, dtype) class TestBitShifts: |