diff options
author | Jay Bourque <jay.bourque@continuum.io> | 2012-11-14 12:39:34 -0600 |
---|---|---|
committer | Jay Bourque <jay.bourque@continuum.io> | 2013-08-16 16:39:30 -0500 |
commit | 7dde42ad510074aaa6273fe156dcd243b85b25fe (patch) | |
tree | 315320833c2a475f118aaa0de69ef5b461ddb06e /numpy | |
parent | d94cd574cf1a31e621ba5384542671d1aa8056c4 (diff) | |
download | numpy-7dde42ad510074aaa6273fe156dcd243b85b25fe.tar.gz |
Correct implementation of 'at' method that covers all corner cases
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 47 |
1 files changed, 22 insertions, 25 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index ae8c8c162..06707c805 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -4924,34 +4924,35 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) } PyArray_MapIterReset(iter); - /* If second operand is a scalar, create 0 dim array from it */ - if (op2 != NULL && PyArray_IsAnyScalar(op2)) { - op2_array = (PyArrayObject *)PyArray_FromAny(op2, NULL, 0, 0, 0, NULL); + /* If second operand exists, we need to broadcast it to + the shape of the indices applied to the first operand. + This will produce an iterator that will match up with + the iterator for the first operand. */ + if (op2 != NULL) { + PyArray_Descr *descr = PyArray_DESCR(iter->ait->ao); + Py_INCREF(descr); + op2_array = (PyArrayObject *)PyArray_FromAny(op2, descr, + 0, 0, NPY_ARRAY_FORCECAST, NULL); if (op2_array == NULL) { - PyErr_SetString(PyExc_TypeError, - "could not convert scalar to array"); goto fail; } - - iter2 = NULL; - first_item[0] = 0; - } - /* If second operand is an array like object, - create MapIter object for it */ - else if (op2 != NULL) { - op2_array = (PyArrayObject *)PyArray_FromAny(op2, NULL, 0, 0, 0, NULL); - if (op2_array == NULL) { - goto fail; + if ((iter->subspace != NULL) && (iter->consec)) { + if (iter->iteraxes[0] > 0) { /* then we need to swap */ + PyArray_MapIterSwapAxes(iter, &op2_array, 0); + if (op2_array == NULL) { + goto fail; + } + } } - int ndims = PyArray_NDIM(op1_array); - npy_intp *dims = PyArray_DIMS(op1_array); - iter2 = PyArray_BroadcastToShape(op2_array, iter->dimensions, iter->nd); - if (iter2 == NULL) { + /* Be sure values array is "broadcastable" + to shape of mit->dimensions, mit->nd */ + + if ((iter2 = (PyArrayIterObject *)\ + PyArray_BroadcastToShape((PyObject *)op2_array, + iter->dimensions, iter->nd))==NULL) { goto fail; } - - PyArray_ITER_RESET(iter2); } /* Create dtypes array for either one or two input operands. @@ -4985,10 +4986,6 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) dataptr[1] = PyArray_ITER_DATA(iter2); dataptr[2] = iter->dataptr; } - else if (op2_array != NULL) { - dataptr[1] = PyArray_GetPtr(op2_array, first_item); - dataptr[2] = iter->dataptr; - } else { dataptr[1] = iter->dataptr; dataptr[2] = NULL; |