summaryrefslogtreecommitdiff
path: root/numpy/core/defmatrix.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-07-24 20:56:16 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-07-24 20:56:16 +0000
commit96d053c58aabcdb5e87d014e16b6f2ca6dccee91 (patch)
treea5f65740cf22e862a5990f78e6fe3274858e93b4 /numpy/core/defmatrix.py
parent19c85f13b9fcee3dc5d5632b77380a61108ca3e6 (diff)
downloadnumpy-96d053c58aabcdb5e87d014e16b6f2ca6dccee91.tar.gz
Fix matrix getitem to return column matrices when appropriate.
Diffstat (limited to 'numpy/core/defmatrix.py')
-rw-r--r--numpy/core/defmatrix.py47
1 files changed, 24 insertions, 23 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py
index 441500c90..f562cba06 100644
--- a/numpy/core/defmatrix.py
+++ b/numpy/core/defmatrix.py
@@ -101,6 +101,8 @@ class matrix(N.ndarray):
return ret
def __array_finalize__(self, obj):
+ self._getitem = False
+ if (isinstance(obj, matrix) and obj._getitem): return
ndim = self.ndim
if (ndim == 2):
return
@@ -121,31 +123,30 @@ class matrix(N.ndarray):
return
def __getitem__(self, index):
- out = N.ndarray.__getitem__(self, index)
- if not isinstance(out, matrix):
- return out
- # Need to swap if slice is on first index
- # or there is an integer on the second
- retscal = False
+ self._getitem = True
try:
- n = len(index)
- if (n==2):
- if isscalar(index[1]):
- if isscalar(index[0]):
- retscal = True
- elif out.shape[0] == 1:
- sh = out.shape
- out.shape = (sh[1], sh[0])
- elif isinstance(index[1], (slice, types.EllipsisType)):
- if out.shape[0] == 1 and not isscalar(index[0]):
- sh = out.shape
- out.shape = (sh[1], sh[0])
- except (TypeError, IndexError):
- pass
- if retscal and out.shape == (1,1): # convert scalars
- return out.A[0,0]
- return out
+ out = N.ndarray.__getitem__(self, index)
+ finally:
+ self._getitem = False
+
+ if not isinstance(out, N.ndarray):
+ return out
+ if out.ndim == 0:
+ return out[()]
+ if out.ndim == 1:
+ sh = out.shape[0]
+ # Determine when we should have a column array
+ try:
+ n = len(index)
+ except:
+ n = 0
+ if n > 1 and isscalar(index[1]):
+ out.shape = (sh,1)
+ else:
+ out.shape = (1,sh)
+ return out
+
def _get_truendim(self):
shp = self.shape
truend = 0