summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRalf Gommers <ralf.gommers@gmail.com>2021-12-15 13:11:21 +0100
committerGitHub <noreply@github.com>2021-12-15 13:11:21 +0100
commit7d2bb368020c116433f8f8dfdbc63925efc23ddd (patch)
tree36d2092157e5b3dafd32489db03116907973d395
parentf8ebfa4a3182033503ac387ebbce9d4336f62a97 (diff)
parent729b85e5d49257a20249e5c2d4f8f4072594f984 (diff)
downloadnumpy-7d2bb368020c116433f8f8dfdbc63925efc23ddd.tar.gz
Merge pull request #20482 from IvanYashchuk/fix-zerodim-tensorsolve
BUG: Fix tensorsolve for 0-sized input
-rw-r--r--numpy/linalg/linalg.py8
-rw-r--r--numpy/linalg/tests/test_linalg.py21
2 files changed, 28 insertions, 1 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 0c27e0631..d831886c0 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -299,7 +299,13 @@ def tensorsolve(a, b, axes=None):
for k in oldshape:
prod *= k
- a = a.reshape(-1, prod)
+ if a.size != prod ** 2:
+ raise LinAlgError(
+ "Input arrays must satisfy the requirement \
+ prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
+ )
+
+ a = a.reshape(prod, prod)
b = b.ravel()
res = wrap(solve(a, b))
res.shape = oldshape
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index c1ba84a8e..2462b3996 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -2103,6 +2103,27 @@ class TestTensorinv:
assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
+class TestTensorsolve:
+
+ @pytest.mark.parametrize("a, axes", [
+ (np.ones((4, 6, 8, 2)), None),
+ (np.ones((3, 3, 2)), (0, 2)),
+ ])
+ def test_non_square_handling(self, a, axes):
+ with assert_raises(LinAlgError):
+ b = np.ones(a.shape[:2])
+ linalg.tensorsolve(a, b, axes=axes)
+
+ @pytest.mark.parametrize("shape",
+ [(2, 3, 6), (3, 4, 4, 3), (0, 3, 3, 0)],
+ )
+ def test_tensorsolve_result(self, shape):
+ a = np.random.randn(*shape)
+ b = np.ones(a.shape[:2])
+ x = np.linalg.tensorsolve(a, b)
+ assert_allclose(np.tensordot(a, x, axes=len(x.shape)), b)
+
+
def test_unsupported_commontype():
# linalg gracefully handles unsupported type
arr = np.array([[1, -2], [2, 5]], dtype='float16')