diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2011-08-25 17:48:26 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2011-08-27 07:27:01 -0600 |
commit | a43d255dbc243ea7910e1b92ba83704e1a880c70 (patch) | |
tree | de7fc08422d835218cfa45a034c190558afd04e3 /numpy/matrixlib/defmatrix.py | |
parent | 6706908c90668fffc729456c35bc8d3a3c8243c6 (diff) | |
download | numpy-a43d255dbc243ea7910e1b92ba83704e1a880c70.tar.gz |
ENH: missingdata: Make numpy.all follow the NA || True == True rule
Diffstat (limited to 'numpy/matrixlib/defmatrix.py')
-rw-r--r-- | numpy/matrixlib/defmatrix.py | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/numpy/matrixlib/defmatrix.py b/numpy/matrixlib/defmatrix.py index aac123e59..dbf2909fd 100644 --- a/numpy/matrixlib/defmatrix.py +++ b/numpy/matrixlib/defmatrix.py @@ -375,6 +375,15 @@ class matrix(N.ndarray): else: raise ValueError("unsupported axis") + def _collapse(self, axis): + """A convenience function for operations that want to collapse + to a scalar like _align, but are using keepdims=True + """ + if axis is None: + return self[0,0] + else: + return self + # Necessary because base-class tolist expects dimension # reduction by x[0] def tolist(self): @@ -432,7 +441,7 @@ class matrix(N.ndarray): [ 7.]]) """ - return N.ndarray.sum(self, axis, dtype, out)._align(axis) + return N.ndarray.sum(self, axis, dtype, out, keepdims=True)._collapse(axis) def mean(self, axis=None, dtype=None, out=None): """ @@ -466,7 +475,7 @@ class matrix(N.ndarray): [ 9.5]]) """ - return N.ndarray.mean(self, axis, dtype, out)._align(axis) + return N.ndarray.mean(self, axis, dtype, out, keepdims=True)._collapse(axis) def std(self, axis=None, dtype=None, out=None, ddof=0): """ @@ -500,7 +509,7 @@ class matrix(N.ndarray): [ 1.11803399]]) """ - return N.ndarray.std(self, axis, dtype, out, ddof)._align(axis) + return N.ndarray.std(self, axis, dtype, out, ddof, keepdims=True)._collapse(axis) def var(self, axis=None, dtype=None, out=None, ddof=0): """ @@ -534,7 +543,7 @@ class matrix(N.ndarray): [ 1.25]]) """ - return N.ndarray.var(self, axis, dtype, out, ddof)._align(axis) + return N.ndarray.var(self, axis, dtype, out, ddof, keepdims=True)._collapse(axis) def prod(self, axis=None, dtype=None, out=None): """ @@ -567,7 +576,7 @@ class matrix(N.ndarray): [7920]]) """ - return N.ndarray.prod(self, axis, dtype, out)._align(axis) + return N.ndarray.prod(self, axis, dtype, out, keepdims=True)._collapse(axis) def any(self, axis=None, out=None): """ @@ -590,7 +599,7 @@ class matrix(N.ndarray): returns `ndarray` """ - return N.ndarray.any(self, axis, out)._align(axis) + return N.ndarray.any(self, axis, out, keepdims=True)._collapse(axis) def all(self, axis=None, out=None): """ @@ -630,7 +639,7 @@ class matrix(N.ndarray): [False]], dtype=bool) """ - return N.ndarray.all(self, axis, out)._align(axis) + return N.ndarray.all(self, axis, out, keepdims=True)._collapse(axis) def max(self, axis=None, out=None): """ @@ -665,7 +674,7 @@ class matrix(N.ndarray): [11]]) """ - return N.ndarray.max(self, axis, out)._align(axis) + return N.ndarray.max(self, axis, out, keepdims=True)._collapse(axis) def argmax(self, axis=None, out=None): """ @@ -735,7 +744,7 @@ class matrix(N.ndarray): [-11]]) """ - return N.ndarray.min(self, axis, out)._align(axis) + return N.ndarray.min(self, axis, out, keepdims=True)._collapse(axis) def argmin(self, axis=None, out=None): """ |