summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorJay Bourque <jay.bourque@continuum.io>2012-11-14 12:39:34 -0600
committerJay Bourque <jay.bourque@continuum.io>2013-08-16 16:39:30 -0500
commit7dde42ad510074aaa6273fe156dcd243b85b25fe (patch)
tree315320833c2a475f118aaa0de69ef5b461ddb06e /numpy
parentd94cd574cf1a31e621ba5384542671d1aa8056c4 (diff)
downloadnumpy-7dde42ad510074aaa6273fe156dcd243b85b25fe.tar.gz
Correct implementation of 'at' method that covers all corner cases
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/ufunc_object.c47
1 files changed, 22 insertions, 25 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index ae8c8c162..06707c805 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -4924,34 +4924,35 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
}
PyArray_MapIterReset(iter);
- /* If second operand is a scalar, create 0 dim array from it */
- if (op2 != NULL && PyArray_IsAnyScalar(op2)) {
- op2_array = (PyArrayObject *)PyArray_FromAny(op2, NULL, 0, 0, 0, NULL);
+ /* If second operand exists, we need to broadcast it to
+ the shape of the indices applied to the first operand.
+ This will produce an iterator that will match up with
+ the iterator for the first operand. */
+ if (op2 != NULL) {
+ PyArray_Descr *descr = PyArray_DESCR(iter->ait->ao);
+ Py_INCREF(descr);
+ op2_array = (PyArrayObject *)PyArray_FromAny(op2, descr,
+ 0, 0, NPY_ARRAY_FORCECAST, NULL);
if (op2_array == NULL) {
- PyErr_SetString(PyExc_TypeError,
- "could not convert scalar to array");
goto fail;
}
-
- iter2 = NULL;
- first_item[0] = 0;
- }
- /* If second operand is an array like object,
- create MapIter object for it */
- else if (op2 != NULL) {
- op2_array = (PyArrayObject *)PyArray_FromAny(op2, NULL, 0, 0, 0, NULL);
- if (op2_array == NULL) {
- goto fail;
+ if ((iter->subspace != NULL) && (iter->consec)) {
+ if (iter->iteraxes[0] > 0) { /* then we need to swap */
+ PyArray_MapIterSwapAxes(iter, &op2_array, 0);
+ if (op2_array == NULL) {
+ goto fail;
+ }
+ }
}
- int ndims = PyArray_NDIM(op1_array);
- npy_intp *dims = PyArray_DIMS(op1_array);
- iter2 = PyArray_BroadcastToShape(op2_array, iter->dimensions, iter->nd);
- if (iter2 == NULL) {
+ /* Be sure values array is "broadcastable"
+ to shape of mit->dimensions, mit->nd */
+
+ if ((iter2 = (PyArrayIterObject *)\
+ PyArray_BroadcastToShape((PyObject *)op2_array,
+ iter->dimensions, iter->nd))==NULL) {
goto fail;
}
-
- PyArray_ITER_RESET(iter2);
}
/* Create dtypes array for either one or two input operands.
@@ -4985,10 +4986,6 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
dataptr[1] = PyArray_ITER_DATA(iter2);
dataptr[2] = iter->dataptr;
}
- else if (op2_array != NULL) {
- dataptr[1] = PyArray_GetPtr(op2_array, first_item);
- dataptr[2] = iter->dataptr;
- }
else {
dataptr[1] = iter->dataptr;
dataptr[2] = NULL;