diff options
author | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-05-30 19:06:24 +0200 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-06-02 22:55:09 +0200 |
commit | 0ae36289aa5104fce4e40c63ba46e19365f33b5d (patch) | |
tree | 12b29bf957f161e70ed8ff316c2519124bf89825 /numpy/ma/extras.py | |
parent | 39fbc1b1e7aee9220c62d8aeaa47eb885a4fde96 (diff) | |
download | numpy-0ae36289aa5104fce4e40c63ba46e19365f33b5d.tar.gz |
ENH: rewrite ma.median to improve poor performance for multiple dimensions
many masked median along a small dimension is extremely slow due to the
usage of apply_along_axis which iterates fully in python. The unmasked
median is about 1000x faster.
Work around this issue by using indexing to select the median element
instead of apply_along_axis.
Further improvements are possible, e.g. using the current np.nanmedian
approach for masked medians along large dimensions so partition is used
instead of sort or to extend partition to allow broadcasting over
multiple elements.
Closes gh-4683.
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r-- | numpy/ma/extras.py | 41 |
1 files changed, 25 insertions, 16 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index f53d9c7e5..82a61a67c 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -668,15 +668,9 @@ def median(a, axis=None, out=None, overwrite_input=False): fill_value = 1e+20) """ - def _median1D(data): - counts = filled(count(data), 0) - (idx, rmd) = divmod(counts, 2) - if rmd: - choice = slice(idx, idx + 1) - else: - choice = slice(idx - 1, idx + 1) - return data[choice].mean(0) - # + if not hasattr(a, 'mask') or np.count_nonzero(a.mask) == 0: + return masked_array(np.median(a, axis=axis, out=out, + overwrite_input=overwrite_input), copy=False) if overwrite_input: if axis is None: asorted = a.ravel() @@ -687,14 +681,29 @@ def median(a, axis=None, out=None, overwrite_input=False): else: asorted = sort(a, axis=axis) if axis is None: - result = _median1D(asorted) + axis = 0 + elif axis < 0: + axis += a.ndim + + counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis) + h = counts // 2 + # create indexing mesh grid for all but reduced axis + axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape) + if i != axis] + ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij') + # insert indices of low and high median + ind.insert(axis, h - 1) + low = asorted[ind] + ind[axis] = h + high = asorted[ind] + # duplicate high if odd number of elements so mean does nothing + odd = counts % 2 == 1 + if asorted.ndim == 1: + if odd: + low = high else: - result = apply_along_axis(_median1D, axis, asorted) - if out is not None: - out = result - return result - - + low[odd] = high[odd] + return np.ma.mean([low, high], axis=0, out=out) #.............................................................................. |