diff options
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 100 |
1 files changed, 58 insertions, 42 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 087156bd0..97d130894 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -99,6 +99,9 @@ def _raise_linalgerror_svd_nonconvergence(err, flag): def _raise_linalgerror_lstsq(err, flag): raise LinAlgError("SVD did not converge in Linear Least Squares") +def _raise_linalgerror_qr_r_raw(err, flag): + raise LinAlgError("Illegal argument found while performing QR factorization") + def get_linalg_error_extobj(callback): extobj = list(_linalg_error_extobj) # make a copy extobj[2] = callback @@ -766,12 +769,12 @@ def cholesky(a): # QR decomposition -def _qr_dispatcher(a, mode=None): +def _qr_dispatcher(a, mode=None, old=None): return (a,) @array_function_dispatch(_qr_dispatcher) -def qr(a, mode='reduced'): +def qr(a, mode='reduced', old=True): """ Compute the qr factorization of a matrix. @@ -912,62 +915,75 @@ def qr(a, mode='reduced'): mn = min(m, n) if m <= n: - gufunc = _umath_linalg.qr_r_raw_e_m + gufunc = _umath_linalg.qr_r_raw_m else: - gufunc = _umath_linalg.qr_r_raw_e_n + gufunc = _umath_linalg.qr_r_raw_n signature = 'D->D' if isComplexType(t) else 'd->d' - extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq) + extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw) tau = gufunc(a, signature=signature, extobj=extobj) # handle modes that don't return q if mode == 'r': - r = _fastCopyAndTranspose(result_t, a[:, :mn]) - return wrap(triu(r)) + return wrap(triu(a.T)) if mode == 'raw': - return a, tau + return wrap(a.T), tau if mode == 'economic': if t != result_t : a = a.astype(result_t, copy=False) - return wrap(a.T) - - # generate q from a - if mode == 'complete' and m > n: - mc = m - q = empty((m, m), t) - else: - mc = mn - q = empty((n, m), t) - q[:n] = a - - if isComplexType(t): - lapack_routine = lapack_lite.zungqr - routine_name = 'zungqr' - else: - lapack_routine = lapack_lite.dorgqr - routine_name = 'dorgqr' + return wrap(a) - # determine optimal lwork - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - # compute q - lwork = max(1, n, int(abs(work[0]))) - work = zeros((lwork,), t) - results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - q = _fastCopyAndTranspose(result_t, q[:mc]) - r = _fastCopyAndTranspose(result_t, a[:, :mc]) + if old: + # generate q from a + if mode == 'complete' and m > n: + mc = m + q = empty((m, m), t) + else: + mc = mn + q = empty((m, n), t) + q[:n] = a - return wrap(q), wrap(triu(r)) + if isComplexType(t): + lapack_routine = lapack_lite.zungqr + routine_name = 'zungqr' + else: + lapack_routine = lapack_lite.dorgqr + routine_name = 'dorgqr' + + # determine optimal lwork + lwork = 1 + work = zeros((lwork,), t) + results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0) + if results['info'] != 0: + raise LinAlgError('%s returns %d' % (routine_name, results['info'])) + + # compute q + lwork = max(1, n, int(abs(work[0]))) + work = zeros((lwork,), t) + results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0) + if results['info'] != 0: + raise LinAlgError('%s returns %d' % (routine_name, results['info'])) + + q = _fastCopyAndTranspose(result_t, q[:mc]) + r = _fastCopyAndTranspose(result_t, a[:, :mc]) + + return wrap(q), wrap(triu(r)) + + if mode == 'reduced': + if m <= n: + gufunc = _umath_linalg.qr_reduced_m + else: + gufunc = _umath_linalg.qr_reduced_n + + print(gufunc) + signature = 'DD->D' if isComplexType(t) else 'dd->d' + extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw) + q = gufunc(a, tau, signature=signature, extobj=extobj) + r = _fastCopyAndTranspose(result_t, a[:, :mn]) + return q, wrap(triu(r)) # Eigenvalues |