summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/linalg.py9
-rw-r--r--numpy/linalg/tests/test_linalg.py18
2 files changed, 23 insertions, 4 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 78e487a25..a1bb842ee 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -368,10 +368,11 @@ def solve(a, b):
gufunc = _umath_linalg.solve1
else:
- if a.shape[-1] == 0 and b.shape[-2] == 0:
- a = a.reshape(a.shape[:-1] + (1,))
- bc = broadcast(a, b)
- return wrap(empty(bc.shape, dtype=result_t))
+ if b.size == 0:
+ if (a.shape[-1] == 0 and b.shape[-2] == 0) or b.shape[-1] == 0:
+ a = a[:,:1].reshape(a.shape[:-1] + (1,))
+ bc = broadcast(a, b)
+ return wrap(empty(bc.shape, dtype=result_t))
gufunc = _umath_linalg.solve
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index af9c778a3..2c003a072 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -246,6 +246,24 @@ class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
assert_raises(ValueError, linalg.solve, a[:, 0:0, 0:0], b)
+ def test_0_size_k(self):
+ # test zero multiple equation (K=0) case.
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.arange(4).reshape(1, 2, 2)
+ b = np.arange(6).reshape(3, 2, 1).view(ArraySubclass)
+
+ expected = linalg.solve(a, b)[:,:, 0:0]
+ result = linalg.solve(a, b[:,:, 0:0])
+ assert_array_equal(result, expected)
+ assert_(isinstance(result, ArraySubclass))
+
+ # test both zero.
+ expected = linalg.solve(a, b)[:, 0:0, 0:0]
+ result = linalg.solve(a[:, 0:0, 0:0], b[:,0:0, 0:0])
+ assert_array_equal(result, expected)
+ assert_(isinstance(result, ArraySubclass))
+
class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):