summary | refs | log | tree | commit | diff
path: root/numpy
diff options
context:
space:
mode:
author Mark Wiebe <mwwiebe@gmail.com> 2011-08-07 18:43:23 -0700
committer Charles Harris <charlesr.harris@gmail.com> 2011-08-27 07:26:53 -0600
commit 51c9d7ca4221570a907501e68f6449051d930742 (patch)
tree 9101e1465a25398b71520721e4e96641e09679d6 /numpy
parent 42b9c84cfcdd27057b902ea094247866e2d741da (diff)
download numpy-51c9d7ca4221570a907501e68f6449051d930742.tar.gz
ENH: core: Make the array assignment routine handle overlapping arrays
Diffstat (limited to 'numpy')
-rw-r--r-- numpy/core/src/multiarray/array_assign.c 47
-rw-r--r-- numpy/core/src/multiarray/array_assign.h 5
-rw-r--r-- numpy/core/src/multiarray/array_assign_array.c 175
-rw-r--r-- numpy/core/src/multiarray/ctors.c 53
-rw-r--r-- numpy/core/src/multiarray/ctors.h 4
-rw-r--r-- numpy/core/src/multiarray/multiarraymodule.c 51
6 files changed, 213 insertions, 122 deletions
diff --git a/numpy/core/src/multiarray/array_assign.c b/numpy/core/src/multiarray/array_assign.c
index 6136f3821..896f32ffd 100644
--- a/numpy/core/src/multiarray/array_assign.c
+++ b/numpy/core/src/multiarray/array_assign.c
@@ -97,3 +97,50 @@ raw_array_is_aligned(int ndim, char *data, npy_intp *strides, int alignment)
}
}
+
+/* Stores in out_start and out_end the half-open address range [start, end) that bounds all of arr's element data */
+NPY_NO_EXPORT void
+get_array_memory_extents(PyArrayObject *arr,
+ npy_uintp *out_start, npy_uintp *out_end)
+{
+ npy_uintp start, end;
+ npy_intp idim, ndim = PyArray_NDIM(arr);
+ npy_intp *dimensions = PyArray_DIMS(arr),
+ *strides = PyArray_STRIDES(arr);
+
+ /* Track a closed range [start, end] of element start addresses; both begin at the data pointer */
+ start = end = (npy_uintp)PyArray_DATA(arr);
+ for (idim = 0; idim < ndim; ++idim) {
+ npy_intp stride = strides[idim], dim = dimensions[idim];
+ /* A zero-length dimension means no elements: report the empty range [data, data) and stop */
+ if (dim == 0) {
+ *out_start = *out_end = (npy_uintp)PyArray_DATA(arr);
+ return;
+ }
+ /* Positive strides push 'end' up, negative strides pull 'start' down; zero strides add nothing */
+ else {
+ if (stride > 0) {
+ end += stride*(dim-1);
+ }
+ else if (stride < 0) {
+ start += stride*(dim-1);
+ }
+ }
+ }
+
+ /* Add one element size to 'end' so the last element's bytes are covered, giving a half-open range */
+ *out_start = start;
+ *out_end = end + PyArray_DESCR(arr)->elsize;
+}
+
+/* Returns 1 if the [start, end) memory extents of arr1 and arr2 intersect, 0 otherwise; compares extents only, not individual elements */
+NPY_NO_EXPORT int
+arrays_overlap(PyArrayObject *arr1, PyArrayObject *arr2)
+{
+ npy_uintp start1 = 0, start2 = 0, end1 = 0, end2 = 0;
+
+ get_array_memory_extents(arr1, &start1, &end1);
+ get_array_memory_extents(arr2, &start2, &end2);
+
+ return (start1 < end2) && (start2 < end1);
+}
diff --git a/numpy/core/src/multiarray/array_assign.h b/numpy/core/src/multiarray/array_assign.h
index 35cb6f6a3..0f9f613b0 100644
--- a/numpy/core/src/multiarray/array_assign.h
+++ b/numpy/core/src/multiarray/array_assign.h
@@ -140,4 +140,9 @@ broadcast_strides(int ndim, npy_intp *shape,
NPY_NO_EXPORT int
raw_array_is_aligned(int ndim, char *data, npy_intp *strides, int alignment);
+/* Returns 1 if the [start, end) memory extents of arr1 and arr2 intersect, 0 otherwise */
+NPY_NO_EXPORT int
+arrays_overlap(PyArrayObject *arr1, PyArrayObject *arr2);
+
+
#endif
diff --git a/numpy/core/src/multiarray/array_assign_array.c b/numpy/core/src/multiarray/array_assign_array.c
index 00fe31229..8e428dcff 100644
--- a/numpy/core/src/multiarray/array_assign_array.c
+++ b/numpy/core/src/multiarray/array_assign_array.c
@@ -66,6 +66,18 @@ raw_array_assign_array(int ndim, npy_intp *shape,
return -1;
}
+ /*
+ * Overlap check for the 1D case. Higher dimensional arrays cause
+ * a temporary copy before getting here.
+ */
+ if (ndim == 1 && src_data < dst_data &&
+ src_data + shape_it[0] * src_strides_it[0] > dst_data) {
+ src_data += (shape_it[0] - 1) * src_strides_it[0];
+ dst_data += (shape_it[0] - 1) * dst_strides_it[0];
+ src_strides_it[0] = -src_strides_it[0];
+ dst_strides_it[0] = -dst_strides_it[0];
+ }
+
/* Get the function to do the casting */
if (PyArray_GetDTypeTransferFunction(aligned,
src_strides_it[0], dst_strides_it[0],
@@ -142,6 +154,20 @@ raw_array_wheremasked_assign_array(int ndim, npy_intp *shape,
return -1;
}
+ /*
+ * Overlap check for the 1D case. Higher dimensional arrays cause
+ * a temporary copy before getting here.
+ */
+ if (ndim == 1 && src_data < dst_data &&
+ src_data + shape_it[0] * src_strides_it[0] > dst_data) {
+ src_data += (shape_it[0] - 1) * src_strides_it[0];
+ dst_data += (shape_it[0] - 1) * dst_strides_it[0];
+ wheremask_data += (shape_it[0] - 1) * wheremask_strides_it[0];
+ src_strides_it[0] = -src_strides_it[0];
+ dst_strides_it[0] = -dst_strides_it[0];
+ wheremask_strides_it[0] = -wheremask_strides_it[0];
+ }
+
/* Get the function to do the casting */
if (PyArray_GetMaskedDTypeTransferFunction(aligned,
src_strides_it[0],
@@ -243,6 +269,22 @@ raw_array_wheremasked_assign_array_preservena(int ndim, npy_intp *shape,
return -1;
}
+ /*
+ * Overlap check for the 1D case. Higher dimensional arrays cause
+ * a temporary copy before getting here.
+ */
+ if (ndim == 1 && src_data < dst_data &&
+ src_data + shape_it[0] * src_strides_it[0] > dst_data) {
+ src_data += (shape_it[0] - 1) * src_strides_it[0];
+ dst_data += (shape_it[0] - 1) * dst_strides_it[0];
+ maskna_data += (shape_it[0] - 1) * maskna_strides_it[0];
+ wheremask_data += (shape_it[0] - 1) * wheremask_strides_it[0];
+ src_strides_it[0] = -src_strides_it[0];
+ dst_strides_it[0] = -dst_strides_it[0];
+ maskna_strides_it[0] = -maskna_strides_it[0];
+ wheremask_strides_it[0] = -wheremask_strides_it[0];
+ }
+
/* Get the function to do the casting */
if (PyArray_GetMaskedDTypeTransferFunction(aligned,
src_strides[0], dst_strides_it[0], maskna_itemsize,
@@ -329,10 +371,10 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
{
int dst_has_maskna = PyArray_HASMASKNA(dst);
int src_has_maskna = PyArray_HASMASKNA(src);
+ int copied_src = 0;
npy_intp src_strides[NPY_MAXDIMS], src_maskna_strides[NPY_MAXDIMS];
-
/* Use array_assign_scalar if 'src' NDIM is 0 */
if (PyArray_NDIM(src) == 0) {
/* If the array is masked, assign to the NA mask */
@@ -351,11 +393,44 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
wheremask, casting, preservena, preservewhichna);
}
+ /*
+ * Performance fix for expressions like "a[1000:6000] += x". In this
+ * case, first an in-place add is done, followed by an assignment,
+ * equivalently expressed like this:
+ *
+ * tmp = a[1000:6000] # Calls array_subscript_nice in mapping.c
+ * np.add(tmp, x, tmp)
+ * a[1000:6000] = tmp # Calls array_ass_sub in mapping.c
+ *
+ * In the assignment the underlying data type, shape, strides, and
+ * data pointers are identical, but src != dst because they are separately
+ * generated slices. By detecting this and skipping the redundant
+ * copy of values to themselves, we potentially give a big speed boost.
+ *
+ * Note that we don't call EquivTypes, because usually the exact same
+ * dtype object will appear, and we don't want to slow things down
+ * with a complicated comparison. The comparisons are ordered to
+ * try and reject this with as little work as possible.
+ */
+ if (PyArray_DATA(src) == PyArray_DATA(dst) &&
+ PyArray_MASKNA_DATA(src) == PyArray_MASKNA_DATA(dst) &&
+ PyArray_DESCR(src) == PyArray_DESCR(dst) &&
+ PyArray_NDIM(src) == PyArray_NDIM(dst) &&
+ PyArray_CompareLists(PyArray_DIMS(src),
+ PyArray_DIMS(dst),
+ PyArray_NDIM(src)) &&
+ PyArray_CompareLists(PyArray_STRIDES(src),
+ PyArray_STRIDES(dst),
+ PyArray_NDIM(src))) {
+ /*printf("Redundant copy operation detected\n");*/
+ return 0;
+ }
+
/* Check that 'dst' is writeable */
if (!PyArray_ISWRITEABLE(dst)) {
PyErr_SetString(PyExc_RuntimeError,
"cannot assign to a read-only array");
- return -1;
+ goto fail;
}
/* Check the casting rule */
@@ -373,13 +448,13 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyUString_FromFormat(" according to the rule %s",
npy_casting_to_string(casting)));
PyErr_SetObject(PyExc_TypeError, errmsg);
- return -1;
+ goto fail;
}
if (preservewhichna != NULL) {
PyErr_SetString(PyExc_RuntimeError,
"multi-NA support is not yet implemented");
- return -1;
+ goto fail;
}
if (src_has_maskna && !dst_has_maskna) {
@@ -388,19 +463,54 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyErr_SetString(PyExc_ValueError,
"Cannot assign NA value to an array which "
"does not support NAs");
- return -1;
+ goto fail;
}
else {
src_has_maskna = 0;
}
}
+ /*
+ * When ndim is 1, the lower-level inner loop handles copying
+ * of overlapping data. For bigger ndim, we make a temporary
+ * copy of 'src' if 'src' and 'dst' overlap.
+ */
+ if (PyArray_NDIM(dst) > 1 && arrays_overlap(src, dst)) {
+ PyArrayObject *tmp;
+
+ /*
+ * Allocate a temporary copy array.
+ */
+ tmp = (PyArrayObject *)PyArray_NewLikeArray(dst,
+ NPY_KEEPORDER, NULL, 0);
+ if (tmp == NULL) {
+ goto fail;
+ }
+
+ /* Make the temporary copy have an NA mask if necessary */
+ if (PyArray_HASMASKNA(src)) {
+ if (PyArray_AllocateMaskNA(tmp, 1, 0, 1) < 0) {
+ Py_DECREF(tmp);
+ goto fail;
+ }
+ }
+
+ if (array_assign_array(tmp, src,
+ NULL, NPY_UNSAFE_CASTING, 0, NULL) < 0) {
+ Py_DECREF(tmp);
+ goto fail;
+ }
+
+ src = tmp;
+ copied_src = 1;
+ }
+
/* Broadcast 'src' to 'dst' for raw iteration */
if (broadcast_strides(PyArray_NDIM(dst), PyArray_DIMS(dst),
PyArray_NDIM(src), PyArray_DIMS(src),
PyArray_STRIDES(src), "input array",
src_strides) < 0) {
- return -1;
+ goto fail;
}
if (src_has_maskna) {
@@ -408,7 +518,7 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyArray_NDIM(src), PyArray_DIMS(src),
PyArray_MASKNA_STRIDES(src), "input array",
src_maskna_strides) < 0) {
- return -1;
+ goto fail;
}
}
@@ -427,11 +537,11 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyArray_MASKNA_DTYPE(src),
PyArray_MASKNA_DATA(src),
src_maskna_strides) < 0) {
- return -1;
+ goto fail;
}
/* Assign the values based on the 'src' NA mask */
- return raw_array_wheremasked_assign_array(
+ if (raw_array_wheremasked_assign_array(
PyArray_NDIM(dst), PyArray_DIMS(dst),
PyArray_DESCR(dst), PyArray_DATA(dst),
PyArray_STRIDES(dst),
@@ -439,12 +549,15 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
src_strides,
PyArray_MASKNA_DTYPE(src),
PyArray_MASKNA_DATA(src),
- src_maskna_strides);
+ src_maskna_strides) < 0) {
+ goto fail;
+ }
+ goto finish;
}
else {
if (PyArray_AssignMaskNA(dst, NULL, 1) < 0) {
- return -1;
+ goto fail;
}
}
}
@@ -453,7 +566,7 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
if (raw_array_assign_array(PyArray_NDIM(dst), PyArray_DIMS(dst),
PyArray_DESCR(dst), PyArray_DATA(dst), PyArray_STRIDES(dst),
PyArray_DESCR(src), PyArray_DATA(src), src_strides) < 0) {
- return -1;
+ goto fail;
}
}
/* A value assignment without overwriting NA values */
@@ -471,7 +584,7 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyArray_MASKNA_DTYPE(dst),
PyArray_MASKNA_DATA(dst),
PyArray_MASKNA_STRIDES(dst)) < 0) {
- return -1;
+ goto fail;
}
}
@@ -487,7 +600,7 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyArray_STRIDES(src),
PyArray_MASKNA_DTYPE(dst), PyArray_MASKNA_DATA(dst),
PyArray_MASKNA_STRIDES(dst)) < 0) {
- return -1;
+ goto fail;
}
}
}
@@ -499,14 +612,14 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyErr_SetString(PyExc_ValueError,
"Cannot assign NA value to an array which "
"does not support NAs");
- return -1;
+ goto fail;
}
else {
/* TODO: add support for this */
PyErr_SetString(PyExc_ValueError,
"A where mask with NA values is not supported "
"yet");
- return -1;
+ goto fail;
}
}
@@ -515,7 +628,7 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyArray_NDIM(wheremask), PyArray_DIMS(wheremask),
PyArray_STRIDES(wheremask), "where mask",
wheremask_strides) < 0) {
- return -1;
+ goto fail;
}
/* A straightforward where-masked assignment */
@@ -539,14 +652,14 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyArray_DESCR(wheremask),
PyArray_DATA(wheremask),
wheremask_strides) < 0) {
- return -1;
+ goto fail;
}
/*
* Assign the values based on the wheremask, not
* overwriting values also masked by the 'src' NA mask
*/
- return raw_array_wheremasked_assign_array_preservena(
+ if (raw_array_wheremasked_assign_array_preservena(
PyArray_NDIM(dst), PyArray_DIMS(dst),
PyArray_DESCR(dst), PyArray_DATA(dst),
PyArray_STRIDES(dst),
@@ -557,12 +670,15 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
src_maskna_strides,
PyArray_DESCR(wheremask),
PyArray_DATA(wheremask),
- wheremask_strides);
+ wheremask_strides)) {
+ goto fail;
+ }
+ goto finish;
}
else {
if (PyArray_AssignMaskNA(dst, wheremask, 1) < 0) {
- return -1;
+ goto fail;
}
}
}
@@ -574,7 +690,7 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyArray_DESCR(src), PyArray_DATA(src), src_strides,
PyArray_DESCR(wheremask), PyArray_DATA(wheremask),
wheremask_strides) < 0) {
- return -1;
+ goto fail;
}
}
/* A masked value assignment without overwriting NA values */
@@ -598,7 +714,7 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyArray_DESCR(wheremask),
PyArray_DATA(wheremask),
wheremask_strides) < 0) {
- return -1;
+ goto fail;
}
}
@@ -617,10 +733,21 @@ array_assign_array(PyArrayObject *dst, PyArrayObject *src,
PyArray_MASKNA_STRIDES(dst),
PyArray_DESCR(wheremask), PyArray_DATA(wheremask),
wheremask_strides) < 0) {
- return -1;
+ goto fail;
}
}
}
+finish:
+ if (copied_src) {
+ Py_DECREF(src);
+ }
return 0;
+
+fail:
+ if (copied_src) {
+ Py_DECREF(src);
+ }
+ return -1;
}
+
diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c
index 5e50166cd..cc7036cbe 100644
--- a/numpy/core/src/multiarray/ctors.c
+++ b/numpy/core/src/multiarray/ctors.c
@@ -25,6 +25,7 @@
#include "_datetime.h"
#include "datetime_strings.h"
#include "na_singleton.h"
+#include "array_assign.h"
/*
* Reading from a file or a string.
@@ -378,53 +379,6 @@ copy_and_swap(void *dst, void *src, int itemsize, npy_intp numitems,
}
}
-/* Gets a half-open range [start, end) which contains the array data */
-NPY_NO_EXPORT void
-_get_array_memory_extents(PyArrayObject *arr,
- npy_uintp *out_start, npy_uintp *out_end)
-{
- npy_uintp start, end;
- npy_intp idim, ndim = PyArray_NDIM(arr);
- npy_intp *dimensions = PyArray_DIMS(arr),
- *strides = PyArray_STRIDES(arr);
-
- /* Calculate with a closed range [start, end] */
- start = end = (npy_uintp)PyArray_DATA(arr);
- for (idim = 0; idim < ndim; ++idim) {
- npy_intp stride = strides[idim], dim = dimensions[idim];
- /* If the array size is zero, return an empty range */
- if (dim == 0) {
- *out_start = *out_end = (npy_uintp)PyArray_DATA(arr);
- return;
- }
- /* Expand either upwards or downwards depending on stride */
- else {
- if (stride > 0) {
- end += stride*(dim-1);
- }
- else if (stride < 0) {
- start += stride*(dim-1);
- }
- }
- }
-
- /* Return a half-open range */
- *out_start = start;
- *out_end = end + PyArray_DESCR(arr)->elsize;
-}
-
-/* Returns 1 if the arrays have overlapping data, 0 otherwise */
-NPY_NO_EXPORT int
-_arrays_overlap(PyArrayObject *arr1, PyArrayObject *arr2)
-{
- npy_uintp start1 = 0, start2 = 0, end1 = 0, end2 = 0;
-
- _get_array_memory_extents(arr1, &start1, &end1);
- _get_array_memory_extents(arr2, &start2, &end2);
-
- return (start1 < end2) && (start2 < end1);
-}
-
/*NUMPY_API
* Move the memory of one array into another, allowing for overlapping data.
*
@@ -460,6 +414,7 @@ PyArray_MoveInto(PyArrayObject *dst, PyArrayObject *src)
* try and reject this with as little work as possible.
*/
if (PyArray_DATA(src) == PyArray_DATA(dst) &&
+ PyArray_MASKNA_DATA(src) == PyArray_MASKNA_DATA(dst) &&
PyArray_DESCR(src) == PyArray_DESCR(dst) &&
PyArray_NDIM(src) == PyArray_NDIM(dst) &&
PyArray_CompareLists(PyArray_DIMS(src),
@@ -486,7 +441,7 @@ PyArray_MoveInto(PyArrayObject *dst, PyArrayObject *src)
PyArray_NDIM(src) == 1 &&
PyArray_STRIDE(dst, 0) > 0 &&
PyArray_STRIDE(src, 0) > 0) ||
- !_arrays_overlap(dst, src)) {
+ !arrays_overlap(dst, src)) {
return PyArray_CopyInto(dst, src);
}
else {
@@ -580,7 +535,7 @@ PyArray_MaskedMoveInto(PyArrayObject *dst, PyArrayObject *src,
PyArray_NDIM(src) == 1 &&
PyArray_STRIDE(dst, 0) > 0 &&
PyArray_STRIDE(src, 0) > 0) ||
- !_arrays_overlap(dst, src)) {
+ !arrays_overlap(dst, src)) {
return PyArray_MaskedCopyInto(dst, src, mask, casting);
}
else {
diff --git a/numpy/core/src/multiarray/ctors.h b/numpy/core/src/multiarray/ctors.h
index 3745c439c..d0a1ba687 100644
--- a/numpy/core/src/multiarray/ctors.h
+++ b/numpy/core/src/multiarray/ctors.h
@@ -88,4 +88,8 @@ PyArray_GetArrayParamsFromObjectEx(PyObject *op,
int *out_contains_na,
PyArrayObject **out_arr, PyObject *context);
+/* Returns 1 if the arrays have overlapping data, 0 otherwise. NOTE(review): this diff deletes _arrays_overlap from ctors.c and switches its callers to arrays_overlap (declared in array_assign.h); this prototype looks stale - confirm */
+NPY_NO_EXPORT int
+_arrays_overlap(PyArrayObject *arr1, PyArrayObject *arr2);
+
#endif
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 018c7c342..97248ca38 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -1730,58 +1730,11 @@ array_copyto(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
}
}
- /* Special case scalar assignment */
- if (PyArray_NDIM(src) == 0 && !PyArray_ContainsNA(src)) {
- if (array_assign_scalar(dst, PyArray_DESCR(src), PyArray_DATA(src),
- wheremask, casting,
- preservena, NULL) < 0) {
- goto fail;
- }
- else {
- goto finish;
- }
- }
-
- if (preservena) {
- PyErr_SetString(PyExc_RuntimeError,
- "This case of copyto doesn't support preservena=True yet");
+ if (array_assign_array(dst, src,
+ wheremask, casting, preservena, NULL) < 0) {
goto fail;
}
- if (wheremask != NULL) {
- /* Use the 'move' function which handles overlapping */
- if (PyArray_MaskedMoveInto(dst, src, wheremask, casting) < 0) {
- goto fail;
- }
- }
- else {
- /*
- * MoveInto doesn't accept a casting rule, must check it
- * ourselves.
- */
- if (!PyArray_CanCastArrayTo(src, PyArray_DESCR(dst), casting)) {
- PyObject *errmsg;
- errmsg = PyUString_FromString("Cannot cast array data from ");
- PyUString_ConcatAndDel(&errmsg,
- PyObject_Repr((PyObject *)PyArray_DESCR(src)));
- PyUString_ConcatAndDel(&errmsg,
- PyUString_FromString(" to "));
- PyUString_ConcatAndDel(&errmsg,
- PyObject_Repr((PyObject *)PyArray_DESCR(dst)));
- PyUString_ConcatAndDel(&errmsg,
- PyUString_FromFormat(" according to the rule %s",
- npy_casting_to_string(casting)));
- PyErr_SetObject(PyExc_TypeError, errmsg);
- goto fail;
- }
-
- /* Use the 'move' function which handles overlapping */
- if (PyArray_MoveInto(dst, src) < 0) {
- goto fail;
- }
- }
-
-finish:
Py_XDECREF(src);
Py_XDECREF(wheremask);