diff options
author | Jeremy Chen <convexset@gmail.com> | 2018-08-01 01:00:07 +0800 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-07-31 10:00:07 -0700 |
commit | 8fdc4460b9065f8a6cd7e95e106bfeafdc64d613 (patch) | |
tree | 1cfde43473d15181536c60c8467077e2ba65b2ac /numpy/linalg/linalg.py | |
parent | 9bb569c4e0e1cf08128179d157bdab10c8706a97 (diff) | |
download | numpy-8fdc4460b9065f8a6cd7e95e106bfeafdc64d613.tar.gz |
ENH: handle empty matrices in qr decomposition (#11593)
Ensure LWORK and LDA respect the requirements of the lapack methods (zgeqrf, dgeqrf, zungqr, dorgqr)
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index f3a3b37a3..c3b76ada7 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -858,13 +858,13 @@ def qr(a, mode='reduced'): a, wrap = _makearray(a) _assertRank2(a) - _assertNoEmpty2d(a) m, n = a.shape t, result_t = _commonType(a) a = _fastCopyAndTranspose(t, a) a = _to_native_byte_order(a) mn = min(m, n) tau = zeros((mn,), t) + if isComplexType(t): lapack_routine = lapack_lite.zgeqrf routine_name = 'zgeqrf' @@ -875,14 +875,14 @@ def qr(a, mode='reduced'): # calculate optimal size of work data 'work' lwork = 1 work = zeros((lwork,), t) - results = lapack_routine(m, n, a, m, tau, work, -1, 0) + results = lapack_routine(m, n, a, max(1, m), tau, work, -1, 0) if results['info'] != 0: raise LinAlgError('%s returns %d' % (routine_name, results['info'])) # do qr decomposition - lwork = int(abs(work[0])) + lwork = max(1, n, int(abs(work[0]))) work = zeros((lwork,), t) - results = lapack_routine(m, n, a, m, tau, work, lwork, 0) + results = lapack_routine(m, n, a, max(1, m), tau, work, lwork, 0) if results['info'] != 0: raise LinAlgError('%s returns %d' % (routine_name, results['info'])) @@ -918,14 +918,14 @@ def qr(a, mode='reduced'): # determine optimal lwork lwork = 1 work = zeros((lwork,), t) - results = lapack_routine(m, mc, mn, q, m, tau, work, -1, 0) + results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0) if results['info'] != 0: raise LinAlgError('%s returns %d' % (routine_name, results['info'])) # compute q - lwork = int(abs(work[0])) + lwork = max(1, n, int(abs(work[0]))) work = zeros((lwork,), t) - results = lapack_routine(m, mc, mn, q, m, tau, work, lwork, 0) + results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0) if results['info'] != 0: raise LinAlgError('%s returns %d' % (routine_name, results['info'])) |