summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests/test_linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/tests/test_linalg.py')
-rw-r--r--numpy/linalg/tests/test_linalg.py141
1 files changed, 140 insertions, 1 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index baa195241..c612eb6bb 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -542,12 +542,13 @@ class TestInv(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
res = linalg.inv(a)
assert_(res.dtype.type is np.float64)
assert_equal(a.shape, res.shape)
- assert_(isinstance(a, ArraySubclass))
+ assert_(isinstance(res, ArraySubclass))
a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
res = linalg.inv(a)
assert_(res.dtype.type is np.complex64)
assert_equal(a.shape, res.shape)
+ assert_(isinstance(res, ArraySubclass))
class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
@@ -566,6 +567,24 @@ class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ # Check that all kinds of 0-sized arrays work
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
+ res = linalg.eigvals(a)
+ assert_(res.dtype.type is np.float64)
+ assert_equal((0, 1), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(res, np.ndarray))
+
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
+ res = linalg.eigvals(a)
+ assert_(res.dtype.type is np.complex64)
+ assert_equal((0,), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(res, np.ndarray))
+
class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
@@ -591,6 +610,28 @@ class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ # Check that all kinds of 0-sized arrays work
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
+ res, res_v = linalg.eig(a)
+ assert_(res_v.dtype.type is np.float64)
+ assert_(res.dtype.type is np.float64)
+ assert_equal(a.shape, res_v.shape)
+ assert_equal((0, 1), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(a, np.ndarray))
+
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
+ res, res_v = linalg.eig(a)
+ assert_(res_v.dtype.type is np.complex64)
+ assert_(res.dtype.type is np.complex64)
+ assert_equal(a.shape, res_v.shape)
+ assert_equal((0,), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(a, np.ndarray))
+
class TestSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
@@ -619,6 +660,16 @@ class TestSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ # These raise errors currently
+ # (which does not mean that it may not make sense)
+ a = np.zeros((0, 0), dtype=np.complex64)
+ assert_raises(linalg.LinAlgError, linalg.svd, a)
+ a = np.zeros((0, 1), dtype=np.complex64)
+ assert_raises(linalg.LinAlgError, linalg.svd, a)
+ a = np.zeros((1, 0), dtype=np.complex64)
+ assert_raises(linalg.LinAlgError, linalg.svd, a)
+
class TestCondSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
@@ -712,6 +763,25 @@ class TestDet(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ a = np.zeros((0, 0), dtype=np.complex64)
+ res = linalg.det(a)
+ assert_equal(res, 1.)
+ assert_(res.dtype.type is np.complex64)
+ res = linalg.slogdet(a)
+ assert_equal(res, (1, 0))
+ assert_(res[0].dtype.type is np.complex64)
+ assert_(res[1].dtype.type is np.float32)
+
+ a = np.zeros((0, 0), dtype=np.float64)
+ res = linalg.det(a)
+ assert_equal(res, 1.)
+ assert_(res.dtype.type is np.float64)
+ res = linalg.slogdet(a)
+ assert_equal(res, (1, 0))
+ assert_(res[0].dtype.type is np.float64)
+ assert_(res[1].dtype.type is np.float64)
+
class TestLstsq(LinalgSquareTestCase, LinalgNonsquareTestCase):
@@ -857,6 +927,24 @@ class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase):
w = np.linalg.eigvalsh(Kup, UPLO='u')
assert_allclose(w, tgt, rtol=rtol)
+ def test_0_size(self):
+ # Check that all kinds of 0-sized arrays work
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
+ res = linalg.eigvalsh(a)
+ assert_(res.dtype.type is np.float64)
+ assert_equal((0, 1), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(res, np.ndarray))
+
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
+ res = linalg.eigvalsh(a)
+ assert_(res.dtype.type is np.float32)
+ assert_equal((0,), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(res, np.ndarray))
+
class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
@@ -916,6 +1004,28 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
w, v = np.linalg.eigh(Kup, UPLO='u')
assert_allclose(w, tgt, rtol=rtol)
+ def test_0_size(self):
+ # Check that all kinds of 0-sized arrays work
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
+ res, res_v = linalg.eigh(a)
+ assert_(res_v.dtype.type is np.float64)
+ assert_(res.dtype.type is np.float64)
+ assert_equal(a.shape, res_v.shape)
+ assert_equal((0, 1), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(a, np.ndarray))
+
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
+ res, res_v = linalg.eigh(a)
+ assert_(res_v.dtype.type is np.complex64)
+ assert_(res.dtype.type is np.float32)
+ assert_equal(a.shape, res_v.shape)
+ assert_equal((0,), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(a, np.ndarray))
+
class _TestNorm(object):
@@ -1350,6 +1460,35 @@ class TestQR(object):
self.check_qr(m2.T)
self.check_qr(matrix(m1))
+ def test_0_size(self):
+ # There may be good ways to do (some of this) reasonably:
+ a = np.zeros((0, 0))
+ assert_raises(linalg.LinAlgError, linalg.qr, a)
+ a = np.zeros((0, 1))
+ assert_raises(linalg.LinAlgError, linalg.qr, a)
+ a = np.zeros((1, 0))
+ assert_raises(linalg.LinAlgError, linalg.qr, a)
+
+
+class TestCholesky(object):
+ # TODO: are there no other tests for cholesky?
+
+ def test_0_size(self):
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
+ res = linalg.cholesky(a)
+ assert_equal(a.shape, res.shape)
+ assert_(res.dtype.type is np.float64)
+ # for documentation purpose:
+ assert_(isinstance(res, np.ndarray))
+
+ a = np.zeros((1, 0, 0), dtype=np.complex64).view(ArraySubclass)
+ res = linalg.cholesky(a)
+ assert_equal(a.shape, res.shape)
+ assert_(res.dtype.type is np.complex64)
+ assert_(isinstance(res, np.ndarray))
+
def test_byteorder_check():
# Byte order check should pass for native order