summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2022-06-12 16:47:47 +0300
committerGitHub <noreply@github.com>2022-06-12 16:47:47 +0300
commitc2db91cbc287ee0052b501f6f28bb88926b54db5 (patch)
tree6a13450f1f3f48429ce5d1c98604f3a49c0570a9 /numpy
parent705244444526804860e9f8d52459e2ca5c255366 (diff)
parent79019ae2a9e2fa1b640ae3eb6727dda1b64fcd53 (diff)
downloadnumpy-c2db91cbc287ee0052b501f6f28bb88926b54db5.tar.gz
Merge pull request #20914 from kinshukdua/fix_ma_mean
BUG: change `ma.mean` dtype to be consistent with `np.mean`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py11
-rw-r--r--numpy/ma/tests/test_core.py6
2 files changed, 16 insertions, 1 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 78333ed02..d8fd4f389 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -31,6 +31,7 @@ from functools import reduce
import numpy as np
import numpy.core.umath as umath
import numpy.core.numerictypes as ntypes
+from numpy.core import multiarray as mu
from numpy import ndarray, amax, amin, iscomplexobj, bool_, _NoValue
from numpy import array as narray
from numpy.lib.function_base import angle
@@ -5289,14 +5290,22 @@ class MaskedArray(ndarray):
"""
kwargs = {} if keepdims is np._NoValue else {'keepdims': keepdims}
-
if self._mask is nomask:
result = super().mean(axis=axis, dtype=dtype, **kwargs)[()]
else:
+ is_float16_result = False
+ if dtype is None:
+ if issubclass(self.dtype.type, (ntypes.integer, ntypes.bool_)):
+ dtype = mu.dtype('f8')
+ elif issubclass(self.dtype.type, ntypes.float16):
+ dtype = mu.dtype('f4')
+ is_float16_result = True
dsum = self.sum(axis=axis, dtype=dtype, **kwargs)
cnt = self.count(axis=axis, **kwargs)
if cnt.shape == () and (cnt == 0):
result = masked
+ elif is_float16_result:
+ result = self.dtype.type(dsum * 1. / cnt)
else:
result = dsum * 1. / cnt
if out is not None:
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 0dada104d..679ed2090 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -4036,6 +4036,12 @@ class TestMaskedArrayMathMethods:
assert_equal(a.max(-1), [3, 6])
assert_equal(a.max(1), [3, 6])
+ def test_mean_overflow(self):
+ # Test overflow in masked arrays
+ # gh-20272
+ a = masked_array(np.full((10000, 10000), 65535, dtype=np.uint16),
+ mask=np.zeros((10000, 10000)))
+ assert_equal(a.mean(), 65535.0)
class TestMaskedArrayMathMethodsComplex:
# Test class for miscellaneous MaskedArrays methods.