diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2018-04-20 20:32:33 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-04-21 11:13:53 -0700 |
commit | 2a43ccbba5cbeec40dee708fc8f56a593d84272b (patch) | |
tree | bcc043ec63d6866be7e24cc5ab99fd69f12d9c11 /numpy/linalg/linalg.py | |
parent | 3a166395409d92211c56cfeff44adc7f485db9ec (diff) | |
download | numpy-2a43ccbba5cbeec40dee708fc8f56a593d84272b.tar.gz |
MAINT: Always spell "get the last two dims" the same way
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 6139e40c4..8ecb90dc9 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -210,7 +210,8 @@ def _assertSquareness(*arrays): def _assertNdSquareness(*arrays): for a in arrays: - if max(a.shape[-2:]) != min(a.shape[-2:]): + m, n = a.shape[-2:] + if m != n: raise LinAlgError('Last 2 dimensions of the array must be square') def _assertFinite(*arrays): @@ -1429,8 +1430,7 @@ def svd(a, full_matrices=True, compute_uv=True): extobj = get_linalg_error_extobj(_raise_linalgerror_svd_nonconvergence) - m = a.shape[-2] - n = a.shape[-1] + m, n = a.shape[-2:] if compute_uv: if full_matrices: if m < n: @@ -1750,7 +1750,8 @@ def pinv(a, rcond=1e-15 ): a, wrap = _makearray(a) rcond = asarray(rcond) if _isEmpty2d(a): - res = empty(a.shape[:-2] + (a.shape[-1], a.shape[-2]), dtype=a.dtype) + m, n = a.shape[-2:] + res = empty(a.shape[:-2] + (n, m), dtype=a.dtype) return wrap(res) a = a.conjugate() u, s, vt = svd(a, full_matrices=False) |