summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-05-22 11:53:39 -0600
committerGitHub <noreply@github.com>2021-05-22 11:53:39 -0600
commitd691f18e8ef8fd76d3ee257563d843f9ad5c7a71 (patch)
tree4917130d28579bb9f91b632d95c3d0d013f795cc
parentef6595517f451edc5c5710f048d1eaf6c82f4022 (diff)
parentae9314eff5d539122bf87800a1bc50a9f99762a8 (diff)
downloadnumpy-d691f18e8ef8fd76d3ee257563d843f9ad5c7a71.tar.gz
Merge pull request #19066 from BvB93/nanmedian
BUG: Fixed an issue wherein `nanmedian` could return an array with the wrong dtype
-rw-r--r--numpy/lib/nanfunctions.py20
-rw-r--r--numpy/lib/tests/test_nanfunctions.py50
2 files changed, 47 insertions, 23 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py
index a02ad779f..2c2c3435b 100644
--- a/numpy/lib/nanfunctions.py
+++ b/numpy/lib/nanfunctions.py
@@ -962,12 +962,16 @@ def _nanmedian1d(arr1d, overwrite_input=False):
Private function for rank 1 arrays. Compute the median ignoring NaNs.
See nanmedian for parameter usage
"""
- arr1d, overwrite_input = _remove_nan_1d(arr1d,
- overwrite_input=overwrite_input)
- if arr1d.size == 0:
- return np.nan
+ arr1d_parsed, overwrite_input = _remove_nan_1d(
+ arr1d, overwrite_input=overwrite_input,
+ )
- return np.median(arr1d, overwrite_input=overwrite_input)
+ if arr1d_parsed.size == 0:
+ # Ensure that a nan-esque scalar of the appropiate type (and unit)
+ # is returned for `timedelta64` and `complexfloating`
+ return arr1d[-1]
+
+ return np.median(arr1d_parsed, overwrite_input=overwrite_input)
def _nanmedian(a, axis=None, out=None, overwrite_input=False):
@@ -1008,10 +1012,12 @@ def _nanmedian_small(a, axis=None, out=None, overwrite_input=False):
for i in range(np.count_nonzero(m.mask.ravel())):
warnings.warn("All-NaN slice encountered", RuntimeWarning,
stacklevel=4)
+
+ fill_value = np.timedelta64("NaT") if m.dtype.kind == "m" else np.nan
if out is not None:
- out[...] = m.filled(np.nan)
+ out[...] = m.filled(fill_value)
return out
- return m.filled(np.nan)
+ return m.filled(fill_value)
def _nanmedian_dispatcher(
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index e0f723a3c..1f1f5601b 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -588,6 +588,15 @@ class TestNanFunctions_MeanVarStd(SharedNanFunctionsTestsMixin):
assert_(len(w) == 0)
+_TIME_UNITS = (
+ "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", "ps", "fs", "as"
+)
+
+# All `inexact` + `timdelta64` type codes
+_TYPE_CODES = list(np.typecodes["AllFloat"])
+_TYPE_CODES += [f"m8[{unit}]" for unit in _TIME_UNITS]
+
+
class TestNanFunctions_Median:
def test_mutation(self):
@@ -662,23 +671,32 @@ class TestNanFunctions_Median:
res = np.nanmedian(_ndat, axis=1)
assert_almost_equal(res, tgt)
- def test_allnans(self):
- mat = np.array([np.nan]*9).reshape(3, 3)
- for axis in [None, 0, 1]:
- with suppress_warnings() as sup:
- sup.record(RuntimeWarning)
+ @pytest.mark.parametrize("axis", [None, 0, 1])
+ @pytest.mark.parametrize("dtype", _TYPE_CODES)
+ def test_allnans(self, dtype, axis):
+ mat = np.full((3, 3), np.nan).astype(dtype)
+ with suppress_warnings() as sup:
+ sup.record(RuntimeWarning)
- assert_(np.isnan(np.nanmedian(mat, axis=axis)).all())
- if axis is None:
- assert_(len(sup.log) == 1)
- else:
- assert_(len(sup.log) == 3)
- # Check scalar
- assert_(np.isnan(np.nanmedian(np.nan)))
- if axis is None:
- assert_(len(sup.log) == 2)
- else:
- assert_(len(sup.log) == 4)
+ output = np.nanmedian(mat, axis=axis)
+ assert output.dtype == mat.dtype
+ assert np.isnan(output).all()
+
+ if axis is None:
+ assert_(len(sup.log) == 1)
+ else:
+ assert_(len(sup.log) == 3)
+
+ # Check scalar
+ scalar = np.array(np.nan).astype(dtype)[()]
+ output_scalar = np.nanmedian(scalar)
+ assert output_scalar.dtype == scalar.dtype
+ assert np.isnan(output_scalar)
+
+ if axis is None:
+ assert_(len(sup.log) == 2)
+ else:
+ assert_(len(sup.log) == 4)
def test_empty(self):
mat = np.zeros((0, 3))