diff options
author | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-05-31 13:20:43 +0200 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-06-02 23:47:24 +0200 |
commit | 99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1 (patch) | |
tree | 6bdb286e11684864ba3ff86a0d913a95863235b1 /numpy/lib/nanfunctions.py | |
parent | 0ae36289aa5104fce4e40c63ba46e19365f33b5d (diff) | |
download | numpy-99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1.tar.gz |
ENH: use masked median for small multidimensional nanmedians
Diffstat (limited to 'numpy/lib/nanfunctions.py')
-rw-r--r-- | numpy/lib/nanfunctions.py | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py index 478e7cf7e..7120760b5 100644 --- a/numpy/lib/nanfunctions.py +++ b/numpy/lib/nanfunctions.py @@ -635,7 +635,7 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False): See nanmedian for parameter usage """ - if axis is None: + if axis is None or a.ndim == 1: part = a.ravel() if out is None: return _nanmedian1d(part, overwrite_input) @@ -643,11 +643,29 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False): out[...] = _nanmedian1d(part, overwrite_input) return out else: + # for small medians use sort + indexing which is still faster than + # apply_along_axis + if a.shape[axis] < 400: + return _nanmedian_small(a, axis, out, overwrite_input) result = np.apply_along_axis(_nanmedian1d, axis, a, overwrite_input) if out is not None: out[...] = result return result +def _nanmedian_small(a, axis=None, out=None, overwrite_input=False): + """ + sort + indexing median, faster for small medians along multiple dimensions + due to the high overhead of apply_along_axis + see nanmedian for parameter usage + """ + a = np.ma.masked_array(a, np.isnan(a)) + m = np.ma.median(a, axis=axis, overwrite_input=overwrite_input) + for i in range(np.count_nonzero(m.mask.ravel())): + warnings.warn("All-NaN slice encountered", RuntimeWarning) + if out is not None: + out[...] = m.filled(np.nan) + return out + return m.filled(np.nan) def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False): """ |