diff options
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 57 |
1 files changed, 52 insertions, 5 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 0e94c2633..235488c6e 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -13,13 +13,14 @@ import pytest import numpy as np from numpy import array, single, double, csingle, cdouble, dot, identity, matmul -from numpy import multiply, atleast_2d, inf, asarray, matrix +from numpy import multiply, atleast_2d, inf, asarray from numpy import linalg from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError from numpy.linalg.linalg import _multi_dot_matrix_chain_order from numpy.testing import ( assert_, assert_equal, assert_raises, assert_array_equal, - assert_almost_equal, assert_allclose, suppress_warnings + assert_almost_equal, assert_allclose, suppress_warnings, + assert_raises_regex, ) @@ -931,6 +932,14 @@ class TestLstsq(LstsqCases): assert_equal(rank, min(m, n)) assert_equal(s.shape, (min(m, n),)) + def test_incompatible_dims(self): + # use modified version of docstring example + x = np.array([0, 1, 2, 3]) + y = np.array([-1, 0.2, 0.9, 2.1, 3.3]) + A = np.vstack([x, np.ones(len(x))]).T + with assert_raises_regex(LinAlgError, "Incompatible dimensions"): + linalg.lstsq(A, y, rcond=None) + @pytest.mark.parametrize('dt', [np.dtype(c) for c in '?bBhHiIqQefdgFDGO']) class TestMatrixPower(object): @@ -946,7 +955,6 @@ class TestMatrixPower(object): dtnoinv = [object, np.dtype('e'), np.dtype('g'), np.dtype('G')] def test_large_power(self, dt): - power = matrix_power rshft = self.rshft_1.astype(dt) assert_equal( matrix_power(rshft, 2**100 + 2**10 + 2**5 + 0), self.rshft_0) @@ -1610,8 +1618,6 @@ class TestQR(object): def test_qr_empty(self, m, n): k = min(m, n) a = np.empty((m, n)) - a_type = type(a) - a_dtype = a.dtype self.check_qr(a) @@ -1915,3 +1921,44 @@ class TestMultiDot(object): def test_too_few_input_arrays(self): assert_raises(ValueError, multi_dot, []) assert_raises(ValueError, multi_dot, [np.random.random((3, 3))]) + + +class TestTensorinv(object): + + @pytest.mark.parametrize("arr, ind", [ + (np.ones((4, 6, 8, 2)), 2), + (np.ones((3, 3, 2)), 1), + ]) + def test_non_square_handling(self, arr, ind): + with assert_raises(LinAlgError): + linalg.tensorinv(arr, ind=ind) + + @pytest.mark.parametrize("shape, ind", [ + # examples from docstring + ((4, 6, 8, 3), 2), + ((24, 8, 3), 1), + ]) + def test_tensorinv_shape(self, shape, ind): + a = np.eye(24) + a.shape = shape + ainv = linalg.tensorinv(a=a, ind=ind) + expected = a.shape[ind:] + a.shape[:ind] + actual = ainv.shape + assert_equal(actual, expected) + + @pytest.mark.parametrize("ind", [ + 0, -2, + ]) + def test_tensorinv_ind_limit(self, ind): + a = np.eye(24) + a.shape = (4, 6, 8, 3) + with assert_raises(ValueError): + linalg.tensorinv(a=a, ind=ind) + + def test_tensorinv_result(self): + # mimic a docstring example + a = np.eye(24) + a.shape = (24, 8, 3) + ainv = linalg.tensorinv(a, ind=1) + b = np.ones(24) + assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b)) |