summaryrefslogtreecommitdiff
path: root/numpy/lib/nanfunctions.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/nanfunctions.py')
-rw-r--r--numpy/lib/nanfunctions.py20
1 files changed, 19 insertions, 1 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py
index 478e7cf7e..7120760b5 100644
--- a/numpy/lib/nanfunctions.py
+++ b/numpy/lib/nanfunctions.py
@@ -635,7 +635,7 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False):
See nanmedian for parameter usage
"""
- if axis is None:
+ if axis is None or a.ndim == 1:
part = a.ravel()
if out is None:
return _nanmedian1d(part, overwrite_input)
@@ -643,11 +643,29 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False):
out[...] = _nanmedian1d(part, overwrite_input)
return out
else:
+ # for small medians use sort + indexing which is still faster than
+ # apply_along_axis
+ if a.shape[axis] < 400:
+ return _nanmedian_small(a, axis, out, overwrite_input)
result = np.apply_along_axis(_nanmedian1d, axis, a, overwrite_input)
if out is not None:
out[...] = result
return result
+def _nanmedian_small(a, axis=None, out=None, overwrite_input=False):
+ """
+ sort + indexing median, faster for small medians along multiple dimensions
+ due to the high overhead of apply_along_axis
+ see nanmedian for parameter usage
+ """
+ a = np.ma.masked_array(a, np.isnan(a))
+ m = np.ma.median(a, axis=axis, overwrite_input=overwrite_input)
+ for i in range(np.count_nonzero(m.mask.ravel())):
+ warnings.warn("All-NaN slice encountered", RuntimeWarning)
+ if out is not None:
+ out[...] = m.filled(np.nan)
+ return out
+ return m.filled(np.nan)
def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False):
"""