diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2021-09-13 15:50:54 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2021-09-14 13:38:07 -0400 |
commit | 6ba48721e22622403a60b7f9d3ec5cae308ba3a9 (patch) | |
tree | b85b31d174fbcf7d6c7723eebfd0d3eabdc042b4 | |
parent | e4f85b08c252b576f3c589cc1202f040c22d64f2 (diff) | |
download | numpy-6ba48721e22622403a60b7f9d3ec5cae308ba3a9.tar.gz |
BUG: ensure np.median does not drop subclass for NaN result.
Currently, np.median is almost completely safe for subclasses, except
if the result is NaN. In that case, it assumes the result is a scalar
and substitutes a NaN with the right dtype. This PR fixes that, since
subclasses like astropy's Quantity generally use array scalars to
preserve subclass information such as the unit.
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 10 | ||||
-rw-r--r-- | numpy/lib/utils.py | 14 |
2 files changed, 16 insertions, 8 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 829691b1c..5f27ea655 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -3432,6 +3432,16 @@ class TestMedian: a = MySubClass([1, 2, 3]) assert_equal(np.median(a), -7) + @pytest.mark.parametrize('arr', + ([1., 2., 3.], [1., np.nan, 3.], np.nan, 0.)) + def test_subclass2(self, arr): + """Check that we return subclasses, even if a NaN scalar.""" + class MySubclass(np.ndarray): + pass + + m = np.median(np.array(arr).view(MySubclass)) + assert isinstance(m, MySubclass) + def test_out(self): o = np.zeros((4,)) d = np.ones((3, 4)) diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index 1f2cb66fa..931669fc1 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -1029,14 +1029,12 @@ def _median_nancheck(data, result, axis, out): # masked NaN values are ok if np.ma.isMaskedArray(n): n = n.filled(False) - if result.ndim == 0: - if n == True: - if out is not None: - out[...] = data.dtype.type(np.nan) - result = out - else: - result = data.dtype.type(np.nan) - elif np.count_nonzero(n.ravel()) > 0: + if np.count_nonzero(n.ravel()) > 0: + # Without given output, it is possible that the current result is a + # numpy scalar, which is not writeable. If so, just return nan. + if isinstance(result, np.generic): + return data.dtype.type(np.nan) + result[n] = np.nan return result |