summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-09-14 02:33:55 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-09-14 02:33:55 +0000
commitd6ce2d7dc3a62b45272779d771c86338cf4f2c56 (patch)
tree61ebc82a5816257d1cd90b2f35ce694e5e17be74 /numpy/lib/index_tricks.py
parent4e76e00cc5afceaf70fe8d655cf59d4a9fb85a0a (diff)
downloadnumpy-d6ce2d7dc3a62b45272779d771c86338cf4f2c56.tar.gz
Fix up r_ so you can specify the minimum number of dimensions to force arrays to and allow alteration of the concatenation axis and whether or not to transpose 1d arrays
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py41
1 files changed, 34 insertions, 7 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 6c864080f..fbc479254 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -10,7 +10,7 @@ __all__ = ['unravel_index',
import sys
import types
import numpy.core.numeric as _nx
-from numpy.core.numeric import asarray, ScalarType
+from numpy.core.numeric import asarray, ScalarType, array
import function_base
import numpy.core.defmatrix as matrix
@@ -207,6 +207,8 @@ class concatenator(object):
self.col = 0
def __getitem__(self,key):
+ trans1d = False
+ ndmin = 1
if isinstance(key, str):
frame = sys._getframe().f_back
mymat = matrix.bmat(key,frame.f_globals,frame.f_locals)
@@ -230,22 +232,47 @@ class concatenator(object):
newobj = function_base.linspace(start, stop, num=size)
else:
newobj = _nx.arange(start, stop, step)
- elif type(key[k]) is str:
- if (key[k] in 'rc'):
+ if ndmin > 1:
+ newobj = array(newobj,copy=False,ndmin=ndmin)
+ if trans1d:
+ newobj = newobj.T
+ elif isinstance(key[k],str):
+ if k != 0:
+ raise ValueError, "special directives must be the"\
+ "first entry."
+ key0 = key[0]
+ if key0 in 'rc':
self.matrix = True
- self.col = (key[k] == 'c')
+ self.col = (key0 == 'c')
continue
+ if ',' in key0:
+ vec = key0.split(',')
+ try:
+ self.axis, ndmin = \
+ [int(x) for x in vec[:2]]
+ if len(vec) == 3 and vec[2] == 't':
+ trans1d = True
+ continue
+ except:
+ raise ValueError, "unknown special directive"
try:
self.axis = int(key[k])
continue
except (ValueError, TypeError):
raise ValueError, "unknown special directive"
elif type(key[k]) in ScalarType:
- newobj = asarray([key[k]])
+ newobj = array(key[k],ndmin=ndmin)
scalars.append(k)
scalar = True
else:
newobj = key[k]
+ if ndmin > 1:
+ tempobj = array(newobj, copy=False, subok=True)
+ newobj = array(newobj, copy=False, subok=True,
+ ndmin=ndmin)
+ if trans1d and tempobj.ndim == 1:
+ newobj = newobj.T
+ del tempobj
objs.append(newobj)
if isinstance(newobj, _nx.ndarray) and not scalar:
if final_dtypedescr is None:
@@ -286,13 +313,13 @@ import warnings
class c_class(concatenator):
"""Translates slice objects to concatenation along the second axis.
- This is deprecated. Use r_[...,'-1']
+ This is deprecated. Use r_['-1',...]
"""
def __init__(self):
concatenator.__init__(self, -1)
def __getitem__(self, obj):
- warnings.warn("c_ is deprecated use r_[...,'-1']")
+ warnings.warn("c_ is deprecated use r_['-1',...]")
return concatenator.__getitem__(self, obj)
c_ = c_class()