diff options
author | Matti Picus <matti.picus@gmail.com> | 2020-02-05 14:04:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-02-05 14:04:33 +0200 |
commit | c4276e2183684c0763f61342022231481cffb329 (patch) | |
tree | 61cbbc3291a6c0753815284e2912ffe1df9ab916 /numpy/core/tests | |
parent | 32ce6b8cea5b464ea9b008e2b0cc3d86615a1bd5 (diff) | |
parent | d5b4b721cce90adea3592c126087f1fbe489784e (diff) | |
download | numpy-c4276e2183684c0763f61342022231481cffb329.tar.gz |
Merge pull request #15408 from r-devulap/cmplx-simd
ENH: Use AVX-512F for complex number arithmetic, absolute, square and conjugate
Diffstat (limited to 'numpy/core/tests')
-rw-r--r-- | numpy/core/tests/test_umath_complex.py | 39 |
1 files changed, 38 insertions, 1 deletions
diff --git a/numpy/core/tests/test_umath_complex.py b/numpy/core/tests/test_umath_complex.py index 5e5ced85c..a21158420 100644 --- a/numpy/core/tests/test_umath_complex.py +++ b/numpy/core/tests/test_umath_complex.py @@ -6,7 +6,7 @@ import numpy as np # import the c-extension module directly since _arg is not exported via umath import numpy.core._multiarray_umath as ncu from numpy.testing import ( - assert_raises, assert_equal, assert_array_equal, assert_almost_equal + assert_raises, assert_equal, assert_array_equal, assert_almost_equal, assert_array_max_ulp ) # TODO: branch cuts (use Pauli code) @@ -540,3 +540,40 @@ def check_complex_value(f, x1, y1, x2, y2, exact=True): assert_equal(f(z1), z2) else: assert_almost_equal(f(z1), z2) + +class TestSpecialComplexAVX(object): + @pytest.mark.parametrize("stride", [-4,-2,-1,1,2,4]) + @pytest.mark.parametrize("astype", [np.complex64, np.complex128]) + def test_array(self, stride, astype): + arr = np.array([np.complex(np.nan , np.nan), + np.complex(np.nan , np.inf), + np.complex(np.inf , np.nan), + np.complex(np.inf , np.inf), + np.complex(0. , np.inf), + np.complex(np.inf , 0.), + np.complex(0. , 0.), + np.complex(0. , np.nan), + np.complex(np.nan , 0.)], dtype=astype) + abs_true = np.array([np.nan, np.inf, np.inf, np.inf, np.inf, np.inf, 0., np.nan, np.nan], dtype=arr.real.dtype) + sq_true = np.array([np.complex(np.nan, np.nan), + np.complex(np.nan, np.nan), + np.complex(np.nan, np.nan), + np.complex(np.nan, np.inf), + np.complex(-np.inf, np.nan), + np.complex(np.inf, np.nan), + np.complex(0., 0.), + np.complex(np.nan, np.nan), + np.complex(np.nan, np.nan)], dtype=astype) + assert_equal(np.abs(arr[::stride]), abs_true[::stride]) + with np.errstate(invalid='ignore'): + assert_equal(np.square(arr[::stride]), sq_true[::stride]) + +class TestComplexAbsoluteAVX(object): + @pytest.mark.parametrize("arraysize", [1,2,3,4,5,6,7,8,9,10,11,13,15,17,18,19]) + @pytest.mark.parametrize("stride", [-4,-3,-2,-1,1,2,3,4]) + @pytest.mark.parametrize("astype", [np.complex64, np.complex128]) + # test to ensure masking and strides work as intended in the AVX implementation + def test_array(self, arraysize, stride, astype): + arr = np.ones(arraysize, dtype=astype) + abs_true = np.ones(arraysize, dtype=arr.real.dtype) + assert_equal(np.abs(arr[::stride]), abs_true[::stride]) |