summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/release/1.14.0-notes.rst4
-rw-r--r--numpy/linalg/linalg.py75
-rw-r--r--numpy/linalg/tests/test_linalg.py8
3 files changed, 59 insertions, 28 deletions
diff --git a/doc/release/1.14.0-notes.rst b/doc/release/1.14.0-notes.rst
index 0bcb3e4a2..0fa4672b2 100644
--- a/doc/release/1.14.0-notes.rst
+++ b/doc/release/1.14.0-notes.rst
@@ -217,6 +217,10 @@ selected via the ``--fcompiler`` and ``--compiler`` options to
supported; by default a gfortran-compatible static archive
``openblas.a`` is looked for.
+``np.linalg.pinv`` now works on stacked matrices
+------------------------------------------------
+Previously it was limited to a single 2d array.
+
Changes
=======
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index fd4e9baf9..bd89a90a1 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -19,12 +19,13 @@ __all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv',
import warnings
from numpy.core import (
- array, asarray, zeros, empty, empty_like, transpose, intc, single, double,
+ array, asarray, zeros, empty, empty_like, intc, single, double,
csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
finfo, errstate, geterrobj, longdouble, moveaxis, amin, amax, product, abs,
- broadcast, atleast_2d, intp, asanyarray, isscalar, object_, ones
- )
+ broadcast, atleast_2d, intp, asanyarray, isscalar, object_, ones, matmul,
+ swapaxes, divide)
+
from numpy.core.multiarray import normalize_axis_index
from numpy.lib import triu, asfarray
from numpy.linalg import lapack_lite, _umath_linalg
@@ -223,6 +224,22 @@ def _assertNoEmpty2d(*arrays):
if _isEmpty2d(a):
raise LinAlgError("Arrays cannot be empty")
+def transpose(a):
+ """
+ Transpose each matrix in a stack of matrices.
+
+ Unlike np.transpose, this only swaps the last two axes, rather than all of
+ them
+
+ Parameters
+ ----------
+ a : (...,M,N) array_like
+
+ Returns
+ -------
+ aT : (...,N,M) ndarray
+ """
+ return swapaxes(a, -1, -2)
# Linear equations
@@ -1494,6 +1511,9 @@ def matrix_rank(M, tol=None):
Rank of the array is the number of SVD singular values of the array that are
greater than `tol`.
+ .. versionchanged:: 1.14
+ Can now operate on stacks of matrices
+
Parameters
----------
M : {(M,), (..., M, N)} array_like
@@ -1582,26 +1602,29 @@ def pinv(a, rcond=1e-15 ):
singular-value decomposition (SVD) and including all
*large* singular values.
+ .. versionchanged:: 1.14
+ Can now operate on stacks of matrices
+
Parameters
----------
- a : (M, N) array_like
- Matrix to be pseudo-inverted.
- rcond : float
- Cutoff for small singular values.
- Singular values smaller (in modulus) than
- `rcond` * largest_singular_value (again, in modulus)
- are set to zero.
+ a : (..., M, N) array_like
+ Matrix or stack of matrices to be pseudo-inverted.
+ rcond : (...) array_like of float
+ Cutoff for small singular values.
+ Singular values smaller (in modulus) than
+ `rcond` * largest_singular_value (again, in modulus)
+ are set to zero. Broadcasts against the stack of matrices
Returns
-------
- B : (N, M) ndarray
- The pseudo-inverse of `a`. If `a` is a `matrix` instance, then so
- is `B`.
+ B : (..., N, M) ndarray
+ The pseudo-inverse of `a`. If `a` is a `matrix` instance, then so
+ is `B`.
Raises
------
LinAlgError
- If the SVD computation does not converge.
+ If the SVD computation does not converge.
Notes
-----
@@ -1638,20 +1661,20 @@ def pinv(a, rcond=1e-15 ):
"""
a, wrap = _makearray(a)
+ rcond = asarray(rcond)
if _isEmpty2d(a):
res = empty(a.shape[:-2] + (a.shape[-1], a.shape[-2]), dtype=a.dtype)
return wrap(res)
a = a.conjugate()
u, s, vt = svd(a, full_matrices=False)
- m = u.shape[0]
- n = vt.shape[1]
- cutoff = rcond*maximum.reduce(s)
- for i in range(min(n, m)):
- if s[i] > cutoff:
- s[i] = 1./s[i]
- else:
- s[i] = 0.
- res = dot(transpose(vt), multiply(s[:, newaxis], transpose(u)))
+
+ # discard small singular values
+ cutoff = rcond[..., newaxis] * amax(s, axis=-1, keepdims=True)
+ large = s > cutoff
+ s = divide(1, s, where=large, out=s)
+ s[~large] = 0
+
+ res = matmul(transpose(vt), multiply(s[..., newaxis], transpose(u)))
return wrap(res)
# Determinant
@@ -1987,13 +2010,13 @@ def lstsq(a, b, rcond="warn"):
resids = array([sum((ravel(bstar)[n:])**2)],
dtype=result_real_t)
else:
- x = array(transpose(bstar)[:n,:], dtype=result_t, copy=True)
+ x = array(bstar.T[:n,:], dtype=result_t, copy=True)
if results['rank'] == n and m > n:
if isComplexType(t):
- resids = sum(abs(transpose(bstar)[n:,:])**2, axis=0).astype(
+ resids = sum(abs(bstar.T[n:,:])**2, axis=0).astype(
result_real_t, copy=False)
else:
- resids = sum((transpose(bstar)[n:,:])**2, axis=0).astype(
+ resids = sum((bstar.T[n:,:])**2, axis=0).astype(
result_real_t, copy=False)
st = s[:min(n, m)].astype(result_real_t, copy=True)
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index ab81fc485..fa20cc5ea 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -712,12 +712,16 @@ class TestCondInf(object):
assert_almost_equal(linalg.cond(A, inf), 3.)
-class TestPinv(LinalgSquareTestCase, LinalgNonsquareTestCase):
+class TestPinv(LinalgSquareTestCase,
+ LinalgNonsquareTestCase,
+ LinalgGeneralizedSquareTestCase,
+ LinalgGeneralizedNonsquareTestCase):
def do(self, a, b, tags):
a_ginv = linalg.pinv(a)
# `a @ a_ginv == I` does not hold if a is singular
- assert_almost_equal(dot(a, a_ginv).dot(a), a, single_decimal=5, double_decimal=11)
+ dot = dot_generalized
+ assert_almost_equal(dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11)
assert_(imply(isinstance(a, matrix), isinstance(a_ginv, matrix)))