summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py50
1 files changed, 34 insertions, 16 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 760ca31b0..4b0d3d86d 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -19,10 +19,11 @@ __all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv',
import warnings
from numpy.core import (
- array, asarray, zeros, empty, transpose, intc, single, double, csingle,
- cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot, add,
- multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
- finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax
+ array, asarray, zeros, empty, empty_like, transpose, intc, single, double,
+ csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
+ add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
+ finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product,
+ broadcast
)
from numpy.lib import triu, asfarray
from numpy.linalg import lapack_lite, _umath_linalg
@@ -215,9 +216,9 @@ def _assertFinite(*arrays):
if not (isfinite(a).all()):
raise LinAlgError("Array must not contain infs or NaNs")
-def _assertNonEmpty(*arrays):
+def _assertNoEmpty2d(*arrays):
for a in arrays:
- if size(a) == 0:
+ if a.size == 0 and product(a.shape[-2:]) == 0:
raise LinAlgError("Arrays cannot be empty")
@@ -350,15 +351,28 @@ def solve(a, b):
"""
a, _ = _makearray(a)
- _assertNonEmpty(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
b, wrap = _makearray(b)
t, result_t = _commonType(a, b)
- if len(b.shape) == len(a.shape) - 1:
+ # We use the b = (..., M,) logic, only if the number of extra dimensions
+ # match exactly
+ if b.ndim == a.ndim - 1:
+ if a.shape[-1] == 0 and b.shape[-1] == 0:
+ # Legal, but the ufunc cannot handle the 0-sized inner dims
+ # let the ufunc handle all wrong cases.
+ a = a.reshape(a.shape[:-1])
+ bc = broadcast(a, b)
+ return wrap(empty(bc.shape, dtype=result_t))
+
gufunc = _umath_linalg.solve1
else:
+ if a.shape[-1] == 0 and b.shape[-2] == 0:
+ a = a.reshape(a.shape[:-1] + (1,))
+ bc = broadcast(a, b)
+ return wrap(empty(bc.shape, dtype=result_t))
+
gufunc = _umath_linalg.solve
signature = 'DD->D' if isComplexType(t) else 'dd->d'
@@ -492,10 +506,14 @@ def inv(a):
"""
a, wrap = _makearray(a)
- _assertNonEmpty(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
+
+ if a.shape[-1] == 0:
+ # The inner array is 0x0, the ufunc cannot handle this case
+ return wrap(empty_like(a, dtype=result_t))
+
signature = 'D->D' if isComplexType(t) else 'd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
@@ -718,7 +736,7 @@ def qr(a, mode='reduced'):
a, wrap = _makearray(a)
_assertRank2(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
m, n = a.shape
t, result_t = _commonType(a)
a = _fastCopyAndTranspose(t, a)
@@ -863,7 +881,7 @@ def eigvals(a):
"""
a, wrap = _makearray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
_assertFinite(a)
@@ -940,7 +958,7 @@ def eigvalsh(a, UPLO='L'):
gufunc = _umath_linalg.eigvalsh_up
a, wrap = _makearray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
@@ -1279,7 +1297,7 @@ def svd(a, full_matrices=1, compute_uv=1):
"""
a, wrap = _makearray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
_assertRankAtLeast2(a)
t, result_t = _commonType(a)
@@ -1556,7 +1574,7 @@ def pinv(a, rcond=1e-15 ):
"""
a, wrap = _makearray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
a = a.conjugate()
u, s, vt = svd(a, 0)
m = u.shape[0]
@@ -1643,7 +1661,7 @@ def slogdet(a):
"""
a = asarray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
@@ -1697,7 +1715,7 @@ def det(a):
"""
a = asarray(a)
- _assertNonEmpty(a)
+ _assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)