summaryrefslogtreecommitdiff
path: root/scipy/base/matrix.py
diff options
context:
space:
mode:
authoredschofield <edschofield@localhost>2005-11-18 16:07:41 +0000
committeredschofield <edschofield@localhost>2005-11-18 16:07:41 +0000
commitdd815f35ebb406a70bb354f459a572cb069f6744 (patch)
tree44844b2d7110ef0d29e50e00110ae48622513eeb /scipy/base/matrix.py
parent4330c22b2fd3c1edb368af386364d397d0aaaceb (diff)
downloadnumpy-dd815f35ebb406a70bb354f459a572cb069f6744.tar.gz
Bug fixes and transparent upcasting for matrix objects
Diffstat (limited to 'scipy/base/matrix.py')
-rw-r--r--scipy/base/matrix.py224
1 files changed, 178 insertions, 46 deletions
diff --git a/scipy/base/matrix.py b/scipy/base/matrix.py
index 623ae1660..b5d8a07cd 100644
--- a/scipy/base/matrix.py
+++ b/scipy/base/matrix.py
@@ -39,26 +39,29 @@ def _convert_from_string(data):
if count == 0:
Ncols = len(newrow)
elif len(newrow) != Ncols:
- raise ValueError, "Rows not the same size."
+ raise ValueError, "rows not the same size"
count += 1
newdata.append(newrow)
return newdata
-class matrix(N.ndarray):
+class matrix(object):
__array_priority__ = 10.0
- def __new__(self, data, dtype=None, copy=True):
+ def __init__(self, data, dtype=None, copy=True):
if isinstance(data, matrix):
+ swapped = data.flags.swapped
dtype2 = data.dtype
if (dtype is None):
dtype = dtype2
if (dtype2 is dtype) and (not copy):
return data
return data.astype(dtype)
-
- if dtype is None:
- if isinstance(data, N.ndarray):
+ elif isinstance(data, N.ndarray):
+ swapped = data.flags.swapped
+ if dtype is None:
dtype = data.dtype
+ else:
+ swapped = False
intype = N.obj2dtype(dtype)
if isinstance(data, types.StringType):
@@ -66,37 +69,32 @@ class matrix(N.ndarray):
# now convert data to an array
arr = N.array(data, dtype=intype, copy=copy)
+ arr.flags.swapped = swapped
ndim = arr.ndim
shape = arr.shape
if (ndim > 2):
raise ValueError, "matrix must be 2-dimensional"
elif ndim == 0:
- shape = (1,1)
+ arr = arr.reshape((1,1))
elif ndim == 1:
shape = (1,shape[0])
+ arr.shape = shape
+ self.arr = arr
- fortran = False;
- if (ndim == 2) and arr.flags['FORTRAN']:
- fortran = True
-
- if not (fortran or arr.flags['CONTIGUOUS']):
- arr = arr.copy()
-
- ret = N.ndarray.__new__(matrix, shape, arr.dtype, buffer=arr,
- fortran=fortran,
- swap=arr.flags['S'])
- return ret
def __array_finalize__(self, obj):
- ndim = self.ndim
+ ndim = self.arr.ndim
if ndim == 0:
- self.shape = (1,1)
+ arr.shape = (1, 1)
elif ndim == 1:
- self.shape = (1,self.shape[0])
+ arr.shape = (1, self.arr.shape[0])
return
+ def __setitem__(self, index, value):
+ out = self.arr.__setitem__(index, value)
+
def __getitem__(self, index):
- out = N.ndarray.__getitem__(self, index)
+ out = self.arr.__getitem__(index)
# Need to swap if slice is on first index
retscal = False
try:
@@ -113,26 +111,159 @@ class matrix(N.ndarray):
pass
if retscal and out.shape == (1,1): # convert scalars
return out.A[0,0]
- return out
+ # Return array if the output is 1-d, or matrix if the output is 2-d
+ if out.ndim == 2:
+ return matrix(out)
+ else:
+ return out
+
+ def copy(self):
+ return matrix(self.arr.copy())
+
+ def __copy__(self):
+ return matrix(self.arr.copy())
+
+ def __add__(self, other):
+ return matrix(self.arr + other)
+
+ def __radd__(self, other):
+ return matrix(other + self.arr)
+
+ def __sub__(self, other):
+ return matrix(self.arr - other)
+
+ def __rsub__(self, other):
+ return matrix(other - self.arr)
def __mul__(self, other):
- if isinstance(other, N.ndarray) and other.ndim == 0:
- return N.multiply(self, other)
+ if (isinstance(other, N.ndarray) or isinstance(other, matrix)) \
+ and other.ndim == 0:
+ return matrix(N.multiply(self.arr, other))
else:
- return N.dot(self, other)
+ return matrix(N.dot(self.arr, other))
def __rmul__(self, other):
- if isinstance(other, N.ndarray) and other.ndim == 0:
- return N.multiply(other, self)
+ if (isinstance(other, N.ndarray) or isinstance(other, matrix)) \
+ and other.ndim == 0:
+ return matrix(N.multiply(other, self.arr))
else:
- return N.dot(other, self)
+ return matrix(N.dot(other, self.arr))
+
+ def __div__(self, other):
+ try:
+ if other.ndim == 0:
+ return matrix(N.divide(self.arr, other))
+ else:
+ raise NotImplementedError, "matrix division not yet implemented"
+ except AttributeError:
+ return matrix(N.divide(self.arr, other))
+
+ def __rdiv__(self, other):
+ try:
+ if other.ndim == 0:
+ return matrix(N.divide(other, self.arr))
+ else:
+ raise NotImplementedError, "matrix division not yet implemented"
+ except AttributeError:
+ return matrix(N.divide(other, self.arr))
+
+ def __iadd__(self, other):
+ new = self.arr + other
+ try:
+ self.arr[:] = new
+ except TypeError:
+ self.arr = new
+ return self
+
+ def __isub__(self, other):
+ new = self.arr - other
+ try:
+ self.arr[:] = new
+ except TypeError:
+ self.arr = new
+ return self
def __imul__(self, other):
- self[:] = self * other
+ new = (self * other).arr
+ try:
+ self.arr[:] = new
+ except TypeError:
+ self.arr = new
return self
+ def __idiv__(self, other):
+ new = (self / other).arr
+ try:
+ self.arr[:] = new
+ except TypeError:
+ self.arr = new
+ return self
+
+ def __int__(self):
+ return int(self.arr)
+
+ def __float__(self):
+ return float(self.arr)
+
+ def __long__(self):
+ return long(self.arr)
+
+ def __complex__(self):
+ return complex(self.arr)
+
+ def __oct__(self):
+ return oct(self.arr)
+
+ def __hex__(self):
+ return hex(self.arr)
+
+ def __len__(self):
+ return len(self.arr)
+
+ def __contains__(self, item):
+ return self.arr.__contains__(item)
+
+ def __nonzero__(self):
+ return self.arr.__nonzero__()
+
+ def __lt__(self, item):
+ return self.arr.__lt__(item)
+
+ def __le__(self, item):
+ return self.arr.__le__(item)
+
+ def __gt__(self, item):
+ return self.arr.__gt__(item)
+
+ def __ge__(self, item):
+ return self.arr.__ge__(item)
+
+ def __eq__(self, item):
+ return self.arr.__eq__(item)
+
+ def __ne__(self, item):
+ return self.arr.__ne__(item)
+
+ def __pos__(self):
+ return self.arr.__pos__()
+
+ def __neg__(self):
+ return self.arr.__neg__()
+
+ def __abs__(self):
+ return self.arr.__abs__()
+
+ def __getattr__(self, obj):
+ return self.arr.__getattribute__(obj)
+
+ def __setattr__(self, obj, value):
+ if obj in ('shape', 'arr'):
+ object.__setattr__(self, obj, value)
+ else:
+ self.arr.__setattr__(obj, value)
+
def __pow__(self, other):
- shape = self.shape
+ shape = self.arr.shape
if len(shape) != 2 or shape[0] != shape[1]:
raise TypeError, "matrix is not square"
if type(other) in (type(1), type(1L)):
@@ -143,17 +274,17 @@ class matrix(N.ndarray):
other=-other
else:
x=self
- result = x
if other <= 3:
+ result = x.copy()
while(other>1):
- result=result*x
- other=other-1
+ result *= x
+ other -= 1
return result
- # binary decomposition to reduce the number of Matrix
- # Multiplies for other > 3.
+ # binary decomposition to reduce the number of matrix
+ # multiplications for 'other' > 3.
beta = binary_repr(other)
t = len(beta)
- Z,q = x.copy(),0
+ Z, q = x.copy(), 0
while beta[t-q-1] == '0':
Z *= Z
q += 1
@@ -170,27 +301,27 @@ class matrix(N.ndarray):
raise NotImplementedError
def __repr__(self):
- return repr(self.__array__()).replace('array','matrix')
+ return repr(self.arr).replace('array','matrix')
def __str__(self):
- return str(self.__array__())
+ return str(self.arr)
# Needed becase tolist method expects a[i]
# to have dimension a.ndim-1
- def tolist(self):
- return self.__array__().tolist()
+ #def tolist(self):
+ # return self.__array__().tolist()
def getA(self):
- return self.__array__()
+ return self.arr
def getT(self):
- return self.transpose()
+ return matrix(self.arr.transpose())
def getH(self):
- if issubclass(self.dtype, N.complexfloating):
- return self.transpose().conjugate()
+ if issubclass(self.arr.dtype, N.complexfloating):
+ return matrix(self.arr.transpose().conjugate())
else:
- return self.transpose()
+ return matrix(self.arr.transpose())
def getI(self):
from scipy import linalg
@@ -264,3 +395,4 @@ def bmat(obj,ldict=None, gdict=None):
return matrix(obj)
mat = matrix
+