summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2016-12-12 18:33:09 +0000
committerEric Wieser <wieser.eric@gmail.com>2016-12-19 14:32:49 +0000
commitc3442066006015dfa8be714686472434756bf83e (patch)
tree950ea3d66f3e6207503ecf14faab29a4ac18cb77 /numpy/linalg/tests
parent340779f53b56d89aa4044af3b1a382e3e1a15592 (diff)
downloadnumpy-c3442066006015dfa8be714686472434756bf83e.tar.gz
TST: Refactor all the test case lists
Allows each individual function to inspect the flags of a certain test, and decide whether an exception will be thrown
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r--numpy/linalg/tests/test_linalg.py253
1 files changed, 176 insertions, 77 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index a8b4dcdf5..9b1724c9c 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -13,7 +13,7 @@ import numpy as np
from numpy import array, single, double, csingle, cdouble, dot, identity
from numpy import multiply, atleast_2d, inf, asarray, matrix
from numpy import linalg
-from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot
+from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError
from numpy.linalg.linalg import _multi_dot_matrix_chain_order
from numpy.testing import (
assert_, assert_equal, assert_raises, assert_array_equal,
@@ -61,22 +61,33 @@ def get_rtol(dtype):
class LinalgCase(object):
- def __init__(self, name, a, b, exception_cls=None):
+ def __init__(self, name, a, b, flags=frozenset()):
assert_(isinstance(name, str))
self.name = name
self.a = a
self.b = b
- self.exception_cls = exception_cls
+ self.flags = flags
def check(self, do):
- if self.exception_cls is None:
- do(self.a, self.b)
- else:
- assert_raises(self.exception_cls, do, self.a, self.b)
+ do(self.a, self.b, flags=self.flags)
def __repr__(self):
return "<LinalgCase: %s>" % (self.name,)
+def apply_flags(flags, cases):
+ for case in cases:
+ case.flags = case.flags | flags
+ return cases
+
+class CaseFlags:
+ """ A simple flags enum. Could be replaced with ints for speed"""
+ square = frozenset(['square'])
+ generalized = frozenset(['generalized'])
+ empty = frozenset(['empty'])
+ hermitian = frozenset(['hermitian']) # not a subset of square, because these have only one argument
+
+ none = frozenset()
+ all = square | generalized | empty | hermitian
#
# Base test cases
@@ -84,7 +95,10 @@ class LinalgCase(object):
np.random.seed(1234)
-SQUARE_CASES = [
+CASES = []
+
+# square test cases
+CASES += apply_flags(CaseFlags.square, [
LinalgCase("single",
array([[1., 2.], [3., 4.]], dtype=single),
array([2., 1.], dtype=single)),
@@ -106,7 +120,7 @@ SQUARE_CASES = [
LinalgCase("0x0",
np.empty((0, 0), dtype=double),
np.empty((0, 0), dtype=double),
- linalg.LinAlgError),
+ flags=CaseFlags.empty),
LinalgCase("8x8",
np.random.rand(8, 8),
np.random.rand(8)),
@@ -122,9 +136,10 @@ SQUARE_CASES = [
LinalgCase("matrix_a_and_b",
matrix([[1., 2.], [3., 4.]]),
matrix([2., 1.]).T),
-]
+])
-NONSQUARE_CASES = [
+# non-square test-cases
+CASES += apply_flags(CaseFlags.none, [
LinalgCase("single_nsq_1",
array([[1., 2., 3.], [3., 4., 6.]], dtype=single),
array([2., 1.], dtype=single)),
@@ -172,13 +187,16 @@ NONSQUARE_CASES = [
np.random.rand(1)),
LinalgCase("0x4",
np.random.rand(0, 4),
- np.random.rand(4)),
+ np.random.rand(4),
+ flags=CaseFlags.empty),
LinalgCase("4x0",
np.random.rand(4, 0),
- np.random.rand(0)),
-]
+ np.random.rand(0),
+ flags=CaseFlags.empty),
+])
-HERMITIAN_CASES = [
+# hermitian test-cases
+CASES += apply_flags(CaseFlags.hermitian, [
LinalgCase("hsingle",
array([[1., 2.], [2., 1.]], dtype=single),
None),
@@ -192,9 +210,9 @@ HERMITIAN_CASES = [
array([[1., 2 + 3j], [2 - 3j, 1]], dtype=cdouble),
None),
LinalgCase("hempty",
- atleast_2d(array([], dtype=double)),
+ np.empty((0, 0), dtype=double),
None,
- linalg.LinAlgError),
+ flags=CaseFlags.empty),
LinalgCase("hnonarray",
[[1, 2], [2, 1]],
None),
@@ -207,21 +225,16 @@ HERMITIAN_CASES = [
LinalgCase("hmatrix_1x1",
np.random.rand(1, 1),
None),
-]
+])
#
# Gufunc test cases
#
+def _make_generalized_cases():
+ new_cases = []
-GENERALIZED_SQUARE_CASES = []
-GENERALIZED_NONSQUARE_CASES = []
-GENERALIZED_HERMITIAN_CASES = []
-
-for tgt, src in ((GENERALIZED_SQUARE_CASES, SQUARE_CASES),
- (GENERALIZED_NONSQUARE_CASES, NONSQUARE_CASES),
- (GENERALIZED_HERMITIAN_CASES, HERMITIAN_CASES)):
- for case in src:
+ for case in CASES:
if not isinstance(case.a, np.ndarray):
continue
@@ -231,8 +244,8 @@ for tgt, src in ((GENERALIZED_SQUARE_CASES, SQUARE_CASES),
else:
b = np.array([case.b, 7 * case.b, 6 * case.b])
new_case = LinalgCase(case.name + "_tile3", a, b,
- case.exception_cls)
- tgt.append(new_case)
+ flags=case.flags | {'generalized'})
+ new_cases.append(new_case)
a = np.array([case.a] * 2 * 3).reshape((3, 2) + case.a.shape)
if case.b is None:
@@ -240,14 +253,17 @@ for tgt, src in ((GENERALIZED_SQUARE_CASES, SQUARE_CASES),
else:
b = np.array([case.b] * 2 * 3).reshape((3, 2) + case.b.shape)
new_case = LinalgCase(case.name + "_tile213", a, b,
- case.exception_cls)
- tgt.append(new_case)
+ flags=case.flags | {'generalized'})
+ new_cases.append(new_case)
+
+ return new_cases
+
+CASES += _make_generalized_cases()
#
# Generate stride combination variations of the above
#
-
def _stride_comb_iter(x):
"""
Generate cartesian product of strides for all axes
@@ -295,29 +311,35 @@ def _stride_comb_iter(x):
xi = np.lib.stride_tricks.as_strided(x, strides=s)
yield xi, "stride_xxx_0_0"
-for src in (SQUARE_CASES,
- NONSQUARE_CASES,
- HERMITIAN_CASES,
- GENERALIZED_SQUARE_CASES,
- GENERALIZED_NONSQUARE_CASES,
- GENERALIZED_HERMITIAN_CASES):
-
+def _make_strided_cases():
new_cases = []
- for case in src:
+ for case in CASES:
for a, a_tag in _stride_comb_iter(case.a):
for b, b_tag in _stride_comb_iter(case.b):
new_case = LinalgCase(case.name + "_" + a_tag + "_" + b_tag, a, b,
- exception_cls=case.exception_cls)
+ flags=case.flags | {'strided'})
new_cases.append(new_case)
- src.extend(new_cases)
+ return new_cases
+
+CASES += _make_strided_cases()
#
# Test different routines against the above cases
#
-def _check_cases(func, cases):
- for case in cases:
+def _check_cases(func, require=CaseFlags.none, exclude=CaseFlags.none):
+ """
+ Run func on each of the cases with all of the flags in require, and none
+ of the flags in exclude
+ """
+ for case in CASES:
+ # filter by require and exclude
+ if case.flags & require != require:
+ continue
+ if case.flags & exclude:
+ continue
+
try:
case.check(func)
except Exception:
@@ -326,43 +348,87 @@ def _check_cases(func, cases):
raise AssertionError(msg)
-class LinalgTestCase(object):
+class LinalgSquareTestCase(object):
def test_sq_cases(self):
- _check_cases(self.do, SQUARE_CASES)
+ _check_cases(self.do,
+ require=CaseFlags.square,
+ exclude=CaseFlags.generalized | CaseFlags.empty)
+
+ def test_empty_sq_cases(self):
+ _check_cases(self.do,
+ require=CaseFlags.square | CaseFlags.empty,
+ exclude=CaseFlags.generalized)
class LinalgNonsquareTestCase(object):
def test_nonsq_cases(self):
- _check_cases(self.do, NONSQUARE_CASES)
+ _check_cases(self.do,
+ require=CaseFlags.none,
+ exclude=CaseFlags.generalized | CaseFlags.square | CaseFlags.hermitian | CaseFlags.empty)
+
+ def test_empty_nonsq_cases(self):
+ _check_cases(self.do,
+ require=CaseFlags.empty,
+ exclude=CaseFlags.generalized | CaseFlags.square | CaseFlags.hermitian)
+
+class HermitianTestCase(object):
+
+ def test_herm_cases(self):
+ _check_cases(self.do,
+ require=CaseFlags.hermitian,
+ exclude=CaseFlags.generalized | CaseFlags.empty)
+
+ def test_empty_herm_cases(self):
+ _check_cases(self.do,
+ require=CaseFlags.hermitian | CaseFlags.empty,
+ exclude=CaseFlags.generalized)
-class LinalgGeneralizedTestCase(object):
+class LinalgGeneralizedSquareTestCase(object):
@dec.slow
def test_generalized_sq_cases(self):
- _check_cases(self.do, GENERALIZED_SQUARE_CASES)
+ _check_cases(self.do,
+ require=CaseFlags.generalized | CaseFlags.square,
+ exclude=CaseFlags.hermitian | CaseFlags.empty)
+
+ @dec.slow
+ def test_generalized_empty_sq_cases(self):
+ _check_cases(self.do,
+ require=CaseFlags.generalized | CaseFlags.square | CaseFlags.empty,
+ exclude=CaseFlags.hermitian)
class LinalgGeneralizedNonsquareTestCase(object):
@dec.slow
def test_generalized_nonsq_cases(self):
- _check_cases(self.do, GENERALIZED_NONSQUARE_CASES)
-
+ _check_cases(self.do,
+ require=CaseFlags.generalized,
+ exclude=CaseFlags.square | CaseFlags.empty)
-class HermitianTestCase(object):
-
- def test_herm_cases(self):
- _check_cases(self.do, HERMITIAN_CASES)
+ @dec.slow
+ def test_generalized_empty_nonsq_cases(self):
+ _check_cases(self.do,
+ require=CaseFlags.generalized | CaseFlags.empty,
+ exclude=CaseFlags.square)
class HermitianGeneralizedTestCase(object):
@dec.slow
def test_generalized_herm_cases(self):
- _check_cases(self.do, GENERALIZED_HERMITIAN_CASES)
+ _check_cases(self.do,
+ require=CaseFlags.generalized | CaseFlags.hermitian,
+ exclude=CaseFlags.empty)
+
+ @dec.slow
+ def test_generalized_empty_herm_cases(self):
+ _check_cases(self.do,
+ require=CaseFlags.generalized | CaseFlags.hermitian | CaseFlags.empty,
+ exclude=CaseFlags.none)
def dot_generalized(a, b):
@@ -395,9 +461,9 @@ def identity_like_generalized(a):
return identity(a.shape[0])
-class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase):
+class TestSolve(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
- def do(self, a, b):
+ def do(self, a, b, flags):
x = linalg.solve(a, b)
assert_almost_equal(b, dot_generalized(a, x))
assert_(imply(isinstance(b, matrix), isinstance(x, matrix)))
@@ -461,9 +527,9 @@ class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase):
assert_(isinstance(result, ArraySubclass))
-class TestInv(LinalgTestCase, LinalgGeneralizedTestCase):
+class TestInv(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
- def do(self, a, b):
+ def do(self, a, b, flags):
a_inv = linalg.inv(a)
assert_almost_equal(dot_generalized(a, a_inv),
identity_like_generalized(a))
@@ -492,9 +558,12 @@ class TestInv(LinalgTestCase, LinalgGeneralizedTestCase):
assert_equal(a.shape, res.shape)
-class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase):
+class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
- def do(self, a, b):
+ def do(self, a, b, flags):
+ if flags & CaseFlags.empty:
+ assert_raises(LinAlgError, linalg.eigvals, a)
+ return
ev = linalg.eigvals(a)
evalues, evectors = linalg.eig(a)
assert_almost_equal(ev, evalues)
@@ -509,9 +578,13 @@ class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase):
yield check, dtype
-class TestEig(LinalgTestCase, LinalgGeneralizedTestCase):
+class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
+
+ def do(self, a, b, flags):
+ if flags & CaseFlags.empty:
+ assert_raises(LinAlgError, linalg.eig, a)
+ return
- def do(self, a, b):
evalues, evectors = linalg.eig(a)
assert_allclose(dot_generalized(a, evectors),
np.asarray(evectors) * np.asarray(evalues)[..., None, :],
@@ -534,9 +607,13 @@ class TestEig(LinalgTestCase, LinalgGeneralizedTestCase):
yield check, dtype
-class TestSVD(LinalgTestCase, LinalgGeneralizedTestCase):
+class TestSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
+
+ def do(self, a, b, flags):
+ if flags & CaseFlags.empty:
+ assert_raises(LinAlgError, linalg.svd, a, 0)
+ return
- def do(self, a, b):
u, s, vt = linalg.svd(a, 0)
assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :],
np.asarray(vt)),
@@ -558,10 +635,13 @@ class TestSVD(LinalgTestCase, LinalgGeneralizedTestCase):
yield check, dtype
-class TestCondSVD(LinalgTestCase, LinalgGeneralizedTestCase):
+class TestCondSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
- def do(self, a, b):
+ def do(self, a, b, flags):
c = asarray(a) # a might be a matrix
+ if flags & CaseFlags.empty:
+ assert_raises(LinAlgError, linalg.svd, c, compute_uv=False)
+ return
s = linalg.svd(c, compute_uv=False)
assert_almost_equal(
s[..., 0] / s[..., -1], linalg.cond(a),
@@ -572,10 +652,13 @@ class TestCondSVD(LinalgTestCase, LinalgGeneralizedTestCase):
assert_equal(linalg.cond(A), linalg.cond(A[None, ...])[0])
-class TestCond2(LinalgTestCase):
+class TestCond2(LinalgSquareTestCase):
- def do(self, a, b):
+ def do(self, a, b, flags):
c = asarray(a) # a might be a matrix
+ if flags & CaseFlags.empty:
+ assert_raises(LinAlgError, linalg.svd, c, compute_uv=False)
+ return
s = linalg.svd(c, compute_uv=False)
assert_almost_equal(
s[..., 0] / s[..., -1], linalg.cond(a, 2),
@@ -593,18 +676,24 @@ class TestCondInf(object):
assert_almost_equal(linalg.cond(A, inf), 3.)
-class TestPinv(LinalgTestCase, LinalgNonsquareTestCase):
+class TestPinv(LinalgSquareTestCase, LinalgNonsquareTestCase):
- def do(self, a, b):
+ def do(self, a, b, flags):
+ if flags & CaseFlags.empty:
+ assert_raises(LinAlgError, linalg.pinv, a)
+ return
a_ginv = linalg.pinv(a)
# `a @ a_ginv == I` does not hold if a is singular
assert_almost_equal(dot(a, a_ginv).dot(a), a, single_decimal=5, double_decimal=11)
assert_(imply(isinstance(a, matrix), isinstance(a_ginv, matrix)))
-class TestDet(LinalgTestCase, LinalgGeneralizedTestCase):
+class TestDet(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
- def do(self, a, b):
+ def do(self, a, b, flags):
+ if flags & CaseFlags.empty:
+ assert_raises(LinAlgError, linalg.det, a)
+ return
d = linalg.det(a)
(s, ld) = linalg.slogdet(a)
if asarray(a).dtype.type in (single, double):
@@ -645,9 +734,13 @@ class TestDet(LinalgTestCase, LinalgGeneralizedTestCase):
yield check, dtype
-class TestLstsq(LinalgTestCase, LinalgNonsquareTestCase):
+class TestLstsq(LinalgSquareTestCase, LinalgNonsquareTestCase):
+
+ def do(self, a, b, flags):
+ if flags & CaseFlags.empty:
+ assert_raises(LinAlgError, linalg.lstsq, a, b)
+ return
- def do(self, a, b):
arr = np.asarray(a)
m, n = arr.shape
u, s, vt = linalg.svd(a, 0)
@@ -738,7 +831,10 @@ class TestBoolPower(object):
class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase):
- def do(self, a, b):
+ def do(self, a, b, flags):
+ if flags & CaseFlags.empty:
+ assert_raises(LinAlgError, linalg.eigvalsh, a, 'L')
+ return
# note that eigenvalue arrays returned by eig must be sorted since
# their order isn't guaranteed.
ev = linalg.eigvalsh(a, 'L')
@@ -788,7 +884,10 @@ class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase):
class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
- def do(self, a, b):
+ def do(self, a, b, flags):
+ if flags & CaseFlags.empty:
+ assert_raises(LinAlgError, linalg.eigh, a)
+ return
# note that eigenvalue arrays returned by eig must be sorted since
# their order isn't guaranteed.
ev, evc = linalg.eigh(a)