diff options
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 31 |
1 files changed, 9 insertions, 22 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 46fb2502e..087156bd0 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -904,34 +904,21 @@ def qr(a, mode='reduced'): raise ValueError(f"Unrecognized mode '{mode}'") a, wrap = _makearray(a) - _assert_2d(a) - m, n = a.shape + _assert_stacked_2d(a) + m, n = a.shape[-2:] t, result_t = _commonType(a) - a = _fastCopyAndTranspose(t, a) + a = a.astype(t, copy=True) a = _to_native_byte_order(a) mn = min(m, n) - tau = zeros((mn,), t) - if isComplexType(t): - lapack_routine = lapack_lite.zgeqrf - routine_name = 'zgeqrf' + if m <= n: + gufunc = _umath_linalg.qr_r_raw_e_m else: - lapack_routine = lapack_lite.dgeqrf - routine_name = 'dgeqrf' + gufunc = _umath_linalg.qr_r_raw_e_n - # calculate optimal size of work data 'work' - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(m, n, a, max(1, m), tau, work, -1, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - # do qr decomposition - lwork = max(1, n, int(abs(work[0]))) - work = zeros((lwork,), t) - results = lapack_routine(m, n, a, max(1, m), tau, work, lwork, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) + signature = 'D->D' if isComplexType(t) else 'd->d' + extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq) + tau = gufunc(a, signature=signature, extobj=extobj) # handle modes that don't return q if mode == 'r': |