diff options
author | Developer-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com> | 2021-11-18 14:31:58 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-18 14:31:58 -0800 |
commit | 5e9ce0c0529e3085498ac892941a020a65c7369a (patch) | |
tree | a70d9e941549b4a51b493f1b170ef33ce0d5a217 /numpy/array_api/_statistical_functions.py | |
parent | 2ff7ab64d4e7d5928e96ca95b85350aa9caa2b63 (diff) | |
parent | 056abda14dab7fa8daf7a1ab44144aeb2250c216 (diff) | |
download | numpy-5e9ce0c0529e3085498ac892941a020a65c7369a.tar.gz |
Merge branch 'numpy:main' into as_min_max
Diffstat (limited to 'numpy/array_api/_statistical_functions.py')
-rw-r--r-- | numpy/array_api/_statistical_functions.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index c5abf9468..7bee3f4db 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -93,11 +93,12 @@ def sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sum") - # Note: sum() and prod() always upcast float32 to float64 for dtype=None - # We need to do so here before summing to avoid overflow + # Note: sum() and prod() always upcast integers to (u)int64 and float32 to + # float64 for dtype=None. `np.sum` does that too for integers, but not for + # float32, so we need to special-case it here if dtype is None and x.dtype == float32: - x = asarray(x, dtype=float64) - return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims)) + dtype = float64 + return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims)) def var( |