diff options
| -rw-r--r-- | numpy/core/src/arrayobject.c | 37 |
1 files changed, 26 insertions, 11 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index cab2484d8..69109475c 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -897,7 +897,13 @@ _broadcast_copy(PyArrayObject *dest, PyArrayObject *src, multi = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, dest, src); if (multi == NULL) return -1; maxaxis = PyArray_RemoveLargest(multi); - if (maxaxis < 0) return -1; + if (maxaxis < 0) { /* copy 1 0-d array to another */ + PyArray_XDECREF(dest); + memcpy(dest->data, src->data, elsize); + if (swap) byte_swap_vector(dest->data, 1, elsize); + PyArray_INCREF(dest); + return 0; + } maxdim = multi->dimensions[maxaxis]; PyArray_XDECREF(dest); @@ -6866,7 +6872,7 @@ _broadcast_cast(PyArrayObject *out, PyArrayObject *in, { int delsize, selsize, maxaxis, i, N; PyArrayMultiIterObject *multi; - intp maxdim; + intp maxdim, ostrides, istrides; char *buffers[2]; PyArray_CopySwapNFunc *ocopyfunc, *icopyfunc; char *obptr; @@ -6875,10 +6881,22 @@ _broadcast_cast(PyArrayObject *out, PyArrayObject *in, selsize = PyArray_ITEMSIZE(in); multi = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, out, in); if (multi == NULL) return -1; + icopyfunc = in->descr->f->copyswapn; + ocopyfunc = out->descr->f->copyswapn; maxaxis = PyArray_RemoveLargest(multi); - if (maxaxis < 0) return -1; - maxdim = multi->dimensions[maxaxis]; - N = (int) (MIN(maxdim, PyArray_BUFSIZE)); + if (maxaxis < 0) { /* cast 1 0-d array to another */ + N = 1; + maxdim = 1; + ostrides = delsize; + istrides = selsize; + } + else { + maxdim = multi->dimensions[maxaxis]; + N = (int) (MIN(maxdim, PyArray_BUFSIZE)); + ostrides = multi->iters[0]->strides[maxaxis]; + istrides = multi->iters[1]->strides[maxaxis]; + + } buffers[0] = _pya_malloc(N*delsize); if (buffers[0] == NULL) { PyErr_NoMemory(); @@ -6886,7 +6904,7 @@ _broadcast_cast(PyArrayObject *out, PyArrayObject *in, } buffers[1] = _pya_malloc(N*selsize); if (buffers[1] == NULL) { - free(buffers[0]); + _pya_free(buffers[0]); PyErr_NoMemory(); return -1; } @@ -6895,15 +6913,12 @@ _broadcast_cast(PyArrayObject *out, PyArrayObject *in, if (in->descr->hasobject) memset(buffers[1], 0, N*selsize); - icopyfunc = in->descr->f->copyswapn; - ocopyfunc = out->descr->f->copyswapn; - while(multi->index < multi->size) { _strided_buffered_cast(multi->iters[0]->dataptr, - multi->iters[0]->strides[maxaxis], + ostrides, delsize, oswap, ocopyfunc, multi->iters[1]->dataptr, - multi->iters[1]->strides[maxaxis], + istrides, selsize, iswap, icopyfunc, maxdim, buffers, N, castfunc, out, in); |
