diff options
| author | Kinshuk Dua <kinshukdua@gmail.com> | 2022-01-27 13:03:12 +0530 |
|---|---|---|
| committer | Kinshuk Dua <kinshukdua@gmail.com> | 2022-01-27 13:03:12 +0530 |
| commit | 79019ae2a9e2fa1b640ae3eb6727dda1b64fcd53 (patch) | |
| tree | 9633dcef827938d359c8fcf13dc27ad454fa41c6 /numpy | |
| parent | df5c66c98144f8364f1a0c11a35419f3dfc0970d (diff) | |
| download | numpy-79019ae2a9e2fa1b640ae3eb6727dda1b64fcd53.tar.gz | |
BUG: change `ma.mean` dtype to be consistent with `np.mean`
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/ma/core.py | 11 | ||||
| -rw-r--r-- | numpy/ma/tests/test_core.py | 6 |
2 files changed, 16 insertions, 1 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 2ff1667ba..7b099f69b 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -31,6 +31,7 @@ from functools import reduce import numpy as np import numpy.core.umath as umath import numpy.core.numerictypes as ntypes +from numpy.core import multiarray as mu from numpy import ndarray, amax, amin, iscomplexobj, bool_, _NoValue from numpy import array as narray from numpy.lib.function_base import angle @@ -5235,14 +5236,22 @@ class MaskedArray(ndarray): """ kwargs = {} if keepdims is np._NoValue else {'keepdims': keepdims} - if self._mask is nomask: result = super().mean(axis=axis, dtype=dtype, **kwargs)[()] else: + is_float16_result = False + if dtype is None: + if issubclass(self.dtype.type, (ntypes.integer, ntypes.bool_)): + dtype = mu.dtype('f8') + elif issubclass(self.dtype.type, ntypes.float16): + dtype = mu.dtype('f4') + is_float16_result = True dsum = self.sum(axis=axis, dtype=dtype, **kwargs) cnt = self.count(axis=axis, **kwargs) if cnt.shape == () and (cnt == 0): result = masked + elif is_float16_result: + result = self.dtype.type(dsum * 1. / cnt) else: result = dsum * 1. / cnt if out is not None: diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index bf95c999a..396b90d93 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -4050,6 +4050,12 @@ class TestMaskedArrayMathMethods: assert_equal(a.max(-1), [3, 6]) assert_equal(a.max(1), [3, 6]) + def test_mean_overflow(self): + # Test overflow in masked arrays + # gh-20272 + a = masked_array(np.full((10000, 10000), 65535, dtype=np.uint16), + mask=np.zeros((10000, 10000))) + assert_equal(a.mean(), 65535.0) class TestMaskedArrayMathMethodsComplex: # Test class for miscellaneous MaskedArrays methods. |
