diff options
author | Pauli Virtanen <pav@iki.fi> | 2013-04-09 22:02:54 +0300 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2013-04-10 22:47:45 +0300 |
commit | bbdca51cac02a9f9352f671229037afa139ac7b5 (patch) | |
tree | ac5a3627916fce8cafe624fc8f9a3c15c6170e72 | |
parent | cc7b048fafeb93150f118651438faf06aefe807b (diff) | |
download | numpy-bbdca51cac02a9f9352f671229037afa139ac7b5.tar.gz |
ENH: linalg: use _umath_linalg for svd()
-rw-r--r-- | numpy/linalg/linalg.py | 98 |
1 files changed, 39 insertions, 59 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 30f5fbeba..370476fab 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -1201,7 +1201,7 @@ def svd(a, full_matrices=1, compute_uv=1): Parameters ---------- - a : array_like + a : (..., M, N) array_like A real or complex matrix of shape (`M`, `N`) . full_matrices : bool, optional If True (default), `u` and `v` have the shapes (`M`, `M`) and @@ -1213,15 +1213,14 @@ def svd(a, full_matrices=1, compute_uv=1): Returns ------- - u : ndarray - Unitary matrix. The shape of `u` is (`M`, `M`) or (`M`, `K`) - depending on value of ``full_matrices``. - s : ndarray - The singular values, sorted so that ``s[i] >= s[i+1]``. `s` is - a 1-d array of length min(`M`, `N`). - v : ndarray - Unitary matrix of shape (`N`, `N`) or (`K`, `N`), depending on - ``full_matrices``. + u : { (..., M, M), (..., M, K) } array + Unitary matrices. The actual shape depends on the value of + ``full_matrices``. Only returned when ``compute_uv`` is True. + s : (..., K) array + The singular values for every matrix, sorted in descending order. + v : { (..., N, N), (..., K, N) } array + Unitary matrices. The actual shape depends on the value of + ``full_matrices``. Only returned when ``compute_uv`` is True. Raises ------ @@ -1230,6 +1229,11 @@ def svd(a, full_matrices=1, compute_uv=1): Notes ----- + Broadcasting rules apply, see the `numpy.linalg` documentation for + details. + + The decomposition is performed using LAPACK routine _gesdd + The SVD is commonly written as ``a = U S V.H``. The `v` returned by this function is ``V.H`` and ``u = U``. @@ -1269,63 +1273,39 @@ def svd(a, full_matrices=1, compute_uv=1): """ a, wrap = _makearray(a) - _assertRank2(a) _assertNonEmpty(a) - m, n = a.shape + _assertRankAtLeast2(a) t, result_t = _commonType(a) - real_t = _linalgRealType(t) - a = _fastCopyAndTranspose(t, a) - a = _to_native_byte_order(a) - s = zeros((min(n, m),), real_t) + + extobj = get_linalg_error_extobj(_raise_linalgerror_svd_nonconvergence) + + m = a.shape[-2] + n = a.shape[-1] if compute_uv: if full_matrices: - nu = m - nvt = n - option = _A + if m < n: + gufunc = _umath_linalg.svd_m_f + else: + gufunc = _umath_linalg.svd_n_f else: - nu = min(n, m) - nvt = min(n, m) - option = _S - u = zeros((nu, m), t) - vt = zeros((n, nvt), t) - else: - option = _N - nu = 1 - nvt = 1 - u = empty((1, 1), t) - vt = empty((1, 1), t) + if m < n: + gufunc = _umath_linalg.svd_m_s + else: + gufunc = _umath_linalg.svd_n_s - iwork = zeros((8*min(m, n),), fortran_int) - if isComplexType(t): - lapack_routine = lapack_lite.zgesdd - lrwork = min(m,n)*max(5*min(m,n)+7, 2*max(m,n)+2*min(m,n)+1) - rwork = zeros((lrwork,), real_t) - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(option, m, n, a, m, s, u, m, vt, nvt, - work, -1, rwork, iwork, 0) - lwork = int(abs(work[0])) - work = zeros((lwork,), t) - results = lapack_routine(option, m, n, a, m, s, u, m, vt, nvt, - work, lwork, rwork, iwork, 0) - else: - lapack_routine = lapack_lite.dgesdd - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(option, m, n, a, m, s, u, m, vt, nvt, - work, -1, iwork, 0) - lwork = int(work[0]) - work = zeros((lwork,), t) - results = lapack_routine(option, m, n, a, m, s, u, m, vt, nvt, - work, lwork, iwork, 0) - if results['info'] > 0: - raise LinAlgError('SVD did not converge') - s = s.astype(_realType(result_t)) - if compute_uv: - u = u.transpose().astype(result_t) - vt = vt.transpose().astype(result_t) + u, s, vt = gufunc(a.astype(t), extobj=extobj) + u = u.astype(result_t) + s = s.astype(_realType(result_t)) + vt = vt.astype(result_t) return wrap(u), s, wrap(vt) else: + if m < n: + gufunc = _umath_linalg.svd_m + else: + gufunc = _umath_linalg.svd_n + + s = gufunc(a.astype(t), extobj=extobj) + s = s.astype(_realType(result_t)) return s def cond(x, p=None): |