diff options
| author | Matti Picus <matti.picus@gmail.com> | 2022-06-12 16:47:47 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-12 16:47:47 +0300 |
| commit | c2db91cbc287ee0052b501f6f28bb88926b54db5 (patch) | |
| tree | 6a13450f1f3f48429ce5d1c98604f3a49c0570a9 /numpy | |
| parent | 705244444526804860e9f8d52459e2ca5c255366 (diff) | |
| parent | 79019ae2a9e2fa1b640ae3eb6727dda1b64fcd53 (diff) | |
| download | numpy-c2db91cbc287ee0052b501f6f28bb88926b54db5.tar.gz | |
Merge pull request #20914 from kinshukdua/fix_ma_mean
BUG: change `ma.mean` dtype to be consistent with `np.mean`
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/ma/core.py | 11 | ||||
| -rw-r--r-- | numpy/ma/tests/test_core.py | 6 |
2 files changed, 16 insertions, 1 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 78333ed02..d8fd4f389 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -31,6 +31,7 @@ from functools import reduce import numpy as np import numpy.core.umath as umath import numpy.core.numerictypes as ntypes +from numpy.core import multiarray as mu from numpy import ndarray, amax, amin, iscomplexobj, bool_, _NoValue from numpy import array as narray from numpy.lib.function_base import angle @@ -5289,14 +5290,22 @@ class MaskedArray(ndarray): """ kwargs = {} if keepdims is np._NoValue else {'keepdims': keepdims} - if self._mask is nomask: result = super().mean(axis=axis, dtype=dtype, **kwargs)[()] else: + is_float16_result = False + if dtype is None: + if issubclass(self.dtype.type, (ntypes.integer, ntypes.bool_)): + dtype = mu.dtype('f8') + elif issubclass(self.dtype.type, ntypes.float16): + dtype = mu.dtype('f4') + is_float16_result = True dsum = self.sum(axis=axis, dtype=dtype, **kwargs) cnt = self.count(axis=axis, **kwargs) if cnt.shape == () and (cnt == 0): result = masked + elif is_float16_result: + result = self.dtype.type(dsum * 1. / cnt) else: result = dsum * 1. / cnt if out is not None: diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 0dada104d..679ed2090 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -4036,6 +4036,12 @@ class TestMaskedArrayMathMethods: assert_equal(a.max(-1), [3, 6]) assert_equal(a.max(1), [3, 6]) + def test_mean_overflow(self): + # Test overflow in masked arrays + # gh-20272 + a = masked_array(np.full((10000, 10000), 65535, dtype=np.uint16), + mask=np.zeros((10000, 10000))) + assert_equal(a.mean(), 65535.0) class TestMaskedArrayMathMethodsComplex: # Test class for miscellaneous MaskedArrays methods. |
