diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 41 |
1 files changed, 34 insertions, 7 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 6c864080f..fbc479254 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -10,7 +10,7 @@ __all__ = ['unravel_index', import sys import types import numpy.core.numeric as _nx -from numpy.core.numeric import asarray, ScalarType +from numpy.core.numeric import asarray, ScalarType, array import function_base import numpy.core.defmatrix as matrix @@ -207,6 +207,8 @@ class concatenator(object): self.col = 0 def __getitem__(self,key): + trans1d = False + ndmin = 1 if isinstance(key, str): frame = sys._getframe().f_back mymat = matrix.bmat(key,frame.f_globals,frame.f_locals) @@ -230,22 +232,47 @@ class concatenator(object): newobj = function_base.linspace(start, stop, num=size) else: newobj = _nx.arange(start, stop, step) - elif type(key[k]) is str: - if (key[k] in 'rc'): + if ndmin > 1: + newobj = array(newobj,copy=False,ndmin=ndmin) + if trans1d: + newobj = newobj.T + elif isinstance(key[k],str): + if k != 0: + raise ValueError, "special directives must be the"\ + "first entry." + key0 = key[0] + if key0 in 'rc': self.matrix = True - self.col = (key[k] == 'c') + self.col = (key0 == 'c') continue + if ',' in key0: + vec = key0.split(',') + try: + self.axis, ndmin = \ + [int(x) for x in vec[:2]] + if len(vec) == 3 and vec[2] == 't': + trans1d = True + continue + except: + raise ValueError, "unknown special directive" try: self.axis = int(key[k]) continue except (ValueError, TypeError): raise ValueError, "unknown special directive" elif type(key[k]) in ScalarType: - newobj = asarray([key[k]]) + newobj = array(key[k],ndmin=ndmin) scalars.append(k) scalar = True else: newobj = key[k] + if ndmin > 1: + tempobj = array(newobj, copy=False, subok=True) + newobj = array(newobj, copy=False, subok=True, + ndmin=ndmin) + if trans1d and tempobj.ndim == 1: + newobj = newobj.T + del tempobj objs.append(newobj) if isinstance(newobj, _nx.ndarray) and not scalar: if final_dtypedescr is None: @@ -286,13 +313,13 @@ import warnings class c_class(concatenator): """Translates slice objects to concatenation along the second axis. - This is deprecated. Use r_[...,'-1'] + This is deprecated. Use r_['-1',...] """ def __init__(self): concatenator.__init__(self, -1) def __getitem__(self, obj): - warnings.warn("c_ is deprecated use r_[...,'-1']") + warnings.warn("c_ is deprecated use r_['-1',...]") return concatenator.__getitem__(self, obj) c_ = c_class() |