summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2008-04-23 20:32:02 +0000
committerStefan van der Walt <stefan@sun.ac.za>2008-04-23 20:32:02 +0000
commit7b2239fbdfc7c5fdefdfd6eb59dc0ed500aa2291 (patch)
treefb2b16eb320a704f60214b1ce71cfa62b696a27c /numpy/core
parent1da799e3541679475c63026dbe82f9840d33188b (diff)
downloadnumpy-7b2239fbdfc7c5fdefdfd6eb59dc0ed500aa2291.tar.gz
Hack to let x[0][0] return a scalar for matrices.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/defmatrix.py9
-rw-r--r--numpy/core/tests/test_defmatrix.py11
2 files changed, 19 insertions, 1 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py
index de37a2686..77e90acbd 100644
--- a/numpy/core/defmatrix.py
+++ b/numpy/core/defmatrix.py
@@ -224,6 +224,14 @@ class matrix(N.ndarray):
def __getitem__(self, index):
self._getitem = True
+
+ # If indexing by scalar, check whether we are indexing into
+ # a vector, and then return the corresponding element
+ if N.isscalar(index) and (1 in self.shape):
+ index = [index,index]
+ index[list(self.shape).index(1)] = 0
+ index = tuple(index)
+
try:
out = N.ndarray.__getitem__(self, index)
finally:
@@ -254,7 +262,6 @@ class matrix(N.ndarray):
if (val > 1): truend += 1
return truend
-
def __mul__(self, other):
if isinstance(other,(N.ndarray, list, tuple)) :
# This promotes 1-D vectors to row vectors
diff --git a/numpy/core/tests/test_defmatrix.py b/numpy/core/tests/test_defmatrix.py
index 659ff4280..7d73bbd02 100644
--- a/numpy/core/tests/test_defmatrix.py
+++ b/numpy/core/tests/test_defmatrix.py
@@ -179,6 +179,17 @@ class TestIndexing(NumpyTestCase):
x[:,1] = y>0.5
assert_equal(x, [[0,1],[0,0],[0,0]])
+ def check_vector_element(self):
+ x = matrix([[1,2,3],[4,5,6]])
+ assert_equal(x[0][0],1)
+ assert_equal(x[0].shape,(1,3))
+ assert_equal(x[:,0].shape,(2,1))
+
+ x = matrix(0)
+ assert_equal(x[0,0],0)
+ assert_equal(x[0],0)
+ assert_equal(x[:,0].shape,x.shape)
+
if __name__ == "__main__":
NumpyTest().run()