diff options
author | Jeremy Chen <convexset@gmail.com> | 2018-08-01 14:49:40 +0800 |
---|---|---|
committer | Jeremy Chen <convexset@gmail.com> | 2018-08-02 14:40:07 +0800 |
commit | cb0fc23d26461899eb37d1d13f6c0bf834695573 (patch) | |
tree | f3425904fb7eee3e500bdcdf876cf1a203fdf246 | |
parent | 977431a6355dfff2ade73f4ea1543598d26c7154 (diff) | |
download | numpy-cb0fc23d26461899eb37d1d13f6c0bf834695573.tar.gz |
ENH: support for empty matrices in linalg.lstsq
-rw-r--r-- | doc/release/1.16.0-notes.rst | 6 | ||||
-rw-r--r-- | numpy/linalg/linalg.py | 10 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 30 |
3 files changed, 41 insertions, 5 deletions
diff --git a/doc/release/1.16.0-notes.rst b/doc/release/1.16.0-notes.rst index ae21f4ffd..f56280171 100644 --- a/doc/release/1.16.0-notes.rst +++ b/doc/release/1.16.0-notes.rst @@ -40,6 +40,12 @@ New Features Improvements ============ +``linalg.lstsq`` now works with empty matrices +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Previously, a ``LinAlgError`` would be raised when an empty matrix/empty +matrices (with zero rows and/or columns) is passed in. Now outputs of +appropriate shapes are returned. + ``randint`` and ``choice`` now work on empty distributions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Even when no elements needed to be drawn, ``np.random.randint`` and diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 8e7ad70cd..b10442906 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -2110,7 +2110,6 @@ def lstsq(a, b, rcond="warn"): if is_1d: b = b[:, newaxis] _assertRank2(a, b) - _assertNoEmpty2d(a, b) # TODO: relax this constraint m, n = a.shape[-2:] m2, n_rhs = b.shape[-2:] if m != m2: @@ -2141,7 +2140,16 @@ def lstsq(a, b, rcond="warn"): signature = 'DDd->Ddid' if isComplexType(t) else 'ddd->ddid' extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq) + if n_rhs == 0: + # lapack can't handle n_rhs = 0 - so allocate the array one larger in that axis + b = zeros(b.shape[:-2] + (m, n_rhs + 1), dtype=b.dtype) x, resids, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj) + if m == 0: + x[...] = 0 + if n_rhs == 0: + # remove the item we added + x = x[..., :n_rhs] + resids = resids[..., :n_rhs] # remove the axis we added if is_1d: diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 1c24f1e04..2326bedcb 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -875,14 +875,12 @@ class TestDet(DetCases): class LstsqCases(LinalgSquareTestCase, LinalgNonsquareTestCase): def do(self, a, b, tags): - if 'size-0' in tags: - assert_raises(LinAlgError, linalg.lstsq, a, b) - return - arr = np.asarray(a) m, n = arr.shape u, s, vt = linalg.svd(a, 0) x, residuals, rank, sv = linalg.lstsq(a, b, rcond=-1) + if m == 0: + assert_((x == 0).all()) if m <= n: assert_almost_equal(b, dot(a, x)) assert_equal(rank, m) @@ -923,6 +921,30 @@ class TestLstsq(LstsqCases): # Warning should be raised exactly once (first command) assert_(len(w) == 1) + @pytest.mark.parametrize(["m", "n", "n_rhs"], [ + (4, 2, 2), + (0, 4, 1), + (0, 4, 2), + (4, 0, 1), + (4, 0, 2), + (4, 2, 0), + (0, 0, 0) + ]) + def test_empty_a_b(self, m, n, n_rhs): + a = np.arange(m * n).reshape(m, n) + b = np.ones((m, n_rhs)) + x, residuals, rank, s = linalg.lstsq(a, b, rcond=None) + if m == 0: + assert_((x == 0).all()) + assert_equal(x.shape, (n, n_rhs)) + assert_equal(residuals.shape, ((n_rhs,) if m > n else (0,))) + if m > n and n_rhs > 0: + # residuals are exactly the squared norms of b's columns + r = b - np.dot(a, x) + assert_almost_equal(residuals, (r * r).sum(axis=-2)) + assert_equal(rank, min(m, n)) + assert_equal(s.shape, (min(m, n),)) + class TestMatrixPower(object): R90 = array([[0, 1], [-1, 0]]) |