summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Julian Taylor <jtaylor.debian@googlemail.com> 2014-05-31 13:20:43 +0200
committer: Julian Taylor <jtaylor.debian@googlemail.com> 2014-06-02 23:47:24 +0200
commit: 99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1 (patch)
tree: 6bdb286e11684864ba3ff86a0d913a95863235b1
parent: 0ae36289aa5104fce4e40c63ba46e19365f33b5d (diff)
download: numpy-99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1.tar.gz
ENH: use masked median for small multidimensional nanmedians
-rw-r--r-- numpy/lib/nanfunctions.py 20
-rw-r--r-- numpy/lib/tests/test_nanfunctions.py 18
2 files changed, 36 insertions, 2 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py
index 478e7cf7e..7120760b5 100644
--- a/numpy/lib/nanfunctions.py
+++ b/numpy/lib/nanfunctions.py
@@ -635,7 +635,7 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False):
See nanmedian for parameter usage
"""
- if axis is None:
+ if axis is None or a.ndim == 1:
part = a.ravel()
if out is None:
return _nanmedian1d(part, overwrite_input)
@@ -643,11 +643,29 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False):
out[...] = _nanmedian1d(part, overwrite_input)
return out
else:
+ # For small medians (fewer than 400 elements along `axis`), use the
+ # sort + indexing path, which is still faster than apply_along_axis.
+ if a.shape[axis] < 400:
+ return _nanmedian_small(a, axis, out, overwrite_input)
result = np.apply_along_axis(_nanmedian1d, axis, a, overwrite_input)
if out is not None:
out[...] = result
return result
+def _nanmedian_small(a, axis=None, out=None, overwrite_input=False):
+ """
+ Median via a masked array (sort + indexing); faster for small medians
+ along multiple dimensions because it avoids the high overhead of
+ apply_along_axis. See nanmedian for parameter usage.
+ """
+ a = np.ma.masked_array(a, np.isnan(a))
+ m = np.ma.median(a, axis=axis, overwrite_input=overwrite_input)
+ for i in range(np.count_nonzero(m.mask.ravel())):
+ warnings.warn("All-NaN slice encountered", RuntimeWarning)
+ if out is not None:
+ out[...] = m.filled(np.nan)
+ return out
+ return m.filled(np.nan)
def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False):
"""
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 3fcfca218..c5af61434 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -5,7 +5,7 @@ import warnings
import numpy as np
from numpy.testing import (
run_module_suite, TestCase, assert_, assert_equal, assert_almost_equal,
- assert_raises
+ assert_raises, assert_array_equal
)
@@ -580,6 +580,22 @@ class TestNanFunctions_Median(TestCase):
assert_almost_equal(res, resout)
assert_almost_equal(res, tgt)
+ def test_small_large(self):
+ # test the small and large code paths, current cutoff 400 elements
+ for s in [5, 20, 51, 200, 1000]:
+ d = np.random.randn(4, s)
+ # Randomly set some elements to NaN:
+ w = np.random.randint(0, d.size, size=d.size // 5)
+ d.ravel()[w] = np.nan
+ d[:,0] = 1. # ensure at least one good value
+ # use normal median without nans to compare
+ tgt = []
+ for x in d:
+ nonan = np.compress(~np.isnan(x), x)
+ tgt.append(np.median(nonan, overwrite_input=True))
+
+ assert_array_equal(np.nanmedian(d, axis=-1), tgt)
+
def test_result_values(self):
tgt = [np.median(d) for d in _rdat]
res = np.nanmedian(_ndat, axis=1)