diff options
author | edschofield <edschofield@localhost> | 2005-11-18 16:07:41 +0000 |
---|---|---|
committer | edschofield <edschofield@localhost> | 2005-11-18 16:07:41 +0000 |
commit | dd815f35ebb406a70bb354f459a572cb069f6744 (patch) | |
tree | 44844b2d7110ef0d29e50e00110ae48622513eeb /scipy/base/matrix.py | |
parent | 4330c22b2fd3c1edb368af386364d397d0aaaceb (diff) | |
download | numpy-dd815f35ebb406a70bb354f459a572cb069f6744.tar.gz |
Bug fixes and transparent upcasting for matrix objects
Diffstat (limited to 'scipy/base/matrix.py')
-rw-r--r-- | scipy/base/matrix.py | 224 |
1 files changed, 178 insertions, 46 deletions
diff --git a/scipy/base/matrix.py b/scipy/base/matrix.py index 623ae1660..b5d8a07cd 100644 --- a/scipy/base/matrix.py +++ b/scipy/base/matrix.py @@ -39,26 +39,29 @@ def _convert_from_string(data): if count == 0: Ncols = len(newrow) elif len(newrow) != Ncols: - raise ValueError, "Rows not the same size." + raise ValueError, "rows not the same size" count += 1 newdata.append(newrow) return newdata -class matrix(N.ndarray): +class matrix(object): __array_priority__ = 10.0 - def __new__(self, data, dtype=None, copy=True): + def __init__(self, data, dtype=None, copy=True): if isinstance(data, matrix): + swapped = data.flags.swapped dtype2 = data.dtype if (dtype is None): dtype = dtype2 if (dtype2 is dtype) and (not copy): return data return data.astype(dtype) - - if dtype is None: - if isinstance(data, N.ndarray): + elif isinstance(data, N.ndarray): + swapped = data.flags.swapped + if dtype is None: dtype = data.dtype + else: + swapped = False intype = N.obj2dtype(dtype) if isinstance(data, types.StringType): @@ -66,37 +69,32 @@ class matrix(N.ndarray): # now convert data to an array arr = N.array(data, dtype=intype, copy=copy) + arr.flags.swapped = swapped ndim = arr.ndim shape = arr.shape if (ndim > 2): raise ValueError, "matrix must be 2-dimensional" elif ndim == 0: - shape = (1,1) + arr = arr.reshape((1,1)) elif ndim == 1: shape = (1,shape[0]) + arr.shape = shape + self.arr = arr - fortran = False; - if (ndim == 2) and arr.flags['FORTRAN']: - fortran = True - - if not (fortran or arr.flags['CONTIGUOUS']): - arr = arr.copy() - - ret = N.ndarray.__new__(matrix, shape, arr.dtype, buffer=arr, - fortran=fortran, - swap=arr.flags['S']) - return ret def __array_finalize__(self, obj): - ndim = self.ndim + ndim = self.arr.ndim if ndim == 0: - self.shape = (1,1) + arr.shape = (1, 1) elif ndim == 1: - self.shape = (1,self.shape[0]) + arr.shape = (1, self.arr.shape[0]) return + def __setitem__(self, index, value): + out = self.arr.__setitem__(index, value) + def __getitem__(self, index): - out = N.ndarray.__getitem__(self, index) + out = self.arr.__getitem__(index) # Need to swap if slice is on first index retscal = False try: @@ -113,26 +111,159 @@ class matrix(N.ndarray): pass if retscal and out.shape == (1,1): # convert scalars return out.A[0,0] - return out + # Return array if the output is 1-d, or matrix if the output is 2-d + if out.ndim == 2: + return matrix(out) + else: + return out + + def copy(self): + return matrix(self.arr.copy()) + + def __copy__(self): + return matrix(self.arr.copy()) + + def __add__(self, other): + return matrix(self.arr + other) + + def __radd__(self, other): + return matrix(other + self.arr) + + def __sub__(self, other): + return matrix(self.arr - other) + + def __rsub__(self, other): + return matrix(other - self.arr) def __mul__(self, other): - if isinstance(other, N.ndarray) and other.ndim == 0: - return N.multiply(self, other) + if (isinstance(other, N.ndarray) or isinstance(other, matrix)) \ + and other.ndim == 0: + return matrix(N.multiply(self.arr, other)) else: - return N.dot(self, other) + return matrix(N.dot(self.arr, other)) def __rmul__(self, other): - if isinstance(other, N.ndarray) and other.ndim == 0: - return N.multiply(other, self) + if (isinstance(other, N.ndarray) or isinstance(other, matrix)) \ + and other.ndim == 0: + return matrix(N.multiply(other, self.arr)) else: - return N.dot(other, self) + return matrix(N.dot(other, self.arr)) + + def __div__(self, other): + try: + if other.ndim == 0: + return matrix(N.divide(self.arr, other)) + else: + raise NotImplementedError, "matrix division not yet implemented" + except AttributeError: + return matrix(N.divide(self.arr, other)) + + def __rdiv__(self, other): + try: + if other.ndim == 0: + return matrix(N.divide(other, self.arr)) + else: + raise NotImplementedError, "matrix division not yet implemented" + except AttributeError: + return matrix(N.divide(other, self.arr)) + + def __iadd__(self, other): + new = self.arr + other + try: + self.arr[:] = new + except TypeError: + self.arr = new + return self + + def __isub__(self, other): + new = self.arr - other + try: + self.arr[:] = new + except TypeError: + self.arr = new + return self def __imul__(self, other): - self[:] = self * other + new = (self * other).arr + try: + self.arr[:] = new + except TypeError: + self.arr = new return self + def __idiv__(self, other): + new = (self / other).arr + try: + self.arr[:] = new + except TypeError: + self.arr = new + return self + + def __int__(self): + return int(self.arr) + + def __float__(self): + return float(self.arr) + + def __long__(self): + return long(self.arr) + + def __complex__(self): + return complex(self.arr) + + def __oct__(self): + return oct(self.arr) + + def __hex__(self): + return hex(self.arr) + + def __len__(self): + return len(self.arr) + + def __contains__(self, item): + return self.arr.__contains__(item) + + def __nonzero__(self): + return self.arr.__nonzero__() + + def __lt__(self, item): + return self.arr.__lt__(item) + + def __le__(self, item): + return self.arr.__le__(item) + + def __gt__(self, item): + return self.arr.__gt__(item) + + def __ge__(self, item): + return self.arr.__ge__(item) + + def __eq__(self, item): + return self.arr.__eq__(item) + + def __ne__(self, item): + return self.arr.__ne__(item) + + def __pos__(self): + return self.arr.__pos__() + + def __neg__(self): + return self.arr.__neg__() + + def __abs__(self): + return self.arr.__abs__() + + def __getattr__(self, obj): + return self.arr.__getattribute__(obj) + + def __setattr__(self, obj, value): + if obj in ('shape', 'arr'): + object.__setattr__(self, obj, value) + else: + self.arr.__setattr__(obj, value) + def __pow__(self, other): - shape = self.shape + shape = self.arr.shape if len(shape) != 2 or shape[0] != shape[1]: raise TypeError, "matrix is not square" if type(other) in (type(1), type(1L)): @@ -143,17 +274,17 @@ class matrix(N.ndarray): other=-other else: x=self - result = x if other <= 3: + result = x.copy() while(other>1): - result=result*x - other=other-1 + result *= x + other -= 1 return result - # binary decomposition to reduce the number of Matrix - # Multiplies for other > 3. + # binary decomposition to reduce the number of matrix + # multiplications for 'other' > 3. beta = binary_repr(other) t = len(beta) - Z,q = x.copy(),0 + Z, q = x.copy(), 0 while beta[t-q-1] == '0': Z *= Z q += 1 @@ -170,27 +301,27 @@ class matrix(N.ndarray): raise NotImplementedError def __repr__(self): - return repr(self.__array__()).replace('array','matrix') + return repr(self.arr).replace('array','matrix') def __str__(self): - return str(self.__array__()) + return str(self.arr) # Needed becase tolist method expects a[i] # to have dimension a.ndim-1 - def tolist(self): - return self.__array__().tolist() + #def tolist(self): + # return self.__array__().tolist() def getA(self): - return self.__array__() + return self.arr def getT(self): - return self.transpose() + return matrix(self.arr.transpose()) def getH(self): - if issubclass(self.dtype, N.complexfloating): - return self.transpose().conjugate() + if issubclass(self.arr.dtype, N.complexfloating): + return matrix(self.arr.transpose().conjugate()) else: - return self.transpose() + return matrix(self.arr.transpose()) def getI(self): from scipy import linalg @@ -264,3 +395,4 @@ def bmat(obj,ldict=None, gdict=None): return matrix(obj) mat = matrix + |