diff options
-rw-r--r-- | numpy/linalg/linalg.py | 64 |
1 files changed, 14 insertions, 50 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index de25d25e9..f073abadf 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -97,6 +97,9 @@ def _raise_linalgerror_eigenvalues_nonconvergence(err, flag): def _raise_linalgerror_svd_nonconvergence(err, flag): raise LinAlgError("SVD did not converge") +def _raise_linalgerror_lstsq(err, flag): + raise LinAlgError("SVD did not converge in Linear Least Squares") + def get_linalg_error_extobj(callback): extobj = list(_linalg_error_extobj) # make a copy extobj[2] = callback @@ -1997,7 +2000,6 @@ def lstsq(a, b, rcond="warn"): >>> plt.show() """ - import math a, _ = _makearray(a) b, wrap = _makearray(b) is_1d = b.ndim == 1 @@ -2008,7 +2010,6 @@ def lstsq(a, b, rcond="warn"): m = a.shape[0] n = a.shape[1] n_rhs = b.shape[1] - ldb = max(n, m) if m != b.shape[0]: raise LinAlgError('Incompatible dimensions') @@ -2028,62 +2029,25 @@ def lstsq(a, b, rcond="warn"): FutureWarning, stacklevel=2) rcond = -1 if rcond is None: - rcond = finfo(t).eps * ldb - - bstar = zeros((ldb, n_rhs), t) - bstar[:m, :n_rhs] = b - a, bstar = _fastCopyAndTranspose(t, a, bstar) - a, bstar = _to_native_byte_order(a, bstar) - s = zeros((min(m, n),), real_t) - # This line: - # * is incorrect, according to the LAPACK documentation - # * raises a ValueError if min(m,n) == 0 - # * should not be calculated here anyway, as LAPACK should calculate - # `liwork` for us. But that only works if our version of lapack does - # not have this bug: - # http://icl.cs.utk.edu/lapack-forum/archives/lapack/msg00899.html - # Lapack_lite does have that bug... - nlvl = max( 0, int( math.log( float(min(m, n))/2. ) ) + 1 ) - iwork = zeros((3*min(m, n)*nlvl+11*min(m, n),), fortran_int) - if isComplexType(t): - lapack_routine = lapack_lite.zgelsd - lwork = 1 - rwork = zeros((lwork,), real_t) - work = zeros((lwork,), t) - results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond, - 0, work, -1, rwork, iwork, 0) - lrwork = int(rwork[0]) - lwork = int(work[0].real) - work = zeros((lwork,), t) - rwork = zeros((lrwork,), real_t) - results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond, - 0, work, lwork, rwork, iwork, 0) + rcond = finfo(t).eps * max(n, m) + + if m <= n: + gufunc = _umath_linalg.lstsq_m else: - lapack_routine = lapack_lite.dgelsd - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond, - 0, work, -1, iwork, 0) - lwork = int(work[0]) - work = zeros((lwork,), t) - results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond, - 0, work, lwork, iwork, 0) - if results['info'] > 0: - raise LinAlgError('SVD did not converge in Linear Least Squares') - - # undo transpose imposed by fortran-order arrays - b_out = bstar.T + gufunc = _umath_linalg.lstsq_n + + signature = 'DDd->Did' if isComplexType(t) else 'ddd->did' + extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq) + b_out, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj) # b_out contains both the solution and the components of the residuals - x = b_out[:n,:] - r_parts = b_out[n:,:] + x = b_out[...,:n,:] + r_parts = b_out[...,n:,:] if isComplexType(t): resids = sum(abs(r_parts)**2, axis=-2) else: resids = sum(r_parts**2, axis=-2) - rank = results['rank'] - # remove the axis we added if is_1d: x = x.squeeze(axis=-1) |