summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/arrayobject.c37
1 files changed, 26 insertions, 11 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index cab2484d8..69109475c 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -897,7 +897,13 @@ _broadcast_copy(PyArrayObject *dest, PyArrayObject *src,
multi = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, dest, src);
if (multi == NULL) return -1;
maxaxis = PyArray_RemoveLargest(multi);
- if (maxaxis < 0) return -1;
+ if (maxaxis < 0) { /* copy 1 0-d array to another */
+ PyArray_XDECREF(dest);
+ memcpy(dest->data, src->data, elsize);
+ if (swap) byte_swap_vector(dest->data, 1, elsize);
+ PyArray_INCREF(dest);
+ return 0;
+ }
maxdim = multi->dimensions[maxaxis];
PyArray_XDECREF(dest);
@@ -6866,7 +6872,7 @@ _broadcast_cast(PyArrayObject *out, PyArrayObject *in,
{
int delsize, selsize, maxaxis, i, N;
PyArrayMultiIterObject *multi;
- intp maxdim;
+ intp maxdim, ostrides, istrides;
char *buffers[2];
PyArray_CopySwapNFunc *ocopyfunc, *icopyfunc;
char *obptr;
@@ -6875,10 +6881,22 @@ _broadcast_cast(PyArrayObject *out, PyArrayObject *in,
selsize = PyArray_ITEMSIZE(in);
multi = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, out, in);
if (multi == NULL) return -1;
+ icopyfunc = in->descr->f->copyswapn;
+ ocopyfunc = out->descr->f->copyswapn;
maxaxis = PyArray_RemoveLargest(multi);
- if (maxaxis < 0) return -1;
- maxdim = multi->dimensions[maxaxis];
- N = (int) (MIN(maxdim, PyArray_BUFSIZE));
+ if (maxaxis < 0) { /* cast 1 0-d array to another */
+ N = 1;
+ maxdim = 1;
+ ostrides = delsize;
+ istrides = selsize;
+ }
+ else {
+ maxdim = multi->dimensions[maxaxis];
+ N = (int) (MIN(maxdim, PyArray_BUFSIZE));
+ ostrides = multi->iters[0]->strides[maxaxis];
+ istrides = multi->iters[1]->strides[maxaxis];
+
+ }
buffers[0] = _pya_malloc(N*delsize);
if (buffers[0] == NULL) {
PyErr_NoMemory();
@@ -6886,7 +6904,7 @@ _broadcast_cast(PyArrayObject *out, PyArrayObject *in,
}
buffers[1] = _pya_malloc(N*selsize);
if (buffers[1] == NULL) {
- free(buffers[0]);
+ _pya_free(buffers[0]);
PyErr_NoMemory();
return -1;
}
@@ -6895,15 +6913,12 @@ _broadcast_cast(PyArrayObject *out, PyArrayObject *in,
if (in->descr->hasobject)
memset(buffers[1], 0, N*selsize);
- icopyfunc = in->descr->f->copyswapn;
- ocopyfunc = out->descr->f->copyswapn;
-
while(multi->index < multi->size) {
_strided_buffered_cast(multi->iters[0]->dataptr,
- multi->iters[0]->strides[maxaxis],
+ ostrides,
delsize, oswap, ocopyfunc,
multi->iters[1]->dataptr,
- multi->iters[1]->strides[maxaxis],
+ istrides,
selsize, iswap, icopyfunc,
maxdim, buffers, N,
castfunc, out, in);