diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index fbc479254..4040ac578 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -207,7 +207,7 @@ class concatenator(object): self.col = 0 def __getitem__(self,key): - trans1d = False + trans1d = -1 ndmin = 1 if isinstance(key, str): frame = sys._getframe().f_back @@ -234,8 +234,8 @@ class concatenator(object): newobj = _nx.arange(start, stop, step) if ndmin > 1: newobj = array(newobj,copy=False,ndmin=ndmin) - if trans1d: - newobj = newobj.T + if trans1d != -1: + newobj = newobj.swapaxes(-1,trans1d) elif isinstance(key[k],str): if k != 0: raise ValueError, "special directives must be the"\ @@ -250,8 +250,8 @@ class concatenator(object): try: self.axis, ndmin = \ [int(x) for x in vec[:2]] - if len(vec) == 3 and vec[2] == 't': - trans1d = True + if len(vec) == 3: + trans1d = int(vec[2]) continue except: raise ValueError, "unknown special directive" @@ -270,8 +270,15 @@ class concatenator(object): tempobj = array(newobj, copy=False, subok=True) newobj = array(newobj, copy=False, subok=True, ndmin=ndmin) - if trans1d and tempobj.ndim == 1: - newobj = newobj.T + if trans1d != -1 and tempobj.ndim < ndmin: + k2 = ndmin-tempobj.ndim + if (trans1d < 0): + trans1d += k2 + 1 + defaxes = range(ndmin) + k1 = trans1d + axes = defaxes[:k1] + defaxes[k2:] + \ + defaxes[k1:k2] + newobj = newobj.transpose(axes) del tempobj objs.append(newobj) if isinstance(newobj, _nx.ndarray) and not scalar: |