diff options
author | Eric Moore <ewm@redtetrahedron.org> | 2014-10-16 11:59:03 -0400 |
---|---|---|
committer | Eric Moore <ewm@redtetrahedron.org> | 2014-10-17 10:50:43 -0400 |
commit | 9b152aaaf3abd0f98b7a88ed66df7518a6a6c85b (patch) | |
tree | 7d6cec3fd0573068497a478ec0522dc4ccb602b4 /numpy/linalg/tests | |
parent | 51f0976c1ca101a01d09e26ee5dfea5360f73c63 (diff) | |
download | numpy-9b152aaaf3abd0f98b7a88ed66df7518a6a6c85b.tar.gz |
ENH: Add keepdims to linalg.norm
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 8edf36aa6..169a455ec 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -885,6 +885,48 @@ class _TestNorm(object): expected = [norm(B[:,:, k], ord=order) for k in range(B.shape[2])] assert_almost_equal(n, expected) + def test_keepdims(self): + A = np.arange(1,25, dtype=self.dt).reshape(2,3,4) + + allclose_err = 'order {0}, axis = {1}' + shape_err = 'Shape mismatch found {0}, expected {1}, order={2}, axis={3}' + + # check the order=None, axis=None case + expected = norm(A, ord=None, axis=None) + found = norm(A, ord=None, axis=None, keepdims=True) + assert_allclose(np.squeeze(found), expected, + err_msg=allclose_err.format(None,None)) + expected_shape = (1,1,1) + assert_(found.shape == expected_shape, + shape_err.format(found.shape, expected_shape, None, None)) + + # Vector norms. + for order in [None, -1, 0, 1, 2, 3, np.Inf, -np.Inf]: + for k in range(A.ndim): + expected = norm(A, ord=order, axis=k) + found = norm(A, ord=order, axis=k, keepdims=True) + assert_allclose(np.squeeze(found), expected, + err_msg=allclose_err.format(order,k)) + expected_shape = list(A.shape) + expected_shape[k] = 1 + expected_shape = tuple(expected_shape) + assert_(found.shape == expected_shape, + shape_err.format(found.shape, expected_shape, order, k)) + + # Matrix norms. + for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']: + for k in itertools.permutations(range(A.ndim), 2): + expected = norm(A, ord=order, axis=k) + found = norm(A, ord=order, axis=k, keepdims=True) + assert_allclose(np.squeeze(found), expected, + err_msg=allclose_err.format(order,k)) + expected_shape = list(A.shape) + expected_shape[k[0]] = 1 + expected_shape[k[1]] = 1 + expected_shape = tuple(expected_shape) + assert_(found.shape == expected_shape, + shape_err.format(found.shape, expected_shape, order, k)) + def test_bad_args(self): # Check that bad arguments raise the appropriate exceptions. |