diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2008-04-23 20:32:02 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2008-04-23 20:32:02 +0000 |
commit | 7b2239fbdfc7c5fdefdfd6eb59dc0ed500aa2291 (patch) | |
tree | fb2b16eb320a704f60214b1ce71cfa62b696a27c /numpy/core | |
parent | 1da799e3541679475c63026dbe82f9840d33188b (diff) | |
download | numpy-7b2239fbdfc7c5fdefdfd6eb59dc0ed500aa2291.tar.gz |
Hack to let x[0][0] return a scalar for matrices.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/defmatrix.py | 9 | ||||
-rw-r--r-- | numpy/core/tests/test_defmatrix.py | 11 |
2 files changed, 19 insertions, 1 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py index de37a2686..77e90acbd 100644 --- a/numpy/core/defmatrix.py +++ b/numpy/core/defmatrix.py @@ -224,6 +224,14 @@ class matrix(N.ndarray): def __getitem__(self, index): self._getitem = True + + # If indexing by scalar, check whether we are indexing into + # a vector, and then return the corresponding element + if N.isscalar(index) and (1 in self.shape): + index = [index,index] + index[list(self.shape).index(1)] = 0 + index = tuple(index) + try: out = N.ndarray.__getitem__(self, index) finally: @@ -254,7 +262,6 @@ class matrix(N.ndarray): if (val > 1): truend += 1 return truend - def __mul__(self, other): if isinstance(other,(N.ndarray, list, tuple)) : # This promotes 1-D vectors to row vectors diff --git a/numpy/core/tests/test_defmatrix.py b/numpy/core/tests/test_defmatrix.py index 659ff4280..7d73bbd02 100644 --- a/numpy/core/tests/test_defmatrix.py +++ b/numpy/core/tests/test_defmatrix.py @@ -179,6 +179,17 @@ class TestIndexing(NumpyTestCase): x[:,1] = y>0.5 assert_equal(x, [[0,1],[0,0],[0,0]]) + def check_vector_element(self): + x = matrix([[1,2,3],[4,5,6]]) + assert_equal(x[0][0],1) + assert_equal(x[0].shape,(1,3)) + assert_equal(x[:,0].shape,(2,1)) + + x = matrix(0) + assert_equal(x[0,0],0) + assert_equal(x[0],0) + assert_equal(x[:,0].shape,x.shape) + if __name__ == "__main__": NumpyTest().run() |