summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py41
1 files changed, 34 insertions, 7 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 6c864080f..fbc479254 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -10,7 +10,7 @@ __all__ = ['unravel_index',
import sys
import types
import numpy.core.numeric as _nx
-from numpy.core.numeric import asarray, ScalarType
+from numpy.core.numeric import asarray, ScalarType, array
import function_base
import numpy.core.defmatrix as matrix
@@ -207,6 +207,8 @@ class concatenator(object):
self.col = 0
def __getitem__(self,key):
+ trans1d = False
+ ndmin = 1
if isinstance(key, str):
frame = sys._getframe().f_back
mymat = matrix.bmat(key,frame.f_globals,frame.f_locals)
@@ -230,22 +232,47 @@ class concatenator(object):
newobj = function_base.linspace(start, stop, num=size)
else:
newobj = _nx.arange(start, stop, step)
- elif type(key[k]) is str:
- if (key[k] in 'rc'):
+ if ndmin > 1:
+ newobj = array(newobj,copy=False,ndmin=ndmin)
+ if trans1d:
+ newobj = newobj.T
+ elif isinstance(key[k],str):
+ if k != 0:
+ raise ValueError, "special directives must be the"\
+ "first entry."
+ key0 = key[0]
+ if key0 in 'rc':
self.matrix = True
- self.col = (key[k] == 'c')
+ self.col = (key0 == 'c')
continue
+ if ',' in key0:
+ vec = key0.split(',')
+ try:
+ self.axis, ndmin = \
+ [int(x) for x in vec[:2]]
+ if len(vec) == 3 and vec[2] == 't':
+ trans1d = True
+ continue
+ except:
+ raise ValueError, "unknown special directive"
try:
self.axis = int(key[k])
continue
except (ValueError, TypeError):
raise ValueError, "unknown special directive"
elif type(key[k]) in ScalarType:
- newobj = asarray([key[k]])
+ newobj = array(key[k],ndmin=ndmin)
scalars.append(k)
scalar = True
else:
newobj = key[k]
+ if ndmin > 1:
+ tempobj = array(newobj, copy=False, subok=True)
+ newobj = array(newobj, copy=False, subok=True,
+ ndmin=ndmin)
+ if trans1d and tempobj.ndim == 1:
+ newobj = newobj.T
+ del tempobj
objs.append(newobj)
if isinstance(newobj, _nx.ndarray) and not scalar:
if final_dtypedescr is None:
@@ -286,13 +313,13 @@ import warnings
class c_class(concatenator):
"""Translates slice objects to concatenation along the second axis.
- This is deprecated. Use r_[...,'-1']
+ This is deprecated. Use r_['-1',...]
"""
def __init__(self):
concatenator.__init__(self, -1)
def __getitem__(self, obj):
- warnings.warn("c_ is deprecated use r_[...,'-1']")
+ warnings.warn("c_ is deprecated use r_['-1',...]")
return concatenator.__getitem__(self, obj)
c_ = c_class()