summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/linalg/linalg.py68
1 files changed, 24 insertions, 44 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 43caa3235..30f5fbeba 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -1096,17 +1096,18 @@ def eigh(a, UPLO='L'):
Parameters
----------
- a : (M, M) array_like
- A complex Hermitian or real symmetric matrix.
+ A : (..., M, M) array
+ Hermitian/Symmetric matrices whose eigenvalues and
+ eigenvectors are to be computed.
UPLO : {'L', 'U'}, optional
Specifies whether the calculation is done with the lower triangular
part of `a` ('L', default) or the upper triangular part ('U').
Returns
-------
- w : (M,) ndarray
+ w : (..., M) ndarray
The eigenvalues, not necessarily ordered.
- v : {(M, M) ndarray, (M, M) matrix}
+ v : {(..., M, M) ndarray, (..., M, M) matrix}
The column ``v[:, i]`` is the normalized eigenvector corresponding
to the eigenvalue ``w[i]``. Will return a matrix object if `a` is
a matrix object.
@@ -1124,9 +1125,11 @@ def eigh(a, UPLO='L'):
Notes
-----
- This is a simple interface to the LAPACK routines dsyevd and zheevd,
- which compute the eigenvalues and eigenvectors of real symmetric and
- complex Hermitian arrays, respectively.
+ Broadcasting rules apply, see the `numpy.linalg` documentation for
+ details.
+
+ The eigenvalues/eigenvectors are computed using LAPACK routines _ssyevd,
+ _heevd
The eigenvalues of real symmetric or complex Hermitian matrices are
always real. [1]_ The array `v` of (column) eigenvectors is unitary
@@ -1168,46 +1171,23 @@ def eigh(a, UPLO='L'):
"""
UPLO = asbytes(UPLO)
+
a, wrap = _makearray(a)
- _assertRank2(a)
- _assertSquareness(a)
+ _assertRankAtLeast2(a)
+ _assertNdSquareness(a)
t, result_t = _commonType(a)
- real_t = _linalgRealType(t)
- a = _fastCopyAndTranspose(t, a)
- a = _to_native_byte_order(a)
- n = a.shape[0]
- liwork = 5*n+3
- iwork = zeros((liwork,), fortran_int)
- if isComplexType(t):
- lapack_routine = lapack_lite.zheevd
- w = zeros((n,), real_t)
- lwork = 1
- work = zeros((lwork,), t)
- lrwork = 1
- rwork = zeros((lrwork,), real_t)
- results = lapack_routine(_V, UPLO, n, a, n, w, work, -1,
- rwork, -1, iwork, liwork, 0)
- lwork = int(abs(work[0]))
- work = zeros((lwork,), t)
- lrwork = int(rwork[0])
- rwork = zeros((lrwork,), real_t)
- results = lapack_routine(_V, UPLO, n, a, n, w, work, lwork,
- rwork, lrwork, iwork, liwork, 0)
+
+ extobj = get_linalg_error_extobj(
+ _raise_linalgerror_eigenvalues_nonconvergence)
+ if 'L' == UPLO:
+ gufunc = _umath_linalg.eigh_lo
else:
- lapack_routine = lapack_lite.dsyevd
- w = zeros((n,), t)
- lwork = 1
- work = zeros((lwork,), t)
- results = lapack_routine(_V, UPLO, n, a, n, w, work, -1,
- iwork, liwork, 0)
- lwork = int(work[0])
- work = zeros((lwork,), t)
- results = lapack_routine(_V, UPLO, n, a, n, w, work, lwork,
- iwork, liwork, 0)
- if results['info'] > 0:
- raise LinAlgError('Eigenvalues did not converge')
- at = a.transpose().astype(result_t)
- return w.astype(_realType(result_t)), wrap(at)
+ gufunc = _umath_linalg.eigh_up
+
+ w, vt = gufunc(a.astype(t), extobj=extobj)
+ w = w.astype(_realType(result_t))
+ vt = vt.astype(result_t)
+ return w, wrap(vt)
# Singular value decomposition