diff options
author | Pauli Virtanen <pav@iki.fi> | 2016-09-11 14:58:40 +0200 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2017-01-19 22:12:47 +0100 |
commit | 0bff7b30466b26963cf4fc1b280eb207b74e9851 (patch) | |
tree | 180623c77ce0230c80a8acb9fe40a3248a3ce8f7 | |
parent | dae0b12d6a790543bee4002f434a1633f8923188 (diff) | |
download | numpy-0bff7b30466b26963cf4fc1b280eb207b74e9851.tar.gz |
ENH: core: handle memory overlap in ufunc.at
This adds a new method PyArray_MapIterArrayCopyIfOverlap to the API.
-rw-r--r-- | numpy/core/code_generators/numpy_api.py | 2 | ||||
-rw-r--r-- | numpy/core/src/multiarray/mapping.c | 90 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 15 | ||||
-rw-r--r-- | numpy/core/tests/test_mem_overlap.py | 24 |
4 files changed, 120 insertions, 11 deletions
diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py index 972966627..7d107f2ac 100644 --- a/numpy/core/code_generators/numpy_api.py +++ b/numpy/core/code_generators/numpy_api.py @@ -344,6 +344,8 @@ multiarray_funcs_api = { # End 1.9 API 'PyArray_CheckAnyScalarExact': (300, NonNull(1)), # End 1.10 API + # End 1.11 API + 'PyArray_MapIterArrayCopyIfOverlap': (301,), } ufunc_types_api = { diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c index 50f516a29..b97cc185e 100644 --- a/numpy/core/src/multiarray/mapping.c +++ b/numpy/core/src/multiarray/mapping.c @@ -19,6 +19,7 @@ #include "mapping.h" #include "lowlevel_strided_loops.h" #include "item_selection.h" +#include "mem_overlap.h" #define HAS_INTEGER 1 @@ -704,6 +705,38 @@ prepare_index(PyArrayObject *self, PyObject *index, /** + * Check if self has memory overlap with one of the index arrays, or with extra_op. + * + * @returns 1 if memory overlap found, 0 if not. + */ +NPY_NO_EXPORT int +index_has_memory_overlap(PyArrayObject *self, + int index_type, npy_index_info *indices, int num, + PyObject *extra_op) +{ + int i; + + if (index_type & (HAS_FANCY | HAS_BOOL)) { + for (i = 0; i < num; ++i) { + if (indices[i].object != NULL && PyArray_Check(indices[i].object) && + solve_may_share_memory(self, (PyArrayObject *)indices[i].object, + 1) != 0) { + + return 1; + } + } + } + + if (extra_op != NULL && PyArray_Check(extra_op) && + solve_may_share_memory(self, (PyArrayObject *)extra_op, 1) != 0) { + return 1; + } + + return 0; +} + + +/** * Get pointer for an integer index. * * For a purely integer index, set ptr to the memory address. @@ -3132,19 +3165,23 @@ PyArray_MapIterNew(npy_index_info *indices , int index_num, int index_type, /*NUMPY_API * - * Use advanced indexing to iterate an array. Please note - * that most of this public API is currently not guaranteed - * to stay the same between versions. If you plan on using - * it, please consider adding more utility functions here - * to accommodate new features. + * Same as PyArray_MapIterArray, but: + * + * If copy_if_overlap != 0, check if `a` has memory overlap with any of the + * arrays in `index` and with `extra_op`. If yes, make copies as appropriate + * to avoid problems if `a` is modified during the iteration. + * `iter->array` may contain a copied array (with UPDATEIFCOPY set). */ NPY_NO_EXPORT PyObject * -PyArray_MapIterArray(PyArrayObject * a, PyObject * index) +PyArray_MapIterArrayCopyIfOverlap(PyArrayObject * a, PyObject * index, + int copy_if_overlap, PyArrayObject *extra_op) { PyArrayMapIterObject * mit = NULL; PyArrayObject *subspace = NULL; npy_index_info indices[NPY_MAXDIMS * 2 + 1]; int i, index_num, ndim, fancy_ndim, index_type; + int need_copy = 0; + PyArrayObject *a_copy = NULL; index_type = prepare_index(a, index, indices, &index_num, &ndim, &fancy_ndim, 0); @@ -3153,6 +3190,30 @@ PyArray_MapIterArray(PyArrayObject * a, PyObject * index) return NULL; } + if (copy_if_overlap && index_has_memory_overlap(a, index_type, indices, + index_num, + (PyObject *)extra_op)) { + /* Make a copy of the input array */ + a_copy = (PyArrayObject *)PyArray_NewLikeArray(a, NPY_ANYORDER, + NULL, 0); + if (a_copy == NULL) { + goto fail; + } + + if (PyArray_CopyInto(a_copy, a) != 0) { + Py_DECREF(a_copy); + goto fail; + } + + Py_INCREF(a); + if (PyArray_SetUpdateIfCopyBase(a_copy, a) < 0) { + Py_DECREF(a); + goto fail; + } + + a = a_copy; + } + /* If it is not a pure fancy index, need to get the subspace */ if (index_type != HAS_FANCY) { if (get_view_from_index(a, &subspace, indices, index_num, 1) < 0) { @@ -3180,6 +3241,7 @@ PyArray_MapIterArray(PyArrayObject * a, PyObject * index) goto fail; } + Py_XDECREF(a_copy); Py_XDECREF(subspace); PyArray_MapIterReset(mit); @@ -3190,6 +3252,7 @@ PyArray_MapIterArray(PyArrayObject * a, PyObject * index) return (PyObject *)mit; fail: + Py_XDECREF(a_copy); Py_XDECREF(subspace); Py_XDECREF((PyObject *)mit); for (i=0; i < index_num; i++) { @@ -3199,6 +3262,21 @@ PyArray_MapIterArray(PyArrayObject * a, PyObject * index) } +/*NUMPY_API + * + * Use advanced indexing to iterate an array. Please note + * that most of this public API is currently not guaranteed + * to stay the same between versions. If you plan on using + * it, please consider adding more utility functions here + * to accommodate new features. + */ +NPY_NO_EXPORT PyObject * +PyArray_MapIterArray(PyArrayObject * a, PyObject * index) +{ + return PyArray_MapIterArrayCopyIfOverlap(a, index, 0, NULL); +} + + #undef HAS_INTEGER #undef HAS_NEWAXIS #undef HAS_SLICE diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 0bef35cec..3c0c54222 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -5278,11 +5278,6 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) op1_array = (PyArrayObject *)op1; - iter = (PyArrayMapIterObject *)PyArray_MapIterArray(op1_array, idx); - if (iter == NULL) { - goto fail; - } - /* Create second operand from number array if needed. */ if (op2 != NULL) { op2_array = (PyArrayObject *)PyArray_FromAny(op2, NULL, @@ -5290,7 +5285,17 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) if (op2_array == NULL) { goto fail; } + } + /* Create map iterator */ + iter = (PyArrayMapIterObject *)PyArray_MapIterArrayCopyIfOverlap( + op1_array, idx, 1, op2_array); + if (iter == NULL) { + goto fail; + } + op1_array = iter->array; /* May be updateifcopied on overlap */ + + if (op2 != NULL) { /* * May need to swap axes so that second operand is * iterated over correctly diff --git a/numpy/core/tests/test_mem_overlap.py b/numpy/core/tests/test_mem_overlap.py index 07b1346cb..a5cb5a4f5 100644 --- a/numpy/core/tests/test_mem_overlap.py +++ b/numpy/core/tests/test_mem_overlap.py @@ -750,6 +750,30 @@ class TestUFunc(object): if (c != cx).any(): assert_equal(c, cx) + def test_ufunc_at_manual(self): + def check(ufunc, a, ind, b=None): + a0 = a.copy() + if b is None: + ufunc.at(a0, ind.copy()) + c1 = a0.copy() + ufunc.at(a, ind) + c2 = a.copy() + else: + ufunc.at(a0, ind.copy(), b.copy()) + c1 = a0.copy() + ufunc.at(a, ind, b) + c2 = a.copy() + assert_array_equal(c1, c2) + + # Overlap with index + a = np.arange(10000, dtype=np.int16) + check(np.invert, a[::-1], a) + + # Overlap with second data array + a = np.arange(100, dtype=np.int16) + ind = np.arange(0, 100, 2, dtype=np.int16) + check(np.add, a, ind, a[25:75]) + def test_unary_ufunc_1d_manual(self): # Exercise branches in PyArray_EQUIVALENTLY_ITERABLE |