diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2017-03-26 14:19:18 +0200 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2017-04-29 18:37:45 +0200 |
commit | 668fb78ed7715c1ca513713cc08b7d6d68e8ddf3 (patch) | |
tree | c2b6285a88adfe8209827125b6674a278c9d1984 /numpy | |
parent | bfb41d6a93b92b6a0aa92075d6163690bb6ed71c (diff) | |
download | numpy-668fb78ed7715c1ca513713cc08b7d6d68e8ddf3.tar.gz |
MAINT: Remove python side empty array handling from linalg
The necessary fixup on the C-side of linalg has been done already
(i.e. the gufuncs correctly work for these empty arrays).
This also enables cholesky decomposition and fixes a small bug in pinv
handling.
Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 37 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 141 |
2 files changed, 140 insertions, 38 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 8776d3c16..31147b9cc 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -366,21 +366,8 @@ def solve(a, b): # We use the b = (..., M,) logic, only if the number of extra dimensions # match exactly if b.ndim == a.ndim - 1: - if a.shape[-1] == 0 and b.shape[-1] == 0: - # Legal, but the ufunc cannot handle the 0-sized inner dims - # let the ufunc handle all wrong cases. - a = a.reshape(a.shape[:-1]) - bc = broadcast(a, b) - return wrap(empty(bc.shape, dtype=result_t)) - gufunc = _umath_linalg.solve1 else: - if b.size == 0: - if (a.shape[-1] == 0 and b.shape[-2] == 0) or b.shape[-1] == 0: - a = a[:,:1].reshape(a.shape[:-1] + (1,)) - bc = broadcast(a, b) - return wrap(empty(bc.shape, dtype=result_t)) - gufunc = _umath_linalg.solve signature = 'DD->D' if isComplexType(t) else 'dd->d' @@ -521,10 +508,6 @@ def inv(a): _assertNdSquareness(a) t, result_t = _commonType(a) - if a.shape[-1] == 0: - # The inner array is 0x0, the ufunc cannot handle this case - return wrap(empty_like(a, dtype=result_t)) - signature = 'D->D' if isComplexType(t) else 'd->d' extobj = get_linalg_error_extobj(_raise_linalgerror_singular) ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj) @@ -905,8 +888,6 @@ def eigvals(a): _assertNdSquareness(a) _assertFinite(a) t, result_t = _commonType(a) - if _isEmpty2d(a): - return empty(a.shape[-1:], dtype=result_t) extobj = get_linalg_error_extobj( _raise_linalgerror_eigenvalues_nonconvergence) @@ -1009,8 +990,6 @@ def eigvalsh(a, UPLO='L'): _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) - if _isEmpty2d(a): - return empty(a.shape[-1:], dtype=result_t) signature = 'D->d' if isComplexType(t) else 'd->d' w = gufunc(a, signature=signature, extobj=extobj) return w.astype(_realType(result_t), copy=False) @@ -1148,10 +1127,6 @@ def eig(a): _assertNdSquareness(a) _assertFinite(a) t, result_t = _commonType(a) - if _isEmpty2d(a): - w = empty(a.shape[-1:], dtype=result_t) - vt = empty(a.shape, dtype=result_t) - return w, wrap(vt) extobj = get_linalg_error_extobj( _raise_linalgerror_eigenvalues_nonconvergence) @@ -1289,10 +1264,6 @@ def eigh(a, UPLO='L'): _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) - if _isEmpty2d(a): - w = empty(a.shape[-1:], dtype=result_t) - vt = empty(a.shape, dtype=result_t) - return w, wrap(vt) extobj = get_linalg_error_extobj( _raise_linalgerror_eigenvalues_nonconvergence) @@ -1766,11 +1737,6 @@ def slogdet(a): _assertNdSquareness(a) t, result_t = _commonType(a) real_t = _realType(result_t) - if _isEmpty2d(a): - # determinant of empty matrix is 1 - sign = ones(a.shape[:-2], dtype=result_t) - logdet = zeros(a.shape[:-2], dtype=real_t) - return sign, logdet signature = 'D->Dd' if isComplexType(t) else 'd->dd' sign, logdet = _umath_linalg.slogdet(a, signature=signature) if isscalar(sign): @@ -1834,9 +1800,6 @@ def det(a): _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) - # 0x0 matrices have determinant 1 - if _isEmpty2d(a): - return ones(a.shape[:-2], dtype=result_t) signature = 'D->D' if isComplexType(t) else 'd->d' r = _umath_linalg.det(a, signature=signature) if isscalar(r): diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index baa195241..c612eb6bb 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -542,12 +542,13 @@ class TestInv(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): res = linalg.inv(a) assert_(res.dtype.type is np.float64) assert_equal(a.shape, res.shape) - assert_(isinstance(a, ArraySubclass)) + assert_(isinstance(res, ArraySubclass)) a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) res = linalg.inv(a) assert_(res.dtype.type is np.complex64) assert_equal(a.shape, res.shape) + assert_(isinstance(res, ArraySubclass)) class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): @@ -566,6 +567,24 @@ class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + # Check that all kinds of 0-sized arrays work + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + res = linalg.eigvals(a) + assert_(res.dtype.type is np.float64) + assert_equal((0, 1), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(res, np.ndarray)) + + a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) + res = linalg.eigvals(a) + assert_(res.dtype.type is np.complex64) + assert_equal((0,), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(res, np.ndarray)) + class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): @@ -591,6 +610,28 @@ class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + # Check that all kinds of 0-sized arrays work + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + res, res_v = linalg.eig(a) + assert_(res_v.dtype.type is np.float64) + assert_(res.dtype.type is np.float64) + assert_equal(a.shape, res_v.shape) + assert_equal((0, 1), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(a, np.ndarray)) + + a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) + res, res_v = linalg.eig(a) + assert_(res_v.dtype.type is np.complex64) + assert_(res.dtype.type is np.complex64) + assert_equal(a.shape, res_v.shape) + assert_equal((0,), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(a, np.ndarray)) + class TestSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): @@ -619,6 +660,16 @@ class TestSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + # These raise errors currently + # (which does not mean that it may not make sense) + a = np.zeros((0, 0), dtype=np.complex64) + assert_raises(linalg.LinAlgError, linalg.svd, a) + a = np.zeros((0, 1), dtype=np.complex64) + assert_raises(linalg.LinAlgError, linalg.svd, a) + a = np.zeros((1, 0), dtype=np.complex64) + assert_raises(linalg.LinAlgError, linalg.svd, a) + class TestCondSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): @@ -712,6 +763,25 @@ class TestDet(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): for dtype in [single, double, csingle, cdouble]: yield check, dtype + def test_0_size(self): + a = np.zeros((0, 0), dtype=np.complex64) + res = linalg.det(a) + assert_equal(res, 1.) + assert_(res.dtype.type is np.complex64) + res = linalg.slogdet(a) + assert_equal(res, (1, 0)) + assert_(res[0].dtype.type is np.complex64) + assert_(res[1].dtype.type is np.float32) + + a = np.zeros((0, 0), dtype=np.float64) + res = linalg.det(a) + assert_equal(res, 1.) + assert_(res.dtype.type is np.float64) + res = linalg.slogdet(a) + assert_equal(res, (1, 0)) + assert_(res[0].dtype.type is np.float64) + assert_(res[1].dtype.type is np.float64) + class TestLstsq(LinalgSquareTestCase, LinalgNonsquareTestCase): @@ -857,6 +927,24 @@ class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase): w = np.linalg.eigvalsh(Kup, UPLO='u') assert_allclose(w, tgt, rtol=rtol) + def test_0_size(self): + # Check that all kinds of 0-sized arrays work + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + res = linalg.eigvalsh(a) + assert_(res.dtype.type is np.float64) + assert_equal((0, 1), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(res, np.ndarray)) + + a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) + res = linalg.eigvalsh(a) + assert_(res.dtype.type is np.float32) + assert_equal((0,), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(res, np.ndarray)) + class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase): @@ -916,6 +1004,28 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase): w, v = np.linalg.eigh(Kup, UPLO='u') assert_allclose(w, tgt, rtol=rtol) + def test_0_size(self): + # Check that all kinds of 0-sized arrays work + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + res, res_v = linalg.eigh(a) + assert_(res_v.dtype.type is np.float64) + assert_(res.dtype.type is np.float64) + assert_equal(a.shape, res_v.shape) + assert_equal((0, 1), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(a, np.ndarray)) + + a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass) + res, res_v = linalg.eigh(a) + assert_(res_v.dtype.type is np.complex64) + assert_(res.dtype.type is np.float32) + assert_equal(a.shape, res_v.shape) + assert_equal((0,), res.shape) + # This is just for documentation, it might make sense to change: + assert_(isinstance(a, np.ndarray)) + class _TestNorm(object): @@ -1350,6 +1460,35 @@ class TestQR(object): self.check_qr(m2.T) self.check_qr(matrix(m1)) + def test_0_size(self): + # There may be good ways to do (some of this) reasonably: + a = np.zeros((0, 0)) + assert_raises(linalg.LinAlgError, linalg.qr, a) + a = np.zeros((0, 1)) + assert_raises(linalg.LinAlgError, linalg.qr, a) + a = np.zeros((1, 0)) + assert_raises(linalg.LinAlgError, linalg.qr, a) + + +class TestCholesky(object): + # TODO: are there no other tests for cholesky? + + def test_0_size(self): + class ArraySubclass(np.ndarray): + pass + a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass) + res = linalg.cholesky(a) + assert_equal(a.shape, res.shape) + assert_(res.dtype.type is np.float64) + # for documentation purpose: + assert_(isinstance(res, np.ndarray)) + + a = np.zeros((1, 0, 0), dtype=np.complex64).view(ArraySubclass) + res = linalg.cholesky(a) + assert_equal(a.shape, res.shape) + assert_(res.dtype.type is np.complex64) + assert_(isinstance(res, np.ndarray)) + def test_byteorder_check(): # Byte order check should pass for native order |