summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorWarren Weckesser <warren.weckesser@gmail.com>2013-06-04 09:21:13 -0400
committerWarren Weckesser <warren.weckesser@gmail.com>2013-06-04 09:21:13 -0400
commita2926149891027fff151d50933405a696384d47c (patch)
tree522aaf1f234eca41a036cfdc09abf6e90504a91c /numpy
parent6eb57a767713a7a2b74e1cd3365c34638ec55eac (diff)
downloadnumpy-a2926149891027fff151d50933405a696384d47c.tar.gz
ENH: linalg: allow the 'axis' argument of linalg.norm to be a 2-tuple, in which case matrix norms of the collection of 2-D matrices are computed.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/linalg.py101
-rw-r--r--numpy/linalg/tests/test_linalg.py52
2 files changed, 134 insertions, 19 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 2bc56f978..49d645e80 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -22,7 +22,7 @@ from numpy.core import array, asarray, zeros, empty, transpose, \
intc, single, double, csingle, cdouble, inexact, complexfloating, \
newaxis, ravel, all, Inf, dot, add, multiply, sqrt, maximum, \
fastCopyAndTranspose, sum, isfinite, size, finfo, errstate, \
- geterrobj, float128
+ geterrobj, float128, rollaxis, amin, amax
from numpy.lib import triu, asfarray
from numpy.linalg import lapack_lite, _umath_linalg
from numpy.matrixlib.defmatrix import matrix_power
@@ -1866,6 +1866,44 @@ def lstsq(a, b, rcond=-1):
return wrap(x), wrap(resids), results['rank'], st
+def _multi_svd_norm(x, row_axis, col_axis, op):
+ """Compute the exteme singular values of the 2-D matrices in `x`.
+
+ This is a private utility function used by numpy.linalg.norm().
+
+ Parameters
+ ----------
+ x : ndarray
+ row_axis, col_axis : int
+ The axes of `x` that hold the 2-D matrices.
+ op : callable
+ This should be either numpy.amin or numpy.amax.
+
+ Returns
+ -------
+ result : float or ndarray
+ If `x` is 2-D, the return values is a float.
+ Otherwise, it is an array with ``x.ndim - 2`` dimensions.
+ The return values are either the minimum or maximum of the
+ singular values of the matrices, depending on whether `op`
+ is `numpy.amin` or `numpy.amax`.
+
+ """
+ if row_axis > col_axis:
+ row_axis -= 1
+ y = rollaxis(rollaxis(x, col_axis, x.ndim), row_axis, -1)
+ if x.ndim > 3:
+ z = y.reshape((-1,) + y.shape[-2:])
+ else:
+ z = y
+ if x.ndim == 2:
+ result = op(svd(z, compute_uv=0))
+ else:
+ result = array([op(svd(m, compute_uv=0)) for m in z])
+ result.shape = y.shape[:-2]
+ return result
+
+
def norm(x, ord=None, axis=None):
"""
Matrix or vector norm.
@@ -1881,10 +1919,12 @@ def norm(x, ord=None, axis=None):
ord : {non-zero int, inf, -inf, 'fro'}, optional
Order of the norm (see table under ``Notes``). inf means numpy's
`inf` object.
- axis : int or None, optional
- If `axis` is not None, it specifies the axis of `x` along which to
- compute the vector norms. If `axis` is None, then either a vector
- norm (when `x` is 1-D) or a matrix norm (when `x` is 2-D) is returned.
+ axis : {int, 2-tuple of ints, None}, optional
+ If `axis` is an integer, it specifies the axis of `x` along which to
+ compute the vector norms. If `axis` is a 2-tuple, it specifies the
+ axes that hold 2-D matrices, and the matrix norms of these matrices
+ are computed. If `axis` is None then either a vector norm (when `x`
+ is 1-D) or a matrix norm (when `x` is 2-D) is returned.
Returns
-------
@@ -1972,7 +2012,7 @@ def norm(x, ord=None, axis=None):
>>> LA.norm(a, -3)
nan
- Using the `axis` argument:
+ Using the `axis` argument to compute vector norms:
>>> c = np.array([[ 1, 2, 3],
... [-1, 1, 4]])
@@ -1983,6 +2023,14 @@ def norm(x, ord=None, axis=None):
>>> LA.norm(c, ord=1, axis=1)
array([6, 6])
+ Using the `axis` argument to compute matrix norms:
+
+ >>> m = np.arange(8).reshape(2,2,2)
+ >>> norm(m, axis=(1,2))
+ array([ 3.74165739, 11.22497216])
+ >>> norm(m[0]), norm(m[1])
+ (3.7416573867739413, 11.224972160321824)
+
"""
x = asarray(x)
@@ -1991,8 +2039,14 @@ def norm(x, ord=None, axis=None):
s = (x.conj() * x).real
return sqrt(add.reduce((x.conj() * x).ravel().real))
+ # Normalize the `axis` argument to a tuple.
+ if axis is None:
+ axis = tuple(range(x.ndim))
+ elif not isinstance(axis, tuple):
+ axis = (axis,)
+
nd = x.ndim
- if nd == 1 or axis is not None:
+ if len(axis) == 1:
if ord == Inf:
return abs(x).max(axis=axis)
elif ord == -Inf:
@@ -2018,21 +2072,36 @@ def norm(x, ord=None, axis=None):
# because it will downcast to float64.
absx = asfarray(abs(x))
return add.reduce(absx**ord, axis=axis)**(1.0/ord)
- elif nd == 2:
+ elif len(axis) == 2:
+ row_axis, col_axis = axis
+ if not (-x.ndim <= row_axis < x.ndim and
+ -x.ndim <= col_axis < x.ndim):
+ raise ValueError('Invalid axis %r for an array with shape %r' %
+ (axis, x.shape))
+ if row_axis % x.ndim == col_axis % x.ndim:
+ raise ValueError('Duplicate axes given.')
if ord == 2:
- return svd(x, compute_uv=0).max()
+ return _multi_svd_norm(x, row_axis, col_axis, amax)
elif ord == -2:
- return svd(x, compute_uv=0).min()
+ return _multi_svd_norm(x, row_axis, col_axis, amin)
elif ord == 1:
- return abs(x).sum(axis=0).max()
+ if col_axis > row_axis:
+ col_axis -= 1
+ return add.reduce(abs(x), axis=row_axis).max(axis=col_axis)
elif ord == Inf:
- return abs(x).sum(axis=1).max()
+ if row_axis > col_axis:
+ row_axis -= 1
+ return add.reduce(abs(x), axis=col_axis).max(axis=row_axis)
elif ord == -1:
- return abs(x).sum(axis=0).min()
+ if col_axis > row_axis:
+ col_axis -= 1
+ return add.reduce(abs(x), axis=row_axis).min(axis=col_axis)
elif ord == -Inf:
- return abs(x).sum(axis=1).min()
- elif ord in ['fro','f']:
- return sqrt(add.reduce((x.conj() * x).real.ravel()))
+ if row_axis > col_axis:
+ row_axis -= 1
+ return add.reduce(abs(x), axis=col_axis).min(axis=row_axis)
+ elif ord in [None, 'fro', 'f']:
+ return sqrt(add.reduce((x.conj() * x).real, axis=axis))
else:
raise ValueError("Invalid norm order for matrices.")
else:
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index fdb243271..7ee7b8317 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -563,6 +563,7 @@ class _TestNorm(TestCase):
self.assertRaises(ValueError, norm, A, 0)
def test_axis(self):
+ # Vector norms.
# Compare the use of `axis` with computing the norm of each row
# or column separately.
A = array([[1, 2, 3], [4, 5, 6]], dtype=self.dt)
@@ -572,10 +573,55 @@ class _TestNorm(TestCase):
expected1 = [norm(A[k,:], ord=order) for k in range(A.shape[0])]
assert_almost_equal(norm(A, ord=order, axis=1), expected1)
- # Check bad case. Using `axis` implies vector norms are being
- # computed, so also using `ord='fro'` raises a ValueError
- # (just like `norm([1,2,3], ord='fro')` does).
+ # Matrix norms.
+ B = np.arange(1, 25, dtype=self.dt).reshape(2,3,4)
+
+ for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']:
+ assert_almost_equal(norm(A, ord=order), norm(A, ord=order, axis=(0,1)))
+
+ n = norm(B, ord=order, axis=(1,2))
+ expected = [norm(B[k], ord=order) for k in range(B.shape[0])]
+ print("shape is %r, axis=(1,2)" % (B.shape,))
+ assert_almost_equal(n, expected)
+
+ n = norm(B, ord=order, axis=(2,1))
+ expected = [norm(B[k].T, ord=order) for k in range(B.shape[0])]
+ print("shape is %r, axis=(2,1)" % (B.shape,))
+ assert_almost_equal(n, expected)
+
+ n = norm(B, ord=order, axis=(0,2))
+ expected = [norm(B[:,k,:], ord=order) for k in range(B.shape[1])]
+ print("shape is %r, axis=(0,2)" % (B.shape,))
+ assert_almost_equal(n, expected)
+
+ n = norm(B, ord=order, axis=(0,1))
+ expected = [norm(B[:,:,k], ord=order) for k in range(B.shape[2])]
+ print("shape is %r, axis=(0,1)" % (B.shape,))
+ assert_almost_equal(n, expected)
+
+ def test_bad_args(self):
+ # Check that bad arguments raise the appropriate exceptions.
+
+ A = array([[1, 2, 3], [4, 5, 6]], dtype=self.dt)
+ B = np.arange(1, 25, dtype=self.dt).reshape(2,3,4)
+
+ # Using `axis=<integer>` or passing in a 1-D array implies vector
+ # norms are being computed, so also using `ord='fro'` raises a
+ # ValueError.
self.assertRaises(ValueError, norm, A, 'fro', 0)
+ self.assertRaises(ValueError, norm, [3, 4], 'fro', None)
+
+ # Similarly, norm should raise an exception when ord is any finite
+ # number other than 1, 2, -1 or -2 when computing matrix norms.
+ for order in [0, 3]:
+ self.assertRaises(ValueError, norm, A, order, None)
+ self.assertRaises(ValueError, norm, A, order, (0,1))
+ self.assertRaises(ValueError, norm, B, order, (1,2))
+
+ # Invalid axis
+ self.assertRaises(ValueError, norm, B, None, 3)
+ self.assertRaises(ValueError, norm, B, None, (2,3))
+ self.assertRaises(ValueError, norm, B, None, (0,1,2))
class TestNormDouble(_TestNorm):