summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/lib/function_base.py15
-rw-r--r--numpy/lib/utils.py12
-rw-r--r--numpy/ma/extras.py4
3 files changed, 15 insertions, 16 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index d875a00ae..e33d55056 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3714,16 +3714,15 @@ def _median(a, axis=None, out=None, overwrite_input=False):
indexer[axis] = slice(index-1, index+1)
indexer = tuple(indexer)
+ # Use mean in both odd and even case to coerce data type,
+ # using out array if needed.
+ rout = mean(part[indexer], axis=axis, out=out)
# Check if the array contains any nan's
if np.issubdtype(a.dtype, np.inexact) and sz > 0:
- # warn and return nans like mean would
- rout = mean(part[indexer], axis=axis, out=out)
- return np.lib.utils._median_nancheck(part, rout, axis, out)
- else:
- # if there are no nans
- # Use mean in odd and even case to coerce data type
- # and check, use out array.
- return mean(part[indexer], axis=axis, out=out)
+ # If nans are possible, warn and replace by nans like mean would.
+ rout = np.lib.utils._median_nancheck(part, rout, axis)
+
+ return rout
def _percentile_dispatcher(a, q, axis=None, out=None, overwrite_input=None,
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index 931669fc1..cd5bb51e7 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -1002,7 +1002,7 @@ def safe_eval(source):
return ast.literal_eval(source)
-def _median_nancheck(data, result, axis, out):
+def _median_nancheck(data, result, axis):
"""
Utility function to check median result from data for NaN values at the end
and return NaN in that case. Input result can also be a MaskedArray.
@@ -1012,16 +1012,16 @@ def _median_nancheck(data, result, axis, out):
data : array
Input data to median function
result : Array or MaskedArray
- Result of median function
+ Result of median function.
axis : int
Axis along which the median was computed.
- out : ndarray, optional
- Output array in which to place the result.
Returns
-------
- median : scalar or ndarray
- Median or NaN in axes which contained NaN in the input.
+ result : scalar or ndarray
+ Median or NaN in axes which contained NaN in the input. If the input
+ was an array, NaN will be inserted in-place. If a scalar, either the
+ input itself or a scalar NaN.
"""
if data.size == 0:
return result
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 73abfc296..38bf1f0e8 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -750,7 +750,7 @@ def _median(a, axis=None, out=None, overwrite_input=False):
s = mid.sum(out=out)
if not odd:
s = np.true_divide(s, 2., casting='safe', out=out)
- s = np.lib.utils._median_nancheck(asorted, s, axis, out)
+ s = np.lib.utils._median_nancheck(asorted, s, axis)
else:
s = mid.mean(out=out)
@@ -790,7 +790,7 @@ def _median(a, axis=None, out=None, overwrite_input=False):
s = np.ma.sum(low_high, axis=axis, out=out)
np.true_divide(s.data, 2., casting='unsafe', out=s.data)
- s = np.lib.utils._median_nancheck(asorted, s, axis, out)
+ s = np.lib.utils._median_nancheck(asorted, s, axis)
else:
s = np.ma.mean(low_high, axis=axis, out=out)