summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py31
1 files changed, 9 insertions, 22 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 46fb2502e..087156bd0 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -904,34 +904,21 @@ def qr(a, mode='reduced'):
raise ValueError(f"Unrecognized mode '{mode}'")
a, wrap = _makearray(a)
- _assert_2d(a)
- m, n = a.shape
+ _assert_stacked_2d(a)
+ m, n = a.shape[-2:]
t, result_t = _commonType(a)
- a = _fastCopyAndTranspose(t, a)
+ a = a.astype(t, copy=True)
a = _to_native_byte_order(a)
mn = min(m, n)
- tau = zeros((mn,), t)
- if isComplexType(t):
- lapack_routine = lapack_lite.zgeqrf
- routine_name = 'zgeqrf'
+ if m <= n:
+ gufunc = _umath_linalg.qr_r_raw_e_m
else:
- lapack_routine = lapack_lite.dgeqrf
- routine_name = 'dgeqrf'
+ gufunc = _umath_linalg.qr_r_raw_e_n
- # calculate optimal size of work data 'work'
- lwork = 1
- work = zeros((lwork,), t)
- results = lapack_routine(m, n, a, max(1, m), tau, work, -1, 0)
- if results['info'] != 0:
- raise LinAlgError('%s returns %d' % (routine_name, results['info']))
-
- # do qr decomposition
- lwork = max(1, n, int(abs(work[0])))
- work = zeros((lwork,), t)
- results = lapack_routine(m, n, a, max(1, m), tau, work, lwork, 0)
- if results['info'] != 0:
- raise LinAlgError('%s returns %d' % (routine_name, results['info']))
+ signature = 'D->D' if isComplexType(t) else 'd->d'
+ extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq)
+ tau = gufunc(a, signature=signature, extobj=extobj)
# handle modes that don't return q
if mode == 'r':