From 8968bc31944eb2af214b591d38eca3b0cea56f4b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Mar 2021 15:20:24 -0600 Subject: bitwise_left_shift and bitwise_right_shift should return the dtype of the first argument --- numpy/_array_api/_elementwise_functions.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) (limited to 'numpy/_array_api') diff --git a/numpy/_array_api/_elementwise_functions.py b/numpy/_array_api/_elementwise_functions.py index cd2e8661f..a8af04c62 100644 --- a/numpy/_array_api/_elementwise_functions.py +++ b/numpy/_array_api/_elementwise_functions.py @@ -123,10 +123,13 @@ def bitwise_left_shift(x1: array, x2: array, /) -> array: See its docstring for more information. """ + # Note: the function name is different here if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError('Only integer dtypes are allowed in bitwise_left_shift') - # Note: the function name is different here - return ndarray._new(np.left_shift(x1._array, x2._array)) + # Note: The spec requires the return dtype of bitwise_left_shift to be the + # same as the first argument. np.left_shift() returns a type that is the + # type promotion of the two input types. + return ndarray._new(np.left_shift(x1._array, x2._array).astype(x1.dtype)) def bitwise_invert(x: array, /) -> array: """ @@ -155,10 +158,13 @@ def bitwise_right_shift(x1: array, x2: array, /) -> array: See its docstring for more information. """ + # Note: the function name is different here if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError('Only integer dtypes are allowed in bitwise_right_shift') - # Note: the function name is different here - return ndarray._new(np.right_shift(x1._array, x2._array)) + # Note: The spec requires the return dtype of bitwise_left_shift to be the + # same as the first argument. np.left_shift() returns a type that is the + # type promotion of the two input types. + return ndarray._new(np.right_shift(x1._array, x2._array).astype(x1.dtype)) def bitwise_xor(x1: array, x2: array, /) -> array: """ -- cgit v1.2.1