diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2014-10-17 09:52:27 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2014-10-17 09:52:27 -0600 |
commit | c9d61533f0a5a40ec0a38c69e19615d375af166e (patch) | |
tree | 28675cba68919cfdd36a067f612b1e1f4631edd0 | |
parent | 5505dc9cdfa54fa5a45c8e69f3f82874c5aee9a6 (diff) | |
parent | 6f88eea024448b913cea881efad405e5d4ece195 (diff) | |
download | numpy-c9d61533f0a5a40ec0a38c69e19615d375af166e.tar.gz |
Merge pull request #5196 from ewmoore/norm_keepdims
Add keepdims to linalg.norm
-rw-r--r-- | numpy/linalg/linalg.py | 45 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 43 |
2 files changed, 73 insertions, 15 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 6b2299fe7..e8f7a8ab1 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -1921,7 +1921,7 @@ def _multi_svd_norm(x, row_axis, col_axis, op): return result -def norm(x, ord=None, axis=None): +def norm(x, ord=None, axis=None, keepdims=False): """ Matrix or vector norm. @@ -1942,6 +1942,11 @@ def norm(x, ord=None, axis=None): axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If `axis` is None then either a vector norm (when `x` is 1-D) or a matrix norm (when `x` is 2-D) is returned. + keepdims : bool, optional + .. versionadded:: 1.10.0 + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `x`. Returns ------- @@ -2053,12 +2058,16 @@ def norm(x, ord=None, axis=None): # Check the default case first and handle it immediately. if ord is None and axis is None: + ndim = x.ndim x = x.ravel(order='K') if isComplexType(x.dtype.type): sqnorm = dot(x.real, x.real) + dot(x.imag, x.imag) else: sqnorm = dot(x, x) - return sqrt(sqnorm) + ret = sqrt(sqnorm) + if keepdims: + ret = ret.reshape(ndim*[1]) + return ret # Normalize the `axis` argument to a tuple. nd = x.ndim @@ -2069,19 +2078,19 @@ def norm(x, ord=None, axis=None): if len(axis) == 1: if ord == Inf: - return abs(x).max(axis=axis) + return abs(x).max(axis=axis, keepdims=keepdims) elif ord == -Inf: - return abs(x).min(axis=axis) + return abs(x).min(axis=axis, keepdims=keepdims) elif ord == 0: # Zero norm - return (x != 0).sum(axis=axis) + return (x != 0).sum(axis=axis, keepdims=keepdims) elif ord == 1: # special case for speedup - return add.reduce(abs(x), axis=axis) + return add.reduce(abs(x), axis=axis, keepdims=keepdims) elif ord is None or ord == 2: # special case for speedup s = (x.conj() * x).real - return sqrt(add.reduce(s, axis=axis)) + return sqrt(add.reduce(s, axis=axis, keepdims=keepdims)) else: try: ord + 1 @@ -2100,7 +2109,7 @@ def norm(x, ord=None, axis=None): # if the type changed, we can safely overwrite absx abs(absx, out=absx) absx **= ord - return add.reduce(absx, axis=axis) ** (1.0 / ord) + return add.reduce(absx, axis=axis, keepdims=keepdims) ** (1.0 / ord) elif len(axis) == 2: row_axis, col_axis = axis if not (-nd <= row_axis < nd and -nd <= col_axis < nd): @@ -2109,28 +2118,34 @@ def norm(x, ord=None, axis=None): if row_axis % nd == col_axis % nd: raise ValueError('Duplicate axes given.') if ord == 2: - return _multi_svd_norm(x, row_axis, col_axis, amax) + ret = _multi_svd_norm(x, row_axis, col_axis, amax) elif ord == -2: - return _multi_svd_norm(x, row_axis, col_axis, amin) + ret = _multi_svd_norm(x, row_axis, col_axis, amin) elif ord == 1: if col_axis > row_axis: col_axis -= 1 - return add.reduce(abs(x), axis=row_axis).max(axis=col_axis) + ret = add.reduce(abs(x), axis=row_axis).max(axis=col_axis) elif ord == Inf: if row_axis > col_axis: row_axis -= 1 - return add.reduce(abs(x), axis=col_axis).max(axis=row_axis) + ret = add.reduce(abs(x), axis=col_axis).max(axis=row_axis) elif ord == -1: if col_axis > row_axis: col_axis -= 1 - return add.reduce(abs(x), axis=row_axis).min(axis=col_axis) + ret = add.reduce(abs(x), axis=row_axis).min(axis=col_axis) elif ord == -Inf: if row_axis > col_axis: row_axis -= 1 - return add.reduce(abs(x), axis=col_axis).min(axis=row_axis) + ret = add.reduce(abs(x), axis=col_axis).min(axis=row_axis) elif ord in [None, 'fro', 'f']: - return sqrt(add.reduce((x.conj() * x).real, axis=axis)) + ret = sqrt(add.reduce((x.conj() * x).real, axis=axis)) else: raise ValueError("Invalid norm order for matrices.") + if keepdims: + ret_shape = list(x.shape) + ret_shape[axis[0]] = 1 + ret_shape[axis[1]] = 1 + ret = ret.reshape(ret_shape) + return ret else: raise ValueError("Improper number of dimensions to norm.") diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 8edf36aa6..34079bb87 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -885,6 +885,48 @@ class _TestNorm(object): expected = [norm(B[:,:, k], ord=order) for k in range(B.shape[2])] assert_almost_equal(n, expected) + def test_keepdims(self): + A = np.arange(1,25, dtype=self.dt).reshape(2,3,4) + + allclose_err = 'order {0}, axis = {1}' + shape_err = 'Shape mismatch found {0}, expected {1}, order={2}, axis={3}' + + # check the order=None, axis=None case + expected = norm(A, ord=None, axis=None) + found = norm(A, ord=None, axis=None, keepdims=True) + assert_allclose(np.squeeze(found), expected, + err_msg=allclose_err.format(None,None)) + expected_shape = (1,1,1) + assert_(found.shape == expected_shape, + shape_err.format(found.shape, expected_shape, None, None)) + + # Vector norms. + for order in [None, -1, 0, 1, 2, 3, np.Inf, -np.Inf]: + for k in range(A.ndim): + expected = norm(A, ord=order, axis=k) + found = norm(A, ord=order, axis=k, keepdims=True) + assert_allclose(np.squeeze(found), expected, + err_msg=allclose_err.format(order,k)) + expected_shape = list(A.shape) + expected_shape[k] = 1 + expected_shape = tuple(expected_shape) + assert_(found.shape == expected_shape, + shape_err.format(found.shape, expected_shape, order, k)) + + # Matrix norms. + for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']: + for k in itertools.permutations(range(A.ndim), 2): + expected = norm(A, ord=order, axis=k) + found = norm(A, ord=order, axis=k, keepdims=True) + assert_allclose(np.squeeze(found), expected, + err_msg=allclose_err.format(order,k)) + expected_shape = list(A.shape) + expected_shape[k[0]] = 1 + expected_shape[k[1]] = 1 + expected_shape = tuple(expected_shape) + assert_(found.shape == expected_shape, + shape_err.format(found.shape, expected_shape, order, k)) + def test_bad_args(self): # Check that bad arguments raise the appropriate exceptions. @@ -909,6 +951,7 @@ class _TestNorm(object): assert_raises(ValueError, norm, B, None, (2, 3)) assert_raises(ValueError, norm, B, None, (0, 1, 2)) +class TestNorm_NonSystematic(object): def test_longdouble_norm(self): # Non-regression test: p-norm of longdouble would previously raise # UnboundLocalError. |