diff options
Diffstat (limited to 'numpy/matrixlib')
-rw-r--r-- | numpy/matrixlib/defmatrix.py | 27 | ||||
-rw-r--r-- | numpy/matrixlib/tests/test_defmatrix.py | 36 |
2 files changed, 44 insertions, 19 deletions
diff --git a/numpy/matrixlib/defmatrix.py b/numpy/matrixlib/defmatrix.py index aac123e59..dbf2909fd 100644 --- a/numpy/matrixlib/defmatrix.py +++ b/numpy/matrixlib/defmatrix.py @@ -375,6 +375,15 @@ class matrix(N.ndarray): else: raise ValueError("unsupported axis") + def _collapse(self, axis): + """A convenience function for operations that want to collapse + to a scalar like _align, but are using keepdims=True + """ + if axis is None: + return self[0,0] + else: + return self + # Necessary because base-class tolist expects dimension # reduction by x[0] def tolist(self): @@ -432,7 +441,7 @@ class matrix(N.ndarray): [ 7.]]) """ - return N.ndarray.sum(self, axis, dtype, out)._align(axis) + return N.ndarray.sum(self, axis, dtype, out, keepdims=True)._collapse(axis) def mean(self, axis=None, dtype=None, out=None): """ @@ -466,7 +475,7 @@ class matrix(N.ndarray): [ 9.5]]) """ - return N.ndarray.mean(self, axis, dtype, out)._align(axis) + return N.ndarray.mean(self, axis, dtype, out, keepdims=True)._collapse(axis) def std(self, axis=None, dtype=None, out=None, ddof=0): """ @@ -500,7 +509,7 @@ class matrix(N.ndarray): [ 1.11803399]]) """ - return N.ndarray.std(self, axis, dtype, out, ddof)._align(axis) + return N.ndarray.std(self, axis, dtype, out, ddof, keepdims=True)._collapse(axis) def var(self, axis=None, dtype=None, out=None, ddof=0): """ @@ -534,7 +543,7 @@ class matrix(N.ndarray): [ 1.25]]) """ - return N.ndarray.var(self, axis, dtype, out, ddof)._align(axis) + return N.ndarray.var(self, axis, dtype, out, ddof, keepdims=True)._collapse(axis) def prod(self, axis=None, dtype=None, out=None): """ @@ -567,7 +576,7 @@ class matrix(N.ndarray): [7920]]) """ - return N.ndarray.prod(self, axis, dtype, out)._align(axis) + return N.ndarray.prod(self, axis, dtype, out, keepdims=True)._collapse(axis) def any(self, axis=None, out=None): """ @@ -590,7 +599,7 @@ class matrix(N.ndarray): returns `ndarray` """ - return N.ndarray.any(self, axis, out)._align(axis) + return N.ndarray.any(self, axis, out, keepdims=True)._collapse(axis) def all(self, axis=None, out=None): """ @@ -630,7 +639,7 @@ class matrix(N.ndarray): [False]], dtype=bool) """ - return N.ndarray.all(self, axis, out)._align(axis) + return N.ndarray.all(self, axis, out, keepdims=True)._collapse(axis) def max(self, axis=None, out=None): """ @@ -665,7 +674,7 @@ class matrix(N.ndarray): [11]]) """ - return N.ndarray.max(self, axis, out)._align(axis) + return N.ndarray.max(self, axis, out, keepdims=True)._collapse(axis) def argmax(self, axis=None, out=None): """ @@ -735,7 +744,7 @@ class matrix(N.ndarray): [-11]]) """ - return N.ndarray.min(self, axis, out)._align(axis) + return N.ndarray.min(self, axis, out, keepdims=True)._collapse(axis) def argmin(self, axis=None, out=None): """ diff --git a/numpy/matrixlib/tests/test_defmatrix.py b/numpy/matrixlib/tests/test_defmatrix.py index 09a4b4892..0a181fca3 100644 --- a/numpy/matrixlib/tests/test_defmatrix.py +++ b/numpy/matrixlib/tests/test_defmatrix.py @@ -65,29 +65,45 @@ class TestProperties(TestCase): sumall = 30 assert_array_equal(sum0, M.sum(axis=0)) assert_array_equal(sum1, M.sum(axis=1)) - assert_(sumall == M.sum()) + assert_equal(sumall, M.sum()) + + assert_array_equal(sum0, np.sum(M, axis=0)) + assert_array_equal(sum1, np.sum(M, axis=1)) + assert_equal(sumall, np.sum(M)) def test_prod(self): x = matrix([[1,2,3],[4,5,6]]) - assert_(x.prod() == 720) - assert_(all(x.prod(0) == matrix([[4,10,18]]))) - assert_(all(x.prod(1) == matrix([[6],[120]]))) + assert_equal(x.prod(), 720) + assert_equal(x.prod(0), matrix([[4,10,18]])) + assert_equal(x.prod(1), matrix([[6],[120]])) + + assert_equal(np.prod(x), 720) + assert_equal(np.prod(x, axis=0), matrix([[4,10,18]])) + assert_equal(np.prod(x, axis=1), matrix([[6],[120]])) y = matrix([0,1,3]) assert_(y.prod() == 0) def test_max(self): x = matrix([[1,2,3],[4,5,6]]) - assert_(x.max() == 6) - assert_(all(x.max(0) == matrix([[4,5,6]]))) - assert_(all(x.max(1) == matrix([[3],[6]]))) + assert_equal(x.max(), 6) + assert_equal(x.max(0), matrix([[4,5,6]])) + assert_equal(x.max(1), matrix([[3],[6]])) + + assert_equal(np.max(x), 6) + assert_equal(np.max(x, axis=0), matrix([[4,5,6]])) + assert_equal(np.max(x, axis=1), matrix([[3],[6]])) def test_min(self): x = matrix([[1,2,3],[4,5,6]]) - assert_(x.min() == 1) - assert_(all(x.min(0) == matrix([[1,2,3]]))) - assert_(all(x.min(1) == matrix([[1],[4]]))) + assert_equal(x.min(), 1) + assert_equal(x.min(0), matrix([[1,2,3]])) + assert_equal(x.min(1), matrix([[1],[4]])) + + assert_equal(np.min(x), 1) + assert_equal(np.min(x, axis=0), matrix([[1,2,3]])) + assert_equal(np.min(x, axis=1), matrix([[1],[4]])) def test_ptp(self): x = np.arange(4).reshape((2,2)) |