summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/linalg/linalg.py9
-rw-r--r--numpy/linalg/tests/test_linalg.py11
2 files changed, 18 insertions, 2 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 7d6d986e0..fce65e4e5 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -20,7 +20,7 @@ from numpy.core import array, asarray, zeros, empty, transpose, \
intc, single, double, csingle, cdouble, inexact, complexfloating, \
newaxis, ravel, all, Inf, dot, add, multiply, identity, sqrt, \
maximum, flatnonzero, diagonal, arange, fastCopyAndTranspose, sum, \
- isfinite
+ isfinite, size
from numpy.lib import triu
from numpy.linalg import lapack_lite
@@ -126,6 +126,11 @@ def _assertFinite(*arrays):
if not (isfinite(a).all()):
raise LinAlgError, "Array must not contain infs or NaNs"
+def _assertNonEmpty(*arrays):
+ for a in arrays:
+ if size(a) == 0:
+ raise LinAlgError("Arrays cannot be empty")
+
# Linear equations
def tensorsolve(a, b, axes=None):
@@ -718,6 +723,7 @@ def svd(a, full_matrices=1, compute_uv=1):
"""
a, wrap = _makearray(a)
_assertRank2(a)
+ _assertNonEmpty(a)
m, n = a.shape
t, result_t = _commonType(a)
real_t = _linalgRealType(t)
@@ -783,6 +789,7 @@ def pinv(a, rcond=1e-15 ):
rcond of the largest.
"""
a, wrap = _makearray(a)
+ _assertNonEmpty(a)
a = a.conjugate()
u, s, vt = svd(a, 0)
m = u.shape[0]
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index af7914200..1b380fa12 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -4,7 +4,7 @@
from numpy.testing import *
set_package_path()
from numpy import array, single, double, csingle, cdouble, dot, identity, \
- multiply
+ multiply, atleast_2d
from numpy import linalg
restore_path()
@@ -37,6 +37,15 @@ class LinalgTestCase(NumpyTestCase):
b = array([2.+1j, 1.+2j], dtype=cdouble)
self.do(a, b)
+ def check_empty(self):
+ a = atleast_2d(array([], dtype = double))
+ b = atleast_2d(array([], dtype = double))
+ try:
+ self.do(a, b)
+ raise AssertionError("%s should fail with empty matrices", self.__name__[5:])
+ except linalg.LinAlgError, e:
+ pass
+
class test_solve(LinalgTestCase):
def do(self, a, b):
x = linalg.solve(a, b)