diff options
-rw-r--r-- | numpy/core/src/private/lowlevel_strided_loops.h | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/numpy/core/src/private/lowlevel_strided_loops.h b/numpy/core/src/private/lowlevel_strided_loops.h index 8a1312c35..691c6bcd5 100644 --- a/numpy/core/src/private/lowlevel_strided_loops.h +++ b/numpy/core/src/private/lowlevel_strided_loops.h @@ -720,20 +720,23 @@ PyArray_EQUIVALENTLY_ITERABLE_OVERLAP_OK(PyArrayObject *arr1, PyArrayObject *arr stride1 = PyArray_TRIVIAL_PAIR_ITERATION_STRIDE(size1, arr1); stride2 = PyArray_TRIVIAL_PAIR_ITERATION_STRIDE(size2, arr2); - if (stride1 >= 0) { + /* Arrays with zero stride are never "ahead" since the element is reused + (at this point we know the arrays do overlap). */ + + if (stride1 > 0) { arr1_ahead = (stride1 >= stride2 && (npy_uintp)PyArray_BYTES(arr1) >= (npy_uintp)PyArray_BYTES(arr2)); } - else { + else if (stride1 < 0) { arr1_ahead = (stride1 <= stride2 && (npy_uintp)PyArray_BYTES(arr1) <= (npy_uintp)PyArray_BYTES(arr2)); } - if (stride2 >= 0) { + if (stride2 > 0) { arr2_ahead = (stride2 >= stride1 && (npy_uintp)PyArray_BYTES(arr2) >= (npy_uintp)PyArray_BYTES(arr1)); } - else { + else if (stride2 < 0) { arr2_ahead = (stride2 <= stride1 && (npy_uintp)PyArray_BYTES(arr2) <= (npy_uintp)PyArray_BYTES(arr1)); } |