summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/ma/extras.py12
-rw-r--r--numpy/ma/tests/test_extras.py42
2 files changed, 46 insertions, 8 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 781b25449..cbf7b6cdb 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -626,10 +626,14 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
fill_value = 1e+20)
"""
- if not hasattr(a, 'mask') or np.count_nonzero(a.mask) == 0:
- return masked_array(np.median(getdata(a, subok=True), axis=axis,
- out=out, overwrite_input=overwrite_input,
- keepdims=keepdims), copy=False)
+ if not hasattr(a, 'mask'):
+ m = np.median(getdata(a, subok=True), axis=axis,
+ out=out, overwrite_input=overwrite_input,
+ keepdims=keepdims)
+ if isinstance(m, np.ndarray) and 1 <= m.ndim:
+ return masked_array(m, copy=False)
+ else:
+ return m
r, k = _ureduce(a, func=_median, axis=axis, out=out,
overwrite_input=overwrite_input)
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index b2c053fbd..bb59aad96 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -659,17 +659,48 @@ class TestMedian(TestCase):
assert_equal(r, np.inf)
def test_non_masked(self):
- assert_equal(np.ma.median(np.arange(9)), 4.)
- assert_equal(np.ma.median(range(9)), 4)
-
- def test_masked_1d(self):
+ x = np.arange(9)
+ assert_equal(np.ma.median(x), 4.)
+ assert_(type(np.ma.median(x)) is not MaskedArray)
+ x = range(9)
+ assert_equal(np.ma.median(x), 4.)
+ assert_(type(np.ma.median(x)) is not MaskedArray)
+ x = 5
+ assert_equal(np.ma.median(x), 5.)
+ assert_(type(np.ma.median(x)) is not MaskedArray)
+
+ def test_docstring_examples(self):
"test the examples given in the docstring of ma.median"
x = array(np.arange(8), mask=[0]*4 + [1]*4)
assert_equal(np.ma.median(x), 1.5)
assert_equal(np.ma.median(x).shape, (), "shape mismatch")
+ assert_(type(np.ma.median(x)) is not MaskedArray)
x = array(np.arange(10).reshape(2, 5), mask=[0]*6 + [1]*4)
assert_equal(np.ma.median(x), 2.5)
assert_equal(np.ma.median(x).shape, (), "shape mismatch")
+ assert_(type(np.ma.median(x)) is not MaskedArray)
+ ma_x = np.ma.median(x, axis=-1, overwrite_input=True)
+ assert_equal(ma_x, [2., 5.])
+ assert_equal(ma_x.shape, (2,), "shape mismatch")
+ assert_(type(ma_x) is MaskedArray)
+
+ def test_masked_1d(self):
+ x = array(np.arange(5), mask=True)
+ assert_equal(np.ma.median(x), np.ma.masked)
+ assert_equal(np.ma.median(x).shape, (), "shape mismatch")
+ assert_(type(np.ma.median(x)) is np.ma.core.MaskedConstant)
+ x = array(np.arange(5), mask=False)
+ assert_equal(np.ma.median(x), 2.)
+ assert_equal(np.ma.median(x).shape, (), "shape mismatch")
+ assert_(type(np.ma.median(x)) is not MaskedArray)
+ x = array(np.arange(5), mask=[0,1,0,0,0])
+ assert_equal(np.ma.median(x), 2.5)
+ assert_equal(np.ma.median(x).shape, (), "shape mismatch")
+ assert_(type(np.ma.median(x)) is not MaskedArray)
+ x = array(np.arange(5), mask=[0,1,1,1,1])
+ assert_equal(np.ma.median(x), 0.)
+ assert_equal(np.ma.median(x).shape, (), "shape mismatch")
+ assert_(type(np.ma.median(x)) is not MaskedArray)
def test_1d_shape_consistency(self):
assert_equal(np.ma.median(array([1,2,3],mask=[0,0,0])).shape,
@@ -696,8 +727,11 @@ class TestMedian(TestCase):
x = masked_array(np.arange(30).reshape(10, 3))
x[:3] = x[-3:] = masked
assert_equal(median(x), 14.5)
+ assert_(type(np.ma.median(x)) is not MaskedArray)
assert_equal(median(x, axis=0), [13.5, 14.5, 15.5])
+ assert_(type(np.ma.median(x, axis=0)) is MaskedArray)
assert_equal(median(x, axis=1), [0, 0, 0, 10, 13, 16, 19, 0, 0, 0])
+ assert_(type(np.ma.median(x, axis=1)) is MaskedArray)
assert_equal(median(x, axis=1).mask, [1, 1, 1, 0, 0, 0, 0, 1, 1, 1])
def test_3d(self):