diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2011-08-27 21:46:08 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2011-08-27 21:46:08 -0600 |
commit | 9ecd91b7bf8c77d696ec9856ba10896d8f60309a (patch) | |
tree | 9884131ece5eada06212538c591965bf5928afa2 /numpy/matrixlib/defmatrix.py | |
parent | aa55ba7437fbe6b8772a360a641b5aa7d3e669e0 (diff) | |
parent | 10fac981763e87f949bed15c66127fc380fa9b27 (diff) | |
download | numpy-9ecd91b7bf8c77d696ec9856ba10896d8f60309a.tar.gz |
Merge branch 'pull-141'
* pull-141: (167 commits)
ENH: missingdata: Make PyArray_Converter and PyArray_OutputConverter safer for legacy code
DOC: missingdata: Add a mention of the design NEP, and masks vs bitpatterns
DOC: missingdata: Updates from pull request feedback
DOC: missingdata: Updates based on pull request feedback
ENH: nditer: Change the Python nditer exposure to automatically add NPY_ITER_USE_MASKNA
ENH: missingdata: Make comparisons with NA return NA(dtype='bool')
BLD: core: onefile build fix and Python3 compatibility change
DOC: Mention the update to np.all and np.any in the release notes
TST: dtype: Adjust void dtype test to pass without raising a zero-size exception
STY: Remove trailing whitespace
TST: missingdata: Write some tests for the np.any and np.all NA behavior
ENH: missingdata: Make numpy.all follow the NA && False == False rule
ENH: missingdata: Make numpy.all follow the NA || True == True rule
DOC: missingdata: Also show what assigning a non-NA value does in each case
DOC: missingdata: Add introductory documentation for NA-masked arrays
ENH: core: Rename PyArrayObject_fieldaccess to PyArrayObject_fields
DOC: missingdata: Some tweaks to the NA mask documentation
DOC: missingdata: Add example of a C-API function supporting NA masks
DOC: missingdata: Documenting C API for NA-masked arrays
ENH: missingdata: Finish adding C-API access to the NpyNA object
...
Diffstat (limited to 'numpy/matrixlib/defmatrix.py')
-rw-r--r-- | numpy/matrixlib/defmatrix.py | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/numpy/matrixlib/defmatrix.py b/numpy/matrixlib/defmatrix.py index aac123e59..dbf2909fd 100644 --- a/numpy/matrixlib/defmatrix.py +++ b/numpy/matrixlib/defmatrix.py @@ -375,6 +375,15 @@ class matrix(N.ndarray): else: raise ValueError("unsupported axis") + def _collapse(self, axis): + """A convenience function for operations that want to collapse + to a scalar like _align, but are using keepdims=True + """ + if axis is None: + return self[0,0] + else: + return self + # Necessary because base-class tolist expects dimension # reduction by x[0] def tolist(self): @@ -432,7 +441,7 @@ class matrix(N.ndarray): [ 7.]]) """ - return N.ndarray.sum(self, axis, dtype, out)._align(axis) + return N.ndarray.sum(self, axis, dtype, out, keepdims=True)._collapse(axis) def mean(self, axis=None, dtype=None, out=None): """ @@ -466,7 +475,7 @@ class matrix(N.ndarray): [ 9.5]]) """ - return N.ndarray.mean(self, axis, dtype, out)._align(axis) + return N.ndarray.mean(self, axis, dtype, out, keepdims=True)._collapse(axis) def std(self, axis=None, dtype=None, out=None, ddof=0): """ @@ -500,7 +509,7 @@ class matrix(N.ndarray): [ 1.11803399]]) """ - return N.ndarray.std(self, axis, dtype, out, ddof)._align(axis) + return N.ndarray.std(self, axis, dtype, out, ddof, keepdims=True)._collapse(axis) def var(self, axis=None, dtype=None, out=None, ddof=0): """ @@ -534,7 +543,7 @@ class matrix(N.ndarray): [ 1.25]]) """ - return N.ndarray.var(self, axis, dtype, out, ddof)._align(axis) + return N.ndarray.var(self, axis, dtype, out, ddof, keepdims=True)._collapse(axis) def prod(self, axis=None, dtype=None, out=None): """ @@ -567,7 +576,7 @@ class matrix(N.ndarray): [7920]]) """ - return N.ndarray.prod(self, axis, dtype, out)._align(axis) + return N.ndarray.prod(self, axis, dtype, out, keepdims=True)._collapse(axis) def any(self, axis=None, out=None): """ @@ -590,7 +599,7 @@ class matrix(N.ndarray): returns `ndarray` """ - return N.ndarray.any(self, axis, out)._align(axis) + return N.ndarray.any(self, axis, out, keepdims=True)._collapse(axis) def all(self, axis=None, out=None): """ @@ -630,7 +639,7 @@ class matrix(N.ndarray): [False]], dtype=bool) """ - return N.ndarray.all(self, axis, out)._align(axis) + return N.ndarray.all(self, axis, out, keepdims=True)._collapse(axis) def max(self, axis=None, out=None): """ @@ -665,7 +674,7 @@ class matrix(N.ndarray): [11]]) """ - return N.ndarray.max(self, axis, out)._align(axis) + return N.ndarray.max(self, axis, out, keepdims=True)._collapse(axis) def argmax(self, axis=None, out=None): """ @@ -735,7 +744,7 @@ class matrix(N.ndarray): [-11]]) """ - return N.ndarray.min(self, axis, out)._align(axis) + return N.ndarray.min(self, axis, out, keepdims=True)._collapse(axis) def argmin(self, axis=None, out=None): """ |