summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2013-04-09 22:02:45 +0300
committerPauli Virtanen <pav@iki.fi>2013-04-10 22:47:45 +0300
commit1253d571a48f3100865e0b38895c0cceb8336aa3 (patch)
treef834fbe8e6cb57bd606da6ef7ebf8bcbd75bac08
parent74e14775a1447942e9fede448958810d479c836b (diff)
downloadnumpy-1253d571a48f3100865e0b38895c0cceb8336aa3.tar.gz
ENH: linalg: use _umath_linalg for eig()
-rw-r--r--numpy/linalg/linalg.py87
1 files changed, 29 insertions, 58 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index d36df8157..43caa3235 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -959,17 +959,20 @@ def eig(a):
Parameters
----------
- a : (M, M) array_like
- A square array of real or complex elements.
+ a : (..., M, M) array
+ Matrices for which the eigenvalues and right eigenvectors will
+ be computed
Returns
-------
- w : (M,) ndarray
+ w : (..., M) array
The eigenvalues, each repeated according to its multiplicity.
- The eigenvalues are not necessarily ordered, nor are they
- necessarily real for real arrays (though for real arrays
- complex-valued eigenvalues should occur in conjugate pairs).
- v : (M, M) ndarray
+ The eigenvalues are not necessarily ordered. The resulting
+ array will be always be of complex type. When `a` is real
+ the resulting eigenvalues will be real (0 imaginary part) or
+ occur in conjugate pairs
+
+ v : (..., M, M) array
The normalized (unit "length") eigenvectors, such that the
column ``v[:,i]`` is the eigenvector corresponding to the
eigenvalue ``w[i]``.
@@ -988,9 +991,11 @@ def eig(a):
Notes
-----
- This is a simple interface to the LAPACK routines dgeev and zgeev
- which compute the eigenvalues and eigenvectors of, respectively,
- general real- and complex-valued square arrays.
+ Broadcasting rules apply, see the `numpy.linalg` documentation for
+ details.
+
+ This is implemented using the _geev LAPACK routines which compute
+ the eigenvalues and eigenvectors of general square arrays.
The number `w` is an eigenvalue of `a` if there exists a vector
`v` such that ``dot(a,v) = w * v``. Thus, the arrays `a`, `w`, and
@@ -1061,57 +1066,23 @@ def eig(a):
"""
a, wrap = _makearray(a)
- _assertRank2(a)
- _assertSquareness(a)
+ _assertRankAtLeast2(a)
+ _assertNdSquareness(a)
_assertFinite(a)
- a, t, result_t = _convertarray(a) # convert to double or cdouble type
- a = _to_native_byte_order(a)
- real_t = _linalgRealType(t)
- n = a.shape[0]
- dummy = zeros((1,), t)
- if isComplexType(t):
- # Complex routines take different arguments
- lapack_routine = lapack_lite.zgeev
- w = zeros((n,), t)
- v = zeros((n, n), t)
- lwork = 1
- work = zeros((lwork,), t)
- rwork = zeros((2*n,), real_t)
- results = lapack_routine(_N, _V, n, a, n, w,
- dummy, 1, v, n, work, -1, rwork, 0)
- lwork = int(abs(work[0]))
- work = zeros((lwork,), t)
- results = lapack_routine(_N, _V, n, a, n, w,
- dummy, 1, v, n, work, lwork, rwork, 0)
+ t, result_t = _commonType(a)
+
+ extobj = get_linalg_error_extobj(
+ _raise_linalgerror_eigenvalues_nonconvergence)
+ w, vt = _umath_linalg.eig(a.astype(t), extobj=extobj)
+
+ if not isComplexType(t) and all(w.imag == 0.0):
+ w = w.real
+ vt = vt.real
+ result_t = _realType(result_t)
else:
- lapack_routine = lapack_lite.dgeev
- wr = zeros((n,), t)
- wi = zeros((n,), t)
- vr = zeros((n, n), t)
- lwork = 1
- work = zeros((lwork,), t)
- results = lapack_routine(_N, _V, n, a, n, wr, wi,
- dummy, 1, vr, n, work, -1, 0)
- lwork = int(work[0])
- work = zeros((lwork,), t)
- results = lapack_routine(_N, _V, n, a, n, wr, wi,
- dummy, 1, vr, n, work, lwork, 0)
- if all(wi == 0.0):
- w = wr
- v = vr
- result_t = _realType(result_t)
- else:
- w = wr+1j*wi
- v = array(vr, w.dtype)
- ind = flatnonzero(wi != 0.0) # indices of complex e-vals
- for i in range(len(ind)//2):
- v[ind[2*i]] = vr[ind[2*i]] + 1j*vr[ind[2*i+1]]
- v[ind[2*i+1]] = vr[ind[2*i]] - 1j*vr[ind[2*i+1]]
- result_t = _complexType(result_t)
+ result_t = _complexType(result_t)
- if results['info'] > 0:
- raise LinAlgError('Eigenvalues did not converge')
- vt = v.transpose().astype(result_t)
+ vt = vt.astype(result_t)
return w.astype(result_t), wrap(vt)