diff options
-rw-r--r-- | numpy/linalg/linalg.py | 10 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 5 |
2 files changed, 7 insertions, 8 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 6002c63b9..3b83ac4a6 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -1516,8 +1516,8 @@ def matrix_rank(M, tol=None): Parameters ---------- - M : {(M,), (M, N)} array_like - array of <=2 dimensions + M : {(M,), (..., M, N)} array_like + input vector or stack of matrices tol : {None, float}, optional threshold below which SVD values are considered zero. If `tol` is None, and ``S`` is an array with singular values for `M`, and @@ -1584,14 +1584,12 @@ def matrix_rank(M, tol=None): 0 """ M = asarray(M) - if M.ndim > 2: - raise TypeError('array should have 2 or fewer dimensions') if M.ndim < 2: return int(not all(M==0)) S = svd(M, compute_uv=False) if tol is None: - tol = S.max() * max(M.shape) * finfo(S.dtype).eps - return sum(S > tol) + tol = S.max(axis=-1, keepdims=True) * max(M.shape[-2:]) * finfo(S.dtype).eps + return (S > tol).sum(axis=-1) # Generalized inverse diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index fc4f98ed7..9e29b343e 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -1261,8 +1261,9 @@ class TestMatrixRank(object): yield assert_equal, matrix_rank(np.zeros((4,))), 0 # accepts array-like yield assert_equal, matrix_rank([1]), 1 - # greater than 2 dimensions raises error - yield assert_raises, TypeError, matrix_rank, np.zeros((2, 2, 2)) + # greater than 2 dimensions treated as stacked matrices + ms = np.array([I, np.eye(4), np.zeros((4,4))]) + yield assert_equal, matrix_rank(ms), np.array([3, 4, 0]) # works on scalar yield assert_equal, matrix_rank(1), 1 |