diff options
author | Pauli Virtanen <pav@iki.fi> | 2013-04-09 23:58:01 +0300 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2013-04-10 22:47:45 +0300 |
commit | 9bfa19b11f38b5fe710d872db6a8628fc6a72359 (patch) | |
tree | d28c4c7ebdca39fa028168aab70f6faf7f755a06 | |
parent | 63a8aef815fdb2526493311b89b4d15afbb4a38d (diff) | |
download | numpy-9bfa19b11f38b5fe710d872db6a8628fc6a72359.tar.gz |
TST: linalg: add tests for generalized linalg functions
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 123 |
1 files changed, 89 insertions, 34 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index d31da3220..0a6f8f4ca 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -85,7 +85,6 @@ class LinalgTestCase(object): b = matrix([2., 1.]).T self.do(a, b) - class LinalgNonsquareTestCase(object): def test_single_nsq_1(self): a = array([[1.,2.,3.], [3.,4.,6.]], dtype=single) @@ -138,43 +137,94 @@ class LinalgNonsquareTestCase(object): self.do(a, b) -class TestSolve(LinalgTestCase, TestCase): +def _generalized_testcase(new_cls_name, old_cls): + def get_method(old_name, new_name): + def method(self): + base = old_cls() + def do(a, b): + a = np.array([a, a, a]) + b = np.array([b, b, b]) + self.do(a, b) + base.do = do + getattr(base, old_name)() + method.__name__ = new_name + return method + + dct = dict() + for old_name in dir(old_cls): + if old_name.startswith('test_'): + new_name = old_name + '_generalized' + dct[new_name] = get_method(old_name, new_name) + + return type(new_cls_name, (object,), dct) + +LinalgGeneralizedTestCase = _generalized_testcase( + 'LinalgGeneralizedTestCase', LinalgTestCase) +LinalgGeneralizedNonsquareTestCase = _generalized_testcase( + 'LinalgGeneralizedNonsquareTestCase', LinalgNonsquareTestCase) + + +def dot_generalized(a, b): + a = asarray(a) + if a.ndim == 3: + return np.array([dot(ax, bx) for ax, bx in zip(a, b)]) + elif a.ndim > 3: + raise ValueError("Not implemented...") + return dot(a, b) + +def identity_like_generalized(a): + a = asarray(a) + if a.ndim == 3: + return np.array([identity(a.shape[-2]) for ax in a]) + elif a.ndim > 3: + raise ValueError("Not implemented...") + return identity(a.shape[0]) + + +class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): x = linalg.solve(a, b) - assert_almost_equal(b, dot(a, x)) + assert_almost_equal(b, dot_generalized(a, x)) assert_(imply(isinstance(b, matrix), isinstance(x, matrix))) -class TestInv(LinalgTestCase, TestCase): +class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): a_inv = linalg.inv(a) - assert_almost_equal(dot(a, a_inv), identity(asarray(a).shape[0])) + assert_almost_equal(dot_generalized(a, a_inv), + identity_like_generalized(a)) assert_(imply(isinstance(a, matrix), isinstance(a_inv, matrix))) -class TestEigvals(LinalgTestCase, TestCase): +class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): ev = linalg.eigvals(a) evalues, evectors = linalg.eig(a) assert_almost_equal(ev, evalues) -class TestEig(LinalgTestCase, TestCase): +class TestEig(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): evalues, evectors = linalg.eig(a) - assert_almost_equal(dot(a, evectors), multiply(evectors, evalues)) + if evectors.ndim == 3: + assert_almost_equal(dot_generalized(a, evectors), evectors * evalues[:,None,:]) + else: + assert_almost_equal(dot(a, evectors), multiply(evectors, evalues)) assert_(imply(isinstance(a, matrix), isinstance(evectors, matrix))) -class TestSVD(LinalgTestCase, TestCase): +class TestSVD(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): u, s, vt = linalg.svd(a, 0) - assert_almost_equal(a, dot(multiply(u, s), vt)) + if u.ndim == 3: + assert_almost_equal(a, dot_generalized(u * s[:,None,:], vt)) + else: + assert_almost_equal(a, dot(multiply(u, s), vt)) assert_(imply(isinstance(a, matrix), isinstance(u, matrix))) assert_(imply(isinstance(a, matrix), isinstance(vt, matrix))) -class TestCondSVD(LinalgTestCase, TestCase): +class TestCondSVD(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): c = asarray(a) # a might be a matrix s = linalg.svd(c, compute_uv=False) @@ -201,7 +251,7 @@ class TestPinv(LinalgTestCase, TestCase): assert_(imply(isinstance(a, matrix), isinstance(a_ginv, matrix))) -class TestDet(LinalgTestCase, TestCase): +class TestDet(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): d = linalg.det(a) (s, ld) = linalg.slogdet(a) @@ -210,12 +260,14 @@ class TestDet(LinalgTestCase, TestCase): else: ad = asarray(a).astype(cdouble) ev = linalg.eigvals(ad) - assert_almost_equal(d, multiply.reduce(ev)) - assert_almost_equal(s * np.exp(ld), multiply.reduce(ev)) - if s != 0: - assert_almost_equal(np.abs(s), 1) - else: - assert_equal(ld, -inf) + assert_almost_equal(d, multiply.reduce(ev, axis=-1)) + assert_almost_equal(s * np.exp(ld), multiply.reduce(ev, axis=-1)) + + s = np.atleast_1d(s) + ld = np.atleast_1d(ld) + m = (s != 0) + assert_almost_equal(np.abs(s[m]), 1) + assert_equal(ld[~m], -inf) def test_zero(self): assert_equal(linalg.det([[0.0]]), 0.0) @@ -320,58 +372,61 @@ class TestBoolPower(TestCase): class HermitianTestCase(object): def test_single(self): a = array([[1.,2.], [2.,1.]], dtype=single) - self.do(a) + self.do(a, None) def test_double(self): a = array([[1.,2.], [2.,1.]], dtype=double) - self.do(a) + self.do(a, None) def test_csingle(self): a = array([[1.,2+3j], [2-3j,1]], dtype=csingle) - self.do(a) + self.do(a, None) def test_cdouble(self): a = array([[1.,2+3j], [2-3j,1]], dtype=cdouble) - self.do(a) + self.do(a, None) def test_empty(self): a = atleast_2d(array([], dtype = double)) - assert_raises(linalg.LinAlgError, self.do, a) + assert_raises(linalg.LinAlgError, self.do, a, None) def test_nonarray(self): a = [[1,2], [2,1]] - self.do(a) + self.do(a, None) def test_matrix_b_only(self): """Check that matrix type is preserved.""" a = array([[1.,2.], [2.,1.]]) - self.do(a) + self.do(a, None) def test_matrix_a_and_b(self): """Check that matrix type is preserved.""" a = matrix([[1.,2.], [2.,1.]]) - self.do(a) + self.do(a, None) + +HermitianGeneralizedTestCase = _generalized_testcase( + 'HermitianGeneralizedTestCase', HermitianTestCase) -class TestEigvalsh(HermitianTestCase, TestCase): - def do(self, a): +class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase, TestCase): + def do(self, a, b): # note that eigenvalue arrays must be sorted since # their order isn't guaranteed. ev = linalg.eigvalsh(a) evalues, evectors = linalg.eig(a) - ev.sort() - evalues.sort() + ev.sort(axis=-1) + evalues.sort(axis=-1) assert_almost_equal(ev, evalues) -class TestEigh(HermitianTestCase, TestCase): - def do(self, a): +class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase, TestCase): + def do(self, a, b): # note that eigenvalue arrays must be sorted since # their order isn't guaranteed. ev, evc = linalg.eigh(a) evalues, evectors = linalg.eig(a) - ev.sort() - evalues.sort() + ev.sort(axis=-1) + evalues.sort(axis=-1) assert_almost_equal(ev, evalues) |