summaryrefslogtreecommitdiff
path: root/numpy/lib/nanfunctions.py
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2014-05-31 13:20:43 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-06-02 23:47:24 +0200
commit99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1 (patch)
tree6bdb286e11684864ba3ff86a0d913a95863235b1 /numpy/lib/nanfunctions.py
parent0ae36289aa5104fce4e40c63ba46e19365f33b5d (diff)
downloadnumpy-99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1.tar.gz
ENH: use masked median for small multidimensional nanmedians
Diffstat (limited to 'numpy/lib/nanfunctions.py')
-rw-r--r--numpy/lib/nanfunctions.py20
1 files changed, 19 insertions, 1 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py
index 478e7cf7e..7120760b5 100644
--- a/numpy/lib/nanfunctions.py
+++ b/numpy/lib/nanfunctions.py
@@ -635,7 +635,7 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False):
See nanmedian for parameter usage
"""
- if axis is None:
+ if axis is None or a.ndim == 1:
part = a.ravel()
if out is None:
return _nanmedian1d(part, overwrite_input)
@@ -643,11 +643,29 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False):
out[...] = _nanmedian1d(part, overwrite_input)
return out
else:
+ # for small medians use sort + indexing which is still faster than
+ # apply_along_axis
+ if a.shape[axis] < 400:
+ return _nanmedian_small(a, axis, out, overwrite_input)
result = np.apply_along_axis(_nanmedian1d, axis, a, overwrite_input)
if out is not None:
out[...] = result
return result
+def _nanmedian_small(a, axis=None, out=None, overwrite_input=False):
+ """
+ sort + indexing median, faster for small medians along multiple dimensions
+ due to the high overhead of apply_along_axis
+ see nanmedian for parameter usage
+ """
+ a = np.ma.masked_array(a, np.isnan(a))
+ m = np.ma.median(a, axis=axis, overwrite_input=overwrite_input)
+ for i in range(np.count_nonzero(m.mask.ravel())):
+ warnings.warn("All-NaN slice encountered", RuntimeWarning)
+ if out is not None:
+ out[...] = m.filled(np.nan)
+ return out
+ return m.filled(np.nan)
def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False):
"""