From a0e082a087e9667e3805d3be859958a292e8f336 Mon Sep 17 00:00:00 2001 From: Travis Oliphant Date: Fri, 3 Oct 2008 15:55:52 +0000 Subject: Fix ticket #925 --- numpy/core/src/multiarraymodule.c | 63 ++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 37 deletions(-) (limited to 'numpy/core/src/multiarraymodule.c') diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c index 022d6f1c4..e3e84fda2 100644 --- a/numpy/core/src/multiarraymodule.c +++ b/numpy/core/src/multiarraymodule.c @@ -2326,50 +2326,40 @@ static PyObject * PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret, NPY_CLIPMODE clipmode) { - intp *sizes, offset; int n, elsize; intp i, m; char *ret_data; PyArrayObject **mps, *ap; - intp *self_data, mi; + PyArrayMultiIterObject *multi=NULL; + intp mi; int copyret=0; ap = NULL; /* Convert all inputs to arrays of a common type */ + /* Also makes them C-contiguous */ mps = PyArray_ConvertToCommonType(op, &n); if (mps == NULL) return NULL; - sizes = (intp *)_pya_malloc(n*sizeof(intp)); - if (sizes == NULL) goto fail; - - ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ip, - PyArray_INTP, - 0, 0); - if (ap == NULL) goto fail; - - /* Check the dimensions of the arrays */ for(i=0; ind < mps[i]->nd) { - PyErr_SetString(PyExc_ValueError, - "too many dimensions"); - goto fail; - } - if (!PyArray_CompareLists(ap->dimensions+(ap->nd-mps[i]->nd), - mps[i]->dimensions, mps[i]->nd)) { - PyErr_SetString(PyExc_ValueError, - "array dimensions must agree"); - goto fail; - } - sizes[i] = PyArray_NBYTES(mps[i]); } + ap = (PyArrayObject *)PyArray_FROM_OT((PyObject *)ip, NPY_INTP); + + if (ap == NULL) goto fail; + + /* Broadcast all arrays to each other, index array at the end. */ + multi = (PyArrayMultiIterObject *)\ + PyArray_MultiIterFromObjects((PyObject **)mps, n, 1, ap); + if (multi == NULL) goto fail; + + /* Set-up return array */ if (!ret) { Py_INCREF(mps[0]->descr); ret = (PyArrayObject *)PyArray_NewFromDescr(ap->ob_type, mps[0]->descr, - ap->nd, - ap->dimensions, + multi->nd, + multi->dimensions, NULL, NULL, 0, (PyObject *)ap); } @@ -2377,8 +2367,10 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret, PyArrayObject *obj; int flags = NPY_CARRAY | NPY_UPDATEIFCOPY | NPY_FORCECAST; - if (PyArray_SIZE(ret) != PyArray_SIZE(ap)) { - PyErr_SetString(PyExc_TypeError, + if ((PyArray_NDIM(ret) != multi->nd) || + !PyArray_CompareLists(PyArray_DIMS(ret), multi->dimensions, + multi->nd)) { + PyErr_SetString(PyExc_TypeError, "invalid shape for output array."); ret = NULL; goto fail; @@ -2399,12 +2391,10 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret, if (ret == NULL) goto fail; elsize = ret->descr->elsize; - m = PyArray_SIZE(ret); - self_data = (intp *)ap->data; ret_data = ret->data; - for (i=0; i= n) { switch(clipmode) { case NPY_RAISE: @@ -2426,17 +2416,16 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret, break; } } - offset = i*elsize; - if (offset >= sizes[mi]) {offset = offset % sizes[mi]; } - memmove(ret_data, mps[mi]->data+offset, elsize); - ret_data += elsize; self_data++; + memmove(ret_data, PyArray_MultiIter_DATA(multi, mi), elsize); + ret_data += elsize; + PyArray_MultiIter_NEXT(multi); } PyArray_INCREF(ret); + Py_DECREF(multi); for(i=0; ibase; @@ -2447,10 +2436,10 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret, return (PyObject *)ret; fail: + Py_XDECREF(multi); for(i=0; i