summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrederic Bastien <nouiz@nouiz.org>2016-10-28 20:09:45 -0400
committerFrederic Bastien <nouiz@nouiz.org>2016-10-28 21:33:04 -0400
commit530d67191287b8dc625bfc49664ad08b22e9e4d3 (patch)
tree66853b1486805c1f205f287fd207295d02697ed8
parent6ae842001332f532e0c76815d49336ecc2b88dde (diff)
downloadnumpy-530d67191287b8dc625bfc49664ad08b22e9e4d3.tar.gz
[ENH]Make numpy.mean() do more precise computation without changing the output dtype that stay in float16.
-rw-r--r--numpy/core/_methods.py15
-rw-r--r--numpy/core/fromnumeric.py2
-rw-r--r--numpy/core/tests/test_multiarray.py5
3 files changed, 19 insertions, 3 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 54e267541..b53c5ca00 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -54,20 +54,29 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False):
arr = asanyarray(a)
rcount = _count_reduce_items(arr, axis)
+ orig_dtype = dtype
# Make this warning show up first
if rcount == 0:
warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)
# Cast bool, unsigned int, and int to float64 by default
- if dtype is None and issubclass(arr.dtype.type, (nt.integer, nt.bool_)):
- dtype = mu.dtype('f8')
+ if dtype is None:
+ if issubclass(arr.dtype.type, (nt.integer, nt.bool_)):
+ dtype = mu.dtype('f8')
+ elif issubclass(arr.dtype.type, nt.float16):
+ dtype = mu.dtype('f4')
ret = umr_sum(arr, axis, dtype, out, keepdims)
if isinstance(ret, mu.ndarray):
ret = um.true_divide(
ret, rcount, out=ret, casting='unsafe', subok=False)
+ if orig_dtype is None and issubclass(arr.dtype.type, nt.float16):
+ ret = a.dtype.type(ret)
elif hasattr(ret, 'dtype'):
- ret = ret.dtype.type(ret / rcount)
+ if orig_dtype is None and issubclass(arr.dtype.type, nt.float16):
+ ret = a.dtype.type(ret / rcount)
+ else:
+ ret = ret.dtype.type(ret / rcount)
else:
ret = ret / rcount
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index 84fdab7cc..a5c63ba84 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -2790,6 +2790,8 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue):
Returns the average of the array elements. The average is taken over
the flattened array by default, otherwise over the specified axis.
`float64` intermediate and return values are used for integer inputs.
+ `float32` intermediate are used for `float16` inputs for extra
+ precision.
Parameters
----------
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 7e3ffe28a..61fcc8f6c 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4530,6 +4530,11 @@ class TestStats(TestCase):
res = _mean(mat, axis=axis) * np.prod(mat.shape)
assert_almost_equal(res, tgt)
+ def test_mean_float16(self):
+ # This fail if the sum inside mean is done in float16 instead
+ # of float32.
+ assert _mean(np.ones(100000, dtype='float16')) == 1
+
def test_var_values(self):
for mat in [self.rmat, self.cmat, self.omat]:
for axis in [0, 1, None]: