summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-03-07 01:04:22 +0000
committerEric Wieser <wieser.eric@gmail.com>2018-06-27 09:49:59 -0700
commitb80d360e2b82cd52ad69548cc292c2bab95de6ce (patch)
tree4bfd5904997294a7741fc30d627ebe1a64a90762 /numpy/linalg/tests
parent65f15a5e881817d8646571d36ab9a0bc39a6667e (diff)
downloadnumpy-b80d360e2b82cd52ad69548cc292c2bab95de6ce.tar.gz
ENH: Allow use of svd on empty arrays
part of #8654
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r--numpy/linalg/tests/test_linalg.py26
1 files changed, 13 insertions, 13 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 87dfe988a..1c24f1e04 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -644,10 +644,6 @@ class TestEig(EigCases):
class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
def do(self, a, b, tags):
- if 'size-0' in tags:
- assert_raises(LinAlgError, linalg.svd, a, 0)
- return
-
u, s, vt = linalg.svd(a, 0)
assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :],
np.asarray(vt)),
@@ -670,15 +666,19 @@ class TestSVD(SVDCases):
for dtype in [single, double, csingle, cdouble]:
check(dtype)
- def test_0_size(self):
- # These raise errors currently
- # (which does not mean that it may not make sense)
- a = np.zeros((0, 0), dtype=np.complex64)
- assert_raises(linalg.LinAlgError, linalg.svd, a)
- a = np.zeros((0, 1), dtype=np.complex64)
- assert_raises(linalg.LinAlgError, linalg.svd, a)
- a = np.zeros((1, 0), dtype=np.complex64)
- assert_raises(linalg.LinAlgError, linalg.svd, a)
+ def test_empty_identity(self):
+ """ Empty input should put an identity matrix in u or vh """
+ x = np.empty((4, 0))
+ u, s, vh = linalg.svd(x, compute_uv=True)
+ assert_equal(u.shape, (4, 4))
+ assert_equal(vh.shape, (0, 0))
+ assert_equal(u, np.eye(4))
+
+ x = np.empty((0, 4))
+ u, s, vh = linalg.svd(x, compute_uv=True)
+ assert_equal(u.shape, (0, 0))
+ assert_equal(vh.shape, (4, 4))
+ assert_equal(vh, np.eye(4))
class CondCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):