diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2018-10-26 12:18:41 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-26 12:18:41 -0500 |
commit | c8709d66ea0b0a201d80e660e27b636ec76bf289 (patch) | |
tree | 055ccb7f0683cfc573ace253d309fb7f7d32b26a | |
parent | 3debe9772ea1b68d997dba3440929a467ad11c52 (diff) | |
parent | e9b1fb1c8f0abb6aa27b223d969b2f0797f77235 (diff) | |
download | numpy-c8709d66ea0b0a201d80e660e27b636ec76bf289.tar.gz |
Merge pull request #12214 from tylerjereddy/test_tensorinv
TST: add test for tensorinv()
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 0e94c2633..836681039 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -1915,3 +1915,44 @@ class TestMultiDot(object): def test_too_few_input_arrays(self): assert_raises(ValueError, multi_dot, []) assert_raises(ValueError, multi_dot, [np.random.random((3, 3))]) + + +class TestTensorinv(object): + + @pytest.mark.parametrize("arr, ind", [ + (np.ones((4, 6, 8, 2)), 2), + (np.ones((3, 3, 2)), 1), + ]) + def test_non_square_handling(self, arr, ind): + with assert_raises(LinAlgError): + linalg.tensorinv(arr, ind=ind) + + @pytest.mark.parametrize("shape, ind", [ + # examples from docstring + ((4, 6, 8, 3), 2), + ((24, 8, 3), 1), + ]) + def test_tensorinv_shape(self, shape, ind): + a = np.eye(24) + a.shape = shape + ainv = linalg.tensorinv(a=a, ind=ind) + expected = a.shape[ind:] + a.shape[:ind] + actual = ainv.shape + assert_equal(actual, expected) + + @pytest.mark.parametrize("ind", [ + 0, -2, + ]) + def test_tensorinv_ind_limit(self, ind): + a = np.eye(24) + a.shape = (4, 6, 8, 3) + with assert_raises(ValueError): + linalg.tensorinv(a=a, ind=ind) + + def test_tensorinv_result(self): + # mimic a docstring example + a = np.eye(24) + a.shape = (24, 8, 3) + ainv = linalg.tensorinv(a, ind=1) + b = np.ones(24) + assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b)) |