diff options
Diffstat (limited to 'numpy/core/defmatrix.py')
-rw-r--r-- | numpy/core/defmatrix.py | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py index 31572f1b4..d76eafea3 100644 --- a/numpy/core/defmatrix.py +++ b/numpy/core/defmatrix.py @@ -74,7 +74,7 @@ class matrix(N.ndarray): if copy: return new.copy() else: return new - if isinstance(data, types.StringType): + if isinstance(data, str): data = _convert_from_string(data) # now convert data to an array @@ -112,27 +112,33 @@ class matrix(N.ndarray): return elif (ndim > 2): raise ValueError, "shape too large to be a matrix." + else: + newshape = self.shape if ndim == 0: self.shape = (1,1) elif ndim == 1: - self.shape = (1,self.shape[0]) + self.shape = (1,newshape[0]) return def __getitem__(self, index): out = N.ndarray.__getitem__(self, index) # Need to swap if slice is on first index + # or there is an integer on the second retscal = False try: n = len(index) if (n==2): - if isinstance(index[0], types.SliceType): - if (isscalar(index[1])): + if isscalar(index[1]): + if isscalar(index[0]): + retscal = True + elif out.shape[0] == 1: sh = out.shape out.shape = (sh[1], sh[0]) - else: - if (isscalar(index[0])) and (isscalar(index[1])): - retscal = True - except TypeError: + elif isinstance(index[1], (slice, types.EllipsisType)): + if out.shape[0] == 1 and not isscalar(index[0]): + sh = out.shape + out.shape = (sh[1], sh[0]) + except (TypeError, IndexError): pass if retscal and out.shape == (1,1): # convert scalars return out.A[0,0] @@ -322,7 +328,7 @@ def bmat(obj, ldict=None, gdict=None): if A, B, C, and D are appropriately shaped 2-d arrays. """ - if isinstance(obj, types.StringType): + if isinstance(obj, str): if gdict is None: # get previous frame frame = sys._getframe().f_back @@ -334,7 +340,7 @@ def bmat(obj, ldict=None, gdict=None): return matrix(_from_string(obj, glob_dict, loc_dict)) - if isinstance(obj, (types.TupleType, types.ListType)): + if isinstance(obj, (tuple, list)): # [[A,B],[C,D]] arr_rows = [] for row in obj: |