author    | Ivan Yashchuk <ivan.yashchuk@aalto.fi> | 2021-11-29 10:45:15 +0000
committer | Ivan Yashchuk <ivan.yashchuk@aalto.fi> | 2021-11-29 11:09:34 +0000
commit    | 729b85e5d49257a20249e5c2d4f8f4072594f984 (patch)
tree      | cf6e0b94b2ac416046ba2de85f4e7f329f856a97
parent    | c8fdb53ee2d1b1130a74ad7a3250366fd042da3f (diff)
download  | numpy-729b85e5d49257a20249e5c2d4f8f4072594f984.tar.gz
BUG: Fix tensorsolve for 0-sized input
`array.reshape(-1, size)` doesn't work when one of the dimensions is 0: for a 0-sized array the `-1` cannot be inferred.
The fix is to use the explicit shape instead of `-1`. For "non-square"
tensors the `ValueError` would then come from the reshape call, while previously
a `LinAlgError` was raised by the solve call. To keep the error type the same,
I added a check for squareness before the reshape.
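For context, a minimal sketch of the failure mode in plain NumPy (not part of the patch), using the 0-sized shape from the new tests: with a 0-sized array, `reshape` cannot resolve `-1`, while an explicit shape is well defined.

```python
import numpy as np

# 0-sized input as exercised by the new test: shape (0, 3, 3, 0)
a = np.ones((0, 3, 3, 0))
prod = 3 * 0  # product of the trailing dimensions (3, 0)

try:
    a.reshape(-1, prod)    # old code path: -1 cannot be inferred when prod == 0
except ValueError as exc:
    print("reshape(-1, prod) failed:", exc)

m = a.reshape(prod, prod)  # fixed code path: explicit (0, 0) shape works
print(m.shape)             # -> (0, 0)
```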
-rw-r--r-- | numpy/linalg/linalg.py            |  8
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 21
2 files changed, 28 insertions(+), 1 deletion(-)
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 0c27e0631..d831886c0 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -299,7 +299,13 @@ def tensorsolve(a, b, axes=None):
     for k in oldshape:
         prod *= k

-    a = a.reshape(-1, prod)
+    if a.size != prod ** 2:
+        raise LinAlgError(
+            "Input arrays must satisfy the requirement \
+            prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
+        )
+
+    a = a.reshape(prod, prod)
     b = b.ravel()
     res = wrap(solve(a, b))
     res.shape = oldshape
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index c1ba84a8e..2462b3996 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -2103,6 +2103,27 @@ class TestTensorinv:
         assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))


+class TestTensorsolve:
+
+    @pytest.mark.parametrize("a, axes", [
+        (np.ones((4, 6, 8, 2)), None),
+        (np.ones((3, 3, 2)), (0, 2)),
+        ])
+    def test_non_square_handling(self, a, axes):
+        with assert_raises(LinAlgError):
+            b = np.ones(a.shape[:2])
+            linalg.tensorsolve(a, b, axes=axes)
+
+    @pytest.mark.parametrize("shape",
+        [(2, 3, 6), (3, 4, 4, 3), (0, 3, 3, 0)],
+    )
+    def test_tensorsolve_result(self, shape):
+        a = np.random.randn(*shape)
+        b = np.ones(a.shape[:2])
+        x = np.linalg.tensorsolve(a, b)
+        assert_allclose(np.tensordot(a, x, axes=len(x.shape)), b)
+
+
 def test_unsupported_commontype():
     # linalg gracefully handles unsupported type
     arr = np.array([[1, -2], [2, 5]], dtype='float16')
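As a quick illustration of the patched behaviour, a sketch mirroring the new tests (assumes a NumPy build that contains this commit):

```python
import numpy as np
from numpy.linalg import LinAlgError

# 0-sized system: solved as an empty problem instead of failing inside reshape.
a = np.random.randn(0, 3, 3, 0)
b = np.ones(a.shape[:2])          # shape (0, 3)
x = np.linalg.tensorsolve(a, b)
print(x.shape)                    # -> (3, 0)
np.testing.assert_allclose(np.tensordot(a, x, axes=x.ndim), b)

# "Non-square" tensor: prod(a.shape[2:]) != prod(a.shape[:2]), so the new
# check raises LinAlgError before the reshape is attempted.
a = np.ones((4, 6, 8, 2))
b = np.ones(a.shape[:2])          # shape (4, 6)
try:
    np.linalg.tensorsolve(a, b)
except LinAlgError as exc:
    print("LinAlgError:", exc)
```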