summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2013-04-09 23:58:01 +0300
committerPauli Virtanen <pav@iki.fi>2013-04-10 22:47:45 +0300
commit9bfa19b11f38b5fe710d872db6a8628fc6a72359 (patch)
treed28c4c7ebdca39fa028168aab70f6faf7f755a06
parent63a8aef815fdb2526493311b89b4d15afbb4a38d (diff)
downloadnumpy-9bfa19b11f38b5fe710d872db6a8628fc6a72359.tar.gz
TST: linalg: add tests for generalized linalg functions
-rw-r--r--numpy/linalg/tests/test_linalg.py123
1 files changed, 89 insertions, 34 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index d31da3220..0a6f8f4ca 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -85,7 +85,6 @@ class LinalgTestCase(object):
b = matrix([2., 1.]).T
self.do(a, b)
-
class LinalgNonsquareTestCase(object):
def test_single_nsq_1(self):
a = array([[1.,2.,3.], [3.,4.,6.]], dtype=single)
@@ -138,43 +137,94 @@ class LinalgNonsquareTestCase(object):
self.do(a, b)
-class TestSolve(LinalgTestCase, TestCase):
+def _generalized_testcase(new_cls_name, old_cls):
+ def get_method(old_name, new_name):
+ def method(self):
+ base = old_cls()
+ def do(a, b):
+ a = np.array([a, a, a])
+ b = np.array([b, b, b])
+ self.do(a, b)
+ base.do = do
+ getattr(base, old_name)()
+ method.__name__ = new_name
+ return method
+
+ dct = dict()
+ for old_name in dir(old_cls):
+ if old_name.startswith('test_'):
+ new_name = old_name + '_generalized'
+ dct[new_name] = get_method(old_name, new_name)
+
+ return type(new_cls_name, (object,), dct)
+
+LinalgGeneralizedTestCase = _generalized_testcase(
+ 'LinalgGeneralizedTestCase', LinalgTestCase)
+LinalgGeneralizedNonsquareTestCase = _generalized_testcase(
+ 'LinalgGeneralizedNonsquareTestCase', LinalgNonsquareTestCase)
+
+
+def dot_generalized(a, b):
+ a = asarray(a)
+ if a.ndim == 3:
+ return np.array([dot(ax, bx) for ax, bx in zip(a, b)])
+ elif a.ndim > 3:
+ raise ValueError("Not implemented...")
+ return dot(a, b)
+
+def identity_like_generalized(a):
+ a = asarray(a)
+ if a.ndim == 3:
+ return np.array([identity(a.shape[-2]) for ax in a])
+ elif a.ndim > 3:
+ raise ValueError("Not implemented...")
+ return identity(a.shape[0])
+
+
+class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
x = linalg.solve(a, b)
- assert_almost_equal(b, dot(a, x))
+ assert_almost_equal(b, dot_generalized(a, x))
assert_(imply(isinstance(b, matrix), isinstance(x, matrix)))
-class TestInv(LinalgTestCase, TestCase):
+class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
a_inv = linalg.inv(a)
- assert_almost_equal(dot(a, a_inv), identity(asarray(a).shape[0]))
+ assert_almost_equal(dot_generalized(a, a_inv),
+ identity_like_generalized(a))
assert_(imply(isinstance(a, matrix), isinstance(a_inv, matrix)))
-class TestEigvals(LinalgTestCase, TestCase):
+class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
ev = linalg.eigvals(a)
evalues, evectors = linalg.eig(a)
assert_almost_equal(ev, evalues)
-class TestEig(LinalgTestCase, TestCase):
+class TestEig(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
evalues, evectors = linalg.eig(a)
- assert_almost_equal(dot(a, evectors), multiply(evectors, evalues))
+ if evectors.ndim == 3:
+ assert_almost_equal(dot_generalized(a, evectors), evectors * evalues[:,None,:])
+ else:
+ assert_almost_equal(dot(a, evectors), multiply(evectors, evalues))
assert_(imply(isinstance(a, matrix), isinstance(evectors, matrix)))
-class TestSVD(LinalgTestCase, TestCase):
+class TestSVD(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
u, s, vt = linalg.svd(a, 0)
- assert_almost_equal(a, dot(multiply(u, s), vt))
+ if u.ndim == 3:
+ assert_almost_equal(a, dot_generalized(u * s[:,None,:], vt))
+ else:
+ assert_almost_equal(a, dot(multiply(u, s), vt))
assert_(imply(isinstance(a, matrix), isinstance(u, matrix)))
assert_(imply(isinstance(a, matrix), isinstance(vt, matrix)))
-class TestCondSVD(LinalgTestCase, TestCase):
+class TestCondSVD(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
c = asarray(a) # a might be a matrix
s = linalg.svd(c, compute_uv=False)
@@ -201,7 +251,7 @@ class TestPinv(LinalgTestCase, TestCase):
assert_(imply(isinstance(a, matrix), isinstance(a_ginv, matrix)))
-class TestDet(LinalgTestCase, TestCase):
+class TestDet(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
d = linalg.det(a)
(s, ld) = linalg.slogdet(a)
@@ -210,12 +260,14 @@ class TestDet(LinalgTestCase, TestCase):
else:
ad = asarray(a).astype(cdouble)
ev = linalg.eigvals(ad)
- assert_almost_equal(d, multiply.reduce(ev))
- assert_almost_equal(s * np.exp(ld), multiply.reduce(ev))
- if s != 0:
- assert_almost_equal(np.abs(s), 1)
- else:
- assert_equal(ld, -inf)
+ assert_almost_equal(d, multiply.reduce(ev, axis=-1))
+ assert_almost_equal(s * np.exp(ld), multiply.reduce(ev, axis=-1))
+
+ s = np.atleast_1d(s)
+ ld = np.atleast_1d(ld)
+ m = (s != 0)
+ assert_almost_equal(np.abs(s[m]), 1)
+ assert_equal(ld[~m], -inf)
def test_zero(self):
assert_equal(linalg.det([[0.0]]), 0.0)
@@ -320,58 +372,61 @@ class TestBoolPower(TestCase):
class HermitianTestCase(object):
def test_single(self):
a = array([[1.,2.], [2.,1.]], dtype=single)
- self.do(a)
+ self.do(a, None)
def test_double(self):
a = array([[1.,2.], [2.,1.]], dtype=double)
- self.do(a)
+ self.do(a, None)
def test_csingle(self):
a = array([[1.,2+3j], [2-3j,1]], dtype=csingle)
- self.do(a)
+ self.do(a, None)
def test_cdouble(self):
a = array([[1.,2+3j], [2-3j,1]], dtype=cdouble)
- self.do(a)
+ self.do(a, None)
def test_empty(self):
a = atleast_2d(array([], dtype = double))
- assert_raises(linalg.LinAlgError, self.do, a)
+ assert_raises(linalg.LinAlgError, self.do, a, None)
def test_nonarray(self):
a = [[1,2], [2,1]]
- self.do(a)
+ self.do(a, None)
def test_matrix_b_only(self):
"""Check that matrix type is preserved."""
a = array([[1.,2.], [2.,1.]])
- self.do(a)
+ self.do(a, None)
def test_matrix_a_and_b(self):
"""Check that matrix type is preserved."""
a = matrix([[1.,2.], [2.,1.]])
- self.do(a)
+ self.do(a, None)
+
+HermitianGeneralizedTestCase = _generalized_testcase(
+ 'HermitianGeneralizedTestCase', HermitianTestCase)
-class TestEigvalsh(HermitianTestCase, TestCase):
- def do(self, a):
+class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase, TestCase):
+ def do(self, a, b):
# note that eigenvalue arrays must be sorted since
# their order isn't guaranteed.
ev = linalg.eigvalsh(a)
evalues, evectors = linalg.eig(a)
- ev.sort()
- evalues.sort()
+ ev.sort(axis=-1)
+ evalues.sort(axis=-1)
assert_almost_equal(ev, evalues)
-class TestEigh(HermitianTestCase, TestCase):
- def do(self, a):
+class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase, TestCase):
+ def do(self, a, b):
# note that eigenvalue arrays must be sorted since
# their order isn't guaranteed.
ev, evc = linalg.eigh(a)
evalues, evectors = linalg.eig(a)
- ev.sort()
- evalues.sort()
+ ev.sort(axis=-1)
+ evalues.sort(axis=-1)
assert_almost_equal(ev, evalues)