diff options
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 92 |
1 files changed, 86 insertions, 6 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index ccc437663..59923f3c5 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -28,6 +28,7 @@ from numpy.core import ( swapaxes, divide, count_nonzero, isnan ) from numpy.core.multiarray import normalize_axis_index +from numpy.core.overrides import array_function_dispatch from numpy.lib.twodim_base import triu, eye from numpy.linalg import lapack_lite, _umath_linalg @@ -198,11 +199,6 @@ def _assertRankAtLeast2(*arrays): raise LinAlgError('%d-dimensional array given. Array must be ' 'at least two-dimensional' % a.ndim) -def _assertSquareness(*arrays): - for a in arrays: - if max(a.shape) != min(a.shape): - raise LinAlgError('Array must be square') - def _assertNdSquareness(*arrays): for a in arrays: m, n = a.shape[-2:] @@ -242,6 +238,11 @@ def transpose(a): # Linear equations +def _tensorsolve_dispatcher(a, b, axes=None): + return (a, b) + + +@array_function_dispatch(_tensorsolve_dispatcher) def tensorsolve(a, b, axes=None): """ Solve the tensor equation ``a x = b`` for x. @@ -311,6 +312,12 @@ def tensorsolve(a, b, axes=None): res.shape = oldshape return res + +def _solve_dispatcher(a, b): + return (a, b) + + +@array_function_dispatch(_solve_dispatcher) def solve(a, b): """ Solve a linear matrix equation, or system of linear scalar equations. @@ -391,6 +398,11 @@ def solve(a, b): return wrap(r.astype(result_t, copy=False)) +def _tensorinv_dispatcher(a, ind=None): + return (a,) + + +@array_function_dispatch(_tensorinv_dispatcher) def tensorinv(a, ind=2): """ Compute the 'inverse' of an N-dimensional array. @@ -460,6 +472,11 @@ def tensorinv(a, ind=2): # Matrix inversion +def _unary_dispatcher(a): + return (a,) + + +@array_function_dispatch(_unary_dispatcher) def inv(a): """ Compute the (multiplicative) inverse of a matrix. @@ -528,6 +545,11 @@ def inv(a): return wrap(ainv.astype(result_t, copy=False)) +def _matrix_power_dispatcher(a, n): + return (a,) + + +@array_function_dispatch(_matrix_power_dispatcher) def matrix_power(a, n): """ Raise a square matrix to the (integer) power `n`. @@ -645,6 +667,8 @@ def matrix_power(a, n): # Cholesky decomposition + +@array_function_dispatch(_unary_dispatcher) def cholesky(a): """ Cholesky decomposition. @@ -728,8 +752,14 @@ def cholesky(a): r = gufunc(a, signature=signature, extobj=extobj) return wrap(r.astype(result_t, copy=False)) + # QR decompostion +def _qr_dispatcher(a, mode=None): + return (a,) + + +@array_function_dispatch(_qr_dispatcher) def qr(a, mode='reduced'): """ Compute the qr factorization of a matrix. @@ -945,6 +975,7 @@ def qr(a, mode='reduced'): # Eigenvalues +@array_function_dispatch(_unary_dispatcher) def eigvals(a): """ Compute the eigenvalues of a general matrix. @@ -1034,6 +1065,12 @@ def eigvals(a): return w.astype(result_t, copy=False) + +def _eigvalsh_dispatcher(a, UPLO=None): + return (a,) + + +@array_function_dispatch(_eigvalsh_dispatcher) def eigvalsh(a, UPLO='L'): """ Compute the eigenvalues of a complex Hermitian or real symmetric matrix. @@ -1135,6 +1172,7 @@ def _convertarray(a): # Eigenvectors +@array_function_dispatch(_unary_dispatcher) def eig(a): """ Compute the eigenvalues and right eigenvectors of a square array. @@ -1276,6 +1314,7 @@ def eig(a): return w.astype(result_t, copy=False), wrap(vt) +@array_function_dispatch(_eigvalsh_dispatcher) def eigh(a, UPLO='L'): """ Return the eigenvalues and eigenvectors of a complex Hermitian @@ -1415,6 +1454,11 @@ def eigh(a, UPLO='L'): # Singular value decomposition +def _svd_dispatcher(a, full_matrices=None, compute_uv=None): + return (a,) + + +@array_function_dispatch(_svd_dispatcher) def svd(a, full_matrices=True, compute_uv=True): """ Singular Value Decomposition. @@ -1575,6 +1619,11 @@ def svd(a, full_matrices=True, compute_uv=True): return s +def _cond_dispatcher(x, p=None): + return (x,) + + +@array_function_dispatch(_cond_dispatcher) def cond(x, p=None): """ Compute the condition number of a matrix. @@ -1692,6 +1741,11 @@ def cond(x, p=None): return r +def _matrix_rank_dispatcher(M, tol=None, hermitian=None): + return (M,) + + +@array_function_dispatch(_matrix_rank_dispatcher) def matrix_rank(M, tol=None, hermitian=False): """ Return matrix rank of array using SVD method @@ -1796,7 +1850,12 @@ def matrix_rank(M, tol=None, hermitian=False): # Generalized inverse -def pinv(a, rcond=1e-15 ): +def _pinv_dispatcher(a, rcond=None): + return (a,) + + +@array_function_dispatch(_pinv_dispatcher) +def pinv(a, rcond=1e-15): """ Compute the (Moore-Penrose) pseudo-inverse of a matrix. @@ -1880,8 +1939,11 @@ def pinv(a, rcond=1e-15 ): res = matmul(transpose(vt), multiply(s[..., newaxis], transpose(u))) return wrap(res) + # Determinant + +@array_function_dispatch(_unary_dispatcher) def slogdet(a): """ Compute the sign and (natural) logarithm of the determinant of an array. @@ -1967,6 +2029,8 @@ def slogdet(a): logdet = logdet.astype(real_t, copy=False) return sign, logdet + +@array_function_dispatch(_unary_dispatcher) def det(a): """ Compute the determinant of an array. @@ -2023,8 +2087,14 @@ def det(a): r = r.astype(result_t, copy=False) return r + # Linear Least Squares +def _lstsq_dispatcher(a, b, rcond=None): + return (a, b) + + +@array_function_dispatch(_lstsq_dispatcher) def lstsq(a, b, rcond="warn"): """ Return the least-squares solution to a linear matrix equation. @@ -2208,6 +2278,11 @@ def _multi_svd_norm(x, row_axis, col_axis, op): return result +def _norm_dispatcher(x, ord=None, axis=None, keepdims=None): + return (x,) + + +@array_function_dispatch(_norm_dispatcher) def norm(x, ord=None, axis=None, keepdims=False): """ Matrix or vector norm. @@ -2450,6 +2525,11 @@ def norm(x, ord=None, axis=None, keepdims=False): # multi_dot +def _multidot_dispatcher(arrays): + return arrays + + +@array_function_dispatch(_multidot_dispatcher) def multi_dot(arrays): """ Compute the dot product of two or more arrays in a single function call, |