summaryrefslogtreecommitdiff
path: root/scipy/base/matrix.py
diff options
context:
space:
mode:
Diffstat (limited to 'scipy/base/matrix.py')
-rw-r--r--scipy/base/matrix.py27
1 files changed, 16 insertions, 11 deletions
diff --git a/scipy/base/matrix.py b/scipy/base/matrix.py
index dfe3cc0a3..59d844a4c 100644
--- a/scipy/base/matrix.py
+++ b/scipy/base/matrix.py
@@ -1,6 +1,7 @@
import numeric as N
from numeric import ArrayType, concatenate, integer
+from type_check import isscalar
from function_base import binary_repr
import types
import string as str_
@@ -103,20 +104,24 @@ class matrix(N.ndarray):
return
def __getitem__(self, index):
- out = (self.A)[index]
- # Need to swap if slice is on first index
+ out = N.ndarray.__getitem__(self, index)
+ # Need to swap if slice is on first inde
+ retscal = False
try:
n = len(index)
- if (n > 1) and isinstance(index[0], types.SliceType):
- if (isinstance(index[1], types.IntType) or
- isinstance(index[1], types.LongType) or
- isinstance(index[1], integer)):
- sh = out.shape
- out.shape = (sh[1], sh[0])
- return matrix(out)
- return out
+ if (n==2):
+ if isinstance(index[0], types.SliceType):
+ if (isscalar(index[1])):
+ sh = out.shape
+ out.shape = (sh[1], sh[0])
+ else:
+ if (isscalar(index[0])) and (isscalar(index[1])):
+ retscal = True
except TypeError:
- return matrix(out)
+ pass
+ if retscal and out.shape == (1,1): # convert scalars
+ return out.A[0,0]
+ return out
def __mul__(self, other):
if isinstance(other, N.ndarray) and other.ndim == 0: