diff options
-rw-r--r-- | numpy/linalg/linalg.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 1942d6358..6139e40c4 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -2007,10 +2007,9 @@ def lstsq(a, b, rcond="warn"): b = b[:, newaxis] _assertRank2(a, b) _assertNoEmpty2d(a, b) # TODO: relax this constraint - m = a.shape[0] - n = a.shape[1] - n_rhs = b.shape[1] - if m != b.shape[0]: + m, n = a.shape[-2:] + m2, n_rhs = b.shape[-2:] + if m != m2: raise LinAlgError('Incompatible dimensions') t, result_t = _commonType(a, b) |