diff options
Diffstat (limited to 'scipy/base/matrix.py')
-rw-r--r-- | scipy/base/matrix.py | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/scipy/base/matrix.py b/scipy/base/matrix.py index dfe3cc0a3..59d844a4c 100644 --- a/scipy/base/matrix.py +++ b/scipy/base/matrix.py @@ -1,6 +1,7 @@ import numeric as N from numeric import ArrayType, concatenate, integer +from type_check import isscalar from function_base import binary_repr import types import string as str_ @@ -103,20 +104,24 @@ class matrix(N.ndarray): return def __getitem__(self, index): - out = (self.A)[index] - # Need to swap if slice is on first index + out = N.ndarray.__getitem__(self, index) + # Need to swap if slice is on first inde + retscal = False try: n = len(index) - if (n > 1) and isinstance(index[0], types.SliceType): - if (isinstance(index[1], types.IntType) or - isinstance(index[1], types.LongType) or - isinstance(index[1], integer)): - sh = out.shape - out.shape = (sh[1], sh[0]) - return matrix(out) - return out + if (n==2): + if isinstance(index[0], types.SliceType): + if (isscalar(index[1])): + sh = out.shape + out.shape = (sh[1], sh[0]) + else: + if (isscalar(index[0])) and (isscalar(index[1])): + retscal = True except TypeError: - return matrix(out) + pass + if retscal and out.shape == (1,1): # convert scalars + return out.A[0,0] + return out def __mul__(self, other): if isinstance(other, N.ndarray) and other.ndim == 0: |