diff options
author | Ralf Gommers <ralf.gommers@gmail.com> | 2021-12-15 13:11:21 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-15 13:11:21 +0100 |
commit | 7d2bb368020c116433f8f8dfdbc63925efc23ddd (patch) | |
tree | 36d2092157e5b3dafd32489db03116907973d395 | |
parent | f8ebfa4a3182033503ac387ebbce9d4336f62a97 (diff) | |
parent | 729b85e5d49257a20249e5c2d4f8f4072594f984 (diff) | |
download | numpy-7d2bb368020c116433f8f8dfdbc63925efc23ddd.tar.gz |
Merge pull request #20482 from IvanYashchuk/fix-zerodim-tensorsolve
BUG: Fix tensorsolve for 0-sized input
-rw-r--r-- | numpy/linalg/linalg.py | 8 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 21 |
2 files changed, 28 insertions, 1 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 0c27e0631..d831886c0 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -299,7 +299,13 @@ def tensorsolve(a, b, axes=None): for k in oldshape: prod *= k - a = a.reshape(-1, prod) + if a.size != prod ** 2: + raise LinAlgError( + "Input arrays must satisfy the requirement \ + prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])" + ) + + a = a.reshape(prod, prod) b = b.ravel() res = wrap(solve(a, b)) res.shape = oldshape diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index c1ba84a8e..2462b3996 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -2103,6 +2103,27 @@ class TestTensorinv: assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b)) +class TestTensorsolve: + + @pytest.mark.parametrize("a, axes", [ + (np.ones((4, 6, 8, 2)), None), + (np.ones((3, 3, 2)), (0, 2)), + ]) + def test_non_square_handling(self, a, axes): + with assert_raises(LinAlgError): + b = np.ones(a.shape[:2]) + linalg.tensorsolve(a, b, axes=axes) + + @pytest.mark.parametrize("shape", + [(2, 3, 6), (3, 4, 4, 3), (0, 3, 3, 0)], + ) + def test_tensorsolve_result(self, shape): + a = np.random.randn(*shape) + b = np.ones(a.shape[:2]) + x = np.linalg.tensorsolve(a, b) + assert_allclose(np.tensordot(a, x, axes=len(x.shape)), b) + + def test_unsupported_commontype(): # linalg gracefully handles unsupported type arr = np.array([[1, -2], [2, 5]], dtype='float16') |