summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r--numpy/linalg/tests/test_linalg.py57
1 files changed, 52 insertions, 5 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 0e94c2633..235488c6e 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -13,13 +13,14 @@ import pytest
import numpy as np
from numpy import array, single, double, csingle, cdouble, dot, identity, matmul
-from numpy import multiply, atleast_2d, inf, asarray, matrix
+from numpy import multiply, atleast_2d, inf, asarray
from numpy import linalg
from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError
from numpy.linalg.linalg import _multi_dot_matrix_chain_order
from numpy.testing import (
assert_, assert_equal, assert_raises, assert_array_equal,
- assert_almost_equal, assert_allclose, suppress_warnings
+ assert_almost_equal, assert_allclose, suppress_warnings,
+ assert_raises_regex,
)
@@ -931,6 +932,14 @@ class TestLstsq(LstsqCases):
assert_equal(rank, min(m, n))
assert_equal(s.shape, (min(m, n),))
+ def test_incompatible_dims(self):
+ # use modified version of docstring example
+ x = np.array([0, 1, 2, 3])
+ y = np.array([-1, 0.2, 0.9, 2.1, 3.3])
+ A = np.vstack([x, np.ones(len(x))]).T
+ with assert_raises_regex(LinAlgError, "Incompatible dimensions"):
+ linalg.lstsq(A, y, rcond=None)
+
@pytest.mark.parametrize('dt', [np.dtype(c) for c in '?bBhHiIqQefdgFDGO'])
class TestMatrixPower(object):
@@ -946,7 +955,6 @@ class TestMatrixPower(object):
dtnoinv = [object, np.dtype('e'), np.dtype('g'), np.dtype('G')]
def test_large_power(self, dt):
- power = matrix_power
rshft = self.rshft_1.astype(dt)
assert_equal(
matrix_power(rshft, 2**100 + 2**10 + 2**5 + 0), self.rshft_0)
@@ -1610,8 +1618,6 @@ class TestQR(object):
def test_qr_empty(self, m, n):
k = min(m, n)
a = np.empty((m, n))
- a_type = type(a)
- a_dtype = a.dtype
self.check_qr(a)
@@ -1915,3 +1921,44 @@ class TestMultiDot(object):
def test_too_few_input_arrays(self):
assert_raises(ValueError, multi_dot, [])
assert_raises(ValueError, multi_dot, [np.random.random((3, 3))])
+
+
+class TestTensorinv(object):
+
+ @pytest.mark.parametrize("arr, ind", [
+ (np.ones((4, 6, 8, 2)), 2),
+ (np.ones((3, 3, 2)), 1),
+ ])
+ def test_non_square_handling(self, arr, ind):
+ with assert_raises(LinAlgError):
+ linalg.tensorinv(arr, ind=ind)
+
+ @pytest.mark.parametrize("shape, ind", [
+ # examples from docstring
+ ((4, 6, 8, 3), 2),
+ ((24, 8, 3), 1),
+ ])
+ def test_tensorinv_shape(self, shape, ind):
+ a = np.eye(24)
+ a.shape = shape
+ ainv = linalg.tensorinv(a=a, ind=ind)
+ expected = a.shape[ind:] + a.shape[:ind]
+ actual = ainv.shape
+ assert_equal(actual, expected)
+
+ @pytest.mark.parametrize("ind", [
+ 0, -2,
+ ])
+ def test_tensorinv_ind_limit(self, ind):
+ a = np.eye(24)
+ a.shape = (4, 6, 8, 3)
+ with assert_raises(ValueError):
+ linalg.tensorinv(a=a, ind=ind)
+
+ def test_tensorinv_result(self):
+ # mimic a docstring example
+ a = np.eye(24)
+ a.shape = (24, 8, 3)
+ ainv = linalg.tensorinv(a, ind=1)
+ b = np.ones(24)
+ assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))