diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-26 07:50:53 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-26 07:50:53 +0000 |
commit | 5dd549836b872e7c65f2d6b69e303f953de0f488 (patch) | |
tree | 1c389c7006f47b4497e404e3cb9d68d500d8d8b5 /numpy/core/src/arrayobject.c | |
parent | c37cfa5b256d257a339b30334392c93ae8b7d78a (diff) | |
download | numpy-5dd549836b872e7c65f2d6b69e303f953de0f488.tar.gz |
Fix broadcast-copy on fancy set-item.
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r-- | numpy/core/src/arrayobject.c | 87 |
1 files changed, 76 insertions, 11 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 360f801e3..abc01614a 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -2399,8 +2399,11 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op) } } - - if ((it = (PyArrayIterObject *)PyArray_IterNew(arr))==NULL) { + /* Be sure values array is "broadcastable" + to shape of mit->dimensions, mit->nd */ + + if ((it = (PyArrayIterObject *)\ + PyArray_BroadcastToShape(arr, mit->dimensions, mit->nd))==NULL) { Py_DECREF(arr); return -1; } @@ -2408,7 +2411,6 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op) index = mit->size; swap = (PyArray_ISNOTSWAPPED(mit->ait->ao) != \ (PyArray_ISNOTSWAPPED(arr))); - copyswap = PyArray_DESCR(arr)->f->copyswap; PyArray_MapIterReset(mit); /* Need to decref hasobject arrays */ @@ -2421,8 +2423,6 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op) copyswap(mit->dataptr, NULL, swap, arr); PyArray_MapIterNext(mit); PyArray_ITER_NEXT(it); - if (it->index == it->size) - PyArray_ITER_RESET(it); } Py_DECREF(arr); Py_DECREF(it); @@ -2434,8 +2434,6 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op) copyswap(mit->dataptr, NULL, swap, arr); PyArray_MapIterNext(mit); PyArray_ITER_NEXT(it); - if (it->index == it->size) - PyArray_ITER_RESET(it); } Py_DECREF(arr); Py_DECREF(it); @@ -2703,7 +2701,8 @@ array_subscript(PyArrayObject *self, PyObject *op) if (oned) { PyArrayIterObject *it; PyObject *rval; - it = (PyArrayIterObject *)PyArray_IterNew((PyObject *)self); + it = (PyArrayIterObject *)\ + PyArray_IterNew((PyObject *)self); if (it == NULL) {Py_DECREF(mit); return NULL;} rval = iter_subscript(it, mit->indexobj); Py_DECREF(it); @@ -8592,16 +8591,74 @@ PyArray_IterNew(PyObject *obj) nd = ao->nd; PyArray_UpdateFlags(ao, CONTIGUOUS); - it->contiguous = 0; if PyArray_ISCONTIGUOUS(ao) it->contiguous = 1; + else it->contiguous = 0; Py_INCREF(ao); it->ao = ao; it->size = PyArray_SIZE(ao); it->nd_m1 = nd - 1; it->factors[nd-1] = 1; for (i=0; i < nd; i++) { - it->dims_m1[i] = it->ao->dimensions[i] - 1; - it->strides[i] = it->ao->strides[i]; + it->dims_m1[i] = ao->dimensions[i] - 1; + it->strides[i] = ao->strides[i]; + it->backstrides[i] = it->strides[i] * \ + it->dims_m1[i]; + if (i > 0) + it->factors[nd-i-1] = it->factors[nd-i] * \ + ao->dimensions[nd-i]; + } + PyArray_ITER_RESET(it); + + return (PyObject *)it; +} + +/*MULTIARRAY_API + Get Iterator broadcast to a particular shape + */ +static PyObject * +PyArray_BroadcastToShape(PyObject *obj, intp *dims, int nd) +{ + PyArrayIterObject *it; + int i, diff, j, compat, k; + PyArrayObject *ao = (PyArrayObject *)obj; + + if (ao->nd > nd) goto err; + compat = 1; + diff = j = nd - ao->nd; + for (i=0; i<ao->nd; i++, j++) { + if (ao->dimensions[i] == 1) continue; + if (ao->dimensions[i] != dims[j]) { + compat = 0; + break; + } + } + if (!compat) goto err; + + it = (PyArrayIterObject *)_pya_malloc(sizeof(PyArrayIterObject)); + PyObject_Init((PyObject *)it, &PyArrayIter_Type); + + if (it == NULL) + return NULL; + + PyArray_UpdateFlags(ao, CONTIGUOUS); + if PyArray_ISCONTIGUOUS(ao) it->contiguous = 1; + else it->contiguous = 0; + Py_INCREF(ao); + it->ao = ao; + it->size = PyArray_MultiplyList(dims, nd); + it->nd_m1 = nd - 1; + it->factors[nd-1] = 1; + for (i=0; i < nd; i++) { + it->dims_m1[i] = dims[i] - 1; + k = i - diff; + if ((k < 0) || + ao->dimensions[k] != dims[i]) { + it->contiguous = 0; + it->strides[i] = 0; + } + else { + it->strides[i] = ao->strides[i]; + } it->backstrides[i] = it->strides[i] * \ it->dims_m1[i]; if (i > 0) @@ -8611,9 +8668,17 @@ PyArray_IterNew(PyObject *obj) PyArray_ITER_RESET(it); return (PyObject *)it; + + err: + PyErr_SetString(PyExc_ValueError, "array is not broadcastable to "\ + "correct shape"); + return NULL; } + + + /*OBJECT_API Get Iterator that iterates over all but one axis (don't use this with PyArray_ITER_GOTO1D). The axis will be over-written if negative. |