diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-07-24 20:56:16 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-07-24 20:56:16 +0000 |
commit | 96d053c58aabcdb5e87d014e16b6f2ca6dccee91 (patch) | |
tree | a5f65740cf22e862a5990f78e6fe3274858e93b4 /numpy/core/defmatrix.py | |
parent | 19c85f13b9fcee3dc5d5632b77380a61108ca3e6 (diff) | |
download | numpy-96d053c58aabcdb5e87d014e16b6f2ca6dccee91.tar.gz |
Fix matrix getitem to return column matrices when appropriate.
Diffstat (limited to 'numpy/core/defmatrix.py')
-rw-r--r-- | numpy/core/defmatrix.py | 47 |
1 files changed, 24 insertions, 23 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py index 441500c90..f562cba06 100644 --- a/numpy/core/defmatrix.py +++ b/numpy/core/defmatrix.py @@ -101,6 +101,8 @@ class matrix(N.ndarray): return ret def __array_finalize__(self, obj): + self._getitem = False + if (isinstance(obj, matrix) and obj._getitem): return ndim = self.ndim if (ndim == 2): return @@ -121,31 +123,30 @@ class matrix(N.ndarray): return def __getitem__(self, index): - out = N.ndarray.__getitem__(self, index) - if not isinstance(out, matrix): - return out - # Need to swap if slice is on first index - # or there is an integer on the second - retscal = False + self._getitem = True try: - n = len(index) - if (n==2): - if isscalar(index[1]): - if isscalar(index[0]): - retscal = True - elif out.shape[0] == 1: - sh = out.shape - out.shape = (sh[1], sh[0]) - elif isinstance(index[1], (slice, types.EllipsisType)): - if out.shape[0] == 1 and not isscalar(index[0]): - sh = out.shape - out.shape = (sh[1], sh[0]) - except (TypeError, IndexError): - pass - if retscal and out.shape == (1,1): # convert scalars - return out.A[0,0] - return out + out = N.ndarray.__getitem__(self, index) + finally: + self._getitem = False + + if not isinstance(out, N.ndarray): + return out + if out.ndim == 0: + return out[()] + if out.ndim == 1: + sh = out.shape[0] + # Determine when we should have a column array + try: + n = len(index) + except: + n = 0 + if n > 1 and isscalar(index[1]): + out.shape = (sh,1) + else: + out.shape = (1,sh) + return out + def _get_truendim(self): shp = self.shape truend = 0 |