summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2017-03-26 14:19:18 +0200
committerSebastian Berg <sebastian@sipsolutions.net>2017-04-29 18:37:45 +0200
commit668fb78ed7715c1ca513713cc08b7d6d68e8ddf3 (patch)
treec2b6285a88adfe8209827125b6674a278c9d1984 /numpy
parentbfb41d6a93b92b6a0aa92075d6163690bb6ed71c (diff)
downloadnumpy-668fb78ed7715c1ca513713cc08b7d6d68e8ddf3.tar.gz
MAINT: Remove python side empty array handling from linalg
The necessary fixup on the C-side of linalg has been done already (i.e. the gufuncs correctly work for these empty arrays). This also enables cholesky decomposition and fixes a small bug in pinv handling. Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/linalg.py37
-rw-r--r--numpy/linalg/tests/test_linalg.py141
2 files changed, 140 insertions, 38 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 8776d3c16..31147b9cc 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -366,21 +366,8 @@ def solve(a, b):
# We use the b = (..., M,) logic, only if the number of extra dimensions
# match exactly
if b.ndim == a.ndim - 1:
- if a.shape[-1] == 0 and b.shape[-1] == 0:
- # Legal, but the ufunc cannot handle the 0-sized inner dims
- # let the ufunc handle all wrong cases.
- a = a.reshape(a.shape[:-1])
- bc = broadcast(a, b)
- return wrap(empty(bc.shape, dtype=result_t))
-
gufunc = _umath_linalg.solve1
else:
- if b.size == 0:
- if (a.shape[-1] == 0 and b.shape[-2] == 0) or b.shape[-1] == 0:
- a = a[:,:1].reshape(a.shape[:-1] + (1,))
- bc = broadcast(a, b)
- return wrap(empty(bc.shape, dtype=result_t))
-
gufunc = _umath_linalg.solve
signature = 'DD->D' if isComplexType(t) else 'dd->d'
@@ -521,10 +508,6 @@ def inv(a):
_assertNdSquareness(a)
t, result_t = _commonType(a)
- if a.shape[-1] == 0:
- # The inner array is 0x0, the ufunc cannot handle this case
- return wrap(empty_like(a, dtype=result_t))
-
signature = 'D->D' if isComplexType(t) else 'd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
@@ -905,8 +888,6 @@ def eigvals(a):
_assertNdSquareness(a)
_assertFinite(a)
t, result_t = _commonType(a)
- if _isEmpty2d(a):
- return empty(a.shape[-1:], dtype=result_t)
extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
@@ -1009,8 +990,6 @@ def eigvalsh(a, UPLO='L'):
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
- if _isEmpty2d(a):
- return empty(a.shape[-1:], dtype=result_t)
signature = 'D->d' if isComplexType(t) else 'd->d'
w = gufunc(a, signature=signature, extobj=extobj)
return w.astype(_realType(result_t), copy=False)
@@ -1148,10 +1127,6 @@ def eig(a):
_assertNdSquareness(a)
_assertFinite(a)
t, result_t = _commonType(a)
- if _isEmpty2d(a):
- w = empty(a.shape[-1:], dtype=result_t)
- vt = empty(a.shape, dtype=result_t)
- return w, wrap(vt)
extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
@@ -1289,10 +1264,6 @@ def eigh(a, UPLO='L'):
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
- if _isEmpty2d(a):
- w = empty(a.shape[-1:], dtype=result_t)
- vt = empty(a.shape, dtype=result_t)
- return w, wrap(vt)
extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
@@ -1766,11 +1737,6 @@ def slogdet(a):
_assertNdSquareness(a)
t, result_t = _commonType(a)
real_t = _realType(result_t)
- if _isEmpty2d(a):
- # determinant of empty matrix is 1
- sign = ones(a.shape[:-2], dtype=result_t)
- logdet = zeros(a.shape[:-2], dtype=real_t)
- return sign, logdet
signature = 'D->Dd' if isComplexType(t) else 'd->dd'
sign, logdet = _umath_linalg.slogdet(a, signature=signature)
if isscalar(sign):
@@ -1834,9 +1800,6 @@ def det(a):
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
- # 0x0 matrices have determinant 1
- if _isEmpty2d(a):
- return ones(a.shape[:-2], dtype=result_t)
signature = 'D->D' if isComplexType(t) else 'd->d'
r = _umath_linalg.det(a, signature=signature)
if isscalar(r):
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index baa195241..c612eb6bb 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -542,12 +542,13 @@ class TestInv(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
res = linalg.inv(a)
assert_(res.dtype.type is np.float64)
assert_equal(a.shape, res.shape)
- assert_(isinstance(a, ArraySubclass))
+ assert_(isinstance(res, ArraySubclass))
a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
res = linalg.inv(a)
assert_(res.dtype.type is np.complex64)
assert_equal(a.shape, res.shape)
+ assert_(isinstance(res, ArraySubclass))
class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
@@ -566,6 +567,24 @@ class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ # Check that all kinds of 0-sized arrays work
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
+ res = linalg.eigvals(a)
+ assert_(res.dtype.type is np.float64)
+ assert_equal((0, 1), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(res, np.ndarray))
+
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
+ res = linalg.eigvals(a)
+ assert_(res.dtype.type is np.complex64)
+ assert_equal((0,), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(res, np.ndarray))
+
class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
@@ -591,6 +610,28 @@ class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ # Check that all kinds of 0-sized arrays work
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
+ res, res_v = linalg.eig(a)
+ assert_(res_v.dtype.type is np.float64)
+ assert_(res.dtype.type is np.float64)
+ assert_equal(a.shape, res_v.shape)
+ assert_equal((0, 1), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(a, np.ndarray))
+
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
+ res, res_v = linalg.eig(a)
+ assert_(res_v.dtype.type is np.complex64)
+ assert_(res.dtype.type is np.complex64)
+ assert_equal(a.shape, res_v.shape)
+ assert_equal((0,), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(a, np.ndarray))
+
class TestSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
@@ -619,6 +660,16 @@ class TestSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ # These raise errors currently
+ # (which does not mean that it may not make sense)
+ a = np.zeros((0, 0), dtype=np.complex64)
+ assert_raises(linalg.LinAlgError, linalg.svd, a)
+ a = np.zeros((0, 1), dtype=np.complex64)
+ assert_raises(linalg.LinAlgError, linalg.svd, a)
+ a = np.zeros((1, 0), dtype=np.complex64)
+ assert_raises(linalg.LinAlgError, linalg.svd, a)
+
class TestCondSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
@@ -712,6 +763,25 @@ class TestDet(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
for dtype in [single, double, csingle, cdouble]:
yield check, dtype
+ def test_0_size(self):
+ a = np.zeros((0, 0), dtype=np.complex64)
+ res = linalg.det(a)
+ assert_equal(res, 1.)
+ assert_(res.dtype.type is np.complex64)
+ res = linalg.slogdet(a)
+ assert_equal(res, (1, 0))
+ assert_(res[0].dtype.type is np.complex64)
+ assert_(res[1].dtype.type is np.float32)
+
+ a = np.zeros((0, 0), dtype=np.float64)
+ res = linalg.det(a)
+ assert_equal(res, 1.)
+ assert_(res.dtype.type is np.float64)
+ res = linalg.slogdet(a)
+ assert_equal(res, (1, 0))
+ assert_(res[0].dtype.type is np.float64)
+ assert_(res[1].dtype.type is np.float64)
+
class TestLstsq(LinalgSquareTestCase, LinalgNonsquareTestCase):
@@ -857,6 +927,24 @@ class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase):
w = np.linalg.eigvalsh(Kup, UPLO='u')
assert_allclose(w, tgt, rtol=rtol)
+ def test_0_size(self):
+ # Check that all kinds of 0-sized arrays work
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
+ res = linalg.eigvalsh(a)
+ assert_(res.dtype.type is np.float64)
+ assert_equal((0, 1), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(res, np.ndarray))
+
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
+ res = linalg.eigvalsh(a)
+ assert_(res.dtype.type is np.float32)
+ assert_equal((0,), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(res, np.ndarray))
+
class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
@@ -916,6 +1004,28 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):
w, v = np.linalg.eigh(Kup, UPLO='u')
assert_allclose(w, tgt, rtol=rtol)
+ def test_0_size(self):
+ # Check that all kinds of 0-sized arrays work
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
+ res, res_v = linalg.eigh(a)
+ assert_(res_v.dtype.type is np.float64)
+ assert_(res.dtype.type is np.float64)
+ assert_equal(a.shape, res_v.shape)
+ assert_equal((0, 1), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(a, np.ndarray))
+
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
+ res, res_v = linalg.eigh(a)
+ assert_(res_v.dtype.type is np.complex64)
+ assert_(res.dtype.type is np.float32)
+ assert_equal(a.shape, res_v.shape)
+ assert_equal((0,), res.shape)
+ # This is just for documentation, it might make sense to change:
+ assert_(isinstance(a, np.ndarray))
+
class _TestNorm(object):
@@ -1350,6 +1460,35 @@ class TestQR(object):
self.check_qr(m2.T)
self.check_qr(matrix(m1))
+ def test_0_size(self):
+ # There may be good ways to do (some of this) reasonably:
+ a = np.zeros((0, 0))
+ assert_raises(linalg.LinAlgError, linalg.qr, a)
+ a = np.zeros((0, 1))
+ assert_raises(linalg.LinAlgError, linalg.qr, a)
+ a = np.zeros((1, 0))
+ assert_raises(linalg.LinAlgError, linalg.qr, a)
+
+
+class TestCholesky(object):
+ # TODO: are there no other tests for cholesky?
+
+ def test_0_size(self):
+ class ArraySubclass(np.ndarray):
+ pass
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
+ res = linalg.cholesky(a)
+ assert_equal(a.shape, res.shape)
+ assert_(res.dtype.type is np.float64)
+ # for documentation purpose:
+ assert_(isinstance(res, np.ndarray))
+
+ a = np.zeros((1, 0, 0), dtype=np.complex64).view(ArraySubclass)
+ res = linalg.cholesky(a)
+ assert_equal(a.shape, res.shape)
+ assert_(res.dtype.type is np.complex64)
+ assert_(isinstance(res, np.ndarray))
+
def test_byteorder_check():
# Byte order check should pass for native order