diff options
author | Hameer Abbasi <einstein.edison@gmail.com> | 2020-01-29 13:43:31 +0100 |
---|---|---|
committer | Hameer Abbasi <einstein.edison@gmail.com> | 2020-02-05 11:47:44 +0100 |
commit | b202aad8d8e138d4c4cb8ccc590b87e1173f45bc (patch) | |
tree | 8b43e03b36425582abf08e0bdb18b569ab2ea33a /numpy/linalg/linalg.py | |
parent | 49102a5da9684330849531589c81d1c5005e3bb9 (diff) | |
download | numpy-b202aad8d8e138d4c4cb8ccc590b87e1173f45bc.tar.gz |
BUG: Fix for SVD not always sorted with hermitian=True
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 0964e807b..1d572f876 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -24,7 +24,7 @@ from numpy.core import ( add, multiply, sqrt, fastCopyAndTranspose, sum, isfinite, finfo, errstate, geterrobj, moveaxis, amin, amax, product, abs, atleast_2d, intp, asanyarray, object_, matmul, - swapaxes, divide, count_nonzero, isnan, sign + swapaxes, divide, count_nonzero, isnan, sign, argsort, sort ) from numpy.core.multiarray import normalize_axis_index from numpy.core.overrides import set_module @@ -1608,24 +1608,29 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False): True """ + import numpy as _nx a, wrap = _makearray(a) if hermitian: - # note: lapack returns eigenvalues in reverse order to our contract. - # reversing is cheap by design in numpy, so we do so to be consistent + # note: lapack svd returns eigenvalues with s ** 2 sorted descending, + # but eig returns s sorted ascending, so we re-order the eigenvalues + # and related arrays to have the correct order if compute_uv: s, u = eigh(a) - s = s[..., ::-1] - u = u[..., ::-1] - # singular values are unsigned, move the sign into v - vt = transpose(u * sign(s)[..., None, :]).conjugate() + sgn = sign(s) s = abs(s) + sidx = argsort(s)[..., ::-1] + sgn = _nx.take_along_axis(sgn, sidx, axis=-1) + s = _nx.take_along_axis(s, sidx, axis=-1) + u = _nx.take_along_axis(u, sidx[..., None, :], axis=-1) + # singular values are unsigned, move the sign into v + vt = transpose(u * sgn[..., None, :]).conjugate() return wrap(u), s, wrap(vt) else: s = eigvalsh(a) s = s[..., ::-1] s = abs(s) - return s + return sort(s)[..., ::-1] _assert_stacked_2d(a) t, result_t = _commonType(a) |