summaryrefslogtreecommitdiff
path: root/numpy/array_api/_statistical_functions.py
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-11-14 14:35:06 -0700
committerGitHub <noreply@github.com>2021-11-14 22:35:06 +0100
commita1813504ad44b70fb139181a9df8465bcb22e24d (patch)
tree1c638079fe4e675976d73a0fb120f2f05d2ec522 /numpy/array_api/_statistical_functions.py
parentb8a0f339dcd90c134e1cc3e19d06348069af685b (diff)
downloadnumpy-a1813504ad44b70fb139181a9df8465bcb22e24d.tar.gz
ENH: Add the linalg extension to the array_api submodule (#19980)
Diffstat (limited to 'numpy/array_api/_statistical_functions.py')
-rw-r--r--numpy/array_api/_statistical_functions.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py
index c5abf9468..7bee3f4db 100644
--- a/numpy/array_api/_statistical_functions.py
+++ b/numpy/array_api/_statistical_functions.py
@@ -93,11 +93,12 @@ def sum(
) -> Array:
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in sum")
- # Note: sum() and prod() always upcast float32 to float64 for dtype=None
- # We need to do so here before summing to avoid overflow
+ # Note: sum() and prod() always upcast integers to (u)int64 and float32 to
+ # float64 for dtype=None. `np.sum` does that too for integers, but not for
+ # float32, so we need to special-case it here
if dtype is None and x.dtype == float32:
- x = asarray(x, dtype=float64)
- return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims))
+ dtype = float64
+ return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
def var(