summaryrefslogtreecommitdiff
path: root/numpy/core/defmatrix.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-01-15 05:52:41 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-01-15 05:52:41 +0000
commit03c804239307a3d71b3df0dc69e4538dcf80f1c8 (patch)
treeebbdea7da0b7de608e978dff5a2395a186647ace /numpy/core/defmatrix.py
parentded59130b65a2046f2aea9a4d3d255437b527141 (diff)
downloadnumpy-03c804239307a3d71b3df0dc69e4538dcf80f1c8.tar.gz
Speed up matrix-vector calculations using matrix class.
Diffstat (limited to 'numpy/core/defmatrix.py')
-rw-r--r--numpy/core/defmatrix.py47
1 files changed, 41 insertions, 6 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py
index 1a9264380..63217b1ba 100644
--- a/numpy/core/defmatrix.py
+++ b/numpy/core/defmatrix.py
@@ -127,17 +127,52 @@ class matrix(N.ndarray):
return out.A[0,0]
return out
+ def _get_truend(self):
+ shp = self.shape
+ truend = 0
+ for val in shp:
+ if (val > 1): truend += 1
+ return truend
+
def __mul__(self, other):
- if isinstance(other, N.ndarray) and other.ndim == 0:
- return N.multiply(self, other)
+ if (self.ndim != self._get_truend()):
+ myself = self.A.squeeze()
else:
- return N.dot(self, other)
+ myself = self
+
+ myother = other
+ if isinstance(other, matrix):
+ if (other.ndim != other._get_truend()):
+ myother = other.A.squeeze()
+
+ if isinstance(myother, N.ndarray) and myother.size == 1:
+ res = N.multiply(myself, myother.item())
+ else:
+ res = N.dot(myself, myother)
+
+ if not isinstance(res, matrix):
+ res = res.view(matrix)
+ return res
def __rmul__(self, other):
- if isinstance(other, N.ndarray) and other.ndim == 0:
- return N.multiply(other, self)
+ if (self.ndim != self._get_truend()):
+ myself = self.A.squeeze()
else:
- return N.dot(other, self)
+ myself = self
+
+ myother = other
+ if isinstance(other, matrix):
+ if (other.ndim != other._get_truend()):
+ myother = other.A.squeeze()
+
+ if isinstance(myother, N.ndarray) and myother.size == 1:
+ res = N.multiply(myother.item(), myself)
+ else:
+ res = N.dot(myother, myself)
+
+ if not isinstance(res, matrix):
+ res = res.view(matrix)
+ return res
def __imul__(self, other):
self[:] = self * other