summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py100
1 files changed, 58 insertions, 42 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 087156bd0..97d130894 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -99,6 +99,9 @@ def _raise_linalgerror_svd_nonconvergence(err, flag):
def _raise_linalgerror_lstsq(err, flag):
raise LinAlgError("SVD did not converge in Linear Least Squares")
+def _raise_linalgerror_qr_r_raw(err, flag):
+ raise LinAlgError("Illegal argument found while performing QR factorization")
+
def get_linalg_error_extobj(callback):
extobj = list(_linalg_error_extobj) # make a copy
extobj[2] = callback
@@ -766,12 +769,12 @@ def cholesky(a):
# QR decomposition
-def _qr_dispatcher(a, mode=None):
+def _qr_dispatcher(a, mode=None, old=None):
return (a,)
@array_function_dispatch(_qr_dispatcher)
-def qr(a, mode='reduced'):
+def qr(a, mode='reduced', old=True):
"""
Compute the qr factorization of a matrix.
@@ -912,62 +915,75 @@ def qr(a, mode='reduced'):
mn = min(m, n)
if m <= n:
- gufunc = _umath_linalg.qr_r_raw_e_m
+ gufunc = _umath_linalg.qr_r_raw_m
else:
- gufunc = _umath_linalg.qr_r_raw_e_n
+ gufunc = _umath_linalg.qr_r_raw_n
signature = 'D->D' if isComplexType(t) else 'd->d'
- extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq)
+ extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw)
tau = gufunc(a, signature=signature, extobj=extobj)
# handle modes that don't return q
if mode == 'r':
- r = _fastCopyAndTranspose(result_t, a[:, :mn])
- return wrap(triu(r))
+ return wrap(triu(a.T))
if mode == 'raw':
- return a, tau
+ return wrap(a.T), tau
if mode == 'economic':
if t != result_t :
a = a.astype(result_t, copy=False)
- return wrap(a.T)
-
- # generate q from a
- if mode == 'complete' and m > n:
- mc = m
- q = empty((m, m), t)
- else:
- mc = mn
- q = empty((n, m), t)
- q[:n] = a
-
- if isComplexType(t):
- lapack_routine = lapack_lite.zungqr
- routine_name = 'zungqr'
- else:
- lapack_routine = lapack_lite.dorgqr
- routine_name = 'dorgqr'
+ return wrap(a)
- # determine optimal lwork
- lwork = 1
- work = zeros((lwork,), t)
- results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0)
- if results['info'] != 0:
- raise LinAlgError('%s returns %d' % (routine_name, results['info']))
-
- # compute q
- lwork = max(1, n, int(abs(work[0])))
- work = zeros((lwork,), t)
- results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0)
- if results['info'] != 0:
- raise LinAlgError('%s returns %d' % (routine_name, results['info']))
-
- q = _fastCopyAndTranspose(result_t, q[:mc])
- r = _fastCopyAndTranspose(result_t, a[:, :mc])
+ if old:
+ # generate q from a
+ if mode == 'complete' and m > n:
+ mc = m
+ q = empty((m, m), t)
+ else:
+ mc = mn
+ q = empty((m, n), t)
+ q[:n] = a
- return wrap(q), wrap(triu(r))
+ if isComplexType(t):
+ lapack_routine = lapack_lite.zungqr
+ routine_name = 'zungqr'
+ else:
+ lapack_routine = lapack_lite.dorgqr
+ routine_name = 'dorgqr'
+
+ # determine optimal lwork
+ lwork = 1
+ work = zeros((lwork,), t)
+ results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0)
+ if results['info'] != 0:
+ raise LinAlgError('%s returns %d' % (routine_name, results['info']))
+
+ # compute q
+ lwork = max(1, n, int(abs(work[0])))
+ work = zeros((lwork,), t)
+ results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0)
+ if results['info'] != 0:
+ raise LinAlgError('%s returns %d' % (routine_name, results['info']))
+
+ q = _fastCopyAndTranspose(result_t, q[:mc])
+ r = _fastCopyAndTranspose(result_t, a[:, :mc])
+
+ return wrap(q), wrap(triu(r))
+
+ if mode == 'reduced':
+ if m <= n:
+ gufunc = _umath_linalg.qr_reduced_m
+ else:
+ gufunc = _umath_linalg.qr_reduced_n
+
+ print(gufunc)
+ signature = 'DD->D' if isComplexType(t) else 'dd->d'
+ extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw)
+ q = gufunc(a, tau, signature=signature, extobj=extobj)
+ r = _fastCopyAndTranspose(result_t, a[:, :mn])
+ return q, wrap(triu(r))
# Eigenvalues