diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 9 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 11 |
2 files changed, 18 insertions, 2 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 7d6d986e0..fce65e4e5 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -20,7 +20,7 @@ from numpy.core import array, asarray, zeros, empty, transpose, \ intc, single, double, csingle, cdouble, inexact, complexfloating, \ newaxis, ravel, all, Inf, dot, add, multiply, identity, sqrt, \ maximum, flatnonzero, diagonal, arange, fastCopyAndTranspose, sum, \ - isfinite + isfinite, size from numpy.lib import triu from numpy.linalg import lapack_lite @@ -126,6 +126,11 @@ def _assertFinite(*arrays): if not (isfinite(a).all()): raise LinAlgError, "Array must not contain infs or NaNs" +def _assertNonEmpty(*arrays): + for a in arrays: + if size(a) == 0: + raise LinAlgError("Arrays cannot be empty") + # Linear equations def tensorsolve(a, b, axes=None): @@ -718,6 +723,7 @@ def svd(a, full_matrices=1, compute_uv=1): """ a, wrap = _makearray(a) _assertRank2(a) + _assertNonEmpty(a) m, n = a.shape t, result_t = _commonType(a) real_t = _linalgRealType(t) @@ -783,6 +789,7 @@ def pinv(a, rcond=1e-15 ): rcond of the largest. """ a, wrap = _makearray(a) + _assertNonEmpty(a) a = a.conjugate() u, s, vt = svd(a, 0) m = u.shape[0] diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index af7914200..1b380fa12 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -4,7 +4,7 @@ from numpy.testing import * set_package_path() from numpy import array, single, double, csingle, cdouble, dot, identity, \ - multiply + multiply, atleast_2d from numpy import linalg restore_path() @@ -37,6 +37,15 @@ class LinalgTestCase(NumpyTestCase): b = array([2.+1j, 1.+2j], dtype=cdouble) self.do(a, b) + def check_empty(self): + a = atleast_2d(array([], dtype = double)) + b = atleast_2d(array([], dtype = double)) + try: + self.do(a, b) + raise AssertionError("%s should fail with empty matrices", self.__name__[5:]) + except linalg.LinAlgError, e: + pass + class test_solve(LinalgTestCase): def do(self, a, b): x = linalg.solve(a, b) |