diff options
author | ahaldane <ealloc@gmail.com> | 2016-05-21 20:46:21 -0400 |
---|---|---|
committer | ahaldane <ealloc@gmail.com> | 2016-05-21 20:46:21 -0400 |
commit | 6ce33a18a62fbc42af2c7a325130f705740c2a1e (patch) | |
tree | 8ffed1b54140e1e6cadeace9d63ad0e9a475ae03 | |
parent | 2423048255f74aeefc22f707fcccfb8d21723ce3 (diff) | |
parent | a4cc361003d3b2b241b826372d9691187b47f86f (diff) | |
download | numpy-6ce33a18a62fbc42af2c7a325130f705740c2a1e.tar.gz |
Merge pull request #7635 from AmitAronovitch/ma_median_fix
BUG: ma.median alternate fix for #7592
-rw-r--r-- | numpy/ma/extras.py | 10 | ||||
-rw-r--r-- | numpy/ma/tests/test_extras.py | 13 |
2 files changed, 20 insertions, 3 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index ea8f9e49a..781b25449 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -27,7 +27,7 @@ import warnings from . import core as ma from .core import ( - MaskedArray, MAError, add, array, asarray, concatenate, filled, + MaskedArray, MAError, add, array, asarray, concatenate, filled, count, getmask, getmaskarray, make_mask_descr, masked, masked_array, mask_or, nomask, ones, sort, zeros, getdata, get_masked_subclass, dot, mask_rowcols @@ -653,6 +653,10 @@ def _median(a, axis=None, out=None, overwrite_input=False): elif axis < 0: axis += a.ndim + if asorted.ndim == 1: + idx, odd = divmod(count(asorted), 2) + return asorted[idx - (not odd) : idx + 1].mean() + counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis) h = counts // 2 # create indexing mesh grid for all but reduced axis @@ -661,10 +665,10 @@ def _median(a, axis=None, out=None, overwrite_input=False): ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij') # insert indices of low and high median ind.insert(axis, h - 1) - low = asorted[ind] + low = asorted[tuple(ind)] low._sharedmask = False ind[axis] = h - high = asorted[ind] + high = asorted[tuple(ind)] # duplicate high if odd number of elements so mean does nothing odd = counts % 2 == 1 if asorted.ndim == 1: diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index 52669cb90..b2c053fbd 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -662,6 +662,19 @@ class TestMedian(TestCase): assert_equal(np.ma.median(np.arange(9)), 4.) assert_equal(np.ma.median(range(9)), 4) + def test_masked_1d(self): + "test the examples given in the docstring of ma.median" + x = array(np.arange(8), mask=[0]*4 + [1]*4) + assert_equal(np.ma.median(x), 1.5) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + x = array(np.arange(10).reshape(2, 5), mask=[0]*6 + [1]*4) + assert_equal(np.ma.median(x), 2.5) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + + def test_1d_shape_consistency(self): + assert_equal(np.ma.median(array([1,2,3],mask=[0,0,0])).shape, + np.ma.median(array([1,2,3],mask=[0,1,0])).shape ) + def test_2d(self): # Tests median w/ 2D (n, p) = (101, 30) |