summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-10-26 12:18:41 -0500
committerGitHub <noreply@github.com>2018-10-26 12:18:41 -0500
commitc8709d66ea0b0a201d80e660e27b636ec76bf289 (patch)
tree055ccb7f0683cfc573ace253d309fb7f7d32b26a
parent3debe9772ea1b68d997dba3440929a467ad11c52 (diff)
parente9b1fb1c8f0abb6aa27b223d969b2f0797f77235 (diff)
downloadnumpy-c8709d66ea0b0a201d80e660e27b636ec76bf289.tar.gz
Merge pull request #12214 from tylerjereddy/test_tensorinv
TST: add test for tensorinv()
-rw-r--r--numpy/linalg/tests/test_linalg.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 0e94c2633..836681039 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -1915,3 +1915,44 @@ class TestMultiDot(object):
def test_too_few_input_arrays(self):
assert_raises(ValueError, multi_dot, [])
assert_raises(ValueError, multi_dot, [np.random.random((3, 3))])
+
+
+class TestTensorinv(object):
+
+ @pytest.mark.parametrize("arr, ind", [
+ (np.ones((4, 6, 8, 2)), 2),
+ (np.ones((3, 3, 2)), 1),
+ ])
+ def test_non_square_handling(self, arr, ind):
+ with assert_raises(LinAlgError):
+ linalg.tensorinv(arr, ind=ind)
+
+ @pytest.mark.parametrize("shape, ind", [
+ # examples from docstring
+ ((4, 6, 8, 3), 2),
+ ((24, 8, 3), 1),
+ ])
+ def test_tensorinv_shape(self, shape, ind):
+ a = np.eye(24)
+ a.shape = shape
+ ainv = linalg.tensorinv(a=a, ind=ind)
+ expected = a.shape[ind:] + a.shape[:ind]
+ actual = ainv.shape
+ assert_equal(actual, expected)
+
+ @pytest.mark.parametrize("ind", [
+ 0, -2,
+ ])
+ def test_tensorinv_ind_limit(self, ind):
+ a = np.eye(24)
+ a.shape = (4, 6, 8, 3)
+ with assert_raises(ValueError):
+ linalg.tensorinv(a=a, ind=ind)
+
+ def test_tensorinv_result(self):
+ # mimic a docstring example
+ a = np.eye(24)
+ a.shape = (24, 8, 3)
+ ainv = linalg.tensorinv(a, ind=1)
+ b = np.ones(24)
+ assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))