summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorKinshuk Dua <kinshukdua@gmail.com>2022-01-27 13:03:12 +0530
committerKinshuk Dua <kinshukdua@gmail.com>2022-01-27 13:03:12 +0530
commit79019ae2a9e2fa1b640ae3eb6727dda1b64fcd53 (patch)
tree9633dcef827938d359c8fcf13dc27ad454fa41c6 /numpy
parentdf5c66c98144f8364f1a0c11a35419f3dfc0970d (diff)
downloadnumpy-79019ae2a9e2fa1b640ae3eb6727dda1b64fcd53.tar.gz
BUG: change `ma.mean` dtype to be consistent with `np.mean`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py11
-rw-r--r--numpy/ma/tests/test_core.py6
2 files changed, 16 insertions, 1 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 2ff1667ba..7b099f69b 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -31,6 +31,7 @@ from functools import reduce
import numpy as np
import numpy.core.umath as umath
import numpy.core.numerictypes as ntypes
+from numpy.core import multiarray as mu
from numpy import ndarray, amax, amin, iscomplexobj, bool_, _NoValue
from numpy import array as narray
from numpy.lib.function_base import angle
@@ -5235,14 +5236,22 @@ class MaskedArray(ndarray):
"""
kwargs = {} if keepdims is np._NoValue else {'keepdims': keepdims}
-
if self._mask is nomask:
result = super().mean(axis=axis, dtype=dtype, **kwargs)[()]
else:
+ is_float16_result = False
+ if dtype is None:
+ if issubclass(self.dtype.type, (ntypes.integer, ntypes.bool_)):
+ dtype = mu.dtype('f8')
+ elif issubclass(self.dtype.type, ntypes.float16):
+ dtype = mu.dtype('f4')
+ is_float16_result = True
dsum = self.sum(axis=axis, dtype=dtype, **kwargs)
cnt = self.count(axis=axis, **kwargs)
if cnt.shape == () and (cnt == 0):
result = masked
+ elif is_float16_result:
+ result = self.dtype.type(dsum * 1. / cnt)
else:
result = dsum * 1. / cnt
if out is not None:
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index bf95c999a..396b90d93 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -4050,6 +4050,12 @@ class TestMaskedArrayMathMethods:
assert_equal(a.max(-1), [3, 6])
assert_equal(a.max(1), [3, 6])
+ def test_mean_overflow(self):
+ # Test overflow in masked arrays
+ # gh-20272
+ a = masked_array(np.full((10000, 10000), 65535, dtype=np.uint16),
+ mask=np.zeros((10000, 10000)))
+ assert_equal(a.mean(), 65535.0)
class TestMaskedArrayMathMethodsComplex:
# Test class for miscellaneous MaskedArrays methods.