summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorahaldane <ealloc@gmail.com>2016-05-21 20:46:21 -0400
committerahaldane <ealloc@gmail.com>2016-05-21 20:46:21 -0400
commit6ce33a18a62fbc42af2c7a325130f705740c2a1e (patch)
tree8ffed1b54140e1e6cadeace9d63ad0e9a475ae03
parent2423048255f74aeefc22f707fcccfb8d21723ce3 (diff)
parenta4cc361003d3b2b241b826372d9691187b47f86f (diff)
downloadnumpy-6ce33a18a62fbc42af2c7a325130f705740c2a1e.tar.gz
Merge pull request #7635 from AmitAronovitch/ma_median_fix
BUG: ma.median alternate fix for #7592
-rw-r--r--numpy/ma/extras.py10
-rw-r--r--numpy/ma/tests/test_extras.py13
2 files changed, 20 insertions, 3 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index ea8f9e49a..781b25449 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -27,7 +27,7 @@ import warnings
from . import core as ma
from .core import (
- MaskedArray, MAError, add, array, asarray, concatenate, filled,
+ MaskedArray, MAError, add, array, asarray, concatenate, filled, count,
getmask, getmaskarray, make_mask_descr, masked, masked_array, mask_or,
nomask, ones, sort, zeros, getdata, get_masked_subclass, dot,
mask_rowcols
@@ -653,6 +653,10 @@ def _median(a, axis=None, out=None, overwrite_input=False):
elif axis < 0:
axis += a.ndim
+ if asorted.ndim == 1:
+ idx, odd = divmod(count(asorted), 2)
+ return asorted[idx - (not odd) : idx + 1].mean()
+
counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis)
h = counts // 2
# create indexing mesh grid for all but reduced axis
@@ -661,10 +665,10 @@ def _median(a, axis=None, out=None, overwrite_input=False):
ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
# insert indices of low and high median
ind.insert(axis, h - 1)
- low = asorted[ind]
+ low = asorted[tuple(ind)]
low._sharedmask = False
ind[axis] = h
- high = asorted[ind]
+ high = asorted[tuple(ind)]
# duplicate high if odd number of elements so mean does nothing
odd = counts % 2 == 1
if asorted.ndim == 1:
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 52669cb90..b2c053fbd 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -662,6 +662,19 @@ class TestMedian(TestCase):
assert_equal(np.ma.median(np.arange(9)), 4.)
assert_equal(np.ma.median(range(9)), 4)
+ def test_masked_1d(self):
+ "test the examples given in the docstring of ma.median"
+ x = array(np.arange(8), mask=[0]*4 + [1]*4)
+ assert_equal(np.ma.median(x), 1.5)
+ assert_equal(np.ma.median(x).shape, (), "shape mismatch")
+ x = array(np.arange(10).reshape(2, 5), mask=[0]*6 + [1]*4)
+ assert_equal(np.ma.median(x), 2.5)
+ assert_equal(np.ma.median(x).shape, (), "shape mismatch")
+
+ def test_1d_shape_consistency(self):
+ assert_equal(np.ma.median(array([1,2,3],mask=[0,0,0])).shape,
+ np.ma.median(array([1,2,3],mask=[0,1,0])).shape )
+
def test_2d(self):
# Tests median w/ 2D
(n, p) = (101, 30)