diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2016-11-03 21:33:06 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-03 21:33:06 -0400 |
commit | 6bb5a22ac873a76f4e80988469fe84c1067a7d58 (patch) | |
tree | 3e9cbc9a69f8b6aa50b7d2862590e136012884c6 | |
parent | 7d3e4488afb7bcdfcf0b5ff52fb1cdcd7dc3748b (diff) | |
parent | af818b08edc30d04cb56de0429a35ca82458e326 (diff) | |
download | numpy-6bb5a22ac873a76f4e80988469fe84c1067a7d58.tar.gz |
Merge pull request #8222 from nouiz/mean_float16
ENH: Make numpy.mean() do more precise computation
-rw-r--r-- | numpy/core/_methods.py | 16 | ||||
-rw-r--r-- | numpy/core/fromnumeric.py | 3 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 5 |
3 files changed, 21 insertions, 3 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 54e267541..abfd0a3cc 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -53,21 +53,31 @@ def _count_reduce_items(arr, axis): def _mean(a, axis=None, dtype=None, out=None, keepdims=False): arr = asanyarray(a) + is_float16_result = False rcount = _count_reduce_items(arr, axis) # Make this warning show up first if rcount == 0: warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2) # Cast bool, unsigned int, and int to float64 by default - if dtype is None and issubclass(arr.dtype.type, (nt.integer, nt.bool_)): - dtype = mu.dtype('f8') + if dtype is None: + if issubclass(arr.dtype.type, (nt.integer, nt.bool_)): + dtype = mu.dtype('f8') + elif issubclass(arr.dtype.type, nt.float16): + dtype = mu.dtype('f4') + is_float16_result = True ret = umr_sum(arr, axis, dtype, out, keepdims) if isinstance(ret, mu.ndarray): ret = um.true_divide( ret, rcount, out=ret, casting='unsafe', subok=False) + if is_float16_result and out is None: + ret = a.dtype.type(ret) elif hasattr(ret, 'dtype'): - ret = ret.dtype.type(ret / rcount) + if is_float16_result: + ret = a.dtype.type(ret / rcount) + else: + ret = ret.dtype.type(ret / rcount) else: ret = ret / rcount diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py index 84fdab7cc..4cb78e3fc 100644 --- a/numpy/core/fromnumeric.py +++ b/numpy/core/fromnumeric.py @@ -2847,6 +2847,9 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue): example below). Specifying a higher-precision accumulator using the `dtype` keyword can alleviate this issue. + By default, `float16` results are computed using `float32` intermediates + for extra precision. + Examples -------- >>> a = np.array([[1, 2], [3, 4]]) diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index a6e039d50..b21e193b9 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -4530,6 +4530,11 @@ class TestStats(TestCase): res = _mean(mat, axis=axis) * np.prod(mat.shape) assert_almost_equal(res, tgt) + def test_mean_float16(self): + # This fail if the sum inside mean is done in float16 instead + # of float32. + assert _mean(np.ones(100000, dtype='float16')) == 1 + def test_var_values(self): for mat in [self.rmat, self.cmat, self.omat]: for axis in [0, 1, None]: |