diff options
Diffstat (limited to 'numpy/core/defmatrix.py')
-rw-r--r-- | numpy/core/defmatrix.py | 98 |
1 files changed, 33 insertions, 65 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py index 617fc811f..24556da43 100644 --- a/numpy/core/defmatrix.py +++ b/numpy/core/defmatrix.py @@ -50,6 +50,8 @@ def asmatrix(data, dtype=None): """ return matrix(data, dtype=dtype, copy=False) + + class matrix(N.ndarray): __array_priority__ = 10.0 def __new__(subtype, data, dtype=None, copy=True): @@ -142,6 +144,7 @@ class matrix(N.ndarray): for val in shp: if (val > 1): truend += 1 return truend + def __mul__(self, other): if isinstance(other, N.ndarray) or N.isscalar(other) or \ @@ -201,94 +204,59 @@ class matrix(N.ndarray): def __str__(self): return str(self.__array__()) + def _align(self, axis): + """A convenience function for operations that need to preserve axis + orientation. + """ + if axis is None: + return self[0,0] + elif axis==0: + return self + elif axis==1: + return self.transpose() + else: + raise ValueError, "unsupported axis" + # To preserve orientation of result... def sum(self, axis=None, dtype=None): """Sum the matrix over the given axis. If the axis is None, sum over all dimensions. This preserves the orientation of the result as a row or column. """ - s = N.ndarray.sum(self, axis, dtype) - if axis==1: - return s.transpose() - else: - return s + return N.ndarray.sum(self, axis, dtype)._align(axis) def mean(self, axis=None): - s = N.ndarray.mean(self, axis) - if axis==1: - return s.transpose() - else: - return s + return N.ndarray.mean(self, axis)._align(axis) def std(self, axis=None, dtype=None): - s = N.ndarray.std(self, axis, dtype) - if axis==1: - return s.transpose() - else: - return s + return N.ndarray.std(self, axis, dtype)._align(axis) def var(self, axis=None, dtype=None): - s = N.ndarray.var(self, axis, dtype) - if axis==1: - return s.transpose() - else: - return s - + return N.ndarray.var(self, axis, dtype)._align(axis) + def prod(self, axis=None, dtype=None): - s = N.ndarray.prod(self, axis, dtype) - if axis==1: - return s.transpose() - else: - return s - + return N.ndarray.prod(self, axis, dtype)._align(axis) + def any(self, axis=None): - s = N.ndarray.any(self, axis) - if axis==1: - return s.transpose() - else: - return s + return N.ndarray.any(self, axis)._align(axis) def all(self, axis=None): - s = N.ndarray.all(self, axis) - if axis==1: - return s.transpose() - else: - return s + return N.ndarray.all(self, axis)._align(axis) def max(self, axis=None): - s = N.ndarray.max(self, axis) - if axis==1: - return s.transpose() - else: - return s + return N.ndarray.max(self, axis)._align(axis) def argmax(self, axis=None): - s = N.ndarray.argmax(self, axis) - if axis==1: - return s.transpose() - else: - return s - + return N.ndarray.argmax(self, axis)._align(axis) + def min(self, axis=None): - s = N.ndarray.min(self, axis) - if axis==1: - return s.transpose() - else: - return s - + return N.ndarray.min(self, axis)._align(axis) + def argmin(self, axis=None): - s = N.ndarray.argmin(self, axis) - if axis==1: - return s.transpose() - else: - return s - + return N.ndarray.argmin(self, axis)._align(axis) + def ptp(self, axis=None): - s = N.ndarray.ptp(self, axis) - if axis==1: - return s.transpose() - else: - return s + return N.ndarray.ptp(self, axis)._align(axis) # Needed becase tolist method expects a[i] # to have dimension a.ndim-1 |