summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2021-09-13 15:50:54 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2021-09-14 13:38:07 -0400
commit6ba48721e22622403a60b7f9d3ec5cae308ba3a9 (patch)
treeb85b31d174fbcf7d6c7723eebfd0d3eabdc042b4
parente4f85b08c252b576f3c589cc1202f040c22d64f2 (diff)
downloadnumpy-6ba48721e22622403a60b7f9d3ec5cae308ba3a9.tar.gz
BUG: ensure np.median does not drop subclass for NaN result.
Currently, np.median is almost completely safe for subclasses, except if the result is NaN. In that case, it assumes the result is a scalar and substitutes a NaN with the right dtype. This PR fixes that, since subclasses like astropy's Quantity generally use array scalars to preserve subclass information such as the unit.
-rw-r--r--numpy/lib/tests/test_function_base.py10
-rw-r--r--numpy/lib/utils.py14
2 files changed, 16 insertions, 8 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 829691b1c..5f27ea655 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -3432,6 +3432,16 @@ class TestMedian:
a = MySubClass([1, 2, 3])
assert_equal(np.median(a), -7)
+ @pytest.mark.parametrize('arr',
+ ([1., 2., 3.], [1., np.nan, 3.], np.nan, 0.))
+ def test_subclass2(self, arr):
+ """Check that we return subclasses, even if a NaN scalar."""
+ class MySubclass(np.ndarray):
+ pass
+
+ m = np.median(np.array(arr).view(MySubclass))
+ assert isinstance(m, MySubclass)
+
def test_out(self):
o = np.zeros((4,))
d = np.ones((3, 4))
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index 1f2cb66fa..931669fc1 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -1029,14 +1029,12 @@ def _median_nancheck(data, result, axis, out):
# masked NaN values are ok
if np.ma.isMaskedArray(n):
n = n.filled(False)
- if result.ndim == 0:
- if n == True:
- if out is not None:
- out[...] = data.dtype.type(np.nan)
- result = out
- else:
- result = data.dtype.type(np.nan)
- elif np.count_nonzero(n.ravel()) > 0:
+ if np.count_nonzero(n.ravel()) > 0:
+ # Without given output, it is possible that the current result is a
+ # numpy scalar, which is not writeable. If so, just return nan.
+ if isinstance(result, np.generic):
+ return data.dtype.type(np.nan)
+
result[n] = np.nan
return result