diff options
-rw-r--r-- | numpy/linalg/linalg.py | 68 |
1 files changed, 24 insertions, 44 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 43caa3235..30f5fbeba 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -1096,17 +1096,18 @@ def eigh(a, UPLO='L'): Parameters ---------- - a : (M, M) array_like - A complex Hermitian or real symmetric matrix. + A : (..., M, M) array + Hermitian/Symmetric matrices whose eigenvalues and + eigenvectors are to be computed. UPLO : {'L', 'U'}, optional Specifies whether the calculation is done with the lower triangular part of `a` ('L', default) or the upper triangular part ('U'). Returns ------- - w : (M,) ndarray + w : (..., M) ndarray The eigenvalues, not necessarily ordered. - v : {(M, M) ndarray, (M, M) matrix} + v : {(..., M, M) ndarray, (..., M, M) matrix} The column ``v[:, i]`` is the normalized eigenvector corresponding to the eigenvalue ``w[i]``. Will return a matrix object if `a` is a matrix object. @@ -1124,9 +1125,11 @@ def eigh(a, UPLO='L'): Notes ----- - This is a simple interface to the LAPACK routines dsyevd and zheevd, - which compute the eigenvalues and eigenvectors of real symmetric and - complex Hermitian arrays, respectively. + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + The eigenvalues/eigenvectors are computed using LAPACK routines _ssyevd, + _heevd The eigenvalues of real symmetric or complex Hermitian matrices are always real. [1]_ The array `v` of (column) eigenvectors is unitary @@ -1168,46 +1171,23 @@ def eigh(a, UPLO='L'): """ UPLO = asbytes(UPLO) + a, wrap = _makearray(a) - _assertRank2(a) - _assertSquareness(a) + _assertRankAtLeast2(a) + _assertNdSquareness(a) t, result_t = _commonType(a) - real_t = _linalgRealType(t) - a = _fastCopyAndTranspose(t, a) - a = _to_native_byte_order(a) - n = a.shape[0] - liwork = 5*n+3 - iwork = zeros((liwork,), fortran_int) - if isComplexType(t): - lapack_routine = lapack_lite.zheevd - w = zeros((n,), real_t) - lwork = 1 - work = zeros((lwork,), t) - lrwork = 1 - rwork = zeros((lrwork,), real_t) - results = lapack_routine(_V, UPLO, n, a, n, w, work, -1, - rwork, -1, iwork, liwork, 0) - lwork = int(abs(work[0])) - work = zeros((lwork,), t) - lrwork = int(rwork[0]) - rwork = zeros((lrwork,), real_t) - results = lapack_routine(_V, UPLO, n, a, n, w, work, lwork, - rwork, lrwork, iwork, liwork, 0) + + extobj = get_linalg_error_extobj( + _raise_linalgerror_eigenvalues_nonconvergence) + if 'L' == UPLO: + gufunc = _umath_linalg.eigh_lo else: - lapack_routine = lapack_lite.dsyevd - w = zeros((n,), t) - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(_V, UPLO, n, a, n, w, work, -1, - iwork, liwork, 0) - lwork = int(work[0]) - work = zeros((lwork,), t) - results = lapack_routine(_V, UPLO, n, a, n, w, work, lwork, - iwork, liwork, 0) - if results['info'] > 0: - raise LinAlgError('Eigenvalues did not converge') - at = a.transpose().astype(result_t) - return w.astype(_realType(result_t)), wrap(at) + gufunc = _umath_linalg.eigh_up + + w, vt = gufunc(a.astype(t), extobj=extobj) + w = w.astype(_realType(result_t)) + vt = vt.astype(result_t) + return w, wrap(vt) # Singular value decomposition |