summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
authorJeremy Chen <convexset@gmail.com>2018-08-01 01:00:07 +0800
committerEric Wieser <wieser.eric@gmail.com>2018-07-31 10:00:07 -0700
commit8fdc4460b9065f8a6cd7e95e106bfeafdc64d613 (patch)
tree1cfde43473d15181536c60c8467077e2ba65b2ac /numpy/linalg/linalg.py
parent9bb569c4e0e1cf08128179d157bdab10c8706a97 (diff)
downloadnumpy-8fdc4460b9065f8a6cd7e95e106bfeafdc64d613.tar.gz
ENH: handle empty matrices in qr decomposition (#11593)
Ensure LWORK and LDA respect the requirements of the lapack methods (zgeqrf, dgeqrf, zungqr, dorgqr)
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index f3a3b37a3..c3b76ada7 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -858,13 +858,13 @@ def qr(a, mode='reduced'):
a, wrap = _makearray(a)
_assertRank2(a)
- _assertNoEmpty2d(a)
m, n = a.shape
t, result_t = _commonType(a)
a = _fastCopyAndTranspose(t, a)
a = _to_native_byte_order(a)
mn = min(m, n)
tau = zeros((mn,), t)
+
if isComplexType(t):
lapack_routine = lapack_lite.zgeqrf
routine_name = 'zgeqrf'
@@ -875,14 +875,14 @@ def qr(a, mode='reduced'):
# calculate optimal size of work data 'work'
lwork = 1
work = zeros((lwork,), t)
- results = lapack_routine(m, n, a, m, tau, work, -1, 0)
+ results = lapack_routine(m, n, a, max(1, m), tau, work, -1, 0)
if results['info'] != 0:
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
# do qr decomposition
- lwork = int(abs(work[0]))
+ lwork = max(1, n, int(abs(work[0])))
work = zeros((lwork,), t)
- results = lapack_routine(m, n, a, m, tau, work, lwork, 0)
+ results = lapack_routine(m, n, a, max(1, m), tau, work, lwork, 0)
if results['info'] != 0:
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
@@ -918,14 +918,14 @@ def qr(a, mode='reduced'):
# determine optimal lwork
lwork = 1
work = zeros((lwork,), t)
- results = lapack_routine(m, mc, mn, q, m, tau, work, -1, 0)
+ results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0)
if results['info'] != 0:
raise LinAlgError('%s returns %d' % (routine_name, results['info']))
# compute q
- lwork = int(abs(work[0]))
+ lwork = max(1, n, int(abs(work[0])))
work = zeros((lwork,), t)
- results = lapack_routine(m, mc, mn, q, m, tau, work, lwork, 0)
+ results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0)
if results['info'] != 0:
raise LinAlgError('%s returns %d' % (routine_name, results['info']))