diff options
| -rw-r--r-- | numpy/lib/function_base.py | 15 | ||||
| -rw-r--r-- | numpy/lib/utils.py | 12 | ||||
| -rw-r--r-- | numpy/ma/extras.py | 4 |
3 files changed, 15 insertions, 16 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index d875a00ae..e33d55056 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -3714,16 +3714,15 @@ def _median(a, axis=None, out=None, overwrite_input=False): indexer[axis] = slice(index-1, index+1) indexer = tuple(indexer) + # Use mean in both odd and even case to coerce data type, + # using out array if needed. + rout = mean(part[indexer], axis=axis, out=out) # Check if the array contains any nan's if np.issubdtype(a.dtype, np.inexact) and sz > 0: - # warn and return nans like mean would - rout = mean(part[indexer], axis=axis, out=out) - return np.lib.utils._median_nancheck(part, rout, axis, out) - else: - # if there are no nans - # Use mean in odd and even case to coerce data type - # and check, use out array. - return mean(part[indexer], axis=axis, out=out) + # If nans are possible, warn and replace by nans like mean would. + rout = np.lib.utils._median_nancheck(part, rout, axis) + + return rout def _percentile_dispatcher(a, q, axis=None, out=None, overwrite_input=None, diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index 931669fc1..cd5bb51e7 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -1002,7 +1002,7 @@ def safe_eval(source): return ast.literal_eval(source) -def _median_nancheck(data, result, axis, out): +def _median_nancheck(data, result, axis): """ Utility function to check median result from data for NaN values at the end and return NaN in that case. Input result can also be a MaskedArray. @@ -1012,16 +1012,16 @@ def _median_nancheck(data, result, axis, out): data : array Input data to median function result : Array or MaskedArray - Result of median function + Result of median function. axis : int Axis along which the median was computed. - out : ndarray, optional - Output array in which to place the result. Returns ------- - median : scalar or ndarray - Median or NaN in axes which contained NaN in the input. + result : scalar or ndarray + Median or NaN in axes which contained NaN in the input. If the input + was an array, NaN will be inserted in-place. If a scalar, either the + input itself or a scalar NaN. """ if data.size == 0: return result diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 73abfc296..38bf1f0e8 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -750,7 +750,7 @@ def _median(a, axis=None, out=None, overwrite_input=False): s = mid.sum(out=out) if not odd: s = np.true_divide(s, 2., casting='safe', out=out) - s = np.lib.utils._median_nancheck(asorted, s, axis, out) + s = np.lib.utils._median_nancheck(asorted, s, axis) else: s = mid.mean(out=out) @@ -790,7 +790,7 @@ def _median(a, axis=None, out=None, overwrite_input=False): s = np.ma.sum(low_high, axis=axis, out=out) np.true_divide(s.data, 2., casting='unsafe', out=s.data) - s = np.lib.utils._median_nancheck(asorted, s, axis, out) + s = np.lib.utils._median_nancheck(asorted, s, axis) else: s = np.ma.mean(low_high, axis=axis, out=out) |
