summaryrefslogtreecommitdiff
path: root/numpy/ma/extras.py
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2016-12-06 00:37:53 +0100
committerJulian Taylor <jtaylor.debian@googlemail.com>2016-12-12 23:19:06 +0100
commitff4758ff8432077c06cb580b97c6080f9b312c5a (patch)
treea40b77c06a1446f053f405a31bf0bcff91d3d8eb /numpy/ma/extras.py
parente00b9587052248486a9bf66c2ce95638c0d9817f (diff)
downloadnumpy-ff4758ff8432077c06cb580b97c6080f9b312c5a.tar.gz
BUG: handle unmasked NaN in ma.median like normal median
This requires to base masked median on sort(endwith=False) as we need to distinguish Inf and NaN. Using Inf as filler element of the sort does not work as then the mask is not guaranteed to be at the end. Closes gh-8340 Also fixed 1d ma.median not handling np.inf correctly, the nd variant was ok.
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r--numpy/ma/extras.py52
1 files changed, 42 insertions, 10 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index e4ff8ef2d..dadf032e0 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -699,15 +699,21 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
return r
def _median(a, axis=None, out=None, overwrite_input=False):
+ # when an unmasked NaN is present return it, so we need to sort the NaN
+ # values behind the mask
+ if np.issubdtype(a.dtype, np.inexact):
+ fill_value = np.inf
+ else:
+ fill_value = None
if overwrite_input:
if axis is None:
asorted = a.ravel()
- asorted.sort()
+ asorted.sort(fill_value=fill_value)
else:
- a.sort(axis=axis)
+ a.sort(axis=axis, fill_value=fill_value)
asorted = a
else:
- asorted = sort(a, axis=axis)
+ asorted = sort(a, axis=axis, fill_value=fill_value)
if axis is None:
axis = 0
@@ -715,8 +721,23 @@ def _median(a, axis=None, out=None, overwrite_input=False):
axis += asorted.ndim
if asorted.ndim == 1:
+ counts = count(asorted)
idx, odd = divmod(count(asorted), 2)
- return asorted[idx + odd - 1 : idx + 1].mean(out=out)
+ mid = asorted[idx + odd - 1 : idx + 1]
+ if np.issubdtype(asorted.dtype, np.inexact) and asorted.size > 0:
+ # avoid inf / x = masked
+ s = mid.sum(out=out)
+ np.true_divide(s, 2., casting='unsafe')
+ s = np.lib.utils._median_nancheck(asorted, s, axis, out)
+ else:
+ s = mid.mean(out=out)
+
+ # if result is masked either the input contained enough
+ # minimum_fill_value so that it would be the median or all values
+ # masked
+ if np.ma.is_masked(s) and not np.all(asorted.mask):
+ return np.ma.minimum_fill_value(asorted)
+ return s
counts = count(asorted, axis=axis)
h = counts // 2
@@ -727,24 +748,35 @@ def _median(a, axis=None, out=None, overwrite_input=False):
ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
# insert indices of low and high median
- ind.insert(axis, np.maximum(0, h - 1))
+ ind.insert(axis, h - 1)
low = asorted[tuple(ind)]
- ind[axis] = h
+ ind[axis] = np.minimum(h, asorted.shape[axis] - 1)
high = asorted[tuple(ind)]
# duplicate high if odd number of elements so mean does nothing
odd = counts % 2 == 1
- if asorted.ndim > 1:
- np.copyto(low, high, where=odd)
- elif odd:
- low = high
+ np.copyto(low, high, where=odd)
+ # not necessary for scalar True/False masks
+ try:
+ np.copyto(low.mask, high.mask, where=odd)
+ except:
+ pass
if np.issubdtype(asorted.dtype, np.inexact):
# avoid inf / x = masked
s = np.ma.sum([low, high], axis=0, out=out)
np.true_divide(s.data, 2., casting='unsafe', out=s.data)
+
+ s = np.lib.utils._median_nancheck(asorted, s, axis, out)
else:
s = np.ma.mean([low, high], axis=0, out=out)
+
+ # if result is masked either the input contained enough minimum_fill_value
+ # so that it would be the median or all values masked
+ if np.ma.is_masked(s):
+ rep = (~np.all(asorted.mask, axis=axis)) & s.mask
+ s.data[rep] = np.ma.minimum_fill_value(asorted)
+ s.mask[rep] = False
return s