diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 9 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 18 |
2 files changed, 23 insertions, 4 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 78e487a25..a1bb842ee 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -368,10 +368,11 @@ def solve(a, b): gufunc = _umath_linalg.solve1 else: - if a.shape[-1] == 0 and b.shape[-2] == 0: - a = a.reshape(a.shape[:-1] + (1,)) - bc = broadcast(a, b) - return wrap(empty(bc.shape, dtype=result_t)) + if b.size == 0: + if (a.shape[-1] == 0 and b.shape[-2] == 0) or b.shape[-1] == 0: + a = a[:,:1].reshape(a.shape[:-1] + (1,)) + bc = broadcast(a, b) + return wrap(empty(bc.shape, dtype=result_t)) gufunc = _umath_linalg.solve diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index af9c778a3..2c003a072 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -246,6 +246,24 @@ class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): assert_raises(ValueError, linalg.solve, a[0:0], b[0:0]) assert_raises(ValueError, linalg.solve, a[:, 0:0, 0:0], b) + def test_0_size_k(self): + # test zero multiple equation (K=0) case. + class ArraySubclass(np.ndarray): + pass + a = np.arange(4).reshape(1, 2, 2) + b = np.arange(6).reshape(3, 2, 1).view(ArraySubclass) + + expected = linalg.solve(a, b)[:,:, 0:0] + result = linalg.solve(a, b[:,:, 0:0]) + assert_array_equal(result, expected) + assert_(isinstance(result, ArraySubclass)) + + # test both zero. + expected = linalg.solve(a, b)[:, 0:0, 0:0] + result = linalg.solve(a[:, 0:0, 0:0], b[:,0:0, 0:0]) + assert_array_equal(result, expected) + assert_(isinstance(result, ArraySubclass)) + class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): |