summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests/test_linalg.py
diff options
context:
space:
mode:
authorJeremy Chen <convexset@gmail.com>2018-08-01 14:49:40 +0800
committerJeremy Chen <convexset@gmail.com>2018-08-03 00:18:34 +0800
commit45d8c5d1562007492c459f290e16cbbf99c72e1c (patch)
treede13615d09cdc1d76ba76161e7db160b636ba37d /numpy/linalg/tests/test_linalg.py
parent6105281cf245c5713660245a0c87ae00e85aec23 (diff)
downloadnumpy-45d8c5d1562007492c459f290e16cbbf99c72e1c.tar.gz
ENH: support for empty matrices in linalg.lstsq
Diffstat (limited to 'numpy/linalg/tests/test_linalg.py')
-rw-r--r--numpy/linalg/tests/test_linalg.py30
1 files changed, 26 insertions, 4 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 0df673884..36b677ac3 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -875,14 +875,12 @@ class TestDet(DetCases):
class LstsqCases(LinalgSquareTestCase, LinalgNonsquareTestCase):
def do(self, a, b, tags):
- if 'size-0' in tags:
- assert_raises(LinAlgError, linalg.lstsq, a, b)
- return
-
arr = np.asarray(a)
m, n = arr.shape
u, s, vt = linalg.svd(a, 0)
x, residuals, rank, sv = linalg.lstsq(a, b, rcond=-1)
+ if m == 0:
+ assert_((x == 0).all())
if m <= n:
assert_almost_equal(b, dot(a, x))
assert_equal(rank, m)
@@ -923,6 +921,30 @@ class TestLstsq(LstsqCases):
# Warning should be raised exactly once (first command)
assert_(len(w) == 1)
+ @pytest.mark.parametrize(["m", "n", "n_rhs"], [
+ (4, 2, 2),
+ (0, 4, 1),
+ (0, 4, 2),
+ (4, 0, 1),
+ (4, 0, 2),
+ (4, 2, 0),
+ (0, 0, 0)
+ ])
+ def test_empty_a_b(self, m, n, n_rhs):
+ a = np.arange(m * n).reshape(m, n)
+ b = np.ones((m, n_rhs))
+ x, residuals, rank, s = linalg.lstsq(a, b, rcond=None)
+ if m == 0:
+ assert_((x == 0).all())
+ assert_equal(x.shape, (n, n_rhs))
+ assert_equal(residuals.shape, ((n_rhs,) if m > n else (0,)))
+ if m > n and n_rhs > 0:
+ # residuals are exactly the squared norms of b's columns
+ r = b - np.dot(a, x)
+ assert_almost_equal(residuals, (r * r).sum(axis=-2))
+ assert_equal(rank, min(m, n))
+ assert_equal(s.shape, (min(m, n),))
+
class TestMatrixPower(object):
R90 = array([[0, 1], [-1, 0]])