summaryrefslogtreecommitdiff
path: root/numpy/matrixlib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/matrixlib')
-rw-r--r--numpy/matrixlib/defmatrix.py27
-rw-r--r--numpy/matrixlib/tests/test_defmatrix.py36
2 files changed, 44 insertions, 19 deletions
diff --git a/numpy/matrixlib/defmatrix.py b/numpy/matrixlib/defmatrix.py
index aac123e59..dbf2909fd 100644
--- a/numpy/matrixlib/defmatrix.py
+++ b/numpy/matrixlib/defmatrix.py
@@ -375,6 +375,15 @@ class matrix(N.ndarray):
else:
raise ValueError("unsupported axis")
+ def _collapse(self, axis):
+ """A convenience function for operations that want to collapse
+ to a scalar like _align, but are using keepdims=True
+ """
+ if axis is None:
+ return self[0,0]
+ else:
+ return self
+
# Necessary because base-class tolist expects dimension
# reduction by x[0]
def tolist(self):
@@ -432,7 +441,7 @@ class matrix(N.ndarray):
[ 7.]])
"""
- return N.ndarray.sum(self, axis, dtype, out)._align(axis)
+ return N.ndarray.sum(self, axis, dtype, out, keepdims=True)._collapse(axis)
def mean(self, axis=None, dtype=None, out=None):
"""
@@ -466,7 +475,7 @@ class matrix(N.ndarray):
[ 9.5]])
"""
- return N.ndarray.mean(self, axis, dtype, out)._align(axis)
+ return N.ndarray.mean(self, axis, dtype, out, keepdims=True)._collapse(axis)
def std(self, axis=None, dtype=None, out=None, ddof=0):
"""
@@ -500,7 +509,7 @@ class matrix(N.ndarray):
[ 1.11803399]])
"""
- return N.ndarray.std(self, axis, dtype, out, ddof)._align(axis)
+ return N.ndarray.std(self, axis, dtype, out, ddof, keepdims=True)._collapse(axis)
def var(self, axis=None, dtype=None, out=None, ddof=0):
"""
@@ -534,7 +543,7 @@ class matrix(N.ndarray):
[ 1.25]])
"""
- return N.ndarray.var(self, axis, dtype, out, ddof)._align(axis)
+ return N.ndarray.var(self, axis, dtype, out, ddof, keepdims=True)._collapse(axis)
def prod(self, axis=None, dtype=None, out=None):
"""
@@ -567,7 +576,7 @@ class matrix(N.ndarray):
[7920]])
"""
- return N.ndarray.prod(self, axis, dtype, out)._align(axis)
+ return N.ndarray.prod(self, axis, dtype, out, keepdims=True)._collapse(axis)
def any(self, axis=None, out=None):
"""
@@ -590,7 +599,7 @@ class matrix(N.ndarray):
returns `ndarray`
"""
- return N.ndarray.any(self, axis, out)._align(axis)
+ return N.ndarray.any(self, axis, out, keepdims=True)._collapse(axis)
def all(self, axis=None, out=None):
"""
@@ -630,7 +639,7 @@ class matrix(N.ndarray):
[False]], dtype=bool)
"""
- return N.ndarray.all(self, axis, out)._align(axis)
+ return N.ndarray.all(self, axis, out, keepdims=True)._collapse(axis)
def max(self, axis=None, out=None):
"""
@@ -665,7 +674,7 @@ class matrix(N.ndarray):
[11]])
"""
- return N.ndarray.max(self, axis, out)._align(axis)
+ return N.ndarray.max(self, axis, out, keepdims=True)._collapse(axis)
def argmax(self, axis=None, out=None):
"""
@@ -735,7 +744,7 @@ class matrix(N.ndarray):
[-11]])
"""
- return N.ndarray.min(self, axis, out)._align(axis)
+ return N.ndarray.min(self, axis, out, keepdims=True)._collapse(axis)
def argmin(self, axis=None, out=None):
"""
diff --git a/numpy/matrixlib/tests/test_defmatrix.py b/numpy/matrixlib/tests/test_defmatrix.py
index 09a4b4892..0a181fca3 100644
--- a/numpy/matrixlib/tests/test_defmatrix.py
+++ b/numpy/matrixlib/tests/test_defmatrix.py
@@ -65,29 +65,45 @@ class TestProperties(TestCase):
sumall = 30
assert_array_equal(sum0, M.sum(axis=0))
assert_array_equal(sum1, M.sum(axis=1))
- assert_(sumall == M.sum())
+ assert_equal(sumall, M.sum())
+
+ assert_array_equal(sum0, np.sum(M, axis=0))
+ assert_array_equal(sum1, np.sum(M, axis=1))
+ assert_equal(sumall, np.sum(M))
def test_prod(self):
x = matrix([[1,2,3],[4,5,6]])
- assert_(x.prod() == 720)
- assert_(all(x.prod(0) == matrix([[4,10,18]])))
- assert_(all(x.prod(1) == matrix([[6],[120]])))
+ assert_equal(x.prod(), 720)
+ assert_equal(x.prod(0), matrix([[4,10,18]]))
+ assert_equal(x.prod(1), matrix([[6],[120]]))
+
+ assert_equal(np.prod(x), 720)
+ assert_equal(np.prod(x, axis=0), matrix([[4,10,18]]))
+ assert_equal(np.prod(x, axis=1), matrix([[6],[120]]))
y = matrix([0,1,3])
assert_(y.prod() == 0)
def test_max(self):
x = matrix([[1,2,3],[4,5,6]])
- assert_(x.max() == 6)
- assert_(all(x.max(0) == matrix([[4,5,6]])))
- assert_(all(x.max(1) == matrix([[3],[6]])))
+ assert_equal(x.max(), 6)
+ assert_equal(x.max(0), matrix([[4,5,6]]))
+ assert_equal(x.max(1), matrix([[3],[6]]))
+
+ assert_equal(np.max(x), 6)
+ assert_equal(np.max(x, axis=0), matrix([[4,5,6]]))
+ assert_equal(np.max(x, axis=1), matrix([[3],[6]]))
def test_min(self):
x = matrix([[1,2,3],[4,5,6]])
- assert_(x.min() == 1)
- assert_(all(x.min(0) == matrix([[1,2,3]])))
- assert_(all(x.min(1) == matrix([[1],[4]])))
+ assert_equal(x.min(), 1)
+ assert_equal(x.min(0), matrix([[1,2,3]]))
+ assert_equal(x.min(1), matrix([[1],[4]]))
+
+ assert_equal(np.min(x), 1)
+ assert_equal(np.min(x, axis=0), matrix([[1,2,3]]))
+ assert_equal(np.min(x, axis=1), matrix([[1],[4]]))
def test_ptp(self):
x = np.arange(4).reshape((2,2))