diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2016-12-19 13:59:14 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2016-12-19 14:32:54 +0000 |
commit | b28ea1c57c9dba8feec7f616986f2e3e28dd6192 (patch) | |
tree | 9cf758da903e1a9ea1d4458173ac56ffc0226367 /numpy/linalg/linalg.py | |
parent | 835653064fc35c08bf001ef9a09c3127d7ea0be7 (diff) | |
download | numpy-b28ea1c57c9dba8feec7f616986f2e3e28dd6192.tar.gz |
BUG: Raise LinAlgError from lstsq rather than a math domain error from math.log
We could make the log conditional on its argument being non-zero, but that entire expression is wrong anyway.
We could omit that calculation entirely and have LAPACK calculate it for us, but the routine in LAPACK is wrong anyway
We could upgrade the version of lapack shipped in lapack_lite, but the tool to do that is wrong anyway.
Let's leave that can of worms for another time, and just improve the error message for now.
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 79ff68ff3..60a69b9c8 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -1919,6 +1919,7 @@ def lstsq(a, b, rcond=-1): if is_1d: b = b[:, newaxis] _assertRank2(a, b) + _assertNoEmpty2d(a, b) # TODO: relax this constraint m = a.shape[0] n = a.shape[1] n_rhs = b.shape[1] @@ -1933,6 +1934,14 @@ def lstsq(a, b, rcond=-1): a, bstar = _fastCopyAndTranspose(t, a, bstar) a, bstar = _to_native_byte_order(a, bstar) s = zeros((min(m, n),), real_t) + # This line: + # * is incorrect, according to the LAPACK documentation + # * raises a ValueError if min(m,n) == 0 + # * should not be calculated here anyway, as LAPACK should calculate + # `liwork` for us. But that only works if our version of lapack does + # not have this bug: + # http://icl.cs.utk.edu/lapack-forum/archives/lapack/msg00899.html + # Lapack_lite does have that bug... nlvl = max( 0, int( math.log( float(min(m, n))/2. ) ) + 1 ) iwork = zeros((3*min(m, n)*nlvl+11*min(m, n),), fortran_int) if isComplexType(t): |