diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-11-07 20:05:50 -0800 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-11-07 22:48:46 -0800 |
commit | a1af647b13b33963eca3f9504f1d4b2603c3e4a3 (patch) | |
tree | 39edeccf93bb690b3577ada04b63744464b22632 /numpy/linalg | |
parent | a2228e9f2cd09b4aea872594fc2ec738eba7a1a6 (diff) | |
download | numpy-a1af647b13b33963eca3f9504f1d4b2603c3e4a3.tar.gz |
MAINT: Remove similar branches from linalg.lstsq
Diffstat (limited to 'numpy/linalg')
-rw-r--r-- | numpy/linalg/linalg.py | 47 |
1 files changed, 27 insertions, 20 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 26254903d..041b34b91 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -2039,28 +2039,35 @@ def lstsq(a, b, rcond="warn"): 0, work, lwork, iwork, 0) if results['info'] > 0: raise LinAlgError('SVD did not converge in Linear Least Squares') - resids = array([], result_real_t) - if is_1d: - x = array(ravel(bstar)[:n], dtype=result_t, copy=True) - if results['rank'] == n and m > n: - if isComplexType(t): - resids = array([sum(abs(ravel(bstar)[n:])**2)], - dtype=result_real_t) - else: - resids = array([sum((ravel(bstar)[n:])**2)], - dtype=result_real_t) + + # undo transpose imposed by fortran-order arrays + b_out = bstar.T + + # b_out contains both the solution and the components of the residuals + x = b_out[:n,:] + r_parts = b_out[n:,:] + if isComplexType(t): + resids = sum(abs(r_parts)**2, axis=-2) else: - x = array(bstar.T[:n,:], dtype=result_t, copy=True) - if results['rank'] == n and m > n: - if isComplexType(t): - resids = sum(abs(bstar.T[n:,:])**2, axis=0).astype( - result_real_t, copy=False) - else: - resids = sum((bstar.T[n:,:])**2, axis=0).astype( - result_real_t, copy=False) + resids = sum(r_parts**2, axis=-2) - st = s[:min(n, m)].astype(result_real_t, copy=True) - return wrap(x), wrap(resids), results['rank'], st + rank = results['rank'] + + # remove the axis we added + if is_1d: + x = x.squeeze(axis=-1) + # we probably should squeeze resids too, but we can't + # without breaking compatibility. + + # as documented + if rank != n or m <= n: + resids = array([], result_real_t) + + # coerce output arrays + s = s.astype(result_real_t, copy=True) + resids = resids.astype(result_real_t, copy=False) # array is temporary + x = x.astype(result_t, copy=True) + return wrap(x), wrap(resids), rank, s def _multi_svd_norm(x, row_axis, col_axis, op): |