From fa55f4c463806599bccf145baf22e13ff79f9a68 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Wed, 10 Jul 2013 14:19:42 +0200 Subject: ENH: inv/solve work with empty inner and others empty outer array This makes the inverse of a 0x0 array simply be 0x0 again. It also modifies the no-empty array check in favor of a no-empty *inner* array, since the gufuncs seem to handle the other case fine. --- numpy/linalg/tests/test_linalg.py | 48 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) (limited to 'numpy/linalg/tests') diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 881311c94..7f102634e 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -204,6 +204,39 @@ class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + class ArraySubclass(np.ndarray): + pass + # Test system of 0x0 matrices + a = np.arange(8).reshape(2, 2, 2) + b = np.arange(6).reshape(1, 2, 3).view(ArraySubclass) + + expected = linalg.solve(a, b)[:,0:0,:] + result = linalg.solve(a[:,0:0,0:0], b[:,0:0,:]) + assert_array_equal(result, expected) + assert_(isinstance(result, ArraySubclass)) + + # Test errors for non-square and only b's dimension being 0 + assert_raises(linalg.LinAlgError, linalg.solve, a[:,0:0,0:1], b) + assert_raises(ValueError, linalg.solve, a, b[:,0:0,:]) + + # Test broadcasting error + b = np.arange(6).reshape(1, 3, 2) # broadcasting error + assert_raises(ValueError, linalg.solve, a, b) + assert_raises(ValueError, linalg.solve, a[0:0], b[0:0]) + + # Test zero "single equations" with 0x0 matrices. + b = np.arange(2).reshape(1, 2).view(ArraySubclass) + expected = linalg.solve(a, b)[:,0:0] + result = linalg.solve(a[:,0:0,0:0], b[:,0:0]) + assert_array_equal(result, expected) + assert_(isinstance(result, ArraySubclass)) + + b = np.arange(3).reshape(1, 3) + assert_raises(ValueError, linalg.solve, a, b) + assert_raises(ValueError, linalg.solve, a[0:0], b[0:0]) + assert_raises(ValueError, linalg.solve, a[:,0:0,0:0], b) + class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): @@ -219,6 +252,21 @@ class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + # Check that all kinds of 0-sized arrays work + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0,1,1), dtype=np.int_).view(ArraySubclass) + res = linalg.inv(a) + assert_(res.dtype.type is np.float64) + assert_equal(a.shape, res.shape) + assert_(isinstance(a, ArraySubclass)) + + a = np.zeros((0,0), dtype=np.complex64).view(ArraySubclass) + res = linalg.inv(a) + assert_(res.dtype.type is np.complex64) + assert_equal(a.shape, res.shape) + class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): -- cgit v1.2.1