summaryrefslogtreecommitdiff
path: root/numpy/core/defmatrix.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/defmatrix.py')
-rw-r--r--numpy/core/defmatrix.py98
1 files changed, 33 insertions, 65 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py
index 617fc811f..24556da43 100644
--- a/numpy/core/defmatrix.py
+++ b/numpy/core/defmatrix.py
@@ -50,6 +50,8 @@ def asmatrix(data, dtype=None):
"""
return matrix(data, dtype=dtype, copy=False)
+
+
class matrix(N.ndarray):
__array_priority__ = 10.0
def __new__(subtype, data, dtype=None, copy=True):
@@ -142,6 +144,7 @@ class matrix(N.ndarray):
for val in shp:
if (val > 1): truend += 1
return truend
+
def __mul__(self, other):
if isinstance(other, N.ndarray) or N.isscalar(other) or \
@@ -201,94 +204,59 @@ class matrix(N.ndarray):
def __str__(self):
return str(self.__array__())
+ def _align(self, axis):
+ """A convenience function for operations that need to preserve axis
+ orientation.
+ """
+ if axis is None:
+ return self[0,0]
+ elif axis==0:
+ return self
+ elif axis==1:
+ return self.transpose()
+ else:
+ raise ValueError, "unsupported axis"
+
# To preserve orientation of result...
def sum(self, axis=None, dtype=None):
"""Sum the matrix over the given axis. If the axis is None, sum
over all dimensions. This preserves the orientation of the
result as a row or column.
"""
- s = N.ndarray.sum(self, axis, dtype)
- if axis==1:
- return s.transpose()
- else:
- return s
+ return N.ndarray.sum(self, axis, dtype)._align(axis)
def mean(self, axis=None):
- s = N.ndarray.mean(self, axis)
- if axis==1:
- return s.transpose()
- else:
- return s
+ return N.ndarray.mean(self, axis)._align(axis)
def std(self, axis=None, dtype=None):
- s = N.ndarray.std(self, axis, dtype)
- if axis==1:
- return s.transpose()
- else:
- return s
+ return N.ndarray.std(self, axis, dtype)._align(axis)
def var(self, axis=None, dtype=None):
- s = N.ndarray.var(self, axis, dtype)
- if axis==1:
- return s.transpose()
- else:
- return s
-
+ return N.ndarray.var(self, axis, dtype)._align(axis)
+
def prod(self, axis=None, dtype=None):
- s = N.ndarray.prod(self, axis, dtype)
- if axis==1:
- return s.transpose()
- else:
- return s
-
+ return N.ndarray.prod(self, axis, dtype)._align(axis)
+
def any(self, axis=None):
- s = N.ndarray.any(self, axis)
- if axis==1:
- return s.transpose()
- else:
- return s
+ return N.ndarray.any(self, axis)._align(axis)
def all(self, axis=None):
- s = N.ndarray.all(self, axis)
- if axis==1:
- return s.transpose()
- else:
- return s
+ return N.ndarray.all(self, axis)._align(axis)
def max(self, axis=None):
- s = N.ndarray.max(self, axis)
- if axis==1:
- return s.transpose()
- else:
- return s
+ return N.ndarray.max(self, axis)._align(axis)
def argmax(self, axis=None):
- s = N.ndarray.argmax(self, axis)
- if axis==1:
- return s.transpose()
- else:
- return s
-
+ return N.ndarray.argmax(self, axis)._align(axis)
+
def min(self, axis=None):
- s = N.ndarray.min(self, axis)
- if axis==1:
- return s.transpose()
- else:
- return s
-
+ return N.ndarray.min(self, axis)._align(axis)
+
def argmin(self, axis=None):
- s = N.ndarray.argmin(self, axis)
- if axis==1:
- return s.transpose()
- else:
- return s
-
+ return N.ndarray.argmin(self, axis)._align(axis)
+
def ptp(self, axis=None):
- s = N.ndarray.ptp(self, axis)
- if axis==1:
- return s.transpose()
- else:
- return s
+ return N.ndarray.ptp(self, axis)._align(axis)
# Needed becase tolist method expects a[i]
# to have dimension a.ndim-1