diff options
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 122 |
1 files changed, 52 insertions, 70 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 98a77d8f5..6bb2cab7a 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -462,12 +462,10 @@ class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): class TestSolve(SolveCases): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - assert_equal(linalg.solve(x, x).dtype, dtype) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(linalg.solve(x, x).dtype, dtype) def test_0_size(self): class ArraySubclass(np.ndarray): @@ -531,12 +529,10 @@ class InvCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): class TestInv(InvCases): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - assert_equal(linalg.inv(x).dtype, dtype) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(linalg.inv(x).dtype, dtype) def test_0_size(self): # Check that all kinds of 0-sized arrays work @@ -564,14 +560,12 @@ class EigvalsCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): class TestEigvals(EigvalsCases): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - assert_equal(linalg.eigvals(x).dtype, dtype) - x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) - assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype)) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(linalg.eigvals(x).dtype, dtype) + x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) + assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype)) def test_0_size(self): # Check that all kinds of 0-sized arrays work @@ -603,20 +597,17 @@ class EigCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): class TestEig(EigCases): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - w, v = np.linalg.eig(x) - assert_equal(w.dtype, dtype) - assert_equal(v.dtype, dtype) - - x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) - w, v = np.linalg.eig(x) - assert_equal(w.dtype, get_complex_dtype(dtype)) - assert_equal(v.dtype, get_complex_dtype(dtype)) - - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + w, v = np.linalg.eig(x) + assert_equal(w.dtype, dtype) + assert_equal(v.dtype, dtype) + + x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) + w, v = np.linalg.eig(x) + assert_equal(w.dtype, get_complex_dtype(dtype)) + assert_equal(v.dtype, get_complex_dtype(dtype)) def test_0_size(self): # Check that all kinds of 0-sized arrays work @@ -653,18 +644,15 @@ class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): class TestSVD(SVDCases): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - u, s, vh = linalg.svd(x) - assert_equal(u.dtype, dtype) - assert_equal(s.dtype, get_real_dtype(dtype)) - assert_equal(vh.dtype, dtype) - s = linalg.svd(x, compute_uv=False) - assert_equal(s.dtype, get_real_dtype(dtype)) - - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + u, s, vh = linalg.svd(x) + assert_equal(u.dtype, dtype) + assert_equal(s.dtype, get_real_dtype(dtype)) + assert_equal(vh.dtype, dtype) + s = linalg.svd(x, compute_uv=False) + assert_equal(s.dtype, get_real_dtype(dtype)) def test_empty_identity(self): """ Empty input should put an identity matrix in u or vh """ @@ -842,15 +830,13 @@ class TestDet(DetCases): assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble) assert_equal(type(linalg.slogdet([[0.0j]])[1]), double) - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - assert_equal(np.linalg.det(x).dtype, dtype) - ph, s = np.linalg.slogdet(x) - assert_equal(s.dtype, get_real_dtype(dtype)) - assert_equal(ph.dtype, dtype) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(np.linalg.det(x).dtype, dtype) + ph, s = np.linalg.slogdet(x) + assert_equal(s.dtype, get_real_dtype(dtype)) + assert_equal(ph.dtype, dtype) def test_0_size(self): a = np.zeros((0, 0), dtype=np.complex64) @@ -1049,13 +1035,11 @@ class TestEigvalshCases(HermitianTestCase, HermitianGeneralizedTestCase): class TestEigvalsh(object): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - w = np.linalg.eigvalsh(x) - assert_equal(w.dtype, get_real_dtype(dtype)) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + w = np.linalg.eigvalsh(x) + assert_equal(w.dtype, get_real_dtype(dtype)) def test_invalid(self): x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32) @@ -1127,14 +1111,12 @@ class TestEighCases(HermitianTestCase, HermitianGeneralizedTestCase): class TestEigh(object): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - w, v = np.linalg.eigh(x) - assert_equal(w.dtype, get_real_dtype(dtype)) - assert_equal(v.dtype, dtype) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + w, v = np.linalg.eigh(x) + assert_equal(w.dtype, get_real_dtype(dtype)) + assert_equal(v.dtype, dtype) def test_invalid(self): x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32) |