summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
authorczgdp1807 <gdp.1807@gmail.com>2021-06-03 15:02:26 +0530
committerczgdp1807 <gdp.1807@gmail.com>2021-06-03 15:02:26 +0530
commitc25c09b0688bcda1da148fbcac47adab697c409c (patch)
tree5ba18bcd6215bf09bc2740d359e00d8c9f0418e5 /numpy/linalg/linalg.py
parent6790873334b143117f4e8d1f515def8c7fdeb9fb (diff)
downloadnumpy-c25c09b0688bcda1da148fbcac47adab697c409c.tar.gz
ENH: r, raw, economic now accept stacked matrices
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py31
1 files changed, 9 insertions, 22 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 46fb2502e..087156bd0 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -904,34 +904,21 @@ def qr(a, mode='reduced'):
raise ValueError(f"Unrecognized mode '{mode}'")
a, wrap = _makearray(a)
- _assert_2d(a)
- m, n = a.shape
+ _assert_stacked_2d(a)
+ m, n = a.shape[-2:]
t, result_t = _commonType(a)
- a = _fastCopyAndTranspose(t, a)
+ a = a.astype(t, copy=True)
a = _to_native_byte_order(a)
mn = min(m, n)
- tau = zeros((mn,), t)
- if isComplexType(t):
- lapack_routine = lapack_lite.zgeqrf
- routine_name = 'zgeqrf'
+ if m <= n:
+ gufunc = _umath_linalg.qr_r_raw_e_m
else:
- lapack_routine = lapack_lite.dgeqrf
- routine_name = 'dgeqrf'
+ gufunc = _umath_linalg.qr_r_raw_e_n
- # calculate optimal size of work data 'work'
- lwork = 1
- work = zeros((lwork,), t)
- results = lapack_routine(m, n, a, max(1, m), tau, work, -1, 0)
- if results['info'] != 0:
- raise LinAlgError('%s returns %d' % (routine_name, results['info']))
-
- # do qr decomposition
- lwork = max(1, n, int(abs(work[0])))
- work = zeros((lwork,), t)
- results = lapack_routine(m, n, a, max(1, m), tau, work, lwork, 0)
- if results['info'] != 0:
- raise LinAlgError('%s returns %d' % (routine_name, results['info']))
+ signature = 'D->D' if isComplexType(t) else 'd->d'
+ extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq)
+ tau = gufunc(a, signature=signature, extobj=extobj)
# handle modes that don't return q
if mode == 'r':