diff options
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index d7d67a91f..de25d25e9 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -24,7 +24,7 @@ from numpy.core import ( add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size, finfo, errstate, geterrobj, longdouble, moveaxis, amin, amax, product, abs, broadcast, atleast_2d, intp, asanyarray, object_, ones, matmul, - swapaxes, divide, count_nonzero + swapaxes, divide, count_nonzero, ndarray, isnan ) from numpy.core.multiarray import normalize_axis_index from numpy.lib import triu, asfarray @@ -1538,11 +1538,40 @@ def cond(x, p=None): """ x = asarray(x) # in case we have a matrix - if p is None: + if p is None or p == 2 or p == -2: s = svd(x, compute_uv=False) - return s[..., 0]/s[..., -1] + with errstate(all='ignore'): + if p == -2: + r = s[..., -1] / s[..., 0] + else: + r = s[..., 0] / s[..., -1] else: - return norm(x, p, axis=(-2, -1)) * norm(inv(x), p, axis=(-2, -1)) + # Call inv(x) ignoring errors. The result array will + # contain nans in the entries where inversion failed. + _assertRankAtLeast2(x) + _assertNdSquareness(x) + t, result_t = _commonType(x) + signature = 'D->D' if isComplexType(t) else 'd->d' + with errstate(all='ignore'): + invx = _umath_linalg.inv(x, signature=signature) + r = norm(x, p, axis=(-2, -1)) * norm(invx, p, axis=(-2, -1)) + r = r.astype(result_t, copy=False) + + # Convert nans to infs unless the original array had nan entries + r = asarray(r) + nan_mask = isnan(r) + if nan_mask.any(): + nan_mask &= ~isnan(x).any(axis=(-2, -1)) + if r.ndim > 0: + r[nan_mask] = Inf + elif nan_mask: + r[()] = Inf + + # Convention is to return scalars instead of 0d arrays + if r.ndim == 0: + r = r[()] + + return r def matrix_rank(M, tol=None, hermitian=False): |