summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Charles Harris <charlesr.harris@gmail.com> 2014-06-02 16:32:02 -0600
committer Charles Harris <charlesr.harris@gmail.com> 2014-06-02 16:32:02 -0600
commit 14cc7172d78fb7cabbddcf7805c634ec0555abca (patch)
tree 17d220ae25a06f1591f1aa609ae4d87264c5c209
parent e6f43660b156438b0ad4f10b4c8503ba478c0cdd (diff)
parent 99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1 (diff)
download numpy-14cc7172d78fb7cabbddcf7805c634ec0555abca.tar.gz
Merge pull request #4760 from juliantaylor/masked-median
ENH: rewrite ma.median to improve poor performance for multiple dimensions
-rw-r--r-- numpy/lib/nanfunctions.py 20
-rw-r--r-- numpy/lib/tests/test_nanfunctions.py 18
-rw-r--r-- numpy/ma/core.py 11
-rw-r--r-- numpy/ma/extras.py 41
-rw-r--r-- numpy/ma/tests/test_extras.py 13
5 files changed, 80 insertions, 23 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py
index 478e7cf7e..7120760b5 100644
--- a/numpy/lib/nanfunctions.py
+++ b/numpy/lib/nanfunctions.py
@@ -635,7 +635,7 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False):
See nanmedian for parameter usage
"""
- if axis is None:
+ if axis is None or a.ndim == 1:
part = a.ravel()
if out is None:
return _nanmedian1d(part, overwrite_input)
@@ -643,11 +643,29 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False):
out[...] = _nanmedian1d(part, overwrite_input)
return out
else:
+ # for small medians use sort + indexing which is still faster than
+ # apply_along_axis
+ if a.shape[axis] < 400:
+ return _nanmedian_small(a, axis, out, overwrite_input)
result = np.apply_along_axis(_nanmedian1d, axis, a, overwrite_input)
if out is not None:
out[...] = result
return result
+def _nanmedian_small(a, axis=None, out=None, overwrite_input=False):
+ """
+ sort + indexing median, faster for small medians along multiple dimensions
+ due to the high overhead of apply_along_axis
+ see nanmedian for parameter usage
+ """
+ a = np.ma.masked_array(a, np.isnan(a))
+ m = np.ma.median(a, axis=axis, overwrite_input=overwrite_input)
+ for i in range(np.count_nonzero(m.mask.ravel())):
+ warnings.warn("All-NaN slice encountered", RuntimeWarning)
+ if out is not None:
+ out[...] = m.filled(np.nan)
+ return out
+ return m.filled(np.nan)
def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False):
"""
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 3fcfca218..c5af61434 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -5,7 +5,7 @@ import warnings
import numpy as np
from numpy.testing import (
run_module_suite, TestCase, assert_, assert_equal, assert_almost_equal,
- assert_raises
+ assert_raises, assert_array_equal
)
@@ -580,6 +580,22 @@ class TestNanFunctions_Median(TestCase):
assert_almost_equal(res, resout)
assert_almost_equal(res, tgt)
+ def test_small_large(self):
+ # test the small and large code paths, current cutoff 400 elements
+ for s in [5, 20, 51, 200, 1000]:
+ d = np.random.randn(4, s)
+ # Randomly set some elements to NaN:
+ w = np.random.randint(0, d.size, size=d.size // 5)
+ d.ravel()[w] = np.nan
+ d[:,0] = 1. # ensure at least one good value
+ # use normal median without nans to compare
+ tgt = []
+ for x in d:
+ nonan = np.compress(~np.isnan(x), x)
+ tgt.append(np.median(nonan, overwrite_input=True))
+
+ assert_array_equal(np.nanmedian(d, axis=-1), tgt)
+
def test_result_values(self):
tgt = [np.median(d) for d in _rdat]
res = np.nanmedian(_ndat, axis=1)
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index bb5c966ec..617f1921e 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -5082,12 +5082,12 @@ class MaskedArray(ndarray):
filler = maximum_fill_value(self)
else:
filler = fill_value
- idx = np.indices(self.shape)
+ idx = np.meshgrid(*[np.arange(x) for x in self.shape], sparse=True,
+ indexing='ij')
idx[axis] = self.filled(filler).argsort(axis=axis, kind=kind,
order=order)
- idx_l = idx.tolist()
- tmp_mask = self._mask[idx_l].flat
- tmp_data = self._data[idx_l].flat
+ tmp_mask = self._mask[idx].flat
+ tmp_data = self._data[idx].flat
self._data.flat = tmp_data
self._mask.flat = tmp_mask
return
@@ -6188,7 +6188,8 @@ def sort(a, axis= -1, kind='quicksort', order=None, endwith=True, fill_value=Non
else:
filler = fill_value
# return
- indx = np.indices(a.shape).tolist()
+ indx = np.meshgrid(*[np.arange(x) for x in a.shape], sparse=True,
+ indexing='ij')
indx[axis] = filled(a, filler).argsort(axis=axis, kind=kind, order=order)
return a[indx]
sort.__doc__ = MaskedArray.sort.__doc__
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index f53d9c7e5..82a61a67c 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -668,15 +668,9 @@ def median(a, axis=None, out=None, overwrite_input=False):
fill_value = 1e+20)
"""
- def _median1D(data):
- counts = filled(count(data), 0)
- (idx, rmd) = divmod(counts, 2)
- if rmd:
- choice = slice(idx, idx + 1)
- else:
- choice = slice(idx - 1, idx + 1)
- return data[choice].mean(0)
- #
+ if not hasattr(a, 'mask') or np.count_nonzero(a.mask) == 0:
+ return masked_array(np.median(a, axis=axis, out=out,
+ overwrite_input=overwrite_input), copy=False)
if overwrite_input:
if axis is None:
asorted = a.ravel()
@@ -687,14 +681,29 @@ def median(a, axis=None, out=None, overwrite_input=False):
else:
asorted = sort(a, axis=axis)
if axis is None:
- result = _median1D(asorted)
+ axis = 0
+ elif axis < 0:
+ axis += a.ndim
+
+ counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis)
+ h = counts // 2
+ # create indexing mesh grid for all but reduced axis
+ axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape)
+ if i != axis]
+ ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
+ # insert indices of low and high median
+ ind.insert(axis, h - 1)
+ low = asorted[ind]
+ ind[axis] = h
+ high = asorted[ind]
+ # duplicate high if odd number of elements so mean does nothing
+ odd = counts % 2 == 1
+ if asorted.ndim == 1:
+ if odd:
+ low = high
else:
- result = apply_along_axis(_median1D, axis, asorted)
- if out is not None:
- out = result
- return result
-
-
+ low[odd] = high[odd]
+ return np.ma.mean([low, high], axis=0, out=out)
#..............................................................................
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index fa7503392..95f935c8b 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -531,6 +531,19 @@ class TestMedian(TestCase):
x[x % 5 == 0] = masked
assert_equal(median(x, 0), [[12, 10], [8, 9], [16, 17]])
+ def test_neg_axis(self):
+ x = masked_array(np.arange(30).reshape(10, 3))
+ x[:3] = x[-3:] = masked
+ assert_equal(median(x, axis=-1), median(x, axis=1))
+
+ def test_out(self):
+ x = masked_array(np.arange(30).reshape(10, 3))
+ x[:3] = x[-3:] = masked
+ out = masked_array(np.ones(10))
+ r = median(x, axis=1, out=out)
+ assert_equal(r, out)
+ assert_(type(r) == MaskedArray)
+
class TestCov(TestCase):