diff options
author | seberg <sebastian@sipsolutions.net> | 2018-06-27 19:35:36 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-27 19:35:36 +0200 |
commit | a32c377102193b4bd6b7921e4497e9d1c284481f (patch) | |
tree | 28bbcd51839d651b2fdba4f79a781f4a52d8f4a0 /numpy/linalg/tests/test_linalg.py | |
parent | 72d2bc0b36478e06cc9d2f942f24308cabf1a22e (diff) | |
parent | b80d360e2b82cd52ad69548cc292c2bab95de6ce (diff) | |
download | numpy-a32c377102193b4bd6b7921e4497e9d1c284481f.tar.gz |
Merge pull request #11424 from eric-wieser/empty-svd
ENH: Allow use of svd on empty arrays
Diffstat (limited to 'numpy/linalg/tests/test_linalg.py')
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 87dfe988a..1c24f1e04 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -644,10 +644,6 @@ class TestEig(EigCases): class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): def do(self, a, b, tags): - if 'size-0' in tags: - assert_raises(LinAlgError, linalg.svd, a, 0) - return - u, s, vt = linalg.svd(a, 0) assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :], np.asarray(vt)), @@ -670,15 +666,19 @@ class TestSVD(SVDCases): for dtype in [single, double, csingle, cdouble]: check(dtype) - def test_0_size(self): - # These raise errors currently - # (which does not mean that it may not make sense) - a = np.zeros((0, 0), dtype=np.complex64) - assert_raises(linalg.LinAlgError, linalg.svd, a) - a = np.zeros((0, 1), dtype=np.complex64) - assert_raises(linalg.LinAlgError, linalg.svd, a) - a = np.zeros((1, 0), dtype=np.complex64) - assert_raises(linalg.LinAlgError, linalg.svd, a) + def test_empty_identity(self): + """ Empty input should put an identity matrix in u or vh """ + x = np.empty((4, 0)) + u, s, vh = linalg.svd(x, compute_uv=True) + assert_equal(u.shape, (4, 4)) + assert_equal(vh.shape, (0, 0)) + assert_equal(u, np.eye(4)) + + x = np.empty((0, 4)) + u, s, vh = linalg.svd(x, compute_uv=True) + assert_equal(u.shape, (0, 0)) + assert_equal(vh.shape, (4, 4)) + assert_equal(vh, np.eye(4)) class CondCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): |