summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2016-09-11 14:58:40 +0200
committerPauli Virtanen <pav@iki.fi>2017-01-19 22:12:47 +0100
commit0bff7b30466b26963cf4fc1b280eb207b74e9851 (patch)
tree180623c77ce0230c80a8acb9fe40a3248a3ce8f7
parentdae0b12d6a790543bee4002f434a1633f8923188 (diff)
downloadnumpy-0bff7b30466b26963cf4fc1b280eb207b74e9851.tar.gz
ENH: core: handle memory overlap in ufunc.at
This adds a new method PyArray_MapIterArrayCopyIfOverlap to the API.
-rw-r--r--numpy/core/code_generators/numpy_api.py2
-rw-r--r--numpy/core/src/multiarray/mapping.c90
-rw-r--r--numpy/core/src/umath/ufunc_object.c15
-rw-r--r--numpy/core/tests/test_mem_overlap.py24
4 files changed, 120 insertions, 11 deletions
diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py
index 972966627..7d107f2ac 100644
--- a/numpy/core/code_generators/numpy_api.py
+++ b/numpy/core/code_generators/numpy_api.py
@@ -344,6 +344,8 @@ multiarray_funcs_api = {
# End 1.9 API
'PyArray_CheckAnyScalarExact': (300, NonNull(1)),
# End 1.10 API
+ # End 1.11 API
+ 'PyArray_MapIterArrayCopyIfOverlap': (301,),
}
ufunc_types_api = {
diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c
index 50f516a29..b97cc185e 100644
--- a/numpy/core/src/multiarray/mapping.c
+++ b/numpy/core/src/multiarray/mapping.c
@@ -19,6 +19,7 @@
#include "mapping.h"
#include "lowlevel_strided_loops.h"
#include "item_selection.h"
+#include "mem_overlap.h"
#define HAS_INTEGER 1
@@ -704,6 +705,38 @@ prepare_index(PyArrayObject *self, PyObject *index,
/**
+ * Check if self has memory overlap with one of the index arrays, or with extra_op.
+ *
+ * @returns 1 if memory overlap found, 0 if not.
+ */
+NPY_NO_EXPORT int
+index_has_memory_overlap(PyArrayObject *self,
+ int index_type, npy_index_info *indices, int num,
+ PyObject *extra_op)
+{
+ int i;
+
+ if (index_type & (HAS_FANCY | HAS_BOOL)) {
+ for (i = 0; i < num; ++i) {
+ if (indices[i].object != NULL && PyArray_Check(indices[i].object) &&
+ solve_may_share_memory(self, (PyArrayObject *)indices[i].object,
+ 1) != 0) {
+
+ return 1;
+ }
+ }
+ }
+
+ if (extra_op != NULL && PyArray_Check(extra_op) &&
+ solve_may_share_memory(self, (PyArrayObject *)extra_op, 1) != 0) {
+ return 1;
+ }
+
+ return 0;
+}
+
+
+/**
* Get pointer for an integer index.
*
* For a purely integer index, set ptr to the memory address.
@@ -3132,19 +3165,23 @@ PyArray_MapIterNew(npy_index_info *indices , int index_num, int index_type,
/*NUMPY_API
*
- * Use advanced indexing to iterate an array. Please note
- * that most of this public API is currently not guaranteed
- * to stay the same between versions. If you plan on using
- * it, please consider adding more utility functions here
- * to accommodate new features.
+ * Same as PyArray_MapIterArray, but:
+ *
+ * If copy_if_overlap != 0, check if `a` has memory overlap with any of the
+ * arrays in `index` and with `extra_op`. If yes, make copies as appropriate
+ * to avoid problems if `a` is modified during the iteration.
+ * `iter->array` may contain a copied array (with UPDATEIFCOPY set).
*/
NPY_NO_EXPORT PyObject *
-PyArray_MapIterArray(PyArrayObject * a, PyObject * index)
+PyArray_MapIterArrayCopyIfOverlap(PyArrayObject * a, PyObject * index,
+ int copy_if_overlap, PyArrayObject *extra_op)
{
PyArrayMapIterObject * mit = NULL;
PyArrayObject *subspace = NULL;
npy_index_info indices[NPY_MAXDIMS * 2 + 1];
int i, index_num, ndim, fancy_ndim, index_type;
+ int need_copy = 0;
+ PyArrayObject *a_copy = NULL;
index_type = prepare_index(a, index, indices, &index_num,
&ndim, &fancy_ndim, 0);
@@ -3153,6 +3190,30 @@ PyArray_MapIterArray(PyArrayObject * a, PyObject * index)
return NULL;
}
+ if (copy_if_overlap && index_has_memory_overlap(a, index_type, indices,
+ index_num,
+ (PyObject *)extra_op)) {
+ /* Make a copy of the input array */
+ a_copy = (PyArrayObject *)PyArray_NewLikeArray(a, NPY_ANYORDER,
+ NULL, 0);
+ if (a_copy == NULL) {
+ goto fail;
+ }
+
+ if (PyArray_CopyInto(a_copy, a) != 0) {
+ Py_DECREF(a_copy);
+ goto fail;
+ }
+
+ Py_INCREF(a);
+ if (PyArray_SetUpdateIfCopyBase(a_copy, a) < 0) {
+ Py_DECREF(a);
+ goto fail;
+ }
+
+ a = a_copy;
+ }
+
/* If it is not a pure fancy index, need to get the subspace */
if (index_type != HAS_FANCY) {
if (get_view_from_index(a, &subspace, indices, index_num, 1) < 0) {
@@ -3180,6 +3241,7 @@ PyArray_MapIterArray(PyArrayObject * a, PyObject * index)
goto fail;
}
+ Py_XDECREF(a_copy);
Py_XDECREF(subspace);
PyArray_MapIterReset(mit);
@@ -3190,6 +3252,7 @@ PyArray_MapIterArray(PyArrayObject * a, PyObject * index)
return (PyObject *)mit;
fail:
+ Py_XDECREF(a_copy);
Py_XDECREF(subspace);
Py_XDECREF((PyObject *)mit);
for (i=0; i < index_num; i++) {
@@ -3199,6 +3262,21 @@ PyArray_MapIterArray(PyArrayObject * a, PyObject * index)
}
+/*NUMPY_API
+ *
+ * Use advanced indexing to iterate an array. Please note
+ * that most of this public API is currently not guaranteed
+ * to stay the same between versions. If you plan on using
+ * it, please consider adding more utility functions here
+ * to accommodate new features.
+ */
+NPY_NO_EXPORT PyObject *
+PyArray_MapIterArray(PyArrayObject * a, PyObject * index)
+{
+ return PyArray_MapIterArrayCopyIfOverlap(a, index, 0, NULL);
+}
+
+
#undef HAS_INTEGER
#undef HAS_NEWAXIS
#undef HAS_SLICE
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 0bef35cec..3c0c54222 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -5278,11 +5278,6 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
op1_array = (PyArrayObject *)op1;
- iter = (PyArrayMapIterObject *)PyArray_MapIterArray(op1_array, idx);
- if (iter == NULL) {
- goto fail;
- }
-
/* Create second operand from number array if needed. */
if (op2 != NULL) {
op2_array = (PyArrayObject *)PyArray_FromAny(op2, NULL,
@@ -5290,7 +5285,17 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
if (op2_array == NULL) {
goto fail;
}
+ }
+ /* Create map iterator */
+ iter = (PyArrayMapIterObject *)PyArray_MapIterArrayCopyIfOverlap(
+ op1_array, idx, 1, op2_array);
+ if (iter == NULL) {
+ goto fail;
+ }
+ op1_array = iter->array; /* May be updateifcopied on overlap */
+
+ if (op2 != NULL) {
/*
* May need to swap axes so that second operand is
* iterated over correctly
diff --git a/numpy/core/tests/test_mem_overlap.py b/numpy/core/tests/test_mem_overlap.py
index 07b1346cb..a5cb5a4f5 100644
--- a/numpy/core/tests/test_mem_overlap.py
+++ b/numpy/core/tests/test_mem_overlap.py
@@ -750,6 +750,30 @@ class TestUFunc(object):
if (c != cx).any():
assert_equal(c, cx)
+ def test_ufunc_at_manual(self):
+ def check(ufunc, a, ind, b=None):
+ a0 = a.copy()
+ if b is None:
+ ufunc.at(a0, ind.copy())
+ c1 = a0.copy()
+ ufunc.at(a, ind)
+ c2 = a.copy()
+ else:
+ ufunc.at(a0, ind.copy(), b.copy())
+ c1 = a0.copy()
+ ufunc.at(a, ind, b)
+ c2 = a.copy()
+ assert_array_equal(c1, c2)
+
+ # Overlap with index
+ a = np.arange(10000, dtype=np.int16)
+ check(np.invert, a[::-1], a)
+
+ # Overlap with second data array
+ a = np.arange(100, dtype=np.int16)
+ ind = np.arange(0, 100, 2, dtype=np.int16)
+ check(np.add, a, ind, a[25:75])
+
def test_unary_ufunc_1d_manual(self):
# Exercise branches in PyArray_EQUIVALENTLY_ITERABLE