diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-01-15 05:52:41 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-01-15 05:52:41 +0000 |
commit | 03c804239307a3d71b3df0dc69e4538dcf80f1c8 (patch) | |
tree | ebbdea7da0b7de608e978dff5a2395a186647ace /numpy/core/defmatrix.py | |
parent | ded59130b65a2046f2aea9a4d3d255437b527141 (diff) | |
download | numpy-03c804239307a3d71b3df0dc69e4538dcf80f1c8.tar.gz |
Speed up matrix-vector calculations using matrix class.
Diffstat (limited to 'numpy/core/defmatrix.py')
-rw-r--r-- | numpy/core/defmatrix.py | 47 |
1 files changed, 41 insertions, 6 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py index 1a9264380..63217b1ba 100644 --- a/numpy/core/defmatrix.py +++ b/numpy/core/defmatrix.py @@ -127,17 +127,52 @@ class matrix(N.ndarray): return out.A[0,0] return out + def _get_truend(self): + shp = self.shape + truend = 0 + for val in shp: + if (val > 1): truend += 1 + return truend + def __mul__(self, other): - if isinstance(other, N.ndarray) and other.ndim == 0: - return N.multiply(self, other) + if (self.ndim != self._get_truend()): + myself = self.A.squeeze() else: - return N.dot(self, other) + myself = self + + myother = other + if isinstance(other, matrix): + if (other.ndim != other._get_truend()): + myother = other.A.squeeze() + + if isinstance(myother, N.ndarray) and myother.size == 1: + res = N.multiply(myself, myother.item()) + else: + res = N.dot(myself, myother) + + if not isinstance(res, matrix): + res = res.view(matrix) + return res def __rmul__(self, other): - if isinstance(other, N.ndarray) and other.ndim == 0: - return N.multiply(other, self) + if (self.ndim != self._get_truend()): + myself = self.A.squeeze() else: - return N.dot(other, self) + myself = self + + myother = other + if isinstance(other, matrix): + if (other.ndim != other._get_truend()): + myother = other.A.squeeze() + + if isinstance(myother, N.ndarray) and myother.size == 1: + res = N.multiply(myother.item(), myself) + else: + res = N.dot(myother, myself) + + if not isinstance(res, matrix): + res = res.view(matrix) + return res def __imul__(self, other): self[:] = self * other |