diff options
author | Warren Weckesser <warren.weckesser@gmail.com> | 2013-06-04 09:21:13 -0400 |
---|---|---|
committer | Warren Weckesser <warren.weckesser@gmail.com> | 2013-06-04 09:21:13 -0400 |
commit | a2926149891027fff151d50933405a696384d47c (patch) | |
tree | 522aaf1f234eca41a036cfdc09abf6e90504a91c /numpy | |
parent | 6eb57a767713a7a2b74e1cd3365c34638ec55eac (diff) | |
download | numpy-a2926149891027fff151d50933405a696384d47c.tar.gz |
ENH: linalg: allow the 'axis' argument of linalg.norm to be a 2-tuple, in which case matrix norms of the collection of 2-D matrices are computed.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 101 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 52 |
2 files changed, 134 insertions, 19 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 2bc56f978..49d645e80 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -22,7 +22,7 @@ from numpy.core import array, asarray, zeros, empty, transpose, \ intc, single, double, csingle, cdouble, inexact, complexfloating, \ newaxis, ravel, all, Inf, dot, add, multiply, sqrt, maximum, \ fastCopyAndTranspose, sum, isfinite, size, finfo, errstate, \ - geterrobj, float128 + geterrobj, float128, rollaxis, amin, amax from numpy.lib import triu, asfarray from numpy.linalg import lapack_lite, _umath_linalg from numpy.matrixlib.defmatrix import matrix_power @@ -1866,6 +1866,44 @@ def lstsq(a, b, rcond=-1): return wrap(x), wrap(resids), results['rank'], st +def _multi_svd_norm(x, row_axis, col_axis, op): + """Compute the exteme singular values of the 2-D matrices in `x`. + + This is a private utility function used by numpy.linalg.norm(). + + Parameters + ---------- + x : ndarray + row_axis, col_axis : int + The axes of `x` that hold the 2-D matrices. + op : callable + This should be either numpy.amin or numpy.amax. + + Returns + ------- + result : float or ndarray + If `x` is 2-D, the return values is a float. + Otherwise, it is an array with ``x.ndim - 2`` dimensions. + The return values are either the minimum or maximum of the + singular values of the matrices, depending on whether `op` + is `numpy.amin` or `numpy.amax`. + + """ + if row_axis > col_axis: + row_axis -= 1 + y = rollaxis(rollaxis(x, col_axis, x.ndim), row_axis, -1) + if x.ndim > 3: + z = y.reshape((-1,) + y.shape[-2:]) + else: + z = y + if x.ndim == 2: + result = op(svd(z, compute_uv=0)) + else: + result = array([op(svd(m, compute_uv=0)) for m in z]) + result.shape = y.shape[:-2] + return result + + def norm(x, ord=None, axis=None): """ Matrix or vector norm. @@ -1881,10 +1919,12 @@ def norm(x, ord=None, axis=None): ord : {non-zero int, inf, -inf, 'fro'}, optional Order of the norm (see table under ``Notes``). inf means numpy's `inf` object. - axis : int or None, optional - If `axis` is not None, it specifies the axis of `x` along which to - compute the vector norms. If `axis` is None, then either a vector - norm (when `x` is 1-D) or a matrix norm (when `x` is 2-D) is returned. + axis : {int, 2-tuple of ints, None}, optional + If `axis` is an integer, it specifies the axis of `x` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None then either a vector norm (when `x` + is 1-D) or a matrix norm (when `x` is 2-D) is returned. Returns ------- @@ -1972,7 +2012,7 @@ def norm(x, ord=None, axis=None): >>> LA.norm(a, -3) nan - Using the `axis` argument: + Using the `axis` argument to compute vector norms: >>> c = np.array([[ 1, 2, 3], ... [-1, 1, 4]]) @@ -1983,6 +2023,14 @@ def norm(x, ord=None, axis=None): >>> LA.norm(c, ord=1, axis=1) array([6, 6]) + Using the `axis` argument to compute matrix norms: + + >>> m = np.arange(8).reshape(2,2,2) + >>> norm(m, axis=(1,2)) + array([ 3.74165739, 11.22497216]) + >>> norm(m[0]), norm(m[1]) + (3.7416573867739413, 11.224972160321824) + """ x = asarray(x) @@ -1991,8 +2039,14 @@ def norm(x, ord=None, axis=None): s = (x.conj() * x).real return sqrt(add.reduce((x.conj() * x).ravel().real)) + # Normalize the `axis` argument to a tuple. + if axis is None: + axis = tuple(range(x.ndim)) + elif not isinstance(axis, tuple): + axis = (axis,) + nd = x.ndim - if nd == 1 or axis is not None: + if len(axis) == 1: if ord == Inf: return abs(x).max(axis=axis) elif ord == -Inf: @@ -2018,21 +2072,36 @@ def norm(x, ord=None, axis=None): # because it will downcast to float64. absx = asfarray(abs(x)) return add.reduce(absx**ord, axis=axis)**(1.0/ord) - elif nd == 2: + elif len(axis) == 2: + row_axis, col_axis = axis + if not (-x.ndim <= row_axis < x.ndim and + -x.ndim <= col_axis < x.ndim): + raise ValueError('Invalid axis %r for an array with shape %r' % + (axis, x.shape)) + if row_axis % x.ndim == col_axis % x.ndim: + raise ValueError('Duplicate axes given.') if ord == 2: - return svd(x, compute_uv=0).max() + return _multi_svd_norm(x, row_axis, col_axis, amax) elif ord == -2: - return svd(x, compute_uv=0).min() + return _multi_svd_norm(x, row_axis, col_axis, amin) elif ord == 1: - return abs(x).sum(axis=0).max() + if col_axis > row_axis: + col_axis -= 1 + return add.reduce(abs(x), axis=row_axis).max(axis=col_axis) elif ord == Inf: - return abs(x).sum(axis=1).max() + if row_axis > col_axis: + row_axis -= 1 + return add.reduce(abs(x), axis=col_axis).max(axis=row_axis) elif ord == -1: - return abs(x).sum(axis=0).min() + if col_axis > row_axis: + col_axis -= 1 + return add.reduce(abs(x), axis=row_axis).min(axis=col_axis) elif ord == -Inf: - return abs(x).sum(axis=1).min() - elif ord in ['fro','f']: - return sqrt(add.reduce((x.conj() * x).real.ravel())) + if row_axis > col_axis: + row_axis -= 1 + return add.reduce(abs(x), axis=col_axis).min(axis=row_axis) + elif ord in [None, 'fro', 'f']: + return sqrt(add.reduce((x.conj() * x).real, axis=axis)) else: raise ValueError("Invalid norm order for matrices.") else: diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index fdb243271..7ee7b8317 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -563,6 +563,7 @@ class _TestNorm(TestCase): self.assertRaises(ValueError, norm, A, 0) def test_axis(self): + # Vector norms. # Compare the use of `axis` with computing the norm of each row # or column separately. A = array([[1, 2, 3], [4, 5, 6]], dtype=self.dt) @@ -572,10 +573,55 @@ class _TestNorm(TestCase): expected1 = [norm(A[k,:], ord=order) for k in range(A.shape[0])] assert_almost_equal(norm(A, ord=order, axis=1), expected1) - # Check bad case. Using `axis` implies vector norms are being - # computed, so also using `ord='fro'` raises a ValueError - # (just like `norm([1,2,3], ord='fro')` does). + # Matrix norms. + B = np.arange(1, 25, dtype=self.dt).reshape(2,3,4) + + for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']: + assert_almost_equal(norm(A, ord=order), norm(A, ord=order, axis=(0,1))) + + n = norm(B, ord=order, axis=(1,2)) + expected = [norm(B[k], ord=order) for k in range(B.shape[0])] + print("shape is %r, axis=(1,2)" % (B.shape,)) + assert_almost_equal(n, expected) + + n = norm(B, ord=order, axis=(2,1)) + expected = [norm(B[k].T, ord=order) for k in range(B.shape[0])] + print("shape is %r, axis=(2,1)" % (B.shape,)) + assert_almost_equal(n, expected) + + n = norm(B, ord=order, axis=(0,2)) + expected = [norm(B[:,k,:], ord=order) for k in range(B.shape[1])] + print("shape is %r, axis=(0,2)" % (B.shape,)) + assert_almost_equal(n, expected) + + n = norm(B, ord=order, axis=(0,1)) + expected = [norm(B[:,:,k], ord=order) for k in range(B.shape[2])] + print("shape is %r, axis=(0,1)" % (B.shape,)) + assert_almost_equal(n, expected) + + def test_bad_args(self): + # Check that bad arguments raise the appropriate exceptions. + + A = array([[1, 2, 3], [4, 5, 6]], dtype=self.dt) + B = np.arange(1, 25, dtype=self.dt).reshape(2,3,4) + + # Using `axis=<integer>` or passing in a 1-D array implies vector + # norms are being computed, so also using `ord='fro'` raises a + # ValueError. self.assertRaises(ValueError, norm, A, 'fro', 0) + self.assertRaises(ValueError, norm, [3, 4], 'fro', None) + + # Similarly, norm should raise an exception when ord is any finite + # number other than 1, 2, -1 or -2 when computing matrix norms. + for order in [0, 3]: + self.assertRaises(ValueError, norm, A, order, None) + self.assertRaises(ValueError, norm, A, order, (0,1)) + self.assertRaises(ValueError, norm, B, order, (1,2)) + + # Invalid axis + self.assertRaises(ValueError, norm, B, None, 3) + self.assertRaises(ValueError, norm, B, None, (2,3)) + self.assertRaises(ValueError, norm, B, None, (0,1,2)) class TestNormDouble(_TestNorm): |