summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIvan Yashchuk <ivan.yashchuk@aalto.fi>2021-11-29 10:45:15 +0000
committerIvan Yashchuk <ivan.yashchuk@aalto.fi>2021-11-29 11:09:34 +0000
commit729b85e5d49257a20249e5c2d4f8f4072594f984 (patch)
treecf6e0b94b2ac416046ba2de85f4e7f329f856a97
parentc8fdb53ee2d1b1130a74ad7a3250366fd042da3f (diff)
downloadnumpy-729b85e5d49257a20249e5c2d4f8f4072594f984.tar.gz
BUG: Fix tensorsolve for 0-sized input
`array.reshape(-1, size)` doesn't work with 0 in the dimensions. The fix is to use explicit shape instead of `-1`. For "non-square" tensors the `ValueError` would come from the reshape call, while previously `LinAlgError` appeared from the solve call. To have the same error type I added a check for squareness before the reshape.
-rw-r--r--numpy/linalg/linalg.py8
-rw-r--r--numpy/linalg/tests/test_linalg.py21
2 files changed, 28 insertions, 1 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 0c27e0631..d831886c0 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -299,7 +299,13 @@ def tensorsolve(a, b, axes=None):
for k in oldshape:
prod *= k
- a = a.reshape(-1, prod)
+ if a.size != prod ** 2:
+ raise LinAlgError(
+ "Input arrays must satisfy the requirement \
+ prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
+ )
+
+ a = a.reshape(prod, prod)
b = b.ravel()
res = wrap(solve(a, b))
res.shape = oldshape
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index c1ba84a8e..2462b3996 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -2103,6 +2103,27 @@ class TestTensorinv:
assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
+class TestTensorsolve:
+
+ @pytest.mark.parametrize("a, axes", [
+ (np.ones((4, 6, 8, 2)), None),
+ (np.ones((3, 3, 2)), (0, 2)),
+ ])
+ def test_non_square_handling(self, a, axes):
+ with assert_raises(LinAlgError):
+ b = np.ones(a.shape[:2])
+ linalg.tensorsolve(a, b, axes=axes)
+
+ @pytest.mark.parametrize("shape",
+ [(2, 3, 6), (3, 4, 4, 3), (0, 3, 3, 0)],
+ )
+ def test_tensorsolve_result(self, shape):
+ a = np.random.randn(*shape)
+ b = np.ones(a.shape[:2])
+ x = np.linalg.tensorsolve(a, b)
+ assert_allclose(np.tensordot(a, x, axes=len(x.shape)), b)
+
+
def test_unsupported_commontype():
# linalg gracefully handles unsupported type
arr = np.array([[1, -2], [2, 5]], dtype='float16')