summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2014-10-17 09:52:27 -0600
committerCharles Harris <charlesr.harris@gmail.com>2014-10-17 09:52:27 -0600
commitc9d61533f0a5a40ec0a38c69e19615d375af166e (patch)
tree28675cba68919cfdd36a067f612b1e1f4631edd0
parent5505dc9cdfa54fa5a45c8e69f3f82874c5aee9a6 (diff)
parent6f88eea024448b913cea881efad405e5d4ece195 (diff)
downloadnumpy-c9d61533f0a5a40ec0a38c69e19615d375af166e.tar.gz
Merge pull request #5196 from ewmoore/norm_keepdims
Add keepdims to linalg.norm
-rw-r--r--numpy/linalg/linalg.py45
-rw-r--r--numpy/linalg/tests/test_linalg.py43
2 files changed, 73 insertions, 15 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 6b2299fe7..e8f7a8ab1 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -1921,7 +1921,7 @@ def _multi_svd_norm(x, row_axis, col_axis, op):
return result
-def norm(x, ord=None, axis=None):
+def norm(x, ord=None, axis=None, keepdims=False):
"""
Matrix or vector norm.
@@ -1942,6 +1942,11 @@ def norm(x, ord=None, axis=None):
axes that hold 2-D matrices, and the matrix norms of these matrices
are computed. If `axis` is None then either a vector norm (when `x`
is 1-D) or a matrix norm (when `x` is 2-D) is returned.
+ keepdims : bool, optional
+ .. versionadded:: 1.10.0
+ If this is set to True, the axes which are normed over are left in the
+ result as dimensions with size one. With this option the result will
+ broadcast correctly against the original `x`.
Returns
-------
@@ -2053,12 +2058,16 @@ def norm(x, ord=None, axis=None):
# Check the default case first and handle it immediately.
if ord is None and axis is None:
+ ndim = x.ndim
x = x.ravel(order='K')
if isComplexType(x.dtype.type):
sqnorm = dot(x.real, x.real) + dot(x.imag, x.imag)
else:
sqnorm = dot(x, x)
- return sqrt(sqnorm)
+ ret = sqrt(sqnorm)
+ if keepdims:
+ ret = ret.reshape(ndim*[1])
+ return ret
# Normalize the `axis` argument to a tuple.
nd = x.ndim
@@ -2069,19 +2078,19 @@ def norm(x, ord=None, axis=None):
if len(axis) == 1:
if ord == Inf:
- return abs(x).max(axis=axis)
+ return abs(x).max(axis=axis, keepdims=keepdims)
elif ord == -Inf:
- return abs(x).min(axis=axis)
+ return abs(x).min(axis=axis, keepdims=keepdims)
elif ord == 0:
# Zero norm
- return (x != 0).sum(axis=axis)
+ return (x != 0).sum(axis=axis, keepdims=keepdims)
elif ord == 1:
# special case for speedup
- return add.reduce(abs(x), axis=axis)
+ return add.reduce(abs(x), axis=axis, keepdims=keepdims)
elif ord is None or ord == 2:
# special case for speedup
s = (x.conj() * x).real
- return sqrt(add.reduce(s, axis=axis))
+ return sqrt(add.reduce(s, axis=axis, keepdims=keepdims))
else:
try:
ord + 1
@@ -2100,7 +2109,7 @@ def norm(x, ord=None, axis=None):
# if the type changed, we can safely overwrite absx
abs(absx, out=absx)
absx **= ord
- return add.reduce(absx, axis=axis) ** (1.0 / ord)
+ return add.reduce(absx, axis=axis, keepdims=keepdims) ** (1.0 / ord)
elif len(axis) == 2:
row_axis, col_axis = axis
if not (-nd <= row_axis < nd and -nd <= col_axis < nd):
@@ -2109,28 +2118,34 @@ def norm(x, ord=None, axis=None):
if row_axis % nd == col_axis % nd:
raise ValueError('Duplicate axes given.')
if ord == 2:
- return _multi_svd_norm(x, row_axis, col_axis, amax)
+ ret = _multi_svd_norm(x, row_axis, col_axis, amax)
elif ord == -2:
- return _multi_svd_norm(x, row_axis, col_axis, amin)
+ ret = _multi_svd_norm(x, row_axis, col_axis, amin)
elif ord == 1:
if col_axis > row_axis:
col_axis -= 1
- return add.reduce(abs(x), axis=row_axis).max(axis=col_axis)
+ ret = add.reduce(abs(x), axis=row_axis).max(axis=col_axis)
elif ord == Inf:
if row_axis > col_axis:
row_axis -= 1
- return add.reduce(abs(x), axis=col_axis).max(axis=row_axis)
+ ret = add.reduce(abs(x), axis=col_axis).max(axis=row_axis)
elif ord == -1:
if col_axis > row_axis:
col_axis -= 1
- return add.reduce(abs(x), axis=row_axis).min(axis=col_axis)
+ ret = add.reduce(abs(x), axis=row_axis).min(axis=col_axis)
elif ord == -Inf:
if row_axis > col_axis:
row_axis -= 1
- return add.reduce(abs(x), axis=col_axis).min(axis=row_axis)
+ ret = add.reduce(abs(x), axis=col_axis).min(axis=row_axis)
elif ord in [None, 'fro', 'f']:
- return sqrt(add.reduce((x.conj() * x).real, axis=axis))
+ ret = sqrt(add.reduce((x.conj() * x).real, axis=axis))
else:
raise ValueError("Invalid norm order for matrices.")
+ if keepdims:
+ ret_shape = list(x.shape)
+ ret_shape[axis[0]] = 1
+ ret_shape[axis[1]] = 1
+ ret = ret.reshape(ret_shape)
+ return ret
else:
raise ValueError("Improper number of dimensions to norm.")
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 8edf36aa6..34079bb87 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -885,6 +885,48 @@ class _TestNorm(object):
expected = [norm(B[:,:, k], ord=order) for k in range(B.shape[2])]
assert_almost_equal(n, expected)
+ def test_keepdims(self):
+ A = np.arange(1,25, dtype=self.dt).reshape(2,3,4)
+
+ allclose_err = 'order {0}, axis = {1}'
+ shape_err = 'Shape mismatch found {0}, expected {1}, order={2}, axis={3}'
+
+ # check the order=None, axis=None case
+ expected = norm(A, ord=None, axis=None)
+ found = norm(A, ord=None, axis=None, keepdims=True)
+ assert_allclose(np.squeeze(found), expected,
+ err_msg=allclose_err.format(None,None))
+ expected_shape = (1,1,1)
+ assert_(found.shape == expected_shape,
+ shape_err.format(found.shape, expected_shape, None, None))
+
+ # Vector norms.
+ for order in [None, -1, 0, 1, 2, 3, np.Inf, -np.Inf]:
+ for k in range(A.ndim):
+ expected = norm(A, ord=order, axis=k)
+ found = norm(A, ord=order, axis=k, keepdims=True)
+ assert_allclose(np.squeeze(found), expected,
+ err_msg=allclose_err.format(order,k))
+ expected_shape = list(A.shape)
+ expected_shape[k] = 1
+ expected_shape = tuple(expected_shape)
+ assert_(found.shape == expected_shape,
+ shape_err.format(found.shape, expected_shape, order, k))
+
+ # Matrix norms.
+ for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']:
+ for k in itertools.permutations(range(A.ndim), 2):
+ expected = norm(A, ord=order, axis=k)
+ found = norm(A, ord=order, axis=k, keepdims=True)
+ assert_allclose(np.squeeze(found), expected,
+ err_msg=allclose_err.format(order,k))
+ expected_shape = list(A.shape)
+ expected_shape[k[0]] = 1
+ expected_shape[k[1]] = 1
+ expected_shape = tuple(expected_shape)
+ assert_(found.shape == expected_shape,
+ shape_err.format(found.shape, expected_shape, order, k))
+
def test_bad_args(self):
# Check that bad arguments raise the appropriate exceptions.
@@ -909,6 +951,7 @@ class _TestNorm(object):
assert_raises(ValueError, norm, B, None, (2, 3))
assert_raises(ValueError, norm, B, None, (0, 1, 2))
+class TestNorm_NonSystematic(object):
def test_longdouble_norm(self):
# Non-regression test: p-norm of longdouble would previously raise
# UnboundLocalError.