summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests/test_linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/tests/test_linalg.py')
-rw-r--r--numpy/linalg/tests/test_linalg.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 881311c94..7f102634e 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -204,6 +204,39 @@ class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ class ArraySubclass(np.ndarray):
+ pass
+ # Test system of 0x0 matrices
+ a = np.arange(8).reshape(2, 2, 2)
+ b = np.arange(6).reshape(1, 2, 3).view(ArraySubclass)
+
+ expected = linalg.solve(a, b)[:,0:0,:]
+ result = linalg.solve(a[:,0:0,0:0], b[:,0:0,:])
+ assert_array_equal(result, expected)
+ assert_(isinstance(result, ArraySubclass))
+
+ # Test errors for non-square and only b's dimension being 0
+ assert_raises(linalg.LinAlgError, linalg.solve, a[:,0:0,0:1], b)
+ assert_raises(ValueError, linalg.solve, a, b[:,0:0,:])
+
+ # Test broadcasting error
+ b = np.arange(6).reshape(1, 3, 2) # broadcasting error
+ assert_raises(ValueError, linalg.solve, a, b)
+ assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
+
+ # Test zero "single equations" with 0x0 matrices.
+ b = np.arange(2).reshape(1, 2).view(ArraySubclass)
+ expected = linalg.solve(a, b)[:,0:0]
+ result = linalg.solve(a[:,0:0,0:0], b[:,0:0])
+ assert_array_equal(result, expected)
+ assert_(isinstance(result, ArraySubclass))
+
+ b = np.arange(3).reshape(1, 3)
+ assert_raises(ValueError, linalg.solve, a, b)
+ assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
+ assert_raises(ValueError, linalg.solve, a[:,0:0,0:0], b)
+
class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
@@ -219,6 +252,21 @@ class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ # Check that all kinds of 0-sized arrays work
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0,1,1), dtype=np.int_).view(ArraySubclass)
+ res = linalg.inv(a)
+ assert_(res.dtype.type is np.float64)
+ assert_equal(a.shape, res.shape)
+ assert_(isinstance(a, ArraySubclass))
+
+ a = np.zeros((0,0), dtype=np.complex64).view(ArraySubclass)
+ res = linalg.inv(a)
+ assert_(res.dtype.type is np.complex64)
+ assert_equal(a.shape, res.shape)
+
class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):