summaryrefslogtreecommitdiff
path: root/numpy/core/tests
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-03-10 16:03:48 -0700
committerGitHub <noreply@github.com>2020-03-10 16:03:48 -0700
commit2e9169601aff252a661b845399ec61c3e575407f (patch)
tree9fdd9d6cd1678d2f7acb3d47a9e1831df06739fa /numpy/core/tests
parentc9bfd4eb68e61c67aa27ed0cb2788f60d11cf354 (diff)
parentffe1f46121cd11b2b876d20ba1758a09cb4e5be7 (diff)
downloadnumpy-2e9169601aff252a661b845399ec61c3e575407f.tar.gz
Merge pull request #15696 from rossbar/enh/var_complex_fastpath
MAINT: Add a fast path to var for complex input
Diffstat (limited to 'numpy/core/tests')
-rw-r--r--numpy/core/tests/test_multiarray.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 13244f3ba..829679dab 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -5507,6 +5507,39 @@ class TestStats:
res = _var(mat, axis=axis)
assert_almost_equal(res, tgt)
+ @pytest.mark.parametrize(('complex_dtype', 'ndec'), (
+ ('complex64', 6),
+ ('complex128', 7),
+ ('clongdouble', 7),
+ ))
+ def test_var_complex_values(self, complex_dtype, ndec):
+ # Test fast-paths for every builtin complex type
+ for axis in [0, 1, None]:
+ mat = self.cmat.copy().astype(complex_dtype)
+ msqr = _mean(mat * mat.conj(), axis=axis)
+ mean = _mean(mat, axis=axis)
+ tgt = msqr - mean * mean.conjugate()
+ res = _var(mat, axis=axis)
+ assert_almost_equal(res, tgt, decimal=ndec)
+
+ def test_var_dimensions(self):
+ # _var paths for complex number introduce additions on views that
+ # increase dimensions. Ensure this generalizes to higher dims
+ mat = np.stack([self.cmat]*3)
+ for axis in [0, 1, 2, -1, None]:
+ msqr = _mean(mat * mat.conj(), axis=axis)
+ mean = _mean(mat, axis=axis)
+ tgt = msqr - mean * mean.conjugate()
+ res = _var(mat, axis=axis)
+ assert_almost_equal(res, tgt)
+
+ def test_var_complex_byteorder(self):
+ # Test that var fast-path does not cause failures for complex arrays
+ # with non-native byteorder
+ cmat = self.cmat.copy().astype('complex128')
+ cmat_swapped = cmat.astype(cmat.dtype.newbyteorder())
+ assert_almost_equal(cmat.var(), cmat_swapped.var())
+
def test_std_values(self):
for mat in [self.rmat, self.cmat, self.omat]:
for axis in [0, 1, None]: