diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2014-06-02 16:32:02 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2014-06-02 16:32:02 -0600 |
commit | 14cc7172d78fb7cabbddcf7805c634ec0555abca (patch) | |
tree | 17d220ae25a06f1591f1aa609ae4d87264c5c209 /numpy/lib/nanfunctions.py | |
parent | e6f43660b156438b0ad4f10b4c8503ba478c0cdd (diff) | |
parent | 99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1 (diff) | |
download | numpy-14cc7172d78fb7cabbddcf7805c634ec0555abca.tar.gz |
Merge pull request #4760 from juliantaylor/masked-median
ENH: rewrite ma.median to improve poor performance for multiple dimensions
Diffstat (limited to 'numpy/lib/nanfunctions.py')
-rw-r--r-- | numpy/lib/nanfunctions.py | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py index 478e7cf7e..7120760b5 100644 --- a/numpy/lib/nanfunctions.py +++ b/numpy/lib/nanfunctions.py @@ -635,7 +635,7 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False): See nanmedian for parameter usage """ - if axis is None: + if axis is None or a.ndim == 1: part = a.ravel() if out is None: return _nanmedian1d(part, overwrite_input) @@ -643,11 +643,29 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False): out[...] = _nanmedian1d(part, overwrite_input) return out else: + # for small medians use sort + indexing which is still faster than + # apply_along_axis + if a.shape[axis] < 400: + return _nanmedian_small(a, axis, out, overwrite_input) result = np.apply_along_axis(_nanmedian1d, axis, a, overwrite_input) if out is not None: out[...] = result return result +def _nanmedian_small(a, axis=None, out=None, overwrite_input=False): + """ + sort + indexing median, faster for small medians along multiple dimensions + due to the high overhead of apply_along_axis + see nanmedian for parameter usage + """ + a = np.ma.masked_array(a, np.isnan(a)) + m = np.ma.median(a, axis=axis, overwrite_input=overwrite_input) + for i in range(np.count_nonzero(m.mask.ravel())): + warnings.warn("All-NaN slice encountered", RuntimeWarning) + if out is not None: + out[...] = m.filled(np.nan) + return out + return m.filled(np.nan) def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False): """ |