diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2021-05-22 11:53:39 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-22 11:53:39 -0600 |
commit | d691f18e8ef8fd76d3ee257563d843f9ad5c7a71 (patch) | |
tree | 4917130d28579bb9f91b632d95c3d0d013f795cc | |
parent | ef6595517f451edc5c5710f048d1eaf6c82f4022 (diff) | |
parent | ae9314eff5d539122bf87800a1bc50a9f99762a8 (diff) | |
download | numpy-d691f18e8ef8fd76d3ee257563d843f9ad5c7a71.tar.gz |
Merge pull request #19066 from BvB93/nanmedian
BUG: Fixed an issue wherein `nanmedian` could return an array with the wrong dtype
-rw-r--r-- | numpy/lib/nanfunctions.py | 20 | ||||
-rw-r--r-- | numpy/lib/tests/test_nanfunctions.py | 50 |
2 files changed, 47 insertions, 23 deletions
diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py index a02ad779f..2c2c3435b 100644 --- a/numpy/lib/nanfunctions.py +++ b/numpy/lib/nanfunctions.py @@ -962,12 +962,16 @@ def _nanmedian1d(arr1d, overwrite_input=False): Private function for rank 1 arrays. Compute the median ignoring NaNs. See nanmedian for parameter usage """ - arr1d, overwrite_input = _remove_nan_1d(arr1d, - overwrite_input=overwrite_input) - if arr1d.size == 0: - return np.nan + arr1d_parsed, overwrite_input = _remove_nan_1d( + arr1d, overwrite_input=overwrite_input, + ) - return np.median(arr1d, overwrite_input=overwrite_input) + if arr1d_parsed.size == 0: + # Ensure that a nan-esque scalar of the appropiate type (and unit) + # is returned for `timedelta64` and `complexfloating` + return arr1d[-1] + + return np.median(arr1d_parsed, overwrite_input=overwrite_input) def _nanmedian(a, axis=None, out=None, overwrite_input=False): @@ -1008,10 +1012,12 @@ def _nanmedian_small(a, axis=None, out=None, overwrite_input=False): for i in range(np.count_nonzero(m.mask.ravel())): warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=4) + + fill_value = np.timedelta64("NaT") if m.dtype.kind == "m" else np.nan if out is not None: - out[...] = m.filled(np.nan) + out[...] = m.filled(fill_value) return out - return m.filled(np.nan) + return m.filled(fill_value) def _nanmedian_dispatcher( diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index e0f723a3c..1f1f5601b 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -588,6 +588,15 @@ class TestNanFunctions_MeanVarStd(SharedNanFunctionsTestsMixin): assert_(len(w) == 0) +_TIME_UNITS = ( + "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", "ps", "fs", "as" +) + +# All `inexact` + `timdelta64` type codes +_TYPE_CODES = list(np.typecodes["AllFloat"]) +_TYPE_CODES += [f"m8[{unit}]" for unit in _TIME_UNITS] + + class TestNanFunctions_Median: def test_mutation(self): @@ -662,23 +671,32 @@ class TestNanFunctions_Median: res = np.nanmedian(_ndat, axis=1) assert_almost_equal(res, tgt) - def test_allnans(self): - mat = np.array([np.nan]*9).reshape(3, 3) - for axis in [None, 0, 1]: - with suppress_warnings() as sup: - sup.record(RuntimeWarning) + @pytest.mark.parametrize("axis", [None, 0, 1]) + @pytest.mark.parametrize("dtype", _TYPE_CODES) + def test_allnans(self, dtype, axis): + mat = np.full((3, 3), np.nan).astype(dtype) + with suppress_warnings() as sup: + sup.record(RuntimeWarning) - assert_(np.isnan(np.nanmedian(mat, axis=axis)).all()) - if axis is None: - assert_(len(sup.log) == 1) - else: - assert_(len(sup.log) == 3) - # Check scalar - assert_(np.isnan(np.nanmedian(np.nan))) - if axis is None: - assert_(len(sup.log) == 2) - else: - assert_(len(sup.log) == 4) + output = np.nanmedian(mat, axis=axis) + assert output.dtype == mat.dtype + assert np.isnan(output).all() + + if axis is None: + assert_(len(sup.log) == 1) + else: + assert_(len(sup.log) == 3) + + # Check scalar + scalar = np.array(np.nan).astype(dtype)[()] + output_scalar = np.nanmedian(scalar) + assert output_scalar.dtype == scalar.dtype + assert np.isnan(output_scalar) + + if axis is None: + assert_(len(sup.log) == 2) + else: + assert_(len(sup.log) == 4) def test_empty(self): mat = np.zeros((0, 3)) |