diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2008-04-06 02:34:19 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2008-04-06 02:34:19 +0000 |
commit | f339b6c31419e77f576e8b2364e186db546135e7 (patch) | |
tree | 896f3a44cf9253f4f105a151c8cba73599b80f16 | |
parent | c24510c81f54547dbc48f1c60b01d0109a967af1 (diff) | |
download | numpy-f339b6c31419e77f576e8b2364e186db546135e7.tar.gz |
Factor out matrix_multiply from defmatrix. Based on a patch by
Anne Archibald.
-rw-r--r-- | numpy/core/defmatrix.py | 112 | ||||
-rw-r--r-- | numpy/linalg/info.py | 1 | ||||
-rw-r--r-- | numpy/linalg/linalg.py | 6 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 70 |
4 files changed, 146 insertions, 43 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py index d8a5a14dc..5582f2878 100644 --- a/numpy/core/defmatrix.py +++ b/numpy/core/defmatrix.py @@ -2,7 +2,8 @@ __all__ = ['matrix', 'bmat', 'mat', 'asmatrix'] import sys import numeric as N -from numeric import concatenate, isscalar, binary_repr +from numeric import concatenate, isscalar, binary_repr, identity +from numpy.lib.utils import issubdtype # make translation table _table = [None]*256 @@ -46,6 +47,81 @@ def asmatrix(data, dtype=None): """ return matrix(data, dtype=dtype, copy=False) +def matrix_power(M,n): + """Raise a square matrix to the (integer) power n. + + For positive integers n, the power is computed by repeated matrix + squarings and matrix multiplications. If n=0, the identity matrix + of the same type as M is returned. If n<0, the inverse is computed + and raised to the exponent. + + Parameters + ---------- + M : array-like + Must be a square array (that is, of dimension two and with + equal sizes). + n : integer + The exponent can be any integer or long integer, positive + negative or zero. + + Returns + ------- + M to the power n + The return value is a an array the same shape and size as M; + if the exponent was positive or zero then the type of the + elements is the same as those of M. If the exponent was negative + the elements are floating-point. + + Raises + ------ + LinAlgException + If the matrix is not numerically invertible, an exception is raised. + + See Also + -------- + The matrix() class provides an equivalent function as the exponentiation + operator. + + Examples + -------- + >>> matrix_power(array([[0,1],[-1,0]]),10) + array([[-1, 0], + [ 0, -1]]) + """ + if len(M.shape) != 2 or M.shape[0] != M.shape[1]: + raise ValueError("input must be a square array") + if not issubdtype(type(n),int): + raise TypeError("exponent must be an integer") + + from numpy.linalg import inv + + if n==0: + M = M.copy() + M[:] = identity(M.shape[0]) + return M + elif n<0: + M = inv(M) + n *= -1 + + result = M + if n <= 3: + for _ in range(n-1): + result=N.dot(result,M) + return result + + # binary decomposition to reduce the number of Matrix + # multiplications for n > 3. + beta = binary_repr(n) + Z,q,t = M,0,len(beta) + while beta[t-q-1] == '0': + Z = N.dot(Z,Z) + q += 1 + result = Z + for k in range(q+1,t): + Z = N.dot(Z,Z) + if beta[t-k-1] == '1': + result = N.dot(result,Z) + return result class matrix(N.ndarray): @@ -195,39 +271,7 @@ class matrix(N.ndarray): return self def __pow__(self, other): - shape = self.shape - if len(shape) != 2 or shape[0] != shape[1]: - raise TypeError, "matrix is not square" - if type(other) in (type(1), type(1L)): - if other==0: - return matrix(N.identity(shape[0])) - if other<0: - x = self.I - other=-other - else: - x=self - result = x - if other <= 3: - while(other>1): - result=result*x - other=other-1 - return result - # binary decomposition to reduce the number of Matrix - # Multiplies for other > 3. - beta = binary_repr(other) - t = len(beta) - Z,q = x.copy(),0 - while beta[t-q-1] == '0': - Z *= Z - q += 1 - result = Z.copy() - for k in range(q+1,t): - Z *= Z - if beta[t-k-1] == '1': - result *= Z - return result - else: - raise TypeError, "exponent must be an integer" + return matrix_power(self, other) def __rpow__(self, other): return NotImplemented diff --git a/numpy/linalg/info.py b/numpy/linalg/info.py index 25afdec1b..235822dfa 100644 --- a/numpy/linalg/info.py +++ b/numpy/linalg/info.py @@ -10,6 +10,7 @@ Linear algebra basics: - lstsq Solve linear least-squares problem - pinv Pseudo-inverse (Moore-Penrose) calculated using a singular value decomposition +- matrix_power Integer power of a square matrix Eigenvalues and decompositions: diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 1ee2b9bf8..c4575e377 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -9,7 +9,7 @@ dgeev, zgeev, dgesdd, zgesdd, dgelsd, zgelsd, dsyevd, zheevd, dgetrf, zgetrf, dpotrf, zpotrf, dgeqrf, zgeqrf, zungqr, dorgqr. """ -__all__ = ['solve', 'tensorsolve', 'tensorinv', +__all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv', 'cholesky', 'eigvals', 'eigvalsh', 'pinv', @@ -26,6 +26,7 @@ from numpy.core import array, asarray, zeros, empty, transpose, \ isfinite, size from numpy.lib import triu from numpy.linalg import lapack_lite +from numpy.core.defmatrix import matrix_power fortran_int = intc @@ -134,6 +135,7 @@ def _assertNonEmpty(*arrays): if size(a) == 0: raise LinAlgError("Arrays cannot be empty") + # Linear equations def tensorsolve(a, b, axes=None): @@ -326,6 +328,7 @@ def inv(a): a, wrap = _makearray(a) return wrap(solve(a, identity(a.shape[0], dtype=a.dtype))) + # Cholesky decomposition def cholesky(a): @@ -1053,6 +1056,7 @@ def det(a): sign = add.reduce(pivots != arange(1, n+1)) % 2 return (1.-2.*sign)*multiply.reduce(diagonal(a), axis=-1) + # Linear Least Squares def lstsq(a, b, rcond=-1): diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 1b380fa12..0ecd8e234 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -6,6 +6,7 @@ set_package_path() from numpy import array, single, double, csingle, cdouble, dot, identity, \ multiply, atleast_2d from numpy import linalg +from linalg import matrix_power restore_path() old_assert_almost_equal = assert_almost_equal @@ -46,38 +47,38 @@ class LinalgTestCase(NumpyTestCase): except linalg.LinAlgError, e: pass -class test_solve(LinalgTestCase): +class TestSolve(LinalgTestCase): def do(self, a, b): x = linalg.solve(a, b) assert_almost_equal(b, dot(a, x)) -class test_inv(LinalgTestCase): +class TestInv(LinalgTestCase): def do(self, a, b): a_inv = linalg.inv(a) assert_almost_equal(dot(a, a_inv), identity(a.shape[0])) -class test_eigvals(LinalgTestCase): +class TestEigvals(LinalgTestCase): def do(self, a, b): ev = linalg.eigvals(a) evalues, evectors = linalg.eig(a) assert_almost_equal(ev, evalues) -class test_eig(LinalgTestCase): +class TestEig(LinalgTestCase): def do(self, a, b): evalues, evectors = linalg.eig(a) assert_almost_equal(dot(a, evectors), evectors*evalues) -class test_svd(LinalgTestCase): +class TestSVD(LinalgTestCase): def do(self, a, b): u, s, vt = linalg.svd(a, 0) assert_almost_equal(a, dot(u*s, vt)) -class test_pinv(LinalgTestCase): +class TestPinv(LinalgTestCase): def do(self, a, b): a_ginv = linalg.pinv(a) assert_almost_equal(dot(a, a_ginv), identity(a.shape[0])) -class test_det(LinalgTestCase): +class TestDet(LinalgTestCase): def do(self, a, b): d = linalg.det(a) if a.dtype.type in (single, double): @@ -87,7 +88,7 @@ class test_det(LinalgTestCase): ev = linalg.eigvals(ad) assert_almost_equal(d, multiply.reduce(ev)) -class test_lstsq(LinalgTestCase): +class TestLstsq(LinalgTestCase): def do(self, a, b): u, s, vt = linalg.svd(a, 0) x, residuals, rank, sv = linalg.lstsq(a, b) @@ -95,5 +96,58 @@ class test_lstsq(LinalgTestCase): assert_equal(rank, a.shape[0]) assert_almost_equal(sv, s) +class TestMatrixPower(ParametricTestCase): + R90 = array([[0,1],[-1,0]]) + Arb22 = array([[4,-7],[-2,10]]) + noninv = array([[1,0],[0,0]]) + arbfloat = array([[0.1,3.2],[1.2,0.7]]) + + large = identity(10) + t = large[1,:].copy() + large[1,:] = large[0,:] + large[0,:] = t + + def test_large_power(self): + assert_equal(matrix_power(self.R90,2L**100+2**10+2**5+1),self.R90) + def test_large_power_trailing_zero(self): + assert_equal(matrix_power(self.R90,2L**100+2**10+2**5),identity(2)) + + def testip_zero(self): + def tz(M): + mz = matrix_power(M,0) + assert_equal(mz, identity(M.shape[0])) + assert_equal(mz.dtype, M.dtype) + for M in [self.Arb22, self.arbfloat, self.large]: + yield tz, M + + def testip_one(self): + def tz(M): + mz = matrix_power(M,1) + assert_equal(mz, M) + assert_equal(mz.dtype, M.dtype) + for M in [self.Arb22, self.arbfloat, self.large]: + yield tz, M + + def testip_two(self): + def tz(M): + mz = matrix_power(M,2) + assert_equal(mz, dot(M,M)) + assert_equal(mz.dtype, M.dtype) + for M in [self.Arb22, self.arbfloat, self.large]: + yield tz, M + + def testip_invert(self): + def tz(M): + mz = matrix_power(M,-1) + assert_almost_equal(identity(M.shape[0]), dot(mz,M)) + for M in [self.R90, self.Arb22, self.arbfloat, self.large]: + yield tz, M + + def test_invert_noninvertible(self): + import numpy.linalg + self.assertRaises(numpy.linalg.linalg.LinAlgError, + lambda: matrix_power(self.noninv,-1)) + + if __name__ == '__main__': NumpyTest().run() |