diff options
Diffstat (limited to 'numpy/linalg')
-rw-r--r-- | numpy/linalg/linalg.py | 92 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 136 |
2 files changed, 149 insertions, 79 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index ccc437663..59923f3c5 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -28,6 +28,7 @@ from numpy.core import ( swapaxes, divide, count_nonzero, isnan ) from numpy.core.multiarray import normalize_axis_index +from numpy.core.overrides import array_function_dispatch from numpy.lib.twodim_base import triu, eye from numpy.linalg import lapack_lite, _umath_linalg @@ -198,11 +199,6 @@ def _assertRankAtLeast2(*arrays): raise LinAlgError('%d-dimensional array given. Array must be ' 'at least two-dimensional' % a.ndim) -def _assertSquareness(*arrays): - for a in arrays: - if max(a.shape) != min(a.shape): - raise LinAlgError('Array must be square') - def _assertNdSquareness(*arrays): for a in arrays: m, n = a.shape[-2:] @@ -242,6 +238,11 @@ def transpose(a): # Linear equations +def _tensorsolve_dispatcher(a, b, axes=None): + return (a, b) + + +@array_function_dispatch(_tensorsolve_dispatcher) def tensorsolve(a, b, axes=None): """ Solve the tensor equation ``a x = b`` for x. @@ -311,6 +312,12 @@ def tensorsolve(a, b, axes=None): res.shape = oldshape return res + +def _solve_dispatcher(a, b): + return (a, b) + + +@array_function_dispatch(_solve_dispatcher) def solve(a, b): """ Solve a linear matrix equation, or system of linear scalar equations. @@ -391,6 +398,11 @@ def solve(a, b): return wrap(r.astype(result_t, copy=False)) +def _tensorinv_dispatcher(a, ind=None): + return (a,) + + +@array_function_dispatch(_tensorinv_dispatcher) def tensorinv(a, ind=2): """ Compute the 'inverse' of an N-dimensional array. @@ -460,6 +472,11 @@ def tensorinv(a, ind=2): # Matrix inversion +def _unary_dispatcher(a): + return (a,) + + +@array_function_dispatch(_unary_dispatcher) def inv(a): """ Compute the (multiplicative) inverse of a matrix. @@ -528,6 +545,11 @@ def inv(a): return wrap(ainv.astype(result_t, copy=False)) +def _matrix_power_dispatcher(a, n): + return (a,) + + +@array_function_dispatch(_matrix_power_dispatcher) def matrix_power(a, n): """ Raise a square matrix to the (integer) power `n`. @@ -645,6 +667,8 @@ def matrix_power(a, n): # Cholesky decomposition + +@array_function_dispatch(_unary_dispatcher) def cholesky(a): """ Cholesky decomposition. @@ -728,8 +752,14 @@ def cholesky(a): r = gufunc(a, signature=signature, extobj=extobj) return wrap(r.astype(result_t, copy=False)) + # QR decompostion +def _qr_dispatcher(a, mode=None): + return (a,) + + +@array_function_dispatch(_qr_dispatcher) def qr(a, mode='reduced'): """ Compute the qr factorization of a matrix. @@ -945,6 +975,7 @@ def qr(a, mode='reduced'): # Eigenvalues +@array_function_dispatch(_unary_dispatcher) def eigvals(a): """ Compute the eigenvalues of a general matrix. @@ -1034,6 +1065,12 @@ def eigvals(a): return w.astype(result_t, copy=False) + +def _eigvalsh_dispatcher(a, UPLO=None): + return (a,) + + +@array_function_dispatch(_eigvalsh_dispatcher) def eigvalsh(a, UPLO='L'): """ Compute the eigenvalues of a complex Hermitian or real symmetric matrix. @@ -1135,6 +1172,7 @@ def _convertarray(a): # Eigenvectors +@array_function_dispatch(_unary_dispatcher) def eig(a): """ Compute the eigenvalues and right eigenvectors of a square array. @@ -1276,6 +1314,7 @@ def eig(a): return w.astype(result_t, copy=False), wrap(vt) +@array_function_dispatch(_eigvalsh_dispatcher) def eigh(a, UPLO='L'): """ Return the eigenvalues and eigenvectors of a complex Hermitian @@ -1415,6 +1454,11 @@ def eigh(a, UPLO='L'): # Singular value decomposition +def _svd_dispatcher(a, full_matrices=None, compute_uv=None): + return (a,) + + +@array_function_dispatch(_svd_dispatcher) def svd(a, full_matrices=True, compute_uv=True): """ Singular Value Decomposition. @@ -1575,6 +1619,11 @@ def svd(a, full_matrices=True, compute_uv=True): return s +def _cond_dispatcher(x, p=None): + return (x,) + + +@array_function_dispatch(_cond_dispatcher) def cond(x, p=None): """ Compute the condition number of a matrix. @@ -1692,6 +1741,11 @@ def cond(x, p=None): return r +def _matrix_rank_dispatcher(M, tol=None, hermitian=None): + return (M,) + + +@array_function_dispatch(_matrix_rank_dispatcher) def matrix_rank(M, tol=None, hermitian=False): """ Return matrix rank of array using SVD method @@ -1796,7 +1850,12 @@ def matrix_rank(M, tol=None, hermitian=False): # Generalized inverse -def pinv(a, rcond=1e-15 ): +def _pinv_dispatcher(a, rcond=None): + return (a,) + + +@array_function_dispatch(_pinv_dispatcher) +def pinv(a, rcond=1e-15): """ Compute the (Moore-Penrose) pseudo-inverse of a matrix. @@ -1880,8 +1939,11 @@ def pinv(a, rcond=1e-15 ): res = matmul(transpose(vt), multiply(s[..., newaxis], transpose(u))) return wrap(res) + # Determinant + +@array_function_dispatch(_unary_dispatcher) def slogdet(a): """ Compute the sign and (natural) logarithm of the determinant of an array. @@ -1967,6 +2029,8 @@ def slogdet(a): logdet = logdet.astype(real_t, copy=False) return sign, logdet + +@array_function_dispatch(_unary_dispatcher) def det(a): """ Compute the determinant of an array. @@ -2023,8 +2087,14 @@ def det(a): r = r.astype(result_t, copy=False) return r + # Linear Least Squares +def _lstsq_dispatcher(a, b, rcond=None): + return (a, b) + + +@array_function_dispatch(_lstsq_dispatcher) def lstsq(a, b, rcond="warn"): """ Return the least-squares solution to a linear matrix equation. @@ -2208,6 +2278,11 @@ def _multi_svd_norm(x, row_axis, col_axis, op): return result +def _norm_dispatcher(x, ord=None, axis=None, keepdims=None): + return (x,) + + +@array_function_dispatch(_norm_dispatcher) def norm(x, ord=None, axis=None, keepdims=False): """ Matrix or vector norm. @@ -2450,6 +2525,11 @@ def norm(x, ord=None, axis=None, keepdims=False): # multi_dot +def _multidot_dispatcher(arrays): + return arrays + + +@array_function_dispatch(_multidot_dispatcher) def multi_dot(arrays): """ Compute the dot product of two or more arrays in a single function call, diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 98a77d8f5..0e94c2633 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -19,7 +19,7 @@ from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError from numpy.linalg.linalg import _multi_dot_matrix_chain_order from numpy.testing import ( assert_, assert_equal, assert_raises, assert_array_equal, - assert_almost_equal, assert_allclose, SkipTest, suppress_warnings + assert_almost_equal, assert_allclose, suppress_warnings ) @@ -462,12 +462,10 @@ class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): class TestSolve(SolveCases): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - assert_equal(linalg.solve(x, x).dtype, dtype) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(linalg.solve(x, x).dtype, dtype) def test_0_size(self): class ArraySubclass(np.ndarray): @@ -531,12 +529,10 @@ class InvCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): class TestInv(InvCases): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - assert_equal(linalg.inv(x).dtype, dtype) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(linalg.inv(x).dtype, dtype) def test_0_size(self): # Check that all kinds of 0-sized arrays work @@ -564,14 +560,12 @@ class EigvalsCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): class TestEigvals(EigvalsCases): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - assert_equal(linalg.eigvals(x).dtype, dtype) - x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) - assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype)) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(linalg.eigvals(x).dtype, dtype) + x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) + assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype)) def test_0_size(self): # Check that all kinds of 0-sized arrays work @@ -603,20 +597,17 @@ class EigCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): class TestEig(EigCases): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - w, v = np.linalg.eig(x) - assert_equal(w.dtype, dtype) - assert_equal(v.dtype, dtype) - - x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) - w, v = np.linalg.eig(x) - assert_equal(w.dtype, get_complex_dtype(dtype)) - assert_equal(v.dtype, get_complex_dtype(dtype)) - - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + w, v = np.linalg.eig(x) + assert_equal(w.dtype, dtype) + assert_equal(v.dtype, dtype) + + x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) + w, v = np.linalg.eig(x) + assert_equal(w.dtype, get_complex_dtype(dtype)) + assert_equal(v.dtype, get_complex_dtype(dtype)) def test_0_size(self): # Check that all kinds of 0-sized arrays work @@ -653,18 +644,15 @@ class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): class TestSVD(SVDCases): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - u, s, vh = linalg.svd(x) - assert_equal(u.dtype, dtype) - assert_equal(s.dtype, get_real_dtype(dtype)) - assert_equal(vh.dtype, dtype) - s = linalg.svd(x, compute_uv=False) - assert_equal(s.dtype, get_real_dtype(dtype)) - - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + u, s, vh = linalg.svd(x) + assert_equal(u.dtype, dtype) + assert_equal(s.dtype, get_real_dtype(dtype)) + assert_equal(vh.dtype, dtype) + s = linalg.svd(x, compute_uv=False) + assert_equal(s.dtype, get_real_dtype(dtype)) def test_empty_identity(self): """ Empty input should put an identity matrix in u or vh """ @@ -842,15 +830,13 @@ class TestDet(DetCases): assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble) assert_equal(type(linalg.slogdet([[0.0j]])[1]), double) - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - assert_equal(np.linalg.det(x).dtype, dtype) - ph, s = np.linalg.slogdet(x) - assert_equal(s.dtype, get_real_dtype(dtype)) - assert_equal(ph.dtype, dtype) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(np.linalg.det(x).dtype, dtype) + ph, s = np.linalg.slogdet(x) + assert_equal(s.dtype, get_real_dtype(dtype)) + assert_equal(ph.dtype, dtype) def test_0_size(self): a = np.zeros((0, 0), dtype=np.complex64) @@ -1049,13 +1035,11 @@ class TestEigvalshCases(HermitianTestCase, HermitianGeneralizedTestCase): class TestEigvalsh(object): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - w = np.linalg.eigvalsh(x) - assert_equal(w.dtype, get_real_dtype(dtype)) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + w = np.linalg.eigvalsh(x) + assert_equal(w.dtype, get_real_dtype(dtype)) def test_invalid(self): x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32) @@ -1127,14 +1111,12 @@ class TestEighCases(HermitianTestCase, HermitianGeneralizedTestCase): class TestEigh(object): - def test_types(self): - def check(dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - w, v = np.linalg.eigh(x) - assert_equal(w.dtype, get_real_dtype(dtype)) - assert_equal(v.dtype, dtype) - for dtype in [single, double, csingle, cdouble]: - check(dtype) + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + w, v = np.linalg.eigh(x) + assert_equal(w.dtype, get_real_dtype(dtype)) + assert_equal(v.dtype, dtype) def test_invalid(self): x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32) @@ -1769,7 +1751,7 @@ def test_xerbla_override(): pid = os.fork() except (OSError, AttributeError): # fork failed, or not running on POSIX - raise SkipTest("Not POSIX or fork failed.") + pytest.skip("Not POSIX or fork failed.") if pid == 0: # child; close i/o file handles @@ -1804,7 +1786,7 @@ def test_xerbla_override(): # parent pid, status = os.wait() if os.WEXITSTATUS(status) != XERBLA_OK: - raise SkipTest('Numpy xerbla not linked in.') + pytest.skip('Numpy xerbla not linked in.') def test_sdot_bug_8577(): @@ -1853,6 +1835,14 @@ class TestMultiDot(object): assert_almost_equal(multi_dot([A, B, C]), A.dot(B).dot(C)) assert_almost_equal(multi_dot([A, B, C]), np.dot(A, np.dot(B, C))) + def test_basic_function_with_two_arguments(self): + # separate code path with two arguments + A = np.random.random((6, 2)) + B = np.random.random((2, 6)) + + assert_almost_equal(multi_dot([A, B]), A.dot(B)) + assert_almost_equal(multi_dot([A, B]), np.dot(A, B)) + def test_basic_function_with_dynamic_programing_optimization(self): # multi_dot with four or more arguments uses the dynamic programing # optimization and therefore deserve a separate |