summaryrefslogtreecommitdiff
path: root/numpy/linalg
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-11-09 09:04:42 -0800
committerEric Wieser <wieser.eric@gmail.com>2018-04-10 23:33:13 -0700
commit3ef55be846bc8e6e21515787b29608fec7e1fad0 (patch)
treeea7123d0a2d36cca039bd356e15b2b85b82a16d3 /numpy/linalg
parentefc254ed833a5f7a68c2dbb69868f5bee693e8e9 (diff)
downloadnumpy-3ef55be846bc8e6e21515787b29608fec7e1fad0.tar.gz
MAINT: Move lstsq to umath_linalg
This does not yet enable any broadcasting, but makes doing so in future far easier.
Diffstat (limited to 'numpy/linalg')
-rw-r--r--numpy/linalg/linalg.py64
1 files changed, 14 insertions, 50 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index de25d25e9..f073abadf 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -97,6 +97,9 @@ def _raise_linalgerror_eigenvalues_nonconvergence(err, flag):
def _raise_linalgerror_svd_nonconvergence(err, flag):
raise LinAlgError("SVD did not converge")
+def _raise_linalgerror_lstsq(err, flag):
+ raise LinAlgError("SVD did not converge in Linear Least Squares")
+
def get_linalg_error_extobj(callback):
extobj = list(_linalg_error_extobj) # make a copy
extobj[2] = callback
@@ -1997,7 +2000,6 @@ def lstsq(a, b, rcond="warn"):
>>> plt.show()
"""
- import math
a, _ = _makearray(a)
b, wrap = _makearray(b)
is_1d = b.ndim == 1
@@ -2008,7 +2010,6 @@ def lstsq(a, b, rcond="warn"):
m = a.shape[0]
n = a.shape[1]
n_rhs = b.shape[1]
- ldb = max(n, m)
if m != b.shape[0]:
raise LinAlgError('Incompatible dimensions')
@@ -2028,62 +2029,25 @@ def lstsq(a, b, rcond="warn"):
FutureWarning, stacklevel=2)
rcond = -1
if rcond is None:
- rcond = finfo(t).eps * ldb
-
- bstar = zeros((ldb, n_rhs), t)
- bstar[:m, :n_rhs] = b
- a, bstar = _fastCopyAndTranspose(t, a, bstar)
- a, bstar = _to_native_byte_order(a, bstar)
- s = zeros((min(m, n),), real_t)
- # This line:
- # * is incorrect, according to the LAPACK documentation
- # * raises a ValueError if min(m,n) == 0
- # * should not be calculated here anyway, as LAPACK should calculate
- # `liwork` for us. But that only works if our version of lapack does
- # not have this bug:
- # http://icl.cs.utk.edu/lapack-forum/archives/lapack/msg00899.html
- # Lapack_lite does have that bug...
- nlvl = max( 0, int( math.log( float(min(m, n))/2. ) ) + 1 )
- iwork = zeros((3*min(m, n)*nlvl+11*min(m, n),), fortran_int)
- if isComplexType(t):
- lapack_routine = lapack_lite.zgelsd
- lwork = 1
- rwork = zeros((lwork,), real_t)
- work = zeros((lwork,), t)
- results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
- 0, work, -1, rwork, iwork, 0)
- lrwork = int(rwork[0])
- lwork = int(work[0].real)
- work = zeros((lwork,), t)
- rwork = zeros((lrwork,), real_t)
- results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
- 0, work, lwork, rwork, iwork, 0)
+ rcond = finfo(t).eps * max(n, m)
+
+ if m <= n:
+ gufunc = _umath_linalg.lstsq_m
else:
- lapack_routine = lapack_lite.dgelsd
- lwork = 1
- work = zeros((lwork,), t)
- results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
- 0, work, -1, iwork, 0)
- lwork = int(work[0])
- work = zeros((lwork,), t)
- results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
- 0, work, lwork, iwork, 0)
- if results['info'] > 0:
- raise LinAlgError('SVD did not converge in Linear Least Squares')
-
- # undo transpose imposed by fortran-order arrays
- b_out = bstar.T
+ gufunc = _umath_linalg.lstsq_n
+
+ signature = 'DDd->Did' if isComplexType(t) else 'ddd->did'
+ extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq)
+ b_out, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj)
# b_out contains both the solution and the components of the residuals
- x = b_out[:n,:]
- r_parts = b_out[n:,:]
+ x = b_out[...,:n,:]
+ r_parts = b_out[...,n:,:]
if isComplexType(t):
resids = sum(abs(r_parts)**2, axis=-2)
else:
resids = sum(r_parts**2, axis=-2)
- rank = results['rank']
-
# remove the axis we added
if is_1d:
x = x.squeeze(axis=-1)