summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py19
1 files changed, 11 insertions, 8 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index a05ea476b..0d5c73e7e 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -708,34 +708,37 @@ def _median(a, axis=None, out=None, overwrite_input=False):
asorted = a
else:
asorted = sort(a, axis=axis)
+
if axis is None:
axis = 0
elif axis < 0:
- axis += a.ndim
+ axis += asorted.ndim
if asorted.ndim == 1:
idx, odd = divmod(count(asorted), 2)
- return asorted[idx - (not odd) : idx + 1].mean()
+ return asorted[idx + odd - 1 : idx + 1].mean(out=out)
- counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis)
+ counts = count(asorted, axis=axis)
h = counts // 2
+
# create indexing mesh grid for all but reduced axis
axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape)
if i != axis]
ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
+
# insert indices of low and high median
ind.insert(axis, h - 1)
low = asorted[tuple(ind)]
low._sharedmask = False
ind[axis] = h
high = asorted[tuple(ind)]
+
# duplicate high if odd number of elements so mean does nothing
odd = counts % 2 == 1
- if asorted.ndim == 1:
- if odd:
- low = high
- else:
- low[odd] = high[odd]
+ if asorted.ndim > 1:
+ np.copyto(low, high, where=odd)
+ elif odd:
+ low = high
if np.issubdtype(asorted.dtype, np.inexact):
# avoid inf / x = masked