diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2008-04-06 02:34:19 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2008-04-06 02:34:19 +0000 |
commit | f339b6c31419e77f576e8b2364e186db546135e7 (patch) | |
tree | 896f3a44cf9253f4f105a151c8cba73599b80f16 /numpy/linalg/tests | |
parent | c24510c81f54547dbc48f1c60b01d0109a967af1 (diff) | |
download | numpy-f339b6c31419e77f576e8b2364e186db546135e7.tar.gz |
Factor out matrix_multiply from defmatrix. Based on a patch by
Anne Archibald.
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 70 |
1 files changed, 62 insertions, 8 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 1b380fa12..0ecd8e234 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -6,6 +6,7 @@ set_package_path() from numpy import array, single, double, csingle, cdouble, dot, identity, \ multiply, atleast_2d from numpy import linalg +from linalg import matrix_power restore_path() old_assert_almost_equal = assert_almost_equal @@ -46,38 +47,38 @@ class LinalgTestCase(NumpyTestCase): except linalg.LinAlgError, e: pass -class test_solve(LinalgTestCase): +class TestSolve(LinalgTestCase): def do(self, a, b): x = linalg.solve(a, b) assert_almost_equal(b, dot(a, x)) -class test_inv(LinalgTestCase): +class TestInv(LinalgTestCase): def do(self, a, b): a_inv = linalg.inv(a) assert_almost_equal(dot(a, a_inv), identity(a.shape[0])) -class test_eigvals(LinalgTestCase): +class TestEigvals(LinalgTestCase): def do(self, a, b): ev = linalg.eigvals(a) evalues, evectors = linalg.eig(a) assert_almost_equal(ev, evalues) -class test_eig(LinalgTestCase): +class TestEig(LinalgTestCase): def do(self, a, b): evalues, evectors = linalg.eig(a) assert_almost_equal(dot(a, evectors), evectors*evalues) -class test_svd(LinalgTestCase): +class TestSVD(LinalgTestCase): def do(self, a, b): u, s, vt = linalg.svd(a, 0) assert_almost_equal(a, dot(u*s, vt)) -class test_pinv(LinalgTestCase): +class TestPinv(LinalgTestCase): def do(self, a, b): a_ginv = linalg.pinv(a) assert_almost_equal(dot(a, a_ginv), identity(a.shape[0])) -class test_det(LinalgTestCase): +class TestDet(LinalgTestCase): def do(self, a, b): d = linalg.det(a) if a.dtype.type in (single, double): @@ -87,7 +88,7 @@ class test_det(LinalgTestCase): ev = linalg.eigvals(ad) assert_almost_equal(d, multiply.reduce(ev)) -class test_lstsq(LinalgTestCase): +class TestLstsq(LinalgTestCase): def do(self, a, b): u, s, vt = linalg.svd(a, 0) x, residuals, rank, sv = linalg.lstsq(a, b) @@ -95,5 +96,58 @@ class test_lstsq(LinalgTestCase): assert_equal(rank, a.shape[0]) assert_almost_equal(sv, s) +class TestMatrixPower(ParametricTestCase): + R90 = array([[0,1],[-1,0]]) + Arb22 = array([[4,-7],[-2,10]]) + noninv = array([[1,0],[0,0]]) + arbfloat = array([[0.1,3.2],[1.2,0.7]]) + + large = identity(10) + t = large[1,:].copy() + large[1,:] = large[0,:] + large[0,:] = t + + def test_large_power(self): + assert_equal(matrix_power(self.R90,2L**100+2**10+2**5+1),self.R90) + def test_large_power_trailing_zero(self): + assert_equal(matrix_power(self.R90,2L**100+2**10+2**5),identity(2)) + + def testip_zero(self): + def tz(M): + mz = matrix_power(M,0) + assert_equal(mz, identity(M.shape[0])) + assert_equal(mz.dtype, M.dtype) + for M in [self.Arb22, self.arbfloat, self.large]: + yield tz, M + + def testip_one(self): + def tz(M): + mz = matrix_power(M,1) + assert_equal(mz, M) + assert_equal(mz.dtype, M.dtype) + for M in [self.Arb22, self.arbfloat, self.large]: + yield tz, M + + def testip_two(self): + def tz(M): + mz = matrix_power(M,2) + assert_equal(mz, dot(M,M)) + assert_equal(mz.dtype, M.dtype) + for M in [self.Arb22, self.arbfloat, self.large]: + yield tz, M + + def testip_invert(self): + def tz(M): + mz = matrix_power(M,-1) + assert_almost_equal(identity(M.shape[0]), dot(mz,M)) + for M in [self.R90, self.Arb22, self.arbfloat, self.large]: + yield tz, M + + def test_invert_noninvertible(self): + import numpy.linalg + self.assertRaises(numpy.linalg.linalg.LinAlgError, + lambda: matrix_power(self.noninv,-1)) + + if __name__ == '__main__': NumpyTest().run() |