diff options
author | czgdp1807 <gdp.1807@gmail.com> | 2021-06-03 15:02:26 +0530 |
---|---|---|
committer | czgdp1807 <gdp.1807@gmail.com> | 2021-06-03 15:02:26 +0530 |
commit | c25c09b0688bcda1da148fbcac47adab697c409c (patch) | |
tree | 5ba18bcd6215bf09bc2740d359e00d8c9f0418e5 /numpy/linalg/linalg.py | |
parent | 6790873334b143117f4e8d1f515def8c7fdeb9fb (diff) | |
download | numpy-c25c09b0688bcda1da148fbcac47adab697c409c.tar.gz |
ENH: r, raw, economic now accept stacked matrices
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 31 |
1 files changed, 9 insertions, 22 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 46fb2502e..087156bd0 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -904,34 +904,21 @@ def qr(a, mode='reduced'): raise ValueError(f"Unrecognized mode '{mode}'") a, wrap = _makearray(a) - _assert_2d(a) - m, n = a.shape + _assert_stacked_2d(a) + m, n = a.shape[-2:] t, result_t = _commonType(a) - a = _fastCopyAndTranspose(t, a) + a = a.astype(t, copy=True) a = _to_native_byte_order(a) mn = min(m, n) - tau = zeros((mn,), t) - if isComplexType(t): - lapack_routine = lapack_lite.zgeqrf - routine_name = 'zgeqrf' + if m <= n: + gufunc = _umath_linalg.qr_r_raw_e_m else: - lapack_routine = lapack_lite.dgeqrf - routine_name = 'dgeqrf' + gufunc = _umath_linalg.qr_r_raw_e_n - # calculate optimal size of work data 'work' - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(m, n, a, max(1, m), tau, work, -1, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - # do qr decomposition - lwork = max(1, n, int(abs(work[0]))) - work = zeros((lwork,), t) - results = lapack_routine(m, n, a, max(1, m), tau, work, lwork, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) + signature = 'D->D' if isComplexType(t) else 'd->d' + extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq) + tau = gufunc(a, signature=signature, extobj=extobj) # handle modes that don't return q if mode == 'r': |