diff options
Diffstat (limited to 'numpy/lib/utils.py')
-rw-r--r-- | numpy/lib/utils.py | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index 5c364268c..61aa5e33b 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -8,6 +8,7 @@ import warnings from numpy.core.numerictypes import issubclass_, issubsctype, issubdtype from numpy.core import ndarray, ufunc, asarray +import numpy as np # getargspec and formatargspec were removed in Python 3.6 from numpy.compat import getargspec, formatargspec @@ -1113,4 +1114,49 @@ def safe_eval(source): import ast return ast.literal_eval(source) + + +def _median_nancheck(data, result, axis, out): + """ + Utility function to check median result from data for NaN values at the end + and return NaN in that case. Input result can also be a MaskedArray. + + Parameters + ---------- + data : array + Input data to median function + result : Array or MaskedArray + Result of median function + axis : {int, sequence of int, None}, optional + Axis or axes along which the median was computed. + out : ndarray, optional + Output array in which to place the result. + Returns + ------- + median : scalar or ndarray + Median or NaN in axes which contained NaN in the input. + """ + if data.size == 0: + return result + data = np.rollaxis(data, axis, data.ndim) + n = np.isnan(data[..., -1]) + # masked NaN values are ok + if np.ma.isMaskedArray(n): + n = n.filled(False) + if result.ndim == 0: + if n == True: + warnings.warn("Invalid value encountered in median", + RuntimeWarning, stacklevel=3) + if out is not None: + out[...] = data.dtype.type(np.nan) + result = out + else: + result = data.dtype.type(np.nan) + elif np.count_nonzero(n.ravel()) > 0: + warnings.warn("Invalid value encountered in median for" + + " %d results" % np.count_nonzero(n.ravel()), + RuntimeWarning, stacklevel=3) + result[n] = np.nan + return result + #----------------------------------------------------------------------------- |