diff options
Diffstat (limited to 'numpy/linalg/tests/test_linalg.py')
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 141 |
1 files changed, 140 insertions, 1 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index baa195241..c612eb6bb 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -542,12 +542,13 @@ class TestInv(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): res = linalg.inv(a) assert_(res.dtype.type is np.float64) assert_equal(a.shape, res.shape) - assert_(isinstance(a, ArraySubclass)) + assert_(isinstance(res, ArraySubclass)) a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) res = linalg.inv(a) assert_(res.dtype.type is np.complex64) assert_equal(a.shape, res.shape) + assert_(isinstance(res, ArraySubclass)) class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): @@ -566,6 +567,24 @@ class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + # Check that all kinds of 0-sized arrays work + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + res = linalg.eigvals(a) + assert_(res.dtype.type is np.float64) + assert_equal((0, 1), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(res, np.ndarray)) + + a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) + res = linalg.eigvals(a) + assert_(res.dtype.type is np.complex64) + assert_equal((0,), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(res, np.ndarray)) + class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): @@ -591,6 +610,28 @@ class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + # Check that all kinds of 0-sized arrays work + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + res, res_v = linalg.eig(a) + assert_(res_v.dtype.type is np.float64) + assert_(res.dtype.type is np.float64) + assert_equal(a.shape, res_v.shape) + assert_equal((0, 1), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(a, np.ndarray)) + + a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) + res, res_v = linalg.eig(a) + assert_(res_v.dtype.type is np.complex64) + assert_(res.dtype.type is np.complex64) + assert_equal(a.shape, res_v.shape) + assert_equal((0,), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(a, np.ndarray)) + class TestSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): @@ -619,6 +660,16 @@ class TestSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + # These raise errors currently + # (which does not mean that it may not make sense) + a = np.zeros((0, 0), dtype=np.complex64) + assert_raises(linalg.LinAlgError, linalg.svd, a) + a = np.zeros((0, 1), dtype=np.complex64) + assert_raises(linalg.LinAlgError, linalg.svd, a) + a = np.zeros((1, 0), dtype=np.complex64) + assert_raises(linalg.LinAlgError, linalg.svd, a) + class TestCondSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): @@ -712,6 +763,25 @@ class TestDet(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + a = np.zeros((0, 0), dtype=np.complex64) + res = linalg.det(a) + assert_equal(res, 1.) + assert_(res.dtype.type is np.complex64) + res = linalg.slogdet(a) + assert_equal(res, (1, 0)) + assert_(res[0].dtype.type is np.complex64) + assert_(res[1].dtype.type is np.float32) + + a = np.zeros((0, 0), dtype=np.float64) + res = linalg.det(a) + assert_equal(res, 1.) + assert_(res.dtype.type is np.float64) + res = linalg.slogdet(a) + assert_equal(res, (1, 0)) + assert_(res[0].dtype.type is np.float64) + assert_(res[1].dtype.type is np.float64) + class TestLstsq(LinalgSquareTestCase, LinalgNonsquareTestCase): @@ -857,6 +927,24 @@ class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase): w = np.linalg.eigvalsh(Kup, UPLO='u') assert_allclose(w, tgt, rtol=rtol) + def test_0_size(self): + # Check that all kinds of 0-sized arrays work + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + res = linalg.eigvalsh(a) + assert_(res.dtype.type is np.float64) + assert_equal((0, 1), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(res, np.ndarray)) + + a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) + res = linalg.eigvalsh(a) + assert_(res.dtype.type is np.float32) + assert_equal((0,), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(res, np.ndarray)) + class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase): @@ -916,6 +1004,28 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase): w, v = np.linalg.eigh(Kup, UPLO='u') assert_allclose(w, tgt, rtol=rtol) + def test_0_size(self): + # Check that all kinds of 0-sized arrays work + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + res, res_v = linalg.eigh(a) + assert_(res_v.dtype.type is np.float64) + assert_(res.dtype.type is np.float64) + assert_equal(a.shape, res_v.shape) + assert_equal((0, 1), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(a, np.ndarray)) + + a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) + res, res_v = linalg.eigh(a) + assert_(res_v.dtype.type is np.complex64) + assert_(res.dtype.type is np.float32) + assert_equal(a.shape, res_v.shape) + assert_equal((0,), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(a, np.ndarray)) + class _TestNorm(object): @@ -1350,6 +1460,35 @@ class TestQR(object): self.check_qr(m2.T) self.check_qr(matrix(m1)) + def test_0_size(self): + # There may be good ways to do (some of this) reasonably: + a = np.zeros((0, 0)) + assert_raises(linalg.LinAlgError, linalg.qr, a) + a = np.zeros((0, 1)) + assert_raises(linalg.LinAlgError, linalg.qr, a) + a = np.zeros((1, 0)) + assert_raises(linalg.LinAlgError, linalg.qr, a) + + +class TestCholesky(object): + # TODO: are there no other tests for cholesky? + + def test_0_size(self): + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + res = linalg.cholesky(a) + assert_equal(a.shape, res.shape) + assert_(res.dtype.type is np.float64) + # for documentation purpose: + assert_(isinstance(res, np.ndarray)) + + a = np.zeros((1, 0, 0), dtype=np.complex64).view(ArraySubclass) + res = linalg.cholesky(a) + assert_equal(a.shape, res.shape) + assert_(res.dtype.type is np.complex64) + assert_(isinstance(res, np.ndarray)) + def test_byteorder_check(): # Byte order check should pass for native order |