summaryrefslogtreecommitdiff
path: root/numpy/array_api/_searching_functions.py
diff options
context:
space:
mode:
authorRalf Gommers <ralf.gommers@gmail.com>2021-11-12 16:15:35 +0100
committerRalf Gommers <ralf.gommers@gmail.com>2021-11-12 16:15:35 +0100
commiteccb8dfbd9b07183e16a1144e8d5d76936671bfc (patch)
tree647a9477b4f3b8b7205f2f7f2feb99eaa482e806 /numpy/array_api/_searching_functions.py
parentd0d75f39f28ac26d4cc1aa3a4cbea63a6a027929 (diff)
parentff2e2a1e7eea29d925063b13922e096d14331222 (diff)
downloadnumpy-eccb8dfbd9b07183e16a1144e8d5d76936671bfc.tar.gz
Merge branch 'main' into never_copy
Diffstat (limited to 'numpy/array_api/_searching_functions.py')
-rw-r--r--numpy/array_api/_searching_functions.py1
1 files changed, 1 insertions, 0 deletions
diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py
index 3dcef61c3..40f5a4d2e 100644
--- a/numpy/array_api/_searching_functions.py
+++ b/numpy/array_api/_searching_functions.py
@@ -43,4 +43,5 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
"""
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.where(condition._array, x1._array, x2._array))