summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index a993fd05d..1849df72b 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -671,8 +671,8 @@ def median(a, axis=None, out=None, overwrite_input=False):
"""
if not hasattr(a, 'mask') or np.count_nonzero(a.mask) == 0:
- return masked_array(np.median(a, axis=axis, out=out,
- overwrite_input=overwrite_input), copy=False)
+ return masked_array(np.median(getattr(a, 'data', a), axis=axis,
+ out=out, overwrite_input=overwrite_input), copy=False)
if overwrite_input:
if axis is None:
asorted = a.ravel()
@@ -705,7 +705,14 @@ def median(a, axis=None, out=None, overwrite_input=False):
low = high
else:
low[odd] = high[odd]
- return np.ma.mean([low, high], axis=0, out=out)
+
+ if np.issubdtype(asorted.dtype, np.inexact):
+ # avoid inf / x = masked
+ s = np.ma.sum([low, high], axis=0, out=out)
+ np.true_divide(s.data, 2., casting='unsafe', out=s.data)
+ else:
+ s = np.ma.mean([low, high], axis=0, out=out)
+ return s
#..............................................................................