summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 6139e40c4..8ecb90dc9 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -210,7 +210,8 @@ def _assertSquareness(*arrays):
def _assertNdSquareness(*arrays):
for a in arrays:
- if max(a.shape[-2:]) != min(a.shape[-2:]):
+ m, n = a.shape[-2:]
+ if m != n:
raise LinAlgError('Last 2 dimensions of the array must be square')
def _assertFinite(*arrays):
@@ -1429,8 +1430,7 @@ def svd(a, full_matrices=True, compute_uv=True):
extobj = get_linalg_error_extobj(_raise_linalgerror_svd_nonconvergence)
- m = a.shape[-2]
- n = a.shape[-1]
+ m, n = a.shape[-2:]
if compute_uv:
if full_matrices:
if m < n:
@@ -1750,7 +1750,8 @@ def pinv(a, rcond=1e-15 ):
a, wrap = _makearray(a)
rcond = asarray(rcond)
if _isEmpty2d(a):
- res = empty(a.shape[:-2] + (a.shape[-1], a.shape[-2]), dtype=a.dtype)
+ m, n = a.shape[-2:]
+ res = empty(a.shape[:-2] + (n, m), dtype=a.dtype)
return wrap(res)
a = a.conjugate()
u, s, vt = svd(a, full_matrices=False)