diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2013-07-10 14:19:42 +0200 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2013-08-04 19:14:25 +0200 |
commit | fa55f4c463806599bccf145baf22e13ff79f9a68 (patch) | |
tree | a7842184034542b4b9f7c9eb51871f7f41c1d457 /numpy/linalg/linalg.py | |
parent | 90ececac5755b0b54d7c8c2c5d71caaeb5c0b45c (diff) | |
download | numpy-fa55f4c463806599bccf145baf22e13ff79f9a68.tar.gz |
ENH: inv/solve work with empty inner and others empty outer array
This makes the inverse of a 0x0 array simply be 0x0 again. It
also modifies the no-empty array check in favor of a no-empty
*inner* array, since the gufuncs seem to handle the other case
fine.
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 50 |
1 files changed, 34 insertions, 16 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 760ca31b0..4b0d3d86d 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -19,10 +19,11 @@ __all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv', import warnings from numpy.core import ( - array, asarray, zeros, empty, transpose, intc, single, double, csingle, - cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot, add, - multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size, - finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax + array, asarray, zeros, empty, empty_like, transpose, intc, single, double, + csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot, + add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size, + finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, + broadcast ) from numpy.lib import triu, asfarray from numpy.linalg import lapack_lite, _umath_linalg @@ -215,9 +216,9 @@ def _assertFinite(*arrays): if not (isfinite(a).all()): raise LinAlgError("Array must not contain infs or NaNs") -def _assertNonEmpty(*arrays): +def _assertNoEmpty2d(*arrays): for a in arrays: - if size(a) == 0: + if a.size == 0 and product(a.shape[-2:]) == 0: raise LinAlgError("Arrays cannot be empty") @@ -350,15 +351,28 @@ def solve(a, b): """ a, _ = _makearray(a) - _assertNonEmpty(a) _assertRankAtLeast2(a) _assertNdSquareness(a) b, wrap = _makearray(b) t, result_t = _commonType(a, b) - if len(b.shape) == len(a.shape) - 1: + # We use the b = (..., M,) logic, only if the number of extra dimensions + # match exactly + if b.ndim == a.ndim - 1: + if a.shape[-1] == 0 and b.shape[-1] == 0: + # Legal, but the ufunc cannot handle the 0-sized inner dims + # let the ufunc handle all wrong cases. + a = a.reshape(a.shape[:-1]) + bc = broadcast(a, b) + return wrap(empty(bc.shape, dtype=result_t)) + gufunc = _umath_linalg.solve1 else: + if a.shape[-1] == 0 and b.shape[-2] == 0: + a = a.reshape(a.shape[:-1] + (1,)) + bc = broadcast(a, b) + return wrap(empty(bc.shape, dtype=result_t)) + gufunc = _umath_linalg.solve signature = 'DD->D' if isComplexType(t) else 'dd->d' @@ -492,10 +506,14 @@ def inv(a): """ a, wrap = _makearray(a) - _assertNonEmpty(a) _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) + + if a.shape[-1] == 0: + # The inner array is 0x0, the ufunc cannot handle this case + return wrap(empty_like(a, dtype=result_t)) + signature = 'D->D' if isComplexType(t) else 'd->d' extobj = get_linalg_error_extobj(_raise_linalgerror_singular) ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj) @@ -718,7 +736,7 @@ def qr(a, mode='reduced'): a, wrap = _makearray(a) _assertRank2(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) m, n = a.shape t, result_t = _commonType(a) a = _fastCopyAndTranspose(t, a) @@ -863,7 +881,7 @@ def eigvals(a): """ a, wrap = _makearray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) _assertRankAtLeast2(a) _assertNdSquareness(a) _assertFinite(a) @@ -940,7 +958,7 @@ def eigvalsh(a, UPLO='L'): gufunc = _umath_linalg.eigvalsh_up a, wrap = _makearray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) @@ -1279,7 +1297,7 @@ def svd(a, full_matrices=1, compute_uv=1): """ a, wrap = _makearray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) _assertRankAtLeast2(a) t, result_t = _commonType(a) @@ -1556,7 +1574,7 @@ def pinv(a, rcond=1e-15 ): """ a, wrap = _makearray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) a = a.conjugate() u, s, vt = svd(a, 0) m = u.shape[0] @@ -1643,7 +1661,7 @@ def slogdet(a): """ a = asarray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) @@ -1697,7 +1715,7 @@ def det(a): """ a = asarray(a) - _assertNonEmpty(a) + _assertNoEmpty2d(a) _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) |