diff options
author | Ross Barnowski <rossbar@berkeley.edu> | 2020-03-03 15:02:49 -0800 |
---|---|---|
committer | Ross Barnowski <rossbar@berkeley.edu> | 2020-03-09 10:39:08 -0700 |
commit | c898ff3f0338f2b24b8f4e7d1b4fffc36d3e51c9 (patch) | |
tree | 6dc8b71d58e0f64a85134ff8c93b7988b2b942c2 /numpy/core/tests | |
parent | 6894bbc6d396b87464cbc21516d239d5f94f13b7 (diff) | |
download | numpy-c898ff3f0338f2b24b8f4e7d1b4fffc36d3e51c9.tar.gz |
ENH: Adds a fast path to var for complex input
var currently has a conditional that results in conjugate being
called for the variance calculation of complex inputs.
This leg of the computation is slow. This PR avoids this
computational leg for complex inputs via a type check.
Closes #15684
Diffstat (limited to 'numpy/core/tests')
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 13244f3ba..829679dab 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -5507,6 +5507,39 @@ class TestStats: res = _var(mat, axis=axis) assert_almost_equal(res, tgt) + @pytest.mark.parametrize(('complex_dtype', 'ndec'), ( + ('complex64', 6), + ('complex128', 7), + ('clongdouble', 7), + )) + def test_var_complex_values(self, complex_dtype, ndec): + # Test fast-paths for every builtin complex type + for axis in [0, 1, None]: + mat = self.cmat.copy().astype(complex_dtype) + msqr = _mean(mat * mat.conj(), axis=axis) + mean = _mean(mat, axis=axis) + tgt = msqr - mean * mean.conjugate() + res = _var(mat, axis=axis) + assert_almost_equal(res, tgt, decimal=ndec) + + def test_var_dimensions(self): + # _var paths for complex number introduce additions on views that + # increase dimensions. Ensure this generalizes to higher dims + mat = np.stack([self.cmat]*3) + for axis in [0, 1, 2, -1, None]: + msqr = _mean(mat * mat.conj(), axis=axis) + mean = _mean(mat, axis=axis) + tgt = msqr - mean * mean.conjugate() + res = _var(mat, axis=axis) + assert_almost_equal(res, tgt) + + def test_var_complex_byteorder(self): + # Test that var fast-path does not cause failures for complex arrays + # with non-native byteorder + cmat = self.cmat.copy().astype('complex128') + cmat_swapped = cmat.astype(cmat.dtype.newbyteorder()) + assert_almost_equal(cmat.var(), cmat_swapped.var()) + def test_std_values(self): for mat in [self.rmat, self.cmat, self.omat]: for axis in [0, 1, None]: |