summaryrefslogtreecommitdiff
path: root/numpy/core/tests
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2020-02-05 14:04:33 +0200
committerGitHub <noreply@github.com>2020-02-05 14:04:33 +0200
commitc4276e2183684c0763f61342022231481cffb329 (patch)
tree61cbbc3291a6c0753815284e2912ffe1df9ab916 /numpy/core/tests
parent32ce6b8cea5b464ea9b008e2b0cc3d86615a1bd5 (diff)
parentd5b4b721cce90adea3592c126087f1fbe489784e (diff)
downloadnumpy-c4276e2183684c0763f61342022231481cffb329.tar.gz
Merge pull request #15408 from r-devulap/cmplx-simd
ENH: Use AVX-512F for complex number arithmetic, absolute, square and conjugate
Diffstat (limited to 'numpy/core/tests')
-rw-r--r--numpy/core/tests/test_umath_complex.py39
1 files changed, 38 insertions, 1 deletions
diff --git a/numpy/core/tests/test_umath_complex.py b/numpy/core/tests/test_umath_complex.py
index 5e5ced85c..a21158420 100644
--- a/numpy/core/tests/test_umath_complex.py
+++ b/numpy/core/tests/test_umath_complex.py
@@ -6,7 +6,7 @@ import numpy as np
# import the c-extension module directly since _arg is not exported via umath
import numpy.core._multiarray_umath as ncu
from numpy.testing import (
- assert_raises, assert_equal, assert_array_equal, assert_almost_equal
+ assert_raises, assert_equal, assert_array_equal, assert_almost_equal, assert_array_max_ulp
)
# TODO: branch cuts (use Pauli code)
@@ -540,3 +540,40 @@ def check_complex_value(f, x1, y1, x2, y2, exact=True):
assert_equal(f(z1), z2)
else:
assert_almost_equal(f(z1), z2)
+
+class TestSpecialComplexAVX(object):
+ @pytest.mark.parametrize("stride", [-4,-2,-1,1,2,4])
+ @pytest.mark.parametrize("astype", [np.complex64, np.complex128])
+ def test_array(self, stride, astype):
+ arr = np.array([np.complex(np.nan , np.nan),
+ np.complex(np.nan , np.inf),
+ np.complex(np.inf , np.nan),
+ np.complex(np.inf , np.inf),
+ np.complex(0. , np.inf),
+ np.complex(np.inf , 0.),
+ np.complex(0. , 0.),
+ np.complex(0. , np.nan),
+ np.complex(np.nan , 0.)], dtype=astype)
+ abs_true = np.array([np.nan, np.inf, np.inf, np.inf, np.inf, np.inf, 0., np.nan, np.nan], dtype=arr.real.dtype)
+ sq_true = np.array([np.complex(np.nan, np.nan),
+ np.complex(np.nan, np.nan),
+ np.complex(np.nan, np.nan),
+ np.complex(np.nan, np.inf),
+ np.complex(-np.inf, np.nan),
+ np.complex(np.inf, np.nan),
+ np.complex(0., 0.),
+ np.complex(np.nan, np.nan),
+ np.complex(np.nan, np.nan)], dtype=astype)
+ assert_equal(np.abs(arr[::stride]), abs_true[::stride])
+ with np.errstate(invalid='ignore'):
+ assert_equal(np.square(arr[::stride]), sq_true[::stride])
+
+class TestComplexAbsoluteAVX(object):
+ @pytest.mark.parametrize("arraysize", [1,2,3,4,5,6,7,8,9,10,11,13,15,17,18,19])
+ @pytest.mark.parametrize("stride", [-4,-3,-2,-1,1,2,3,4])
+ @pytest.mark.parametrize("astype", [np.complex64, np.complex128])
+ # test to ensure masking and strides work as intended in the AVX implementation
+ def test_array(self, arraysize, stride, astype):
+ arr = np.ones(arraysize, dtype=astype)
+ abs_true = np.ones(arraysize, dtype=arr.real.dtype)
+ assert_equal(np.abs(arr[::stride]), abs_true[::stride])