summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-03-10 16:03:48 -0700
committerGitHub <noreply@github.com>2020-03-10 16:03:48 -0700
commit2e9169601aff252a661b845399ec61c3e575407f (patch)
tree9fdd9d6cd1678d2f7acb3d47a9e1831df06739fa
parentc9bfd4eb68e61c67aa27ed0cb2788f60d11cf354 (diff)
parentffe1f46121cd11b2b876d20ba1758a09cb4e5be7 (diff)
downloadnumpy-2e9169601aff252a661b845399ec61c3e575407f.tar.gz
Merge pull request #15696 from rossbar/enh/var_complex_fastpath
MAINT: Add a fast path to var for complex input
-rw-r--r--benchmarks/benchmarks/bench_core.py11
-rw-r--r--numpy/core/_methods.py23
-rw-r--r--numpy/core/tests/test_multiarray.py33
3 files changed, 67 insertions, 0 deletions
diff --git a/benchmarks/benchmarks/bench_core.py b/benchmarks/benchmarks/bench_core.py
index 94d3ad503..060d0f7db 100644
--- a/benchmarks/benchmarks/bench_core.py
+++ b/benchmarks/benchmarks/bench_core.py
@@ -180,3 +180,14 @@ class UnpackBits(Benchmark):
class Indices(Benchmark):
def time_indices(self):
np.indices((1000, 500))
+
+class VarComplex(Benchmark):
+ params = [10**n for n in range(1, 9)]
+ def setup(self, n):
+ self.arr = np.random.randn(n) + 1j * np.random.randn(n)
+
+ def teardown(self, n):
+ del self.arr
+
+ def time_var(self, n):
+ self.arr.var()
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 694523b20..8a90731e9 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -21,6 +21,21 @@ umr_prod = um.multiply.reduce
umr_any = um.logical_or.reduce
umr_all = um.logical_and.reduce
+# Complex types to -> (2,)float view for fast-path computation in _var()
+_complex_to_float = {
+ nt.dtype(nt.csingle) : nt.dtype(nt.single),
+ nt.dtype(nt.cdouble) : nt.dtype(nt.double),
+}
+# Special case for windows: ensure double takes precedence
+if nt.dtype(nt.longdouble) != nt.dtype(nt.double):
+ _complex_to_float.update({
+ nt.dtype(nt.clongdouble) : nt.dtype(nt.longdouble),
+ })
+# Add reverse-endian types
+_complex_to_float.update({
+ k.newbyteorder() : v.newbyteorder() for k, v in _complex_to_float.items()
+})
+
# avoid keyword arguments to speed up parsing, saves about 15%-20% for very
# small reductions
def _amax(a, axis=None, out=None, keepdims=False,
@@ -189,8 +204,16 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
# Note that x may not be inexact and that we need it to be an array,
# not a scalar.
x = asanyarray(arr - arrmean)
+
if issubclass(arr.dtype.type, (nt.floating, nt.integer)):
x = um.multiply(x, x, out=x)
+ # Fast-paths for built-in complex types
+ elif x.dtype in _complex_to_float:
+ xv = x.view(dtype=(_complex_to_float[x.dtype], (2,)))
+ um.multiply(xv, xv, out=xv)
+ x = um.add(xv[..., 0], xv[..., 1], out=x.real).real
+ # Most general case; includes handling object arrays containing imaginary
+ # numbers and complex types with non-native byteorder
else:
x = um.multiply(x, um.conjugate(x), out=x).real
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 13244f3ba..829679dab 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -5507,6 +5507,39 @@ class TestStats:
res = _var(mat, axis=axis)
assert_almost_equal(res, tgt)
+ @pytest.mark.parametrize(('complex_dtype', 'ndec'), (
+ ('complex64', 6),
+ ('complex128', 7),
+ ('clongdouble', 7),
+ ))
+ def test_var_complex_values(self, complex_dtype, ndec):
+ # Test fast-paths for every builtin complex type
+ for axis in [0, 1, None]:
+ mat = self.cmat.copy().astype(complex_dtype)
+ msqr = _mean(mat * mat.conj(), axis=axis)
+ mean = _mean(mat, axis=axis)
+ tgt = msqr - mean * mean.conjugate()
+ res = _var(mat, axis=axis)
+ assert_almost_equal(res, tgt, decimal=ndec)
+
+ def test_var_dimensions(self):
+ # _var paths for complex number introduce additions on views that
+ # increase dimensions. Ensure this generalizes to higher dims
+ mat = np.stack([self.cmat]*3)
+ for axis in [0, 1, 2, -1, None]:
+ msqr = _mean(mat * mat.conj(), axis=axis)
+ mean = _mean(mat, axis=axis)
+ tgt = msqr - mean * mean.conjugate()
+ res = _var(mat, axis=axis)
+ assert_almost_equal(res, tgt)
+
+ def test_var_complex_byteorder(self):
+ # Test that var fast-path does not cause failures for complex arrays
+ # with non-native byteorder
+ cmat = self.cmat.copy().astype('complex128')
+ cmat_swapped = cmat.astype(cmat.dtype.newbyteorder())
+ assert_almost_equal(cmat.var(), cmat_swapped.var())
+
def test_std_values(self):
for mat in [self.rmat, self.cmat, self.omat]:
for axis in [0, 1, None]: