diff options
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 56 |
1 files changed, 12 insertions, 44 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 97d130894..2031014ac 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -769,12 +769,12 @@ def cholesky(a): # QR decomposition -def _qr_dispatcher(a, mode=None, old=None): +def _qr_dispatcher(a, mode=None): return (a,) @array_function_dispatch(_qr_dispatcher) -def qr(a, mode='reduced', old=True): +def qr(a, mode='reduced'): """ Compute the qr factorization of a matrix. @@ -925,65 +925,33 @@ def qr(a, mode='reduced', old=True): # handle modes that don't return q if mode == 'r': - return wrap(triu(a.T)) + return wrap(triu(a)) if mode == 'raw': - return wrap(a.T), tau + return wrap(transpose(a)), tau if mode == 'economic': if t != result_t : a = a.astype(result_t, copy=False) return wrap(a) - if old: - # generate q from a - if mode == 'complete' and m > n: - mc = m - q = empty((m, m), t) - else: - mc = mn - q = empty((m, n), t) - q[:n] = a - - if isComplexType(t): - lapack_routine = lapack_lite.zungqr - routine_name = 'zungqr' + if mode == 'complete' and m > n: + mc = m + if m <= n: + gufunc = _umath_linalg.qr_complete_m else: - lapack_routine = lapack_lite.dorgqr - routine_name = 'dorgqr' - - # determine optimal lwork - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - # compute q - lwork = max(1, n, int(abs(work[0]))) - work = zeros((lwork,), t) - results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - q = _fastCopyAndTranspose(result_t, q[:mc]) - r = _fastCopyAndTranspose(result_t, a[:, :mc]) - - return wrap(q), wrap(triu(r)) - - if mode == 'reduced': + gufunc = _umath_linalg.qr_complete_n + else: + mc = mn if m <= n: gufunc = _umath_linalg.qr_reduced_m else: gufunc = _umath_linalg.qr_reduced_n - print(gufunc) signature = 'DD->D' if isComplexType(t) else 'dd->d' extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw) q = gufunc(a, tau, signature=signature, extobj=extobj) - - r = _fastCopyAndTranspose(result_t, a[:, :mn]) - return q, wrap(triu(r)) + return wrap(q), wrap(triu(a[:, :mc])) # Eigenvalues |