diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/function_base.py | 50 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 36 |
2 files changed, 70 insertions, 16 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 8440be52e..a4516c276 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -145,7 +145,7 @@ def rot90(m, k=1, axes=(0,1)): return flip(transpose(m, axes_list), axes[1]) -def flip(m, axis): +def flip(m, axis=None): """ Reverse the order of elements in an array along the given axis. @@ -157,9 +157,16 @@ def flip(m, axis): ---------- m : array_like Input array. - axis : integer - Axis in array, which entries are reversed. + axis : None or int or tuple of ints, optional + Axis or axes along which to flip over. The default, + axis=None, will flip over all of the axes of the input array. + If axis is negative it counts from the last to the first axis. + + If axis is a tuple of ints, flipping is performed on all of the axes + specified in the tuple. + .. versionchanged:: 1.15.0 + None and tuples of axes are supported Returns ------- @@ -175,9 +182,17 @@ def flip(m, axis): Notes ----- flip(m, 0) is equivalent to flipud(m). + flip(m, 1) is equivalent to fliplr(m). + flip(m, n) corresponds to ``m[...,::-1,...]`` with ``::-1`` at position n. + flip(m) corresponds to ``m[::-1,::-1,...,::-1]`` with ``::-1`` at all + positions. + + flip(m, (0, 1)) corresponds to ``m[::-1,::-1,...]`` with ``::-1`` at + position 0 and position 1. + Examples -------- >>> A = np.arange(8).reshape((2,2,2)) @@ -186,32 +201,41 @@ def flip(m, axis): [2, 3]], [[4, 5], [6, 7]]]) - >>> flip(A, 0) array([[[4, 5], [6, 7]], [[0, 1], [2, 3]]]) - >>> flip(A, 1) array([[[2, 3], [0, 1]], [[6, 7], [4, 5]]]) - + >>> np.flip(A) + array([[[7, 6], + [5, 4]], + [[3, 2], + [1, 0]]]) + >>> np.flip(A, (0, 2)) + array([[[5, 4], + [7, 6]], + [[1, 0], + [3, 2]]]) >>> A = np.random.randn(3,4,5) >>> np.all(flip(A,2) == A[:,:,::-1,...]) True """ if not hasattr(m, 'ndim'): m = asarray(m) - indexer = [slice(None)] * m.ndim - try: - indexer[axis] = slice(None, None, -1) - except IndexError: - raise ValueError("axis=%i is invalid for the %i-dimensional input array" - % (axis, m.ndim)) - return m[tuple(indexer)] + if axis is None: + indexer = (np.s_[::-1],) * m.ndim + else: + axis = _nx.normalize_axis_tuple(axis, m.ndim) + indexer = [np.s_[:]] * m.ndim + for ax in axis: + indexer[ax] = np.s_[::-1] + indexer = tuple(indexer) + return m[indexer] def iterable(y): diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 0a4c7c370..6653b5ba1 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -104,9 +104,10 @@ class TestRot90(object): class TestFlip(object): def test_axes(self): - assert_raises(ValueError, np.flip, np.ones(4), axis=1) - assert_raises(ValueError, np.flip, np.ones((4, 4)), axis=2) - assert_raises(ValueError, np.flip, np.ones((4, 4)), axis=-3) + assert_raises(np.AxisError, np.flip, np.ones(4), axis=1) + assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=2) + assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=-3) + assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=(0, 3)) def test_basic_lr(self): a = get_mat(4) @@ -173,6 +174,35 @@ class TestFlip(object): assert_equal(np.flip(a, i), np.flipud(a.swapaxes(0, i)).swapaxes(i, 0)) + def test_default_axis(self): + a = np.array([[1, 2, 3], + [4, 5, 6]]) + b = np.array([[6, 5, 4], + [3, 2, 1]]) + assert_equal(np.flip(a), b) + + def test_multiple_axes(self): + a = np.array([[[0, 1], + [2, 3]], + [[4, 5], + [6, 7]]]) + + assert_equal(np.flip(a, axis=()), a) + + b = np.array([[[5, 4], + [7, 6]], + [[1, 0], + [3, 2]]]) + + assert_equal(np.flip(a, axis=(0, 2)), b) + + c = np.array([[[3, 2], + [1, 0]], + [[7, 6], + [5, 4]]]) + + assert_equal(np.flip(a, axis=(1, 2)), c) + class TestAny(object): |