diff options
author | Developer-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com> | 2022-12-21 11:30:48 -0800 |
---|---|---|
committer | Developer-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com> | 2022-12-21 11:30:48 -0800 |
commit | f2fa2e533b0fee50aee2db8cfe2906fd072774c7 (patch) | |
tree | 5dec4905416c417b644370ebafdbb2525c07a8ec | |
parent | 2f64274d360b25d509636897d5d39157080a42cc (diff) | |
download | numpy-f2fa2e533b0fee50aee2db8cfe2906fd072774c7.tar.gz |
BUG, SIMD: Restore behavior converting non bool input to 0x00/0xff
This resolves https://github.com/numpy/numpy/issues/22845 by restoring the prior behavior of converting non-bool input to 0x00/0xff.
-rw-r--r-- | numpy/core/src/umath/loops_logical.dispatch.c.src | 12 |
1 file changed, 12 insertions, 0 deletions
diff --git a/numpy/core/src/umath/loops_logical.dispatch.c.src b/numpy/core/src/umath/loops_logical.dispatch.c.src
index 793a2af19..4a021d50b 100644
--- a/numpy/core/src/umath/loops_logical.dispatch.c.src
+++ b/numpy/core/src/umath/loops_logical.dispatch.c.src
@@ -63,6 +63,10 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp
     const int vstep = npyv_nlanes_u8;
     const int wstep = vstep * UNROLL;
 
+#if @and@
+    const npyv_u8 zero = npyv_zero_u8();
+#endif
+
     // Unrolled vectors loop
     for (; len >= wstep; len -= wstep, ip1 += wstep, ip2 += wstep, op += wstep) {
     /**begin repeat1
@@ -71,6 +75,10 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp
     #if UNROLL > @unroll@
         npyv_u8 a@unroll@ = npyv_load_u8(ip1 + vstep * @unroll@);
         npyv_u8 b@unroll@ = npyv_load_u8(ip2 + vstep * @unroll@);
+#if @and@
+        // a = 0x00/0xff if any bit is set. ensures non-bool inputs are handled properly.
+        a@unroll@ = npyv_cvt_u8_b8(npyv_cmpgt_u8(a@unroll@, zero));
+#endif
         npyv_u8 r@unroll@ = npyv_@intrin@_u8(a@unroll@, b@unroll@);
         npyv_store_u8(op + vstep * @unroll@, byte_to_true(r@unroll@));
     #endif
@@ -82,6 +90,10 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp
     for (; len >= vstep; len -= vstep, ip1 += vstep, ip2 += vstep, op += vstep) {
         npyv_u8 a = npyv_load_u8(ip1);
         npyv_u8 b = npyv_load_u8(ip2);
+#if @and@
+        // a = 0x00/0xff if any bit is set. ensures non-bool inputs are handled properly.
+        a = npyv_cvt_u8_b8(npyv_cmpgt_u8(a, zero));
+#endif
         npyv_u8 r = npyv_@intrin@_u8(a, b);
         npyv_store_u8(op, byte_to_true(r));
     }