summaryrefslogtreecommitdiff
path: root/numpy/core/src/arrayobject.c
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-08-26 07:50:53 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-08-26 07:50:53 +0000
commit5dd549836b872e7c65f2d6b69e303f953de0f488 (patch)
tree1c389c7006f47b4497e404e3cb9d68d500d8d8b5 /numpy/core/src/arrayobject.c
parentc37cfa5b256d257a339b30334392c93ae8b7d78a (diff)
downloadnumpy-5dd549836b872e7c65f2d6b69e303f953de0f488.tar.gz
Fix broadcast-copy on fancy set-item.
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r--numpy/core/src/arrayobject.c87
1 files changed, 76 insertions, 11 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 360f801e3..abc01614a 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -2399,8 +2399,11 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op)
}
}
-
- if ((it = (PyArrayIterObject *)PyArray_IterNew(arr))==NULL) {
+ /* Be sure values array is "broadcastable"
+ to shape of mit->dimensions, mit->nd */
+
+ if ((it = (PyArrayIterObject *)\
+ PyArray_BroadcastToShape(arr, mit->dimensions, mit->nd))==NULL) {
Py_DECREF(arr);
return -1;
}
@@ -2408,7 +2411,6 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op)
index = mit->size;
swap = (PyArray_ISNOTSWAPPED(mit->ait->ao) != \
(PyArray_ISNOTSWAPPED(arr)));
-
copyswap = PyArray_DESCR(arr)->f->copyswap;
PyArray_MapIterReset(mit);
/* Need to decref hasobject arrays */
@@ -2421,8 +2423,6 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op)
copyswap(mit->dataptr, NULL, swap, arr);
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
- if (it->index == it->size)
- PyArray_ITER_RESET(it);
}
Py_DECREF(arr);
Py_DECREF(it);
@@ -2434,8 +2434,6 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op)
copyswap(mit->dataptr, NULL, swap, arr);
PyArray_MapIterNext(mit);
PyArray_ITER_NEXT(it);
- if (it->index == it->size)
- PyArray_ITER_RESET(it);
}
Py_DECREF(arr);
Py_DECREF(it);
@@ -2703,7 +2701,8 @@ array_subscript(PyArrayObject *self, PyObject *op)
if (oned) {
PyArrayIterObject *it;
PyObject *rval;
- it = (PyArrayIterObject *)PyArray_IterNew((PyObject *)self);
+ it = (PyArrayIterObject *)\
+ PyArray_IterNew((PyObject *)self);
if (it == NULL) {Py_DECREF(mit); return NULL;}
rval = iter_subscript(it, mit->indexobj);
Py_DECREF(it);
@@ -8592,16 +8591,74 @@ PyArray_IterNew(PyObject *obj)
nd = ao->nd;
PyArray_UpdateFlags(ao, CONTIGUOUS);
- it->contiguous = 0;
if PyArray_ISCONTIGUOUS(ao) it->contiguous = 1;
+ else it->contiguous = 0;
Py_INCREF(ao);
it->ao = ao;
it->size = PyArray_SIZE(ao);
it->nd_m1 = nd - 1;
it->factors[nd-1] = 1;
for (i=0; i < nd; i++) {
- it->dims_m1[i] = it->ao->dimensions[i] - 1;
- it->strides[i] = it->ao->strides[i];
+ it->dims_m1[i] = ao->dimensions[i] - 1;
+ it->strides[i] = ao->strides[i];
+ it->backstrides[i] = it->strides[i] * \
+ it->dims_m1[i];
+ if (i > 0)
+ it->factors[nd-i-1] = it->factors[nd-i] * \
+ ao->dimensions[nd-i];
+ }
+ PyArray_ITER_RESET(it);
+
+ return (PyObject *)it;
+}
+
+/*MULTIARRAY_API
+ Get Iterator broadcast to a particular shape
+ */
+static PyObject *
+PyArray_BroadcastToShape(PyObject *obj, intp *dims, int nd)
+{
+ PyArrayIterObject *it;
+ int i, diff, j, compat, k;
+ PyArrayObject *ao = (PyArrayObject *)obj;
+
+ if (ao->nd > nd) goto err;
+ compat = 1;
+ diff = j = nd - ao->nd;
+ for (i=0; i<ao->nd; i++, j++) {
+ if (ao->dimensions[i] == 1) continue;
+ if (ao->dimensions[i] != dims[j]) {
+ compat = 0;
+ break;
+ }
+ }
+ if (!compat) goto err;
+
+ it = (PyArrayIterObject *)_pya_malloc(sizeof(PyArrayIterObject));
+ PyObject_Init((PyObject *)it, &PyArrayIter_Type);
+
+ if (it == NULL)
+ return NULL;
+
+ PyArray_UpdateFlags(ao, CONTIGUOUS);
+ if PyArray_ISCONTIGUOUS(ao) it->contiguous = 1;
+ else it->contiguous = 0;
+ Py_INCREF(ao);
+ it->ao = ao;
+ it->size = PyArray_MultiplyList(dims, nd);
+ it->nd_m1 = nd - 1;
+ it->factors[nd-1] = 1;
+ for (i=0; i < nd; i++) {
+ it->dims_m1[i] = dims[i] - 1;
+ k = i - diff;
+ if ((k < 0) ||
+ ao->dimensions[k] != dims[i]) {
+ it->contiguous = 0;
+ it->strides[i] = 0;
+ }
+ else {
+ it->strides[i] = ao->strides[i];
+ }
it->backstrides[i] = it->strides[i] * \
it->dims_m1[i];
if (i > 0)
@@ -8611,9 +8668,17 @@ PyArray_IterNew(PyObject *obj)
PyArray_ITER_RESET(it);
return (PyObject *)it;
+
+ err:
+ PyErr_SetString(PyExc_ValueError, "array is not broadcastable to "\
+ "correct shape");
+ return NULL;
}
+
+
+
/*OBJECT_API
Get Iterator that iterates over all but one axis (don't use this with
PyArray_ITER_GOTO1D). The axis will be over-written if negative.