diff options
Diffstat (limited to 'numpy/core/src/multiarraymodule.c')
-rw-r--r-- | numpy/core/src/multiarraymodule.c | 81 |
1 files changed, 58 insertions, 23 deletions
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c index 7f57c3632..ac3df1e15 100644 --- a/numpy/core/src/multiarraymodule.c +++ b/numpy/core/src/multiarraymodule.c @@ -216,11 +216,9 @@ static PyObject * PyArray_Round(PyArrayObject *a, int decimals, PyArrayObject *out) { PyObject *f, *ret=NULL, *tmp, *op1, *op2; - if (out && (!PyArray_SAMESHAPE(out, a) || - !PyArray_EquivTypes(a->descr, out->descr))) { + if (out && (PyArray_SIZE(out) != PyArray_SIZE(a))) { PyErr_SetString(PyExc_ValueError, - "output array must have the same shape" - "and type"); + "invalid output shape"); return NULL; } if (PyArray_ISCOMPLEX(a)) { @@ -266,7 +264,7 @@ PyArray_Round(PyArrayObject *a, int decimals, PyArrayObject *out) if (decimals >= 0) { if (PyArray_ISINTEGER(a)) { if (out) { - if (PyArray_CopyInto(out, a) < 0) return NULL; + if (PyArray_CopyAnyInto(out, a) < 0) return NULL; Py_INCREF(out); return (PyObject *)out; } @@ -897,7 +895,7 @@ PyArray_Nonzero(PyArrayObject *self) Clip */ static PyObject * -PyArray_Clip(PyArrayObject *self, PyObject *min, PyObject *max) +PyArray_Clip(PyArrayObject *self, PyObject *min, PyObject *max, PyArrayObject *out) { PyObject *selector=NULL, *newtup=NULL, *ret=NULL; PyObject *res1=NULL, *res2=NULL, *res3=NULL; @@ -924,7 +922,7 @@ PyArray_Clip(PyArrayObject *self, PyObject *min, PyObject *max) newtup = Py_BuildValue("(OOO)", (PyObject *)self, min, max); if (newtup == NULL) {Py_DECREF(selector); return NULL;} - ret = PyArray_Choose((PyAO *)selector, newtup, NULL, NPY_RAISE); + ret = PyArray_Choose((PyAO *)selector, newtup, out, NPY_RAISE); Py_DECREF(selector); Py_DECREF(newtup); return ret; @@ -934,14 +932,14 @@ PyArray_Clip(PyArrayObject *self, PyObject *min, PyObject *max) Conjugate */ static PyObject * -PyArray_Conjugate(PyArrayObject *self) +PyArray_Conjugate(PyArrayObject *self, PyArrayObject *out) { if (PyArray_ISCOMPLEX(self)) { PyObject *new; intp size, i; /* Make a copy */ - new = PyArray_NewCopy(self, -1); - if (new==NULL) return NULL; + new = PyArray_NewCopy(self, -1); + if (new==NULL) return NULL; size = PyArray_SIZE(new); if (self->descr->type_num == PyArray_CFLOAT) { cfloat *dptr = (cfloat *) PyArray_DATA(new); @@ -964,11 +962,25 @@ PyArray_Conjugate(PyArrayObject *self) dptr++; } } + if (out) { + if (PyArray_CopyAnyInto(out, (PyArrayObject *)new)<0) + return NULL; + Py_INCREF(out); + Py_DECREF(new); + return (PyObject *)out; + } return new; } else { - Py_INCREF(self); - return (PyObject *) self; + PyArrayObject *ret; + if (out) { + if (PyArray_CopyAnyInto(out, self)< 0) + return NULL; + ret = out; + } + else ret = self; + Py_INCREF(ret); + return (PyObject *)ret; } } @@ -1836,7 +1848,7 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret, PyArrayObject *obj; int flags = NPY_CARRAY | NPY_UPDATEIFCOPY; - if (!PyArray_SAMESHAPE(ap, ret)) { + if (PyArray_SIZE(ret) != PyArray_SIZE(ap)) { PyErr_SetString(PyExc_TypeError, "invalid shape for output array."); ret = NULL; @@ -3018,7 +3030,7 @@ PyArray_Correlate(PyObject *op1, PyObject *op2, int mode) ArgMin */ static PyObject * -PyArray_ArgMin(PyArrayObject *ap, int axis) +PyArray_ArgMin(PyArrayObject *ap, int axis, PyArrayObject *out) { PyObject *obj, *new, *ret; @@ -3039,7 +3051,7 @@ PyArray_ArgMin(PyArrayObject *ap, int axis) new = PyArray_EnsureAnyArray(PyNumber_Subtract(obj, (PyObject *)ap)); Py_DECREF(obj); if (new == NULL) return NULL; - ret = PyArray_ArgMax((PyArrayObject *)new, axis); + ret = PyArray_ArgMax((PyArrayObject *)new, axis, out); Py_DECREF(new); return ret; } @@ -3117,7 +3129,7 @@ PyArray_Ptp(PyArrayObject *ap, int axis, PyArrayObject *out) ArgMax */ static PyObject * -PyArray_ArgMax(PyArrayObject *op, int axis) +PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out) { PyArrayObject *ap=NULL, *rp=NULL; PyArray_ArgFunc* arg_func; @@ -3125,6 +3137,7 @@ PyArray_ArgMax(PyArrayObject *op, int axis) intp *rptr; intp i, n, m; int elsize; + int copyret=0; NPY_BEGIN_THREADS_DEF @@ -3163,13 +3176,6 @@ PyArray_ArgMax(PyArrayObject *op, int axis) goto fail; } - rp = (PyArrayObject *)PyArray_New(ap->ob_type, ap->nd-1, - ap->dimensions, PyArray_INTP, - NULL, NULL, 0, 0, - (PyObject *)ap); - if (rp == NULL) goto fail; - - elsize = ap->descr->elsize; m = ap->dimensions[ap->nd-1]; if (m == 0) { @@ -3178,6 +3184,28 @@ PyArray_ArgMax(PyArrayObject *op, int axis) "of an empty sequence??"); goto fail; } + + if (!out) { + rp = (PyArrayObject *)PyArray_New(ap->ob_type, ap->nd-1, + ap->dimensions, PyArray_INTP, + NULL, NULL, 0, 0, + (PyObject *)ap); + if (rp == NULL) goto fail; + } + else { + if (PyArray_SIZE(out) != \ + PyArray_MultiplyList(ap->dimensions, ap->nd-1)) { + PyErr_SetString(PyExc_TypeError, + "invalid shape for output array."); + } + rp = (PyArrayObject *)\ + PyArray_FromArray(out, + PyArray_DescrFromType(PyArray_INTP), + NPY_CARRAY | NPY_UPDATEIFCOPY); + if (rp == NULL) goto fail; + if (rp != out) copyret = 1; + } + NPY_BEGIN_THREADS_DESCR(ap->descr) n = PyArray_SIZE(ap)/m; rptr = (intp *)rp->data; @@ -3188,6 +3216,13 @@ PyArray_ArgMax(PyArrayObject *op, int axis) NPY_END_THREADS_DESCR(ap->descr) Py_DECREF(ap); + if (copyret) { + PyArrayObject *obj; + obj = (PyArrayObject *)rp->base; + Py_INCREF(obj); + Py_DECREF(rp); + rp = obj; + } return (PyObject *)rp; fail: |