summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests/test_linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/tests/test_linalg.py')
-rw-r--r--numpy/linalg/tests/test_linalg.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 8edf36aa6..169a455ec 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -885,6 +885,48 @@ class _TestNorm(object):
expected = [norm(B[:,:, k], ord=order) for k in range(B.shape[2])]
assert_almost_equal(n, expected)
+ def test_keepdims(self):
+ A = np.arange(1,25, dtype=self.dt).reshape(2,3,4)
+
+ allclose_err = 'order {0}, axis = {1}'
+ shape_err = 'Shape mismatch found {0}, expected {1}, order={2}, axis={3}'
+
+ # check the order=None, axis=None case
+ expected = norm(A, ord=None, axis=None)
+ found = norm(A, ord=None, axis=None, keepdims=True)
+ assert_allclose(np.squeeze(found), expected,
+ err_msg=allclose_err.format(None,None))
+ expected_shape = (1,1,1)
+ assert_(found.shape == expected_shape,
+ shape_err.format(found.shape, expected_shape, None, None))
+
+ # Vector norms.
+ for order in [None, -1, 0, 1, 2, 3, np.Inf, -np.Inf]:
+ for k in range(A.ndim):
+ expected = norm(A, ord=order, axis=k)
+ found = norm(A, ord=order, axis=k, keepdims=True)
+ assert_allclose(np.squeeze(found), expected,
+ err_msg=allclose_err.format(order,k))
+ expected_shape = list(A.shape)
+ expected_shape[k] = 1
+ expected_shape = tuple(expected_shape)
+ assert_(found.shape == expected_shape,
+ shape_err.format(found.shape, expected_shape, order, k))
+
+ # Matrix norms.
+ for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']:
+ for k in itertools.permutations(range(A.ndim), 2):
+ expected = norm(A, ord=order, axis=k)
+ found = norm(A, ord=order, axis=k, keepdims=True)
+ assert_allclose(np.squeeze(found), expected,
+ err_msg=allclose_err.format(order,k))
+ expected_shape = list(A.shape)
+ expected_shape[k[0]] = 1
+ expected_shape[k[1]] = 1
+ expected_shape = tuple(expected_shape)
+ assert_(found.shape == expected_shape,
+ shape_err.format(found.shape, expected_shape, order, k))
+
def test_bad_args(self):
# Check that bad arguments raise the appropriate exceptions.