summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-02-27 19:32:27 +0000
committerGitHub <noreply@github.com>2017-02-27 19:32:27 +0000
commit2dd9125bad6d5b78a2d10620f76b3906860a526d (patch)
treec0e72c2e61156ab8bf34f7961bdd85dc505ccd6d /numpy
parent533cef98154164928759c352a43a461ac9951c03 (diff)
parent05aa44d53f4f9528847a0c014fe4bda5caa5fd3d (diff)
downloadnumpy-2dd9125bad6d5b78a2d10620f76b3906860a526d.tar.gz
Merge pull request #8705 from juliantaylor/ma-median-empty
BUG: fix ma.median for empty ndarrays
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/extras.py7
-rw-r--r--numpy/ma/tests/test_extras.py6
2 files changed, 10 insertions, 3 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 7149b525b..697565251 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -717,6 +717,13 @@ def _median(a, axis=None, out=None, overwrite_input=False):
else:
axis = normalize_axis_index(axis, asorted.ndim)
+ if asorted.shape[axis] == 0:
+ # for empty axis integer indices fail so use slicing to get same result
+ # as median (which is mean of empty slice = nan)
+ indexer = [slice(None)] * asorted.ndim
+ indexer[axis] = slice(0, 0)
+ return np.ma.mean(asorted[indexer], axis=axis, out=out)
+
if asorted.ndim == 1:
counts = count(asorted)
idx, odd = divmod(count(asorted), 2)
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index fb68bd261..547a3fb81 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -1039,14 +1039,14 @@ class TestMedian(TestCase):
# axis 0 and 1
b = np.ma.masked_array(np.array([], dtype=float, ndmin=2))
- assert_equal(np.median(a, axis=0), b)
- assert_equal(np.median(a, axis=1), b)
+ assert_equal(np.ma.median(a, axis=0), b)
+ assert_equal(np.ma.median(a, axis=1), b)
# axis 2
b = np.ma.masked_array(np.array(np.nan, dtype=float, ndmin=2))
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', RuntimeWarning)
- assert_equal(np.median(a, axis=2), b)
+ assert_equal(np.ma.median(a, axis=2), b)
assert_(w[0].category is RuntimeWarning)
def test_object(self):