summaryrefslogtreecommitdiff
path: root/numpy/array_api/_statistical_functions.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-12-07 17:25:43 -0700
committerGitHub <noreply@github.com>2021-12-07 17:25:43 -0700
commitab7a1927353ab9dd52e3f2f7a1a889ae790667b9 (patch)
treed84a6fe6a59c7d584792a388ded73bb2168dc651 /numpy/array_api/_statistical_functions.py
parent3b3998fd9a3c126ae18ba9fc33f5d585e251ee86 (diff)
parent19a398ae78c3f35ce3d29d87b35da059558d72b4 (diff)
downloadnumpy-ab7a1927353ab9dd52e3f2f7a1a889ae790667b9.tar.gz
Merge pull request #20533 from asmeurer/array_api-prod-fix
BUG: Fix handling of the dtype parameter to numpy.array_api.prod()
Diffstat (limited to 'numpy/array_api/_statistical_functions.py')
-rw-r--r--numpy/array_api/_statistical_functions.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py
index 7bee3f4db..5bc831ac2 100644
--- a/numpy/array_api/_statistical_functions.py
+++ b/numpy/array_api/_statistical_functions.py
@@ -65,8 +65,8 @@ def prod(
# Note: sum() and prod() always upcast float32 to float64 for dtype=None
# We need to do so here before computing the product to avoid overflow
if dtype is None and x.dtype == float32:
- x = asarray(x, dtype=float64)
- return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims))
+ dtype = float64
+ return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
def std(