summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-11-03 21:33:06 -0400
committerGitHub <noreply@github.com>2016-11-03 21:33:06 -0400
commit6bb5a22ac873a76f4e80988469fe84c1067a7d58 (patch)
tree3e9cbc9a69f8b6aa50b7d2862590e136012884c6
parent7d3e4488afb7bcdfcf0b5ff52fb1cdcd7dc3748b (diff)
parentaf818b08edc30d04cb56de0429a35ca82458e326 (diff)
downloadnumpy-6bb5a22ac873a76f4e80988469fe84c1067a7d58.tar.gz
Merge pull request #8222 from nouiz/mean_float16
ENH: Make numpy.mean() do more precise computation
-rw-r--r--numpy/core/_methods.py16
-rw-r--r--numpy/core/fromnumeric.py3
-rw-r--r--numpy/core/tests/test_multiarray.py5
3 files changed, 21 insertions, 3 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 54e267541..abfd0a3cc 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -53,21 +53,31 @@ def _count_reduce_items(arr, axis):
def _mean(a, axis=None, dtype=None, out=None, keepdims=False):
arr = asanyarray(a)
+ is_float16_result = False
rcount = _count_reduce_items(arr, axis)
# Make this warning show up first
if rcount == 0:
warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)
# Cast bool, unsigned int, and int to float64 by default
- if dtype is None and issubclass(arr.dtype.type, (nt.integer, nt.bool_)):
- dtype = mu.dtype('f8')
+ if dtype is None:
+ if issubclass(arr.dtype.type, (nt.integer, nt.bool_)):
+ dtype = mu.dtype('f8')
+ elif issubclass(arr.dtype.type, nt.float16):
+ dtype = mu.dtype('f4')
+ is_float16_result = True
ret = umr_sum(arr, axis, dtype, out, keepdims)
if isinstance(ret, mu.ndarray):
ret = um.true_divide(
ret, rcount, out=ret, casting='unsafe', subok=False)
+ if is_float16_result and out is None:
+ ret = a.dtype.type(ret)
elif hasattr(ret, 'dtype'):
- ret = ret.dtype.type(ret / rcount)
+ if is_float16_result:
+ ret = a.dtype.type(ret / rcount)
+ else:
+ ret = ret.dtype.type(ret / rcount)
else:
ret = ret / rcount
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index 84fdab7cc..4cb78e3fc 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -2847,6 +2847,9 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue):
example below). Specifying a higher-precision accumulator using the
`dtype` keyword can alleviate this issue.
+ By default, `float16` results are computed using `float32` intermediates
+ for extra precision.
+
Examples
--------
>>> a = np.array([[1, 2], [3, 4]])
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index a6e039d50..b21e193b9 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4530,6 +4530,11 @@ class TestStats(TestCase):
res = _mean(mat, axis=axis) * np.prod(mat.shape)
assert_almost_equal(res, tgt)
+ def test_mean_float16(self):
+ # This fail if the sum inside mean is done in float16 instead
+ # of float32.
+ assert _mean(np.ones(100000, dtype='float16')) == 1
+
def test_var_values(self):
for mat in [self.rmat, self.cmat, self.omat]:
for axis in [0, 1, None]: