summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-07-03 10:12:07 -0600
committerGitHub <noreply@github.com>2018-07-03 10:12:07 -0600
commitfb0565e2e0e066e16c93f92e5b18b1e6fc4c6b57 (patch)
tree880b6c869d16c927be9955589eca66472dd1e255 /numpy/core
parent8909c09e0b4a4ec1d33ed53609a9613bf225c2ff (diff)
parent2b680c2311246f2d78f8b7786284336b92ef74a3 (diff)
downloadnumpy-fb0565e2e0e066e16c93f92e5b18b1e6fc4c6b57.tar.gz
Merge pull request #11468 from seberg/fancy-assignment-no-fastpath
BUG: Advanced indexing assignment incorrectly took 1-D fastpath
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/mapping.c2
-rw-r--r--numpy/core/src/private/lowlevel_strided_loops.h39
-rw-r--r--numpy/core/tests/test_indexing.py15
3 files changed, 36 insertions, 20 deletions
diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c
index 46ff78b9c..cdca1d606 100644
--- a/numpy/core/src/multiarray/mapping.c
+++ b/numpy/core/src/multiarray/mapping.c
@@ -2084,7 +2084,7 @@ array_assign_subscript(PyArrayObject *self, PyObject *ind, PyObject *op)
PyArray_TRIVIALLY_ITERABLE_OP_READ,
PyArray_TRIVIALLY_ITERABLE_OP_READ) ||
(PyArray_NDIM(tmp_arr) == 0 &&
- PyArray_TRIVIALLY_ITERABLE(tmp_arr))) &&
+ PyArray_TRIVIALLY_ITERABLE(ind))) &&
/* Check if the type is equivalent to INTP */
PyArray_ITEMSIZE(ind) == sizeof(npy_intp) &&
PyArray_DESCR(ind)->kind == 'i' &&
diff --git a/numpy/core/src/private/lowlevel_strided_loops.h b/numpy/core/src/private/lowlevel_strided_loops.h
index 094612b7d..f9c671f77 100644
--- a/numpy/core/src/private/lowlevel_strided_loops.h
+++ b/numpy/core/src/private/lowlevel_strided_loops.h
@@ -689,21 +689,16 @@ npy_bswap8_unaligned(char * x)
#define PyArray_TRIVIALLY_ITERABLE_OP_NOREAD 0
#define PyArray_TRIVIALLY_ITERABLE_OP_READ 1
-#define PyArray_EQUIVALENTLY_ITERABLE_BASE(arr1, arr2) ( \
- PyArray_NDIM(arr1) == PyArray_NDIM(arr2) && \
- PyArray_CompareLists(PyArray_DIMS(arr1), \
- PyArray_DIMS(arr2), \
- PyArray_NDIM(arr1)) && \
- (PyArray_FLAGS(arr1)&(NPY_ARRAY_C_CONTIGUOUS| \
- NPY_ARRAY_F_CONTIGUOUS)) & \
- (PyArray_FLAGS(arr2)&(NPY_ARRAY_C_CONTIGUOUS| \
- NPY_ARRAY_F_CONTIGUOUS)) \
- )
+#define PyArray_TRIVIALLY_ITERABLE(arr) ( \
+ PyArray_NDIM(arr) <= 1 || \
+ PyArray_CHKFLAGS(arr, NPY_ARRAY_C_CONTIGUOUS) || \
+ PyArray_CHKFLAGS(arr, NPY_ARRAY_F_CONTIGUOUS) \
+ )
#define PyArray_TRIVIAL_PAIR_ITERATION_STRIDE(size, arr) ( \
- size == 1 ? 0 : ((PyArray_NDIM(arr) == 1) ? \
- PyArray_STRIDE(arr, 0) : \
- PyArray_ITEMSIZE(arr)))
+ assert(PyArray_TRIVIALLY_ITERABLE(arr)), \
+ size == 1 ? 0 : ((PyArray_NDIM(arr) == 1) ? \
+ PyArray_STRIDE(arr, 0) : PyArray_ITEMSIZE(arr)))
static NPY_INLINE int
PyArray_EQUIVALENTLY_ITERABLE_OVERLAP_OK(PyArrayObject *arr1, PyArrayObject *arr2,
@@ -757,15 +752,22 @@ PyArray_EQUIVALENTLY_ITERABLE_OVERLAP_OK(PyArrayObject *arr1, PyArrayObject *arr
return (!arr1_read || arr1_ahead) && (!arr2_read || arr2_ahead);
}
+#define PyArray_EQUIVALENTLY_ITERABLE_BASE(arr1, arr2) ( \
+ PyArray_NDIM(arr1) == PyArray_NDIM(arr2) && \
+ PyArray_CompareLists(PyArray_DIMS(arr1), \
+ PyArray_DIMS(arr2), \
+ PyArray_NDIM(arr1)) && \
+ (PyArray_FLAGS(arr1)&(NPY_ARRAY_C_CONTIGUOUS| \
+ NPY_ARRAY_F_CONTIGUOUS)) & \
+ (PyArray_FLAGS(arr2)&(NPY_ARRAY_C_CONTIGUOUS| \
+ NPY_ARRAY_F_CONTIGUOUS)) \
+ )
+
#define PyArray_EQUIVALENTLY_ITERABLE(arr1, arr2, arr1_read, arr2_read) ( \
PyArray_EQUIVALENTLY_ITERABLE_BASE(arr1, arr2) && \
PyArray_EQUIVALENTLY_ITERABLE_OVERLAP_OK( \
arr1, arr2, arr1_read, arr2_read))
-#define PyArray_TRIVIALLY_ITERABLE(arr) ( \
- PyArray_NDIM(arr) <= 1 || \
- PyArray_CHKFLAGS(arr, NPY_ARRAY_C_CONTIGUOUS) || \
- PyArray_CHKFLAGS(arr, NPY_ARRAY_F_CONTIGUOUS) \
- )
+
#define PyArray_PREPARE_TRIVIAL_ITERATION(arr, count, data, stride) \
count = PyArray_SIZE(arr); \
data = PyArray_BYTES(arr); \
@@ -774,7 +776,6 @@ PyArray_EQUIVALENTLY_ITERABLE_OVERLAP_OK(PyArrayObject *arr1, PyArrayObject *arr
PyArray_STRIDE(arr, 0) : \
PyArray_ITEMSIZE(arr)));
-
#define PyArray_TRIVIALLY_ITERABLE_PAIR(arr1, arr2, arr1_read, arr2_read) ( \
PyArray_TRIVIALLY_ITERABLE(arr1) && \
(PyArray_NDIM(arr2) == 0 || \
diff --git a/numpy/core/tests/test_indexing.py b/numpy/core/tests/test_indexing.py
index cbcd3e994..276cd9f93 100644
--- a/numpy/core/tests/test_indexing.py
+++ b/numpy/core/tests/test_indexing.py
@@ -329,6 +329,21 @@ class TestIndexing(object):
assert_raises(IndexError, a.__getitem__, ind)
assert_raises(IndexError, a.__setitem__, ind, 0)
+ def test_trivial_fancy_not_possible(self):
+ # Test that the fast path for trivial assignment is not incorrectly
+ # used when the index is not contiguous or 1D, see also gh-11467.
+ a = np.arange(6)
+ idx = np.arange(6, dtype=np.intp).reshape(2, 1, 3)[:, :, 0]
+ assert_array_equal(a[idx], idx)
+
+ # this case must not go into the fast path, note that idx is
+ # a non-contiuguous none 1D array here.
+ a[idx] = -1
+ res = np.arange(6)
+ res[0] = -1
+ res[3] = -1
+ assert_array_equal(a, res)
+
def test_nonbaseclass_values(self):
class SubClass(np.ndarray):
def __array_finalize__(self, old):