summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2014-05-30 19:06:24 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-06-02 22:55:09 +0200
commit0ae36289aa5104fce4e40c63ba46e19365f33b5d (patch)
tree12b29bf957f161e70ed8ff316c2519124bf89825 /numpy/ma/extras.py
parent39fbc1b1e7aee9220c62d8aeaa47eb885a4fde96 (diff)
downloadnumpy-0ae36289aa5104fce4e40c63ba46e19365f33b5d.tar.gz
ENH: rewrite ma.median to improve poor performance for multiple dimensions
many masked median along a small dimension is extremely slow due to the usage of apply_along_axis which iterates fully in python. The unmasked median is about 1000x faster. Work around this issue by using indexing to select the median element instead of apply_along_axis. Further improvements are possible, e.g. using the current np.nanmedian approach for masked medians along large dimensions so partition is used instead of sort or to extend partition to allow broadcasting over multiple elements. Closes gh-4683.
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py41
1 files changed, 25 insertions, 16 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index f53d9c7e5..82a61a67c 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -668,15 +668,9 @@ def median(a, axis=None, out=None, overwrite_input=False):
fill_value = 1e+20)
"""
- def _median1D(data):
- counts = filled(count(data), 0)
- (idx, rmd) = divmod(counts, 2)
- if rmd:
- choice = slice(idx, idx + 1)
- else:
- choice = slice(idx - 1, idx + 1)
- return data[choice].mean(0)
- #
+ if not hasattr(a, 'mask') or np.count_nonzero(a.mask) == 0:
+ return masked_array(np.median(a, axis=axis, out=out,
+ overwrite_input=overwrite_input), copy=False)
if overwrite_input:
if axis is None:
asorted = a.ravel()
@@ -687,14 +681,29 @@ def median(a, axis=None, out=None, overwrite_input=False):
else:
asorted = sort(a, axis=axis)
if axis is None:
- result = _median1D(asorted)
+ axis = 0
+ elif axis < 0:
+ axis += a.ndim
+
+ counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis)
+ h = counts // 2
+ # create indexing mesh grid for all but reduced axis
+ axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape)
+ if i != axis]
+ ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
+ # insert indices of low and high median
+ ind.insert(axis, h - 1)
+ low = asorted[ind]
+ ind[axis] = h
+ high = asorted[ind]
+ # duplicate high if odd number of elements so mean does nothing
+ odd = counts % 2 == 1
+ if asorted.ndim == 1:
+ if odd:
+ low = high
else:
- result = apply_along_axis(_median1D, axis, asorted)
- if out is not None:
- out = result
- return result
-
-
+ low[odd] = high[odd]
+ return np.ma.mean([low, high], axis=0, out=out)
#..............................................................................