summaryrefslogtreecommitdiff
path: root/numpy/matrixlib/defmatrix.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2011-08-27 21:46:08 -0600
committerCharles Harris <charlesr.harris@gmail.com>2011-08-27 21:46:08 -0600
commit9ecd91b7bf8c77d696ec9856ba10896d8f60309a (patch)
tree9884131ece5eada06212538c591965bf5928afa2 /numpy/matrixlib/defmatrix.py
parentaa55ba7437fbe6b8772a360a641b5aa7d3e669e0 (diff)
parent10fac981763e87f949bed15c66127fc380fa9b27 (diff)
downloadnumpy-9ecd91b7bf8c77d696ec9856ba10896d8f60309a.tar.gz
Merge branch 'pull-141'
* pull-141: (167 commits) ENH: missingdata: Make PyArray_Converter and PyArray_OutputConverter safer for legacy code DOC: missingdata: Add a mention of the design NEP, and masks vs bitpatterns DOC: missingdata: Updates from pull request feedback DOC: missingdata: Updates based on pull request feedback ENH: nditer: Change the Python nditer exposure to automatically add NPY_ITER_USE_MASKNA ENH: missingdata: Make comparisons with NA return NA(dtype='bool') BLD: core: onefile build fix and Python3 compatibility change DOC: Mention the update to np.all and np.any in the release notes TST: dtype: Adjust void dtype test to pass without raising a zero-size exception STY: Remove trailing whitespace TST: missingdata: Write some tests for the np.any and np.all NA behavior ENH: missingdata: Make numpy.all follow the NA && False == False rule ENH: missingdata: Make numpy.all follow the NA || True == True rule DOC: missingdata: Also show what assigning a non-NA value does in each case DOC: missingdata: Add introductory documentation for NA-masked arrays ENH: core: Rename PyArrayObject_fieldaccess to PyArrayObject_fields DOC: missingdata: Some tweaks to the NA mask documentation DOC: missingdata: Add example of a C-API function supporting NA masks DOC: missingdata: Documenting C API for NA-masked arrays ENH: missingdata: Finish adding C-API access to the NpyNA object ...
Diffstat (limited to 'numpy/matrixlib/defmatrix.py')
-rw-r--r--numpy/matrixlib/defmatrix.py27
1 files changed, 18 insertions, 9 deletions
diff --git a/numpy/matrixlib/defmatrix.py b/numpy/matrixlib/defmatrix.py
index aac123e59..dbf2909fd 100644
--- a/numpy/matrixlib/defmatrix.py
+++ b/numpy/matrixlib/defmatrix.py
@@ -375,6 +375,15 @@ class matrix(N.ndarray):
else:
raise ValueError("unsupported axis")
+ def _collapse(self, axis):
+ """A convenience function for operations that want to collapse
+ to a scalar like _align, but are using keepdims=True
+ """
+ if axis is None:
+ return self[0,0]
+ else:
+ return self
+
# Necessary because base-class tolist expects dimension
# reduction by x[0]
def tolist(self):
@@ -432,7 +441,7 @@ class matrix(N.ndarray):
[ 7.]])
"""
- return N.ndarray.sum(self, axis, dtype, out)._align(axis)
+ return N.ndarray.sum(self, axis, dtype, out, keepdims=True)._collapse(axis)
def mean(self, axis=None, dtype=None, out=None):
"""
@@ -466,7 +475,7 @@ class matrix(N.ndarray):
[ 9.5]])
"""
- return N.ndarray.mean(self, axis, dtype, out)._align(axis)
+ return N.ndarray.mean(self, axis, dtype, out, keepdims=True)._collapse(axis)
def std(self, axis=None, dtype=None, out=None, ddof=0):
"""
@@ -500,7 +509,7 @@ class matrix(N.ndarray):
[ 1.11803399]])
"""
- return N.ndarray.std(self, axis, dtype, out, ddof)._align(axis)
+ return N.ndarray.std(self, axis, dtype, out, ddof, keepdims=True)._collapse(axis)
def var(self, axis=None, dtype=None, out=None, ddof=0):
"""
@@ -534,7 +543,7 @@ class matrix(N.ndarray):
[ 1.25]])
"""
- return N.ndarray.var(self, axis, dtype, out, ddof)._align(axis)
+ return N.ndarray.var(self, axis, dtype, out, ddof, keepdims=True)._collapse(axis)
def prod(self, axis=None, dtype=None, out=None):
"""
@@ -567,7 +576,7 @@ class matrix(N.ndarray):
[7920]])
"""
- return N.ndarray.prod(self, axis, dtype, out)._align(axis)
+ return N.ndarray.prod(self, axis, dtype, out, keepdims=True)._collapse(axis)
def any(self, axis=None, out=None):
"""
@@ -590,7 +599,7 @@ class matrix(N.ndarray):
returns `ndarray`
"""
- return N.ndarray.any(self, axis, out)._align(axis)
+ return N.ndarray.any(self, axis, out, keepdims=True)._collapse(axis)
def all(self, axis=None, out=None):
"""
@@ -630,7 +639,7 @@ class matrix(N.ndarray):
[False]], dtype=bool)
"""
- return N.ndarray.all(self, axis, out)._align(axis)
+ return N.ndarray.all(self, axis, out, keepdims=True)._collapse(axis)
def max(self, axis=None, out=None):
"""
@@ -665,7 +674,7 @@ class matrix(N.ndarray):
[11]])
"""
- return N.ndarray.max(self, axis, out)._align(axis)
+ return N.ndarray.max(self, axis, out, keepdims=True)._collapse(axis)
def argmax(self, axis=None, out=None):
"""
@@ -735,7 +744,7 @@ class matrix(N.ndarray):
[-11]])
"""
- return N.ndarray.min(self, axis, out)._align(axis)
+ return N.ndarray.min(self, axis, out, keepdims=True)._collapse(axis)
def argmin(self, axis=None, out=None):
"""