summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/linalg/tests/test_linalg.py122
1 files changed, 52 insertions, 70 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 98a77d8f5..6bb2cab7a 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -462,12 +462,10 @@ class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
class TestSolve(SolveCases):
- def test_types(self):
- def check(dtype):
- x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- assert_equal(linalg.solve(x, x).dtype, dtype)
- for dtype in [single, double, csingle, cdouble]:
- check(dtype)
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
+ def test_types(self, dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ assert_equal(linalg.solve(x, x).dtype, dtype)
def test_0_size(self):
class ArraySubclass(np.ndarray):
@@ -531,12 +529,10 @@ class InvCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
class TestInv(InvCases):
- def test_types(self):
- def check(dtype):
- x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- assert_equal(linalg.inv(x).dtype, dtype)
- for dtype in [single, double, csingle, cdouble]:
- check(dtype)
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
+ def test_types(self, dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ assert_equal(linalg.inv(x).dtype, dtype)
def test_0_size(self):
# Check that all kinds of 0-sized arrays work
@@ -564,14 +560,12 @@ class EigvalsCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
class TestEigvals(EigvalsCases):
- def test_types(self):
- def check(dtype):
- x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- assert_equal(linalg.eigvals(x).dtype, dtype)
- x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
- assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype))
- for dtype in [single, double, csingle, cdouble]:
- check(dtype)
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
+ def test_types(self, dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ assert_equal(linalg.eigvals(x).dtype, dtype)
+ x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
+ assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype))
def test_0_size(self):
# Check that all kinds of 0-sized arrays work
@@ -603,20 +597,17 @@ class EigCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
class TestEig(EigCases):
- def test_types(self):
- def check(dtype):
- x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- w, v = np.linalg.eig(x)
- assert_equal(w.dtype, dtype)
- assert_equal(v.dtype, dtype)
-
- x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
- w, v = np.linalg.eig(x)
- assert_equal(w.dtype, get_complex_dtype(dtype))
- assert_equal(v.dtype, get_complex_dtype(dtype))
-
- for dtype in [single, double, csingle, cdouble]:
- check(dtype)
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
+ def test_types(self, dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ w, v = np.linalg.eig(x)
+ assert_equal(w.dtype, dtype)
+ assert_equal(v.dtype, dtype)
+
+ x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
+ w, v = np.linalg.eig(x)
+ assert_equal(w.dtype, get_complex_dtype(dtype))
+ assert_equal(v.dtype, get_complex_dtype(dtype))
def test_0_size(self):
# Check that all kinds of 0-sized arrays work
@@ -653,18 +644,15 @@ class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
class TestSVD(SVDCases):
- def test_types(self):
- def check(dtype):
- x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- u, s, vh = linalg.svd(x)
- assert_equal(u.dtype, dtype)
- assert_equal(s.dtype, get_real_dtype(dtype))
- assert_equal(vh.dtype, dtype)
- s = linalg.svd(x, compute_uv=False)
- assert_equal(s.dtype, get_real_dtype(dtype))
-
- for dtype in [single, double, csingle, cdouble]:
- check(dtype)
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
+ def test_types(self, dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ u, s, vh = linalg.svd(x)
+ assert_equal(u.dtype, dtype)
+ assert_equal(s.dtype, get_real_dtype(dtype))
+ assert_equal(vh.dtype, dtype)
+ s = linalg.svd(x, compute_uv=False)
+ assert_equal(s.dtype, get_real_dtype(dtype))
def test_empty_identity(self):
""" Empty input should put an identity matrix in u or vh """
@@ -842,15 +830,13 @@ class TestDet(DetCases):
assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble)
assert_equal(type(linalg.slogdet([[0.0j]])[1]), double)
- def test_types(self):
- def check(dtype):
- x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- assert_equal(np.linalg.det(x).dtype, dtype)
- ph, s = np.linalg.slogdet(x)
- assert_equal(s.dtype, get_real_dtype(dtype))
- assert_equal(ph.dtype, dtype)
- for dtype in [single, double, csingle, cdouble]:
- check(dtype)
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
+ def test_types(self, dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ assert_equal(np.linalg.det(x).dtype, dtype)
+ ph, s = np.linalg.slogdet(x)
+ assert_equal(s.dtype, get_real_dtype(dtype))
+ assert_equal(ph.dtype, dtype)
def test_0_size(self):
a = np.zeros((0, 0), dtype=np.complex64)
@@ -1049,13 +1035,11 @@ class TestEigvalshCases(HermitianTestCase, HermitianGeneralizedTestCase):
class TestEigvalsh(object):
- def test_types(self):
- def check(dtype):
- x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- w = np.linalg.eigvalsh(x)
- assert_equal(w.dtype, get_real_dtype(dtype))
- for dtype in [single, double, csingle, cdouble]:
- check(dtype)
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
+ def test_types(self, dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ w = np.linalg.eigvalsh(x)
+ assert_equal(w.dtype, get_real_dtype(dtype))
def test_invalid(self):
x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)
@@ -1127,14 +1111,12 @@ class TestEighCases(HermitianTestCase, HermitianGeneralizedTestCase):
class TestEigh(object):
- def test_types(self):
- def check(dtype):
- x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- w, v = np.linalg.eigh(x)
- assert_equal(w.dtype, get_real_dtype(dtype))
- assert_equal(v.dtype, dtype)
- for dtype in [single, double, csingle, cdouble]:
- check(dtype)
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
+ def test_types(self, dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ w, v = np.linalg.eigh(x)
+ assert_equal(w.dtype, get_real_dtype(dtype))
+ assert_equal(v.dtype, dtype)
def test_invalid(self):
x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)