summaryrefslogtreecommitdiff
path: root/numpy/core/tests
diff options
context:
space:
mode:
authorRoss Barnowski <rossbar@berkeley.edu>2020-03-03 15:02:49 -0800
committerRoss Barnowski <rossbar@berkeley.edu>2020-03-09 10:39:08 -0700
commitc898ff3f0338f2b24b8f4e7d1b4fffc36d3e51c9 (patch)
tree6dc8b71d58e0f64a85134ff8c93b7988b2b942c2 /numpy/core/tests
parent6894bbc6d396b87464cbc21516d239d5f94f13b7 (diff)
downloadnumpy-c898ff3f0338f2b24b8f4e7d1b4fffc36d3e51c9.tar.gz
ENH: Adds a fast path to var for complex input
var currently has a conditional that results in conjugate being called for the variance calculation of complex inputs. This leg of the computation is slow. This PR avoids this computational leg for complex inputs via a type check. Closes #15684
Diffstat (limited to 'numpy/core/tests')
-rw-r--r--numpy/core/tests/test_multiarray.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 13244f3ba..829679dab 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -5507,6 +5507,39 @@ class TestStats:
res = _var(mat, axis=axis)
assert_almost_equal(res, tgt)
+ @pytest.mark.parametrize(('complex_dtype', 'ndec'), (
+ ('complex64', 6),
+ ('complex128', 7),
+ ('clongdouble', 7),
+ ))
+ def test_var_complex_values(self, complex_dtype, ndec):
+ # Test fast-paths for every builtin complex type
+ for axis in [0, 1, None]:
+ mat = self.cmat.copy().astype(complex_dtype)
+ msqr = _mean(mat * mat.conj(), axis=axis)
+ mean = _mean(mat, axis=axis)
+ tgt = msqr - mean * mean.conjugate()
+ res = _var(mat, axis=axis)
+ assert_almost_equal(res, tgt, decimal=ndec)
+
+ def test_var_dimensions(self):
+ # _var paths for complex number introduce additions on views that
+ # increase dimensions. Ensure this generalizes to higher dims
+ mat = np.stack([self.cmat]*3)
+ for axis in [0, 1, 2, -1, None]:
+ msqr = _mean(mat * mat.conj(), axis=axis)
+ mean = _mean(mat, axis=axis)
+ tgt = msqr - mean * mean.conjugate()
+ res = _var(mat, axis=axis)
+ assert_almost_equal(res, tgt)
+
+ def test_var_complex_byteorder(self):
+ # Test that var fast-path does not cause failures for complex arrays
+ # with non-native byteorder
+ cmat = self.cmat.copy().astype('complex128')
+ cmat_swapped = cmat.astype(cmat.dtype.newbyteorder())
+ assert_almost_equal(cmat.var(), cmat_swapped.var())
+
def test_std_values(self):
for mat in [self.rmat, self.cmat, self.omat]:
for axis in [0, 1, None]: