diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-03-10 16:03:48 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-10 16:03:48 -0700 |
commit | 2e9169601aff252a661b845399ec61c3e575407f (patch) | |
tree | 9fdd9d6cd1678d2f7acb3d47a9e1831df06739fa | |
parent | c9bfd4eb68e61c67aa27ed0cb2788f60d11cf354 (diff) | |
parent | ffe1f46121cd11b2b876d20ba1758a09cb4e5be7 (diff) | |
download | numpy-2e9169601aff252a661b845399ec61c3e575407f.tar.gz |
Merge pull request #15696 from rossbar/enh/var_complex_fastpath
MAINT: Add a fast path to var for complex input
-rw-r--r-- | benchmarks/benchmarks/bench_core.py | 11 | ||||
-rw-r--r-- | numpy/core/_methods.py | 23 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 33 |
3 files changed, 67 insertions, 0 deletions
diff --git a/benchmarks/benchmarks/bench_core.py b/benchmarks/benchmarks/bench_core.py index 94d3ad503..060d0f7db 100644 --- a/benchmarks/benchmarks/bench_core.py +++ b/benchmarks/benchmarks/bench_core.py @@ -180,3 +180,14 @@ class UnpackBits(Benchmark): class Indices(Benchmark): def time_indices(self): np.indices((1000, 500)) + +class VarComplex(Benchmark): + params = [10**n for n in range(1, 9)] + def setup(self, n): + self.arr = np.random.randn(n) + 1j * np.random.randn(n) + + def teardown(self, n): + del self.arr + + def time_var(self, n): + self.arr.var() diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 694523b20..8a90731e9 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -21,6 +21,21 @@ umr_prod = um.multiply.reduce umr_any = um.logical_or.reduce umr_all = um.logical_and.reduce +# Complex types to -> (2,)float view for fast-path computation in _var() +_complex_to_float = { + nt.dtype(nt.csingle) : nt.dtype(nt.single), + nt.dtype(nt.cdouble) : nt.dtype(nt.double), +} +# Special case for windows: ensure double takes precedence +if nt.dtype(nt.longdouble) != nt.dtype(nt.double): + _complex_to_float.update({ + nt.dtype(nt.clongdouble) : nt.dtype(nt.longdouble), + }) +# Add reverse-endian types +_complex_to_float.update({ + k.newbyteorder() : v.newbyteorder() for k, v in _complex_to_float.items() +}) + # avoid keyword arguments to speed up parsing, saves about 15%-20% for very # small reductions def _amax(a, axis=None, out=None, keepdims=False, @@ -189,8 +204,16 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # Note that x may not be inexact and that we need it to be an array, # not a scalar. x = asanyarray(arr - arrmean) + if issubclass(arr.dtype.type, (nt.floating, nt.integer)): x = um.multiply(x, x, out=x) + # Fast-paths for built-in complex types + elif x.dtype in _complex_to_float: + xv = x.view(dtype=(_complex_to_float[x.dtype], (2,))) + um.multiply(xv, xv, out=xv) + x = um.add(xv[..., 0], xv[..., 1], out=x.real).real + # Most general case; includes handling object arrays containing imaginary + # numbers and complex types with non-native byteorder else: x = um.multiply(x, um.conjugate(x), out=x).real diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 13244f3ba..829679dab 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -5507,6 +5507,39 @@ class TestStats: res = _var(mat, axis=axis) assert_almost_equal(res, tgt) + @pytest.mark.parametrize(('complex_dtype', 'ndec'), ( + ('complex64', 6), + ('complex128', 7), + ('clongdouble', 7), + )) + def test_var_complex_values(self, complex_dtype, ndec): + # Test fast-paths for every builtin complex type + for axis in [0, 1, None]: + mat = self.cmat.copy().astype(complex_dtype) + msqr = _mean(mat * mat.conj(), axis=axis) + mean = _mean(mat, axis=axis) + tgt = msqr - mean * mean.conjugate() + res = _var(mat, axis=axis) + assert_almost_equal(res, tgt, decimal=ndec) + + def test_var_dimensions(self): + # _var paths for complex number introduce additions on views that + # increase dimensions. Ensure this generalizes to higher dims + mat = np.stack([self.cmat]*3) + for axis in [0, 1, 2, -1, None]: + msqr = _mean(mat * mat.conj(), axis=axis) + mean = _mean(mat, axis=axis) + tgt = msqr - mean * mean.conjugate() + res = _var(mat, axis=axis) + assert_almost_equal(res, tgt) + + def test_var_complex_byteorder(self): + # Test that var fast-path does not cause failures for complex arrays + # with non-native byteorder + cmat = self.cmat.copy().astype('complex128') + cmat_swapped = cmat.astype(cmat.dtype.newbyteorder()) + assert_almost_equal(cmat.var(), cmat_swapped.var()) + def test_std_values(self): for mat in [self.rmat, self.cmat, self.omat]: for axis in [0, 1, None]: |