diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2016-12-12 18:33:09 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2016-12-19 14:32:49 +0000 |
commit | c3442066006015dfa8be714686472434756bf83e (patch) | |
tree | 950ea3d66f3e6207503ecf14faab29a4ac18cb77 /numpy/linalg/tests | |
parent | 340779f53b56d89aa4044af3b1a382e3e1a15592 (diff) | |
download | numpy-c3442066006015dfa8be714686472434756bf83e.tar.gz |
TST: Refactor all the test case lists
Allows each individual function to inspect the flags of a certain test, and
decide whether an exception will be thrown
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 253 |
1 files changed, 176 insertions, 77 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index a8b4dcdf5..9b1724c9c 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -13,7 +13,7 @@ import numpy as np from numpy import array, single, double, csingle, cdouble, dot, identity from numpy import multiply, atleast_2d, inf, asarray, matrix from numpy import linalg -from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot +from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError from numpy.linalg.linalg import _multi_dot_matrix_chain_order from numpy.testing import ( assert_, assert_equal, assert_raises, assert_array_equal, @@ -61,22 +61,33 @@ def get_rtol(dtype): class LinalgCase(object): - def __init__(self, name, a, b, exception_cls=None): + def __init__(self, name, a, b, flags=frozenset()): assert_(isinstance(name, str)) self.name = name self.a = a self.b = b - self.exception_cls = exception_cls + self.flags = flags def check(self, do): - if self.exception_cls is None: - do(self.a, self.b) - else: - assert_raises(self.exception_cls, do, self.a, self.b) + do(self.a, self.b, flags=self.flags) def __repr__(self): return "<LinalgCase: %s>" % (self.name,) +def apply_flags(flags, cases): + for case in cases: + case.flags = case.flags | flags + return cases + +class CaseFlags: + """ A simple flags enum. Could be replaced with ints for speed""" + square = frozenset(['square']) + generalized = frozenset(['generalized']) + empty = frozenset(['empty']) + hermitian = frozenset(['hermitian']) # not a subset of square, because these have only one argument + + none = frozenset() + all = square | generalized | empty | hermitian # # Base test cases @@ -84,7 +95,10 @@ class LinalgCase(object): np.random.seed(1234) -SQUARE_CASES = [ +CASES = [] + +# square test cases +CASES += apply_flags(CaseFlags.square, [ LinalgCase("single", array([[1., 2.], [3., 4.]], dtype=single), array([2., 1.], dtype=single)), @@ -106,7 +120,7 @@ SQUARE_CASES = [ LinalgCase("0x0", np.empty((0, 0), dtype=double), np.empty((0, 0), dtype=double), - linalg.LinAlgError), + flags=CaseFlags.empty), LinalgCase("8x8", np.random.rand(8, 8), np.random.rand(8)), @@ -122,9 +136,10 @@ SQUARE_CASES = [ LinalgCase("matrix_a_and_b", matrix([[1., 2.], [3., 4.]]), matrix([2., 1.]).T), -] +]) -NONSQUARE_CASES = [ +# non-square test-cases +CASES += apply_flags(CaseFlags.none, [ LinalgCase("single_nsq_1", array([[1., 2., 3.], [3., 4., 6.]], dtype=single), array([2., 1.], dtype=single)), @@ -172,13 +187,16 @@ NONSQUARE_CASES = [ np.random.rand(1)), LinalgCase("0x4", np.random.rand(0, 4), - np.random.rand(4)), + np.random.rand(4), + flags=CaseFlags.empty), LinalgCase("4x0", np.random.rand(4, 0), - np.random.rand(0)), -] + np.random.rand(0), + flags=CaseFlags.empty), +]) -HERMITIAN_CASES = [ +# hermitian test-cases +CASES += apply_flags(CaseFlags.hermitian, [ LinalgCase("hsingle", array([[1., 2.], [2., 1.]], dtype=single), None), @@ -192,9 +210,9 @@ HERMITIAN_CASES = [ array([[1., 2 + 3j], [2 - 3j, 1]], dtype=cdouble), None), LinalgCase("hempty", - atleast_2d(array([], dtype=double)), + np.empty((0, 0), dtype=double), None, - linalg.LinAlgError), + flags=CaseFlags.empty), LinalgCase("hnonarray", [[1, 2], [2, 1]], None), @@ -207,21 +225,16 @@ HERMITIAN_CASES = [ LinalgCase("hmatrix_1x1", np.random.rand(1, 1), None), -] +]) # # Gufunc test cases # +def _make_generalized_cases(): + new_cases = [] -GENERALIZED_SQUARE_CASES = [] -GENERALIZED_NONSQUARE_CASES = [] -GENERALIZED_HERMITIAN_CASES = [] - -for tgt, src in ((GENERALIZED_SQUARE_CASES, SQUARE_CASES), - (GENERALIZED_NONSQUARE_CASES, NONSQUARE_CASES), - (GENERALIZED_HERMITIAN_CASES, HERMITIAN_CASES)): - for case in src: + for case in CASES: if not isinstance(case.a, np.ndarray): continue @@ -231,8 +244,8 @@ for tgt, src in ((GENERALIZED_SQUARE_CASES, SQUARE_CASES), else: b = np.array([case.b, 7 * case.b, 6 * case.b]) new_case = LinalgCase(case.name + "_tile3", a, b, - case.exception_cls) - tgt.append(new_case) + flags=case.flags | {'generalized'}) + new_cases.append(new_case) a = np.array([case.a] * 2 * 3).reshape((3, 2) + case.a.shape) if case.b is None: @@ -240,14 +253,17 @@ for tgt, src in ((GENERALIZED_SQUARE_CASES, SQUARE_CASES), else: b = np.array([case.b] * 2 * 3).reshape((3, 2) + case.b.shape) new_case = LinalgCase(case.name + "_tile213", a, b, - case.exception_cls) - tgt.append(new_case) + flags=case.flags | {'generalized'}) + new_cases.append(new_case) + + return new_cases + +CASES += _make_generalized_cases() # # Generate stride combination variations of the above # - def _stride_comb_iter(x): """ Generate cartesian product of strides for all axes @@ -295,29 +311,35 @@ def _stride_comb_iter(x): xi = np.lib.stride_tricks.as_strided(x, strides=s) yield xi, "stride_xxx_0_0" -for src in (SQUARE_CASES, - NONSQUARE_CASES, - HERMITIAN_CASES, - GENERALIZED_SQUARE_CASES, - GENERALIZED_NONSQUARE_CASES, - GENERALIZED_HERMITIAN_CASES): - +def _make_strided_cases(): new_cases = [] - for case in src: + for case in CASES: for a, a_tag in _stride_comb_iter(case.a): for b, b_tag in _stride_comb_iter(case.b): new_case = LinalgCase(case.name + "_" + a_tag + "_" + b_tag, a, b, - exception_cls=case.exception_cls) + flags=case.flags | {'strided'}) new_cases.append(new_case) - src.extend(new_cases) + return new_cases + +CASES += _make_strided_cases() # # Test different routines against the above cases # -def _check_cases(func, cases): - for case in cases: +def _check_cases(func, require=CaseFlags.none, exclude=CaseFlags.none): + """ + Run func on each of the cases with all of the flags in require, and none + of the flags in exclude + """ + for case in CASES: + # filter by require and exclude + if case.flags & require != require: + continue + if case.flags & exclude: + continue + try: case.check(func) except Exception: @@ -326,43 +348,87 @@ def _check_cases(func, cases): raise AssertionError(msg) -class LinalgTestCase(object): +class LinalgSquareTestCase(object): def test_sq_cases(self): - _check_cases(self.do, SQUARE_CASES) + _check_cases(self.do, + require=CaseFlags.square, + exclude=CaseFlags.generalized | CaseFlags.empty) + + def test_empty_sq_cases(self): + _check_cases(self.do, + require=CaseFlags.square | CaseFlags.empty, + exclude=CaseFlags.generalized) class LinalgNonsquareTestCase(object): def test_nonsq_cases(self): - _check_cases(self.do, NONSQUARE_CASES) + _check_cases(self.do, + require=CaseFlags.none, + exclude=CaseFlags.generalized | CaseFlags.square | CaseFlags.hermitian | CaseFlags.empty) + + def test_empty_nonsq_cases(self): + _check_cases(self.do, + require=CaseFlags.empty, + exclude=CaseFlags.generalized | CaseFlags.square | CaseFlags.hermitian) + +class HermitianTestCase(object): + + def test_herm_cases(self): + _check_cases(self.do, + require=CaseFlags.hermitian, + exclude=CaseFlags.generalized | CaseFlags.empty) + + def test_empty_herm_cases(self): + _check_cases(self.do, + require=CaseFlags.hermitian | CaseFlags.empty, + exclude=CaseFlags.generalized) -class LinalgGeneralizedTestCase(object): +class LinalgGeneralizedSquareTestCase(object): @dec.slow def test_generalized_sq_cases(self): - _check_cases(self.do, GENERALIZED_SQUARE_CASES) + _check_cases(self.do, + require=CaseFlags.generalized | CaseFlags.square, + exclude=CaseFlags.hermitian | CaseFlags.empty) + + @dec.slow + def test_generalized_empty_sq_cases(self): + _check_cases(self.do, + require=CaseFlags.generalized | CaseFlags.square | CaseFlags.empty, + exclude=CaseFlags.hermitian) class LinalgGeneralizedNonsquareTestCase(object): @dec.slow def test_generalized_nonsq_cases(self): - _check_cases(self.do, GENERALIZED_NONSQUARE_CASES) - + _check_cases(self.do, + require=CaseFlags.generalized, + exclude=CaseFlags.square | CaseFlags.empty) -class HermitianTestCase(object): - - def test_herm_cases(self): - _check_cases(self.do, HERMITIAN_CASES) + @dec.slow + def test_generalized_empty_nonsq_cases(self): + _check_cases(self.do, + require=CaseFlags.generalized | CaseFlags.empty, + exclude=CaseFlags.square) class HermitianGeneralizedTestCase(object): @dec.slow def test_generalized_herm_cases(self): - _check_cases(self.do, GENERALIZED_HERMITIAN_CASES) + _check_cases(self.do, + require=CaseFlags.generalized | CaseFlags.hermitian, + exclude=CaseFlags.empty) + + @dec.slow + def test_generalized_empty_herm_cases(self): + _check_cases(self.do, + require=CaseFlags.generalized | CaseFlags.hermitian | CaseFlags.empty, + exclude=CaseFlags.none) def dot_generalized(a, b): @@ -395,9 +461,9 @@ def identity_like_generalized(a): return identity(a.shape[0]) -class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase): +class TestSolve(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): - def do(self, a, b): + def do(self, a, b, flags): x = linalg.solve(a, b) assert_almost_equal(b, dot_generalized(a, x)) assert_(imply(isinstance(b, matrix), isinstance(x, matrix))) @@ -461,9 +527,9 @@ class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase): assert_(isinstance(result, ArraySubclass)) -class TestInv(LinalgTestCase, LinalgGeneralizedTestCase): +class TestInv(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): - def do(self, a, b): + def do(self, a, b, flags): a_inv = linalg.inv(a) assert_almost_equal(dot_generalized(a, a_inv), identity_like_generalized(a)) @@ -492,9 +558,12 @@ class TestInv(LinalgTestCase, LinalgGeneralizedTestCase): assert_equal(a.shape, res.shape) -class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase): +class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): - def do(self, a, b): + def do(self, a, b, flags): + if flags & CaseFlags.empty: + assert_raises(LinAlgError, linalg.eigvals, a) + return ev = linalg.eigvals(a) evalues, evectors = linalg.eig(a) assert_almost_equal(ev, evalues) @@ -509,9 +578,13 @@ class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase): yield check, dtype -class TestEig(LinalgTestCase, LinalgGeneralizedTestCase): +class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): + + def do(self, a, b, flags): + if flags & CaseFlags.empty: + assert_raises(LinAlgError, linalg.eig, a) + return - def do(self, a, b): evalues, evectors = linalg.eig(a) assert_allclose(dot_generalized(a, evectors), np.asarray(evectors) * np.asarray(evalues)[..., None, :], @@ -534,9 +607,13 @@ class TestEig(LinalgTestCase, LinalgGeneralizedTestCase): yield check, dtype -class TestSVD(LinalgTestCase, LinalgGeneralizedTestCase): +class TestSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): + + def do(self, a, b, flags): + if flags & CaseFlags.empty: + assert_raises(LinAlgError, linalg.svd, a, 0) + return - def do(self, a, b): u, s, vt = linalg.svd(a, 0) assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :], np.asarray(vt)), @@ -558,10 +635,13 @@ class TestSVD(LinalgTestCase, LinalgGeneralizedTestCase): yield check, dtype -class TestCondSVD(LinalgTestCase, LinalgGeneralizedTestCase): +class TestCondSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): - def do(self, a, b): + def do(self, a, b, flags): c = asarray(a) # a might be a matrix + if flags & CaseFlags.empty: + assert_raises(LinAlgError, linalg.svd, c, compute_uv=False) + return s = linalg.svd(c, compute_uv=False) assert_almost_equal( s[..., 0] / s[..., -1], linalg.cond(a), @@ -572,10 +652,13 @@ class TestCondSVD(LinalgTestCase, LinalgGeneralizedTestCase): assert_equal(linalg.cond(A), linalg.cond(A[None, ...])[0]) -class TestCond2(LinalgTestCase): +class TestCond2(LinalgSquareTestCase): - def do(self, a, b): + def do(self, a, b, flags): c = asarray(a) # a might be a matrix + if flags & CaseFlags.empty: + assert_raises(LinAlgError, linalg.svd, c, compute_uv=False) + return s = linalg.svd(c, compute_uv=False) assert_almost_equal( s[..., 0] / s[..., -1], linalg.cond(a, 2), @@ -593,18 +676,24 @@ class TestCondInf(object): assert_almost_equal(linalg.cond(A, inf), 3.) -class TestPinv(LinalgTestCase, LinalgNonsquareTestCase): +class TestPinv(LinalgSquareTestCase, LinalgNonsquareTestCase): - def do(self, a, b): + def do(self, a, b, flags): + if flags & CaseFlags.empty: + assert_raises(LinAlgError, linalg.pinv, a) + return a_ginv = linalg.pinv(a) # `a @ a_ginv == I` does not hold if a is singular assert_almost_equal(dot(a, a_ginv).dot(a), a, single_decimal=5, double_decimal=11) assert_(imply(isinstance(a, matrix), isinstance(a_ginv, matrix))) -class TestDet(LinalgTestCase, LinalgGeneralizedTestCase): +class TestDet(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): - def do(self, a, b): + def do(self, a, b, flags): + if flags & CaseFlags.empty: + assert_raises(LinAlgError, linalg.det, a) + return d = linalg.det(a) (s, ld) = linalg.slogdet(a) if asarray(a).dtype.type in (single, double): @@ -645,9 +734,13 @@ class TestDet(LinalgTestCase, LinalgGeneralizedTestCase): yield check, dtype -class TestLstsq(LinalgTestCase, LinalgNonsquareTestCase): +class TestLstsq(LinalgSquareTestCase, LinalgNonsquareTestCase): + + def do(self, a, b, flags): + if flags & CaseFlags.empty: + assert_raises(LinAlgError, linalg.lstsq, a, b) + return - def do(self, a, b): arr = np.asarray(a) m, n = arr.shape u, s, vt = linalg.svd(a, 0) @@ -738,7 +831,10 @@ class TestBoolPower(object): class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase): - def do(self, a, b): + def do(self, a, b, flags): + if flags & CaseFlags.empty: + assert_raises(LinAlgError, linalg.eigvalsh, a, 'L') + return # note that eigenvalue arrays returned by eig must be sorted since # their order isn't guaranteed. ev = linalg.eigvalsh(a, 'L') @@ -788,7 +884,10 @@ class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase): class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase): - def do(self, a, b): + def do(self, a, b, flags): + if flags & CaseFlags.empty: + assert_raises(LinAlgError, linalg.eigh, a) + return # note that eigenvalue arrays returned by eig must be sorted since # their order isn't guaranteed. ev, evc = linalg.eigh(a) |