diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2021-05-22 11:53:39 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-22 11:53:39 -0600 |
commit | d691f18e8ef8fd76d3ee257563d843f9ad5c7a71 (patch) | |
tree | 4917130d28579bb9f91b632d95c3d0d013f795cc /numpy/lib/tests | |
parent | ef6595517f451edc5c5710f048d1eaf6c82f4022 (diff) | |
parent | ae9314eff5d539122bf87800a1bc50a9f99762a8 (diff) | |
download | numpy-d691f18e8ef8fd76d3ee257563d843f9ad5c7a71.tar.gz |
Merge pull request #19066 from BvB93/nanmedian
BUG: Fixed an issue wherein `nanmedian` could return an array with the wrong dtype
Diffstat (limited to 'numpy/lib/tests')
-rw-r--r-- | numpy/lib/tests/test_nanfunctions.py | 50 |
1 files changed, 34 insertions, 16 deletions
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index e0f723a3c..1f1f5601b 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -588,6 +588,15 @@ class TestNanFunctions_MeanVarStd(SharedNanFunctionsTestsMixin): assert_(len(w) == 0) +_TIME_UNITS = ( + "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", "ps", "fs", "as" +) + +# All `inexact` + `timdelta64` type codes +_TYPE_CODES = list(np.typecodes["AllFloat"]) +_TYPE_CODES += [f"m8[{unit}]" for unit in _TIME_UNITS] + + class TestNanFunctions_Median: def test_mutation(self): @@ -662,23 +671,32 @@ class TestNanFunctions_Median: res = np.nanmedian(_ndat, axis=1) assert_almost_equal(res, tgt) - def test_allnans(self): - mat = np.array([np.nan]*9).reshape(3, 3) - for axis in [None, 0, 1]: - with suppress_warnings() as sup: - sup.record(RuntimeWarning) + @pytest.mark.parametrize("axis", [None, 0, 1]) + @pytest.mark.parametrize("dtype", _TYPE_CODES) + def test_allnans(self, dtype, axis): + mat = np.full((3, 3), np.nan).astype(dtype) + with suppress_warnings() as sup: + sup.record(RuntimeWarning) - assert_(np.isnan(np.nanmedian(mat, axis=axis)).all()) - if axis is None: - assert_(len(sup.log) == 1) - else: - assert_(len(sup.log) == 3) - # Check scalar - assert_(np.isnan(np.nanmedian(np.nan))) - if axis is None: - assert_(len(sup.log) == 2) - else: - assert_(len(sup.log) == 4) + output = np.nanmedian(mat, axis=axis) + assert output.dtype == mat.dtype + assert np.isnan(output).all() + + if axis is None: + assert_(len(sup.log) == 1) + else: + assert_(len(sup.log) == 3) + + # Check scalar + scalar = np.array(np.nan).astype(dtype)[()] + output_scalar = np.nanmedian(scalar) + assert output_scalar.dtype == scalar.dtype + assert np.isnan(output_scalar) + + if axis is None: + assert_(len(sup.log) == 2) + else: + assert_(len(sup.log) == 4) def test_empty(self): mat = np.zeros((0, 3)) |