From 19a398ae78c3f35ce3d29d87b35da059558d72b4 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 7 Dec 2021 15:35:20 -0700 Subject: BUG: Fix handling of the dtype parameter to numpy.array_api.prod() --- numpy/array_api/_statistical_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'numpy/array_api/_statistical_functions.py') diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index 7bee3f4db..5bc831ac2 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -65,8 +65,8 @@ def prod( # Note: sum() and prod() always upcast float32 to float64 for dtype=None # We need to do so here before computing the product to avoid overflow if dtype is None and x.dtype == float32: - x = asarray(x, dtype=float64) - return Array._new(np.prod(x._array, axis=axis, keepdims=keepdims)) + dtype = float64 + return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims)) def std( -- cgit v1.2.1