summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2013-04-09 22:02:54 +0300
committerPauli Virtanen <pav@iki.fi>2013-04-10 22:47:45 +0300
commitbbdca51cac02a9f9352f671229037afa139ac7b5 (patch)
treeac5a3627916fce8cafe624fc8f9a3c15c6170e72
parentcc7b048fafeb93150f118651438faf06aefe807b (diff)
downloadnumpy-bbdca51cac02a9f9352f671229037afa139ac7b5.tar.gz
ENH: linalg: use _umath_linalg for svd()
-rw-r--r--numpy/linalg/linalg.py98
1 files changed, 39 insertions, 59 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 30f5fbeba..370476fab 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -1201,7 +1201,7 @@ def svd(a, full_matrices=1, compute_uv=1):
Parameters
----------
- a : array_like
+ a : (..., M, N) array_like
A real or complex matrix of shape (`M`, `N`) .
full_matrices : bool, optional
If True (default), `u` and `v` have the shapes (`M`, `M`) and
@@ -1213,15 +1213,14 @@ def svd(a, full_matrices=1, compute_uv=1):
Returns
-------
- u : ndarray
- Unitary matrix. The shape of `u` is (`M`, `M`) or (`M`, `K`)
- depending on value of ``full_matrices``.
- s : ndarray
- The singular values, sorted so that ``s[i] >= s[i+1]``. `s` is
- a 1-d array of length min(`M`, `N`).
- v : ndarray
- Unitary matrix of shape (`N`, `N`) or (`K`, `N`), depending on
- ``full_matrices``.
+ u : { (..., M, M), (..., M, K) } array
+ Unitary matrices. The actual shape depends on the value of
+ ``full_matrices``. Only returned when ``compute_uv`` is True.
+ s : (..., K) array
+ The singular values for every matrix, sorted in descending order.
+ v : { (..., N, N), (..., K, N) } array
+ Unitary matrices. The actual shape depends on the value of
+ ``full_matrices``. Only returned when ``compute_uv`` is True.
Raises
------
@@ -1230,6 +1229,11 @@ def svd(a, full_matrices=1, compute_uv=1):
Notes
-----
+ Broadcasting rules apply, see the `numpy.linalg` documentation for
+ details.
+
+ The decomposition is performed using LAPACK routine _gesdd
+
The SVD is commonly written as ``a = U S V.H``. The `v` returned
by this function is ``V.H`` and ``u = U``.
@@ -1269,63 +1273,39 @@ def svd(a, full_matrices=1, compute_uv=1):
"""
a, wrap = _makearray(a)
- _assertRank2(a)
_assertNonEmpty(a)
- m, n = a.shape
+ _assertRankAtLeast2(a)
t, result_t = _commonType(a)
- real_t = _linalgRealType(t)
- a = _fastCopyAndTranspose(t, a)
- a = _to_native_byte_order(a)
- s = zeros((min(n, m),), real_t)
+
+ extobj = get_linalg_error_extobj(_raise_linalgerror_svd_nonconvergence)
+
+ m = a.shape[-2]
+ n = a.shape[-1]
if compute_uv:
if full_matrices:
- nu = m
- nvt = n
- option = _A
+ if m < n:
+ gufunc = _umath_linalg.svd_m_f
+ else:
+ gufunc = _umath_linalg.svd_n_f
else:
- nu = min(n, m)
- nvt = min(n, m)
- option = _S
- u = zeros((nu, m), t)
- vt = zeros((n, nvt), t)
- else:
- option = _N
- nu = 1
- nvt = 1
- u = empty((1, 1), t)
- vt = empty((1, 1), t)
+ if m < n:
+ gufunc = _umath_linalg.svd_m_s
+ else:
+ gufunc = _umath_linalg.svd_n_s
- iwork = zeros((8*min(m, n),), fortran_int)
- if isComplexType(t):
- lapack_routine = lapack_lite.zgesdd
- lrwork = min(m,n)*max(5*min(m,n)+7, 2*max(m,n)+2*min(m,n)+1)
- rwork = zeros((lrwork,), real_t)
- lwork = 1
- work = zeros((lwork,), t)
- results = lapack_routine(option, m, n, a, m, s, u, m, vt, nvt,
- work, -1, rwork, iwork, 0)
- lwork = int(abs(work[0]))
- work = zeros((lwork,), t)
- results = lapack_routine(option, m, n, a, m, s, u, m, vt, nvt,
- work, lwork, rwork, iwork, 0)
- else:
- lapack_routine = lapack_lite.dgesdd
- lwork = 1
- work = zeros((lwork,), t)
- results = lapack_routine(option, m, n, a, m, s, u, m, vt, nvt,
- work, -1, iwork, 0)
- lwork = int(work[0])
- work = zeros((lwork,), t)
- results = lapack_routine(option, m, n, a, m, s, u, m, vt, nvt,
- work, lwork, iwork, 0)
- if results['info'] > 0:
- raise LinAlgError('SVD did not converge')
- s = s.astype(_realType(result_t))
- if compute_uv:
- u = u.transpose().astype(result_t)
- vt = vt.transpose().astype(result_t)
+ u, s, vt = gufunc(a.astype(t), extobj=extobj)
+ u = u.astype(result_t)
+ s = s.astype(_realType(result_t))
+ vt = vt.astype(result_t)
return wrap(u), s, wrap(vt)
else:
+ if m < n:
+ gufunc = _umath_linalg.svd_m
+ else:
+ gufunc = _umath_linalg.svd_n
+
+ s = gufunc(a.astype(t), extobj=extobj)
+ s = s.astype(_realType(result_t))
return s
def cond(x, p=None):