diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-09-14 02:33:55 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-09-14 02:33:55 +0000 |
commit | d6ce2d7dc3a62b45272779d771c86338cf4f2c56 (patch) | |
tree | 61ebc82a5816257d1cd90b2f35ce694e5e17be74 /numpy/lib/index_tricks.py | |
parent | 4e76e00cc5afceaf70fe8d655cf59d4a9fb85a0a (diff) | |
download | numpy-d6ce2d7dc3a62b45272779d771c86338cf4f2c56.tar.gz |
Fix up r_ so you can specify the minimum number of dimensions to force arrays to and allow alteration of the concatenation axis and whether or not to transpose 1d arrays
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 41 |
1 files changed, 34 insertions, 7 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 6c864080f..fbc479254 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -10,7 +10,7 @@ __all__ = ['unravel_index', import sys import types import numpy.core.numeric as _nx -from numpy.core.numeric import asarray, ScalarType +from numpy.core.numeric import asarray, ScalarType, array import function_base import numpy.core.defmatrix as matrix @@ -207,6 +207,8 @@ class concatenator(object): self.col = 0 def __getitem__(self,key): + trans1d = False + ndmin = 1 if isinstance(key, str): frame = sys._getframe().f_back mymat = matrix.bmat(key,frame.f_globals,frame.f_locals) @@ -230,22 +232,47 @@ class concatenator(object): newobj = function_base.linspace(start, stop, num=size) else: newobj = _nx.arange(start, stop, step) - elif type(key[k]) is str: - if (key[k] in 'rc'): + if ndmin > 1: + newobj = array(newobj,copy=False,ndmin=ndmin) + if trans1d: + newobj = newobj.T + elif isinstance(key[k],str): + if k != 0: + raise ValueError, "special directives must be the"\ + "first entry." + key0 = key[0] + if key0 in 'rc': self.matrix = True - self.col = (key[k] == 'c') + self.col = (key0 == 'c') continue + if ',' in key0: + vec = key0.split(',') + try: + self.axis, ndmin = \ + [int(x) for x in vec[:2]] + if len(vec) == 3 and vec[2] == 't': + trans1d = True + continue + except: + raise ValueError, "unknown special directive" try: self.axis = int(key[k]) continue except (ValueError, TypeError): raise ValueError, "unknown special directive" elif type(key[k]) in ScalarType: - newobj = asarray([key[k]]) + newobj = array(key[k],ndmin=ndmin) scalars.append(k) scalar = True else: newobj = key[k] + if ndmin > 1: + tempobj = array(newobj, copy=False, subok=True) + newobj = array(newobj, copy=False, subok=True, + ndmin=ndmin) + if trans1d and tempobj.ndim == 1: + newobj = newobj.T + del tempobj objs.append(newobj) if isinstance(newobj, _nx.ndarray) and not scalar: if final_dtypedescr is None: @@ -286,13 +313,13 @@ import warnings class c_class(concatenator): """Translates slice objects to concatenation along the second axis. - This is deprecated. Use r_[...,'-1'] + This is deprecated. Use r_['-1',...] """ def __init__(self): concatenator.__init__(self, -1) def __getitem__(self, obj): - warnings.warn("c_ is deprecated use r_[...,'-1']") + warnings.warn("c_ is deprecated use r_['-1',...]") return concatenator.__getitem__(self, obj) c_ = c_class() |