summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py21
1 files changed, 14 insertions, 7 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index fbc479254..4040ac578 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -207,7 +207,7 @@ class concatenator(object):
self.col = 0
def __getitem__(self,key):
- trans1d = False
+ trans1d = -1
ndmin = 1
if isinstance(key, str):
frame = sys._getframe().f_back
@@ -234,8 +234,8 @@ class concatenator(object):
newobj = _nx.arange(start, stop, step)
if ndmin > 1:
newobj = array(newobj,copy=False,ndmin=ndmin)
- if trans1d:
- newobj = newobj.T
+ if trans1d != -1:
+ newobj = newobj.swapaxes(-1,trans1d)
elif isinstance(key[k],str):
if k != 0:
raise ValueError, "special directives must be the"\
@@ -250,8 +250,8 @@ class concatenator(object):
try:
self.axis, ndmin = \
[int(x) for x in vec[:2]]
- if len(vec) == 3 and vec[2] == 't':
- trans1d = True
+ if len(vec) == 3:
+ trans1d = int(vec[2])
continue
except:
raise ValueError, "unknown special directive"
@@ -270,8 +270,15 @@ class concatenator(object):
tempobj = array(newobj, copy=False, subok=True)
newobj = array(newobj, copy=False, subok=True,
ndmin=ndmin)
- if trans1d and tempobj.ndim == 1:
- newobj = newobj.T
+ if trans1d != -1 and tempobj.ndim < ndmin:
+ k2 = ndmin-tempobj.ndim
+ if (trans1d < 0):
+ trans1d += k2 + 1
+ defaxes = range(ndmin)
+ k1 = trans1d
+ axes = defaxes[:k1] + defaxes[k2:] + \
+ defaxes[k1:k2]
+ newobj = newobj.transpose(axes)
del tempobj
objs.append(newobj)
if isinstance(newobj, _nx.ndarray) and not scalar: