summaryrefslogtreecommitdiff
path: root/numpy/linalg
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-11-07 20:05:50 -0800
committerEric Wieser <wieser.eric@gmail.com>2017-11-07 22:48:46 -0800
commita1af647b13b33963eca3f9504f1d4b2603c3e4a3 (patch)
tree39edeccf93bb690b3577ada04b63744464b22632 /numpy/linalg
parenta2228e9f2cd09b4aea872594fc2ec738eba7a1a6 (diff)
downloadnumpy-a1af647b13b33963eca3f9504f1d4b2603c3e4a3.tar.gz
MAINT: Remove similar branches from linalg.lstsq
Diffstat (limited to 'numpy/linalg')
-rw-r--r--numpy/linalg/linalg.py47
1 files changed, 27 insertions, 20 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 26254903d..041b34b91 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -2039,28 +2039,35 @@ def lstsq(a, b, rcond="warn"):
0, work, lwork, iwork, 0)
if results['info'] > 0:
raise LinAlgError('SVD did not converge in Linear Least Squares')
- resids = array([], result_real_t)
- if is_1d:
- x = array(ravel(bstar)[:n], dtype=result_t, copy=True)
- if results['rank'] == n and m > n:
- if isComplexType(t):
- resids = array([sum(abs(ravel(bstar)[n:])**2)],
- dtype=result_real_t)
- else:
- resids = array([sum((ravel(bstar)[n:])**2)],
- dtype=result_real_t)
+
+ # undo transpose imposed by fortran-order arrays
+ b_out = bstar.T
+
+ # b_out contains both the solution and the components of the residuals
+ x = b_out[:n,:]
+ r_parts = b_out[n:,:]
+ if isComplexType(t):
+ resids = sum(abs(r_parts)**2, axis=-2)
else:
- x = array(bstar.T[:n,:], dtype=result_t, copy=True)
- if results['rank'] == n and m > n:
- if isComplexType(t):
- resids = sum(abs(bstar.T[n:,:])**2, axis=0).astype(
- result_real_t, copy=False)
- else:
- resids = sum((bstar.T[n:,:])**2, axis=0).astype(
- result_real_t, copy=False)
+ resids = sum(r_parts**2, axis=-2)
- st = s[:min(n, m)].astype(result_real_t, copy=True)
- return wrap(x), wrap(resids), results['rank'], st
+ rank = results['rank']
+
+ # remove the axis we added
+ if is_1d:
+ x = x.squeeze(axis=-1)
+ # we probably should squeeze resids too, but we can't
+ # without breaking compatibility.
+
+ # as documented
+ if rank != n or m <= n:
+ resids = array([], result_real_t)
+
+ # coerce output arrays
+ s = s.astype(result_real_t, copy=True)
+ resids = resids.astype(result_real_t, copy=False) # array is temporary
+ x = x.astype(result_t, copy=True)
+ return wrap(x), wrap(resids), rank, s
def _multi_svd_norm(x, row_axis, col_axis, op):