summaryrefslogtreecommitdiff
path: root/numpy/core/src/arrayobject.c
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r--numpy/core/src/arrayobject.c32
1 files changed, 23 insertions, 9 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 0485e5ad3..ac85fcd23 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -2280,10 +2280,10 @@ parse_index(PyArrayObject *self, PyObject *op,
}
static void
-_swap_axes(PyArrayMapIterObject *mit, PyArrayObject **ret)
+_swap_axes(PyArrayMapIterObject *mit, PyArrayObject **ret, int getmap)
{
PyObject *new;
- int n1, n2, n3, val;
+ int n1, n2, n3, val, bnd;
int i;
PyArray_Dims permute;
intp d[MAX_DIMS];
@@ -2308,24 +2308,38 @@ _swap_axes(PyArrayMapIterObject *mit, PyArrayObject **ret)
if (new == NULL) return;
}
- /* tuple for transpose is
- (n1,..,n1+n2-1,0,..,n1-1,n1+n2,...,n3-1)
+ /* Setting and getting need to have different permutations.
+ On the get we are permuting the returned object, but on
+ setting we are permuting the object-to-be-set.
+ The set permutation is the inverse of the get permutation.
+ */
+
+ /* For getting the array the tuple for transpose is
+ (n1,...,n1+n2-1,0,...,n1-1,n1+n2,...,n3-1)
n1 is the number of dimensions of
the broadcasted index array
n2 is the number of dimensions skipped at the
- start
+ start
n3 is the number of dimensions of the
result
*/
+
+ /* For setting the array the tuple for transpose is
+ (n2,...,n1+n2-1,0,...,n2-1,n1+n2,...n3-1)
+ */
n1 = mit->iters[0]->nd_m1 + 1;
n2 = mit->iteraxes[0];
n3 = mit->nd;
- val = n1;
+
+ bnd = (getmap ? n1 : n2); /* use n1 as the boundary if getting
+ but n2 if setting */
+
+ val = bnd;
i = 0;
while(val < n1+n2)
permute.ptr[i++] = val++;
val = 0;
- while(val < n1)
+ while(val < bnd)
permute.ptr[i++] = val++;
val = n1+n2;
while(val < n3)
@@ -2395,7 +2409,7 @@ PyArray_GetMap(PyArrayMapIterObject *mit)
/* check for consecutive axes */
if ((mit->subspace != NULL) && (mit->consec)) {
if (mit->iteraxes[0] > 0) { /* then we need to swap */
- _swap_axes(mit, &ret);
+ _swap_axes(mit, &ret, 1);
}
}
return (PyObject *)ret;
@@ -2421,7 +2435,7 @@ PyArray_SetMap(PyArrayMapIterObject *mit, PyObject *op)
if ((mit->subspace != NULL) && (mit->consec)) {
if (mit->iteraxes[0] > 0) { /* then we need to swap */
- _swap_axes(mit, (PyArrayObject **)&arr);
+ _swap_axes(mit, (PyArrayObject **)&arr, 0);
if (arr == NULL) return -1;
}
}