diff options
Diffstat (limited to 'numpy/core/defmatrix.py')
-rw-r--r-- | numpy/core/defmatrix.py | 112 |
1 files changed, 78 insertions, 34 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py index d8a5a14dc..5582f2878 100644 --- a/numpy/core/defmatrix.py +++ b/numpy/core/defmatrix.py @@ -2,7 +2,8 @@ __all__ = ['matrix', 'bmat', 'mat', 'asmatrix'] import sys import numeric as N -from numeric import concatenate, isscalar, binary_repr +from numeric import concatenate, isscalar, binary_repr, identity +from numpy.lib.utils import issubdtype # make translation table _table = [None]*256 @@ -46,6 +47,81 @@ def asmatrix(data, dtype=None): """ return matrix(data, dtype=dtype, copy=False) +def matrix_power(M,n): + """Raise a square matrix to the (integer) power n. + + For positive integers n, the power is computed by repeated matrix + squarings and matrix multiplications. If n=0, the identity matrix + of the same type as M is returned. If n<0, the inverse is computed + and raised to the exponent. + + Parameters + ---------- + M : array-like + Must be a square array (that is, of dimension two and with + equal sizes). + n : integer + The exponent can be any integer or long integer, positive + negative or zero. + + Returns + ------- + M to the power n + The return value is a an array the same shape and size as M; + if the exponent was positive or zero then the type of the + elements is the same as those of M. If the exponent was negative + the elements are floating-point. + + Raises + ------ + LinAlgException + If the matrix is not numerically invertible, an exception is raised. + + See Also + -------- + The matrix() class provides an equivalent function as the exponentiation + operator. + + Examples + -------- + >>> matrix_power(array([[0,1],[-1,0]]),10) + array([[-1, 0], + [ 0, -1]]) + """ + if len(M.shape) != 2 or M.shape[0] != M.shape[1]: + raise ValueError("input must be a square array") + if not issubdtype(type(n),int): + raise TypeError("exponent must be an integer") + + from numpy.linalg import inv + + if n==0: + M = M.copy() + M[:] = identity(M.shape[0]) + return M + elif n<0: + M = inv(M) + n *= -1 + + result = M + if n <= 3: + for _ in range(n-1): + result=N.dot(result,M) + return result + + # binary decomposition to reduce the number of Matrix + # multiplications for n > 3. + beta = binary_repr(n) + Z,q,t = M,0,len(beta) + while beta[t-q-1] == '0': + Z = N.dot(Z,Z) + q += 1 + result = Z + for k in range(q+1,t): + Z = N.dot(Z,Z) + if beta[t-k-1] == '1': + result = N.dot(result,Z) + return result class matrix(N.ndarray): @@ -195,39 +271,7 @@ class matrix(N.ndarray): return self def __pow__(self, other): - shape = self.shape - if len(shape) != 2 or shape[0] != shape[1]: - raise TypeError, "matrix is not square" - if type(other) in (type(1), type(1L)): - if other==0: - return matrix(N.identity(shape[0])) - if other<0: - x = self.I - other=-other - else: - x=self - result = x - if other <= 3: - while(other>1): - result=result*x - other=other-1 - return result - # binary decomposition to reduce the number of Matrix - # Multiplies for other > 3. - beta = binary_repr(other) - t = len(beta) - Z,q = x.copy(),0 - while beta[t-q-1] == '0': - Z *= Z - q += 1 - result = Z.copy() - for k in range(q+1,t): - Z *= Z - if beta[t-k-1] == '1': - result *= Z - return result - else: - raise TypeError, "exponent must be an integer" + return matrix_power(self, other) def __rpow__(self, other): return NotImplemented |