diff options
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r-- | numpy/core/src/arrayobject.c | 32 |
1 files changed, 23 insertions, 9 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 0485e5ad3..ac85fcd23 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -2280,10 +2280,10 @@ parse_index(PyArrayObject *self, PyObject *op, } static void -_swap_axes(PyArrayMapIterObject *mit, PyArrayObject **ret) +_swap_axes(PyArrayMapIterObject *mit, PyArrayObject **ret, int getmap) { PyObject *new; - int n1, n2, n3, val; + int n1, n2, n3, val, bnd; int i; PyArray_Dims permute; intp d[MAX_DIMS]; @@ -2308,24 +2308,38 @@ _swap_axes(PyArrayMapIterObject *mit, PyArrayObject **ret) if (new == NULL) return; } - /* tuple for transpose is - (n1,..,n1+n2-1,0,..,n1-1,n1+n2,...,n3-1) + /* Setting and getting need to have different permutations. + On the get we are permuting the returned object, but on + setting we are permuting the object-to-be-set. + The set permutation is the inverse of the get permutation. + */ + + /* For getting the array the tuple for transpose is + (n1,...,n1+n2-1,0,...,n1-1,n1+n2,...,n3-1) n1 is the number of dimensions of the broadcasted index array n2 is the number of dimensions skipped at the - start + start n3 is the number of dimensions of the result */ + + /* For setting the array the tuple for transpose is + (n2,...,n1+n2-1,0,...,n2-1,n1+n2,...n3-1) + */ n1 = mit->iters[0]->nd_m1 + 1; n2 = mit->iteraxes[0]; n3 = mit->nd; - val = n1; + + bnd = (getmap ? n1 : n2); /* use n1 as the boundary if getting + but n2 if setting */ + + val = bnd; i = 0; while(val < n1+n2) permute.ptr[i++] = val++; val = 0; - while(val < n1) + while(val < bnd) permute.ptr[i++] = val++; val = n1+n2; while(val < n3) @@ -2395,7 +2409,7 @@ PyArray_GetMap(PyArrayMapIterObject *mit) /* check for consecutive axes */ if ((mit->subspace != NULL) && (mit->consec)) { if (mit->iteraxes[0] > 0) { /* then we need to swap */ - _swap_axes(mit, &ret); + _swap_axes(mit, &ret, 1); } } return (PyObject *)ret; @@ -2421,7 +2435,7 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op) if ((mit->subspace != NULL) && (mit->consec)) { if (mit->iteraxes[0] > 0) { /* then we need to swap */ - _swap_axes(mit, (PyArrayObject **)&arr); + _swap_axes(mit, (PyArrayObject **)&arr, 0); if (arr == NULL) return -1; } } |