diff options
author | Pauli Virtanen <pav@iki.fi> | 2013-04-09 22:02:45 +0300 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2013-04-10 22:47:45 +0300 |
commit | 1253d571a48f3100865e0b38895c0cceb8336aa3 (patch) | |
tree | f834fbe8e6cb57bd606da6ef7ebf8bcbd75bac08 | |
parent | 74e14775a1447942e9fede448958810d479c836b (diff) | |
download | numpy-1253d571a48f3100865e0b38895c0cceb8336aa3.tar.gz |
ENH: linalg: use _umath_linalg for eig()
-rw-r--r-- | numpy/linalg/linalg.py | 87 |
1 files changed, 29 insertions, 58 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index d36df8157..43caa3235 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -959,17 +959,20 @@ def eig(a): Parameters ---------- - a : (M, M) array_like - A square array of real or complex elements. + a : (..., M, M) array + Matrices for which the eigenvalues and right eigenvectors will + be computed Returns ------- - w : (M,) ndarray + w : (..., M) array The eigenvalues, each repeated according to its multiplicity. - The eigenvalues are not necessarily ordered, nor are they - necessarily real for real arrays (though for real arrays - complex-valued eigenvalues should occur in conjugate pairs). - v : (M, M) ndarray + The eigenvalues are not necessarily ordered. The resulting + array will be always be of complex type. When `a` is real + the resulting eigenvalues will be real (0 imaginary part) or + occur in conjugate pairs + + v : (..., M, M) array The normalized (unit "length") eigenvectors, such that the column ``v[:,i]`` is the eigenvector corresponding to the eigenvalue ``w[i]``. @@ -988,9 +991,11 @@ def eig(a): Notes ----- - This is a simple interface to the LAPACK routines dgeev and zgeev - which compute the eigenvalues and eigenvectors of, respectively, - general real- and complex-valued square arrays. + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + This is implemented using the _geev LAPACK routines which compute + the eigenvalues and eigenvectors of general square arrays. The number `w` is an eigenvalue of `a` if there exists a vector `v` such that ``dot(a,v) = w * v``. Thus, the arrays `a`, `w`, and @@ -1061,57 +1066,23 @@ def eig(a): """ a, wrap = _makearray(a) - _assertRank2(a) - _assertSquareness(a) + _assertRankAtLeast2(a) + _assertNdSquareness(a) _assertFinite(a) - a, t, result_t = _convertarray(a) # convert to double or cdouble type - a = _to_native_byte_order(a) - real_t = _linalgRealType(t) - n = a.shape[0] - dummy = zeros((1,), t) - if isComplexType(t): - # Complex routines take different arguments - lapack_routine = lapack_lite.zgeev - w = zeros((n,), t) - v = zeros((n, n), t) - lwork = 1 - work = zeros((lwork,), t) - rwork = zeros((2*n,), real_t) - results = lapack_routine(_N, _V, n, a, n, w, - dummy, 1, v, n, work, -1, rwork, 0) - lwork = int(abs(work[0])) - work = zeros((lwork,), t) - results = lapack_routine(_N, _V, n, a, n, w, - dummy, 1, v, n, work, lwork, rwork, 0) + t, result_t = _commonType(a) + + extobj = get_linalg_error_extobj( + _raise_linalgerror_eigenvalues_nonconvergence) + w, vt = _umath_linalg.eig(a.astype(t), extobj=extobj) + + if not isComplexType(t) and all(w.imag == 0.0): + w = w.real + vt = vt.real + result_t = _realType(result_t) else: - lapack_routine = lapack_lite.dgeev - wr = zeros((n,), t) - wi = zeros((n,), t) - vr = zeros((n, n), t) - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(_N, _V, n, a, n, wr, wi, - dummy, 1, vr, n, work, -1, 0) - lwork = int(work[0]) - work = zeros((lwork,), t) - results = lapack_routine(_N, _V, n, a, n, wr, wi, - dummy, 1, vr, n, work, lwork, 0) - if all(wi == 0.0): - w = wr - v = vr - result_t = _realType(result_t) - else: - w = wr+1j*wi - v = array(vr, w.dtype) - ind = flatnonzero(wi != 0.0) # indices of complex e-vals - for i in range(len(ind)//2): - v[ind[2*i]] = vr[ind[2*i]] + 1j*vr[ind[2*i+1]] - v[ind[2*i+1]] = vr[ind[2*i]] - 1j*vr[ind[2*i+1]] - result_t = _complexType(result_t) + result_t = _complexType(result_t) - if results['info'] > 0: - raise LinAlgError('Eigenvalues did not converge') - vt = v.transpose().astype(result_t) + vt = vt.astype(result_t) return w.astype(result_t), wrap(vt) |