summaryrefslogtreecommitdiff
path: root/numpy/core/src/multiarraymodule.c
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src/multiarraymodule.c')
-rw-r--r--numpy/core/src/multiarraymodule.c63
1 files changed, 26 insertions, 37 deletions
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index 022d6f1c4..e3e84fda2 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -2326,50 +2326,40 @@ static PyObject *
PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret,
NPY_CLIPMODE clipmode)
{
- intp *sizes, offset;
int n, elsize;
intp i, m;
char *ret_data;
PyArrayObject **mps, *ap;
- intp *self_data, mi;
+ PyArrayMultiIterObject *multi=NULL;
+ intp mi;
int copyret=0;
ap = NULL;
/* Convert all inputs to arrays of a common type */
+ /* Also makes them C-contiguous */
mps = PyArray_ConvertToCommonType(op, &n);
if (mps == NULL) return NULL;
- sizes = (intp *)_pya_malloc(n*sizeof(intp));
- if (sizes == NULL) goto fail;
-
- ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ip,
- PyArray_INTP,
- 0, 0);
- if (ap == NULL) goto fail;
-
- /* Check the dimensions of the arrays */
for(i=0; i<n; i++) {
if (mps[i] == NULL) goto fail;
- if (ap->nd < mps[i]->nd) {
- PyErr_SetString(PyExc_ValueError,
- "too many dimensions");
- goto fail;
- }
- if (!PyArray_CompareLists(ap->dimensions+(ap->nd-mps[i]->nd),
- mps[i]->dimensions, mps[i]->nd)) {
- PyErr_SetString(PyExc_ValueError,
- "array dimensions must agree");
- goto fail;
- }
- sizes[i] = PyArray_NBYTES(mps[i]);
}
+ ap = (PyArrayObject *)PyArray_FROM_OT((PyObject *)ip, NPY_INTP);
+
+ if (ap == NULL) goto fail;
+
+ /* Broadcast all arrays to each other, index array at the end. */
+ multi = (PyArrayMultiIterObject *)\
+ PyArray_MultiIterFromObjects((PyObject **)mps, n, 1, ap);
+ if (multi == NULL) goto fail;
+
+ /* Set-up return array */
if (!ret) {
Py_INCREF(mps[0]->descr);
ret = (PyArrayObject *)PyArray_NewFromDescr(ap->ob_type,
mps[0]->descr,
- ap->nd,
- ap->dimensions,
+ multi->nd,
+ multi->dimensions,
NULL, NULL, 0,
(PyObject *)ap);
}
@@ -2377,8 +2367,10 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret,
PyArrayObject *obj;
int flags = NPY_CARRAY | NPY_UPDATEIFCOPY | NPY_FORCECAST;
- if (PyArray_SIZE(ret) != PyArray_SIZE(ap)) {
- PyErr_SetString(PyExc_TypeError,
+ if ((PyArray_NDIM(ret) != multi->nd) ||
+ !PyArray_CompareLists(PyArray_DIMS(ret), multi->dimensions,
+ multi->nd)) {
+ PyErr_SetString(PyExc_TypeError,
"invalid shape for output array.");
ret = NULL;
goto fail;
@@ -2399,12 +2391,10 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret,
if (ret == NULL) goto fail;
elsize = ret->descr->elsize;
- m = PyArray_SIZE(ret);
- self_data = (intp *)ap->data;
ret_data = ret->data;
- for (i=0; i<m; i++) {
- mi = *self_data;
+ while (PyArray_MultiIter_NOTDONE(multi)) {
+ mi = *((intp *)PyArray_MultiIter_DATA(multi, n));
if (mi < 0 || mi >= n) {
switch(clipmode) {
case NPY_RAISE:
@@ -2426,17 +2416,16 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret,
break;
}
}
- offset = i*elsize;
- if (offset >= sizes[mi]) {offset = offset % sizes[mi]; }
- memmove(ret_data, mps[mi]->data+offset, elsize);
- ret_data += elsize; self_data++;
+ memmove(ret_data, PyArray_MultiIter_DATA(multi, mi), elsize);
+ ret_data += elsize;
+ PyArray_MultiIter_NEXT(multi);
}
PyArray_INCREF(ret);
+ Py_DECREF(multi);
for(i=0; i<n; i++) Py_XDECREF(mps[i]);
Py_DECREF(ap);
PyDataMem_FREE(mps);
- _pya_free(sizes);
if (copyret) {
PyObject *obj;
obj = ret->base;
@@ -2447,10 +2436,10 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret,
return (PyObject *)ret;
fail:
+ Py_XDECREF(multi);
for(i=0; i<n; i++) Py_XDECREF(mps[i]);
Py_XDECREF(ap);
PyDataMem_FREE(mps);
- _pya_free(sizes);
PyArray_XDECREF_ERR(ret);
return NULL;
}