summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py21
1 files changed, 13 insertions, 8 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 0964e807b..1d572f876 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -24,7 +24,7 @@ from numpy.core import (
add, multiply, sqrt, fastCopyAndTranspose, sum, isfinite,
finfo, errstate, geterrobj, moveaxis, amin, amax, product, abs,
atleast_2d, intp, asanyarray, object_, matmul,
- swapaxes, divide, count_nonzero, isnan, sign
+ swapaxes, divide, count_nonzero, isnan, sign, argsort, sort
)
from numpy.core.multiarray import normalize_axis_index
from numpy.core.overrides import set_module
@@ -1608,24 +1608,29 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
True
"""
+ import numpy as _nx
a, wrap = _makearray(a)
if hermitian:
- # note: lapack returns eigenvalues in reverse order to our contract.
- # reversing is cheap by design in numpy, so we do so to be consistent
+ # note: lapack svd returns eigenvalues with s ** 2 sorted descending,
+ # but eig returns s sorted ascending, so we re-order the eigenvalues
+ # and related arrays to have the correct order
if compute_uv:
s, u = eigh(a)
- s = s[..., ::-1]
- u = u[..., ::-1]
- # singular values are unsigned, move the sign into v
- vt = transpose(u * sign(s)[..., None, :]).conjugate()
+ sgn = sign(s)
s = abs(s)
+ sidx = argsort(s)[..., ::-1]
+ sgn = _nx.take_along_axis(sgn, sidx, axis=-1)
+ s = _nx.take_along_axis(s, sidx, axis=-1)
+ u = _nx.take_along_axis(u, sidx[..., None, :], axis=-1)
+ # singular values are unsigned, move the sign into v
+ vt = transpose(u * sgn[..., None, :]).conjugate()
return wrap(u), s, wrap(vt)
else:
s = eigvalsh(a)
s = s[..., ::-1]
s = abs(s)
- return s
+ return sort(s)[..., ::-1]
_assert_stacked_2d(a)
t, result_t = _commonType(a)