diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-03-10 16:03:48 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-10 16:03:48 -0700 |
commit | 2e9169601aff252a661b845399ec61c3e575407f (patch) | |
tree | 9fdd9d6cd1678d2f7acb3d47a9e1831df06739fa /numpy/core/tests | |
parent | c9bfd4eb68e61c67aa27ed0cb2788f60d11cf354 (diff) | |
parent | ffe1f46121cd11b2b876d20ba1758a09cb4e5be7 (diff) | |
download | numpy-2e9169601aff252a661b845399ec61c3e575407f.tar.gz |
Merge pull request #15696 from rossbar/enh/var_complex_fastpath
MAINT: Add a fast path to var for complex input
Diffstat (limited to 'numpy/core/tests')
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 13244f3ba..829679dab 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -5507,6 +5507,39 @@ class TestStats: res = _var(mat, axis=axis) assert_almost_equal(res, tgt) + @pytest.mark.parametrize(('complex_dtype', 'ndec'), ( + ('complex64', 6), + ('complex128', 7), + ('clongdouble', 7), + )) + def test_var_complex_values(self, complex_dtype, ndec): + # Test fast-paths for every builtin complex type + for axis in [0, 1, None]: + mat = self.cmat.copy().astype(complex_dtype) + msqr = _mean(mat * mat.conj(), axis=axis) + mean = _mean(mat, axis=axis) + tgt = msqr - mean * mean.conjugate() + res = _var(mat, axis=axis) + assert_almost_equal(res, tgt, decimal=ndec) + + def test_var_dimensions(self): + # _var paths for complex number introduce additions on views that + # increase dimensions. Ensure this generalizes to higher dims + mat = np.stack([self.cmat]*3) + for axis in [0, 1, 2, -1, None]: + msqr = _mean(mat * mat.conj(), axis=axis) + mean = _mean(mat, axis=axis) + tgt = msqr - mean * mean.conjugate() + res = _var(mat, axis=axis) + assert_almost_equal(res, tgt) + + def test_var_complex_byteorder(self): + # Test that var fast-path does not cause failures for complex arrays + # with non-native byteorder + cmat = self.cmat.copy().astype('complex128') + cmat_swapped = cmat.astype(cmat.dtype.newbyteorder()) + assert_almost_equal(cmat.var(), cmat_swapped.var()) + def test_std_values(self): for mat in [self.rmat, self.cmat, self.omat]: for axis in [0, 1, None]: |