summaryrefslogtreecommitdiff
path: root/numpy/lib/utils.py
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2016-12-06 00:37:53 +0100
committerJulian Taylor <jtaylor.debian@googlemail.com>2016-12-12 23:19:06 +0100
commitff4758ff8432077c06cb580b97c6080f9b312c5a (patch)
treea40b77c06a1446f053f405a31bf0bcff91d3d8eb /numpy/lib/utils.py
parente00b9587052248486a9bf66c2ce95638c0d9817f (diff)
downloadnumpy-ff4758ff8432077c06cb580b97c6080f9b312c5a.tar.gz
BUG: handle unmasked NaN in ma.median like normal median
This requires basing the masked median on sort(endwith=False), as we need to distinguish Inf and NaN. Using Inf as the filler element of the sort does not work, as then the mask is not guaranteed to be at the end. Closes gh-8340. Also fixed 1d ma.median not handling np.inf correctly; the nd variant was ok.
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r--numpy/lib/utils.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index 5c364268c..61aa5e33b 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -8,6 +8,7 @@ import warnings
from numpy.core.numerictypes import issubclass_, issubsctype, issubdtype
from numpy.core import ndarray, ufunc, asarray
+import numpy as np
# getargspec and formatargspec were removed in Python 3.6
from numpy.compat import getargspec, formatargspec
@@ -1113,4 +1114,49 @@ def safe_eval(source):
import ast
return ast.literal_eval(source)
+
+
+def _median_nancheck(data, result, axis, out):
+ """
+ Utility function to check median result from data for NaN values at the end
+ and return NaN in that case. Input result can also be a MaskedArray.
+
+ Parameters
+ ----------
+ data : array
+ Input data to median function
+ result : Array or MaskedArray
+ Result of median function
+ axis : {int, sequence of int, None}, optional
+ Axis or axes along which the median was computed.
+ out : ndarray, optional
+ Output array in which to place the result.
+ Returns
+ -------
+ median : scalar or ndarray
+ Median or NaN in axes which contained NaN in the input.
+ """
+ if data.size == 0:
+ return result
+ data = np.rollaxis(data, axis, data.ndim)
+ n = np.isnan(data[..., -1])
+ # masked NaN values are ok
+ if np.ma.isMaskedArray(n):
+ n = n.filled(False)
+ if result.ndim == 0:
+ if n == True:
+ warnings.warn("Invalid value encountered in median",
+ RuntimeWarning, stacklevel=3)
+ if out is not None:
+ out[...] = data.dtype.type(np.nan)
+ result = out
+ else:
+ result = data.dtype.type(np.nan)
+ elif np.count_nonzero(n.ravel()) > 0:
+ warnings.warn("Invalid value encountered in median for" +
+ " %d results" % np.count_nonzero(n.ravel()),
+ RuntimeWarning, stacklevel=3)
+ result[n] = np.nan
+ return result
+
#-----------------------------------------------------------------------------