summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
authornjsmith <njs@pobox.com>2013-06-05 08:36:40 -0700
committernjsmith <njs@pobox.com>2013-06-05 08:36:40 -0700
commitd307ff623dc58cd7afd02e99805e409edd5ba179 (patch)
tree4e0fd365dd610bb9ea479b8ec9ad376695b2fabe /numpy/linalg/linalg.py
parent7e92df5c483c4a466315be1df7ee077c956a32b0 (diff)
parent1d9607ac4a1fc7270ee0686b7e3f20b2ce709bf2 (diff)
downloadnumpy-d307ff623dc58cd7afd02e99805e409edd5ba179.tar.gz
Merge pull request #3387 from WarrenWeckesser/norm-axis
ENH: linalg: Add an `axis` argument to linalg.norm
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py143
1 files changed, 116 insertions, 27 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index ae0da3685..f11f905f7 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -20,10 +20,10 @@ import warnings
from numpy.core import array, asarray, zeros, empty, transpose, \
intc, single, double, csingle, cdouble, inexact, complexfloating, \
- newaxis, ravel, all, Inf, dot, add, multiply, identity, sqrt, \
- maximum, flatnonzero, diagonal, arange, fastCopyAndTranspose, sum, \
- isfinite, size, finfo, absolute, log, exp, errstate, geterrobj
-from numpy.lib import triu
+ newaxis, ravel, all, Inf, dot, add, multiply, sqrt, maximum, \
+ fastCopyAndTranspose, sum, isfinite, size, finfo, errstate, \
+ geterrobj, float128, rollaxis, amin, amax
+from numpy.lib import triu, asfarray
from numpy.linalg import lapack_lite, _umath_linalg
from numpy.matrixlib.defmatrix import matrix_power
from numpy.compat import asbytes
@@ -1865,7 +1865,38 @@ def lstsq(a, b, rcond=-1):
st = s[:min(n, m)].copy().astype(result_real_t)
return wrap(x), wrap(resids), results['rank'], st
-def norm(x, ord=None):
+
+def _multi_svd_norm(x, row_axis, col_axis, op):
+ """Compute the exteme singular values of the 2-D matrices in `x`.
+
+ This is a private utility function used by numpy.linalg.norm().
+
+ Parameters
+ ----------
+ x : ndarray
+ row_axis, col_axis : int
+ The axes of `x` that hold the 2-D matrices.
+ op : callable
+ This should be either numpy.amin or numpy.amax.
+
+ Returns
+ -------
+ result : float or ndarray
+ If `x` is 2-D, the return values is a float.
+ Otherwise, it is an array with ``x.ndim - 2`` dimensions.
+ The return values are either the minimum or maximum of the
+ singular values of the matrices, depending on whether `op`
+ is `numpy.amin` or `numpy.amax`.
+
+ """
+ if row_axis > col_axis:
+ row_axis -= 1
+ y = rollaxis(rollaxis(x, col_axis, x.ndim), row_axis, -1)
+ result = op(svd(y, compute_uv=0), axis=-1)
+ return result
+
+
+def norm(x, ord=None, axis=None):
"""
Matrix or vector norm.
@@ -1875,16 +1906,22 @@ def norm(x, ord=None):
Parameters
----------
- x : {(M,), (M, N)} array_like
- Input array.
+ x : array_like
+ Input array. If `axis` is None, `x` must be 1-D or 2-D.
ord : {non-zero int, inf, -inf, 'fro'}, optional
Order of the norm (see table under ``Notes``). inf means numpy's
`inf` object.
+ axis : {int, 2-tuple of ints, None}, optional
+ If `axis` is an integer, it specifies the axis of `x` along which to
+ compute the vector norms. If `axis` is a 2-tuple, it specifies the
+ axes that hold 2-D matrices, and the matrix norms of these matrices
+ are computed. If `axis` is None then either a vector norm (when `x`
+ is 1-D) or a matrix norm (when `x` is 2-D) is returned.
Returns
-------
- n : float
- Norm of the matrix or vector.
+ n : float or ndarray
+ Norm of the matrix or vector(s).
Notes
-----
@@ -1967,44 +2004,96 @@ def norm(x, ord=None):
>>> LA.norm(a, -3)
nan
+ Using the `axis` argument to compute vector norms:
+
+ >>> c = np.array([[ 1, 2, 3],
+ ... [-1, 1, 4]])
+ >>> LA.norm(c, axis=0)
+ array([ 1.41421356, 2.23606798, 5. ])
+ >>> LA.norm(c, axis=1)
+ array([ 3.74165739, 4.24264069])
+ >>> LA.norm(c, ord=1, axis=1)
+ array([6, 6])
+
+ Using the `axis` argument to compute matrix norms:
+
+ >>> m = np.arange(8).reshape(2,2,2)
+ >>> LA.norm(m, axis=(1,2))
+ array([ 3.74165739, 11.22497216])
+ >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
+ (3.7416573867739413, 11.224972160321824)
+
"""
x = asarray(x)
- if ord is None: # check the default case first and handle it immediately
+
+ # Check the default case first and handle it immediately.
+ if ord is None and axis is None:
+ s = (x.conj() * x).real
return sqrt(add.reduce((x.conj() * x).ravel().real))
+ # Normalize the `axis` argument to a tuple.
+ if axis is None:
+ axis = tuple(range(x.ndim))
+ elif not isinstance(axis, tuple):
+ axis = (axis,)
+
nd = x.ndim
- if nd == 1:
+ if len(axis) == 1:
if ord == Inf:
- return abs(x).max()
+ return abs(x).max(axis=axis)
elif ord == -Inf:
- return abs(x).min()
+ return abs(x).min(axis=axis)
elif ord == 0:
- return (x != 0).sum() # Zero norm
+ # Zero norm
+ return (x != 0).sum(axis=axis)
elif ord == 1:
- return abs(x).sum() # special case for speedup
- elif ord == 2:
- return sqrt(((x.conj()*x).real).sum()) # special case for speedup
+ # special case for speedup
+ return add.reduce(abs(x), axis=axis)
+ elif ord is None or ord == 2:
+ # special case for speedup
+ s = (x.conj() * x).real
+ return sqrt(add.reduce(s, axis=axis))
else:
try:
ord + 1
except TypeError:
raise ValueError("Invalid norm order for vectors.")
- return ((abs(x)**ord).sum())**(1.0/ord)
- elif nd == 2:
+ if x.dtype != float128:
+ # Convert to a float type, so integer arrays give
+ # float results. Don't apply asfarray to float128 arrays,
+ # because it will downcast to float64.
+ absx = asfarray(abs(x))
+ return add.reduce(absx**ord, axis=axis)**(1.0/ord)
+ elif len(axis) == 2:
+ row_axis, col_axis = axis
+ if not (-x.ndim <= row_axis < x.ndim and
+ -x.ndim <= col_axis < x.ndim):
+ raise ValueError('Invalid axis %r for an array with shape %r' %
+ (axis, x.shape))
+ if row_axis % x.ndim == col_axis % x.ndim:
+ raise ValueError('Duplicate axes given.')
if ord == 2:
- return svd(x, compute_uv=0).max()
+ return _multi_svd_norm(x, row_axis, col_axis, amax)
elif ord == -2:
- return svd(x, compute_uv=0).min()
+ return _multi_svd_norm(x, row_axis, col_axis, amin)
elif ord == 1:
- return abs(x).sum(axis=0).max()
+ if col_axis > row_axis:
+ col_axis -= 1
+ return add.reduce(abs(x), axis=row_axis).max(axis=col_axis)
elif ord == Inf:
- return abs(x).sum(axis=1).max()
+ if row_axis > col_axis:
+ row_axis -= 1
+ return add.reduce(abs(x), axis=col_axis).max(axis=row_axis)
elif ord == -1:
- return abs(x).sum(axis=0).min()
+ if col_axis > row_axis:
+ col_axis -= 1
+ return add.reduce(abs(x), axis=row_axis).min(axis=col_axis)
elif ord == -Inf:
- return abs(x).sum(axis=1).min()
- elif ord in ['fro','f']:
- return sqrt(add.reduce((x.conj() * x).real.ravel()))
+ if row_axis > col_axis:
+ row_axis -= 1
+ return add.reduce(abs(x), axis=col_axis).min(axis=row_axis)
+ elif ord in [None, 'fro', 'f']:
+ return sqrt(add.reduce((x.conj() * x).real, axis=axis))
else:
raise ValueError("Invalid norm order for matrices.")
else: