diff options
author | Junjie Bai <bai@in.tum.de> | 2018-04-10 01:24:35 -0700 |
---|---|---|
committer | Junjie Bai <bai@in.tum.de> | 2018-04-17 13:51:12 -0700 |
commit | 9d1d44fc120c451a4925720027e85c70e13c26f2 (patch) | |
tree | 225480480d981b854c089aace86f6ca0cd282b77 /numpy/lib/function_base.py | |
parent | 8323be1bc44c2811fc36f5b99c1a30ebcee8edbd (diff) | |
download | numpy-9d1d44fc120c451a4925720027e85c70e13c26f2.tar.gz |
ENH: Extend np.flip to work over multiple axes
Closes #10847
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 50 |
1 files changed, 37 insertions, 13 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 8440be52e..a4516c276 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -145,7 +145,7 @@ def rot90(m, k=1, axes=(0,1)): return flip(transpose(m, axes_list), axes[1]) -def flip(m, axis): +def flip(m, axis=None): """ Reverse the order of elements in an array along the given axis. @@ -157,9 +157,16 @@ def flip(m, axis): ---------- m : array_like Input array. - axis : integer - Axis in array, which entries are reversed. + axis : None or int or tuple of ints, optional + Axis or axes along which to flip over. The default, + axis=None, will flip over all of the axes of the input array. + If axis is negative it counts from the last to the first axis. + + If axis is a tuple of ints, flipping is performed on all of the axes + specified in the tuple. + .. versionchanged:: 1.15.0 + None and tuples of axes are supported Returns ------- @@ -175,9 +182,17 @@ def flip(m, axis): Notes ----- flip(m, 0) is equivalent to flipud(m). + flip(m, 1) is equivalent to fliplr(m). + flip(m, n) corresponds to ``m[...,::-1,...]`` with ``::-1`` at position n. + flip(m) corresponds to ``m[::-1,::-1,...,::-1]`` with ``::-1`` at all + positions. + + flip(m, (0, 1)) corresponds to ``m[::-1,::-1,...]`` with ``::-1`` at + position 0 and position 1. + Examples -------- >>> A = np.arange(8).reshape((2,2,2)) @@ -186,32 +201,41 @@ def flip(m, axis): [2, 3]], [[4, 5], [6, 7]]]) - >>> flip(A, 0) array([[[4, 5], [6, 7]], [[0, 1], [2, 3]]]) - >>> flip(A, 1) array([[[2, 3], [0, 1]], [[6, 7], [4, 5]]]) - + >>> np.flip(A) + array([[[7, 6], + [5, 4]], + [[3, 2], + [1, 0]]]) + >>> np.flip(A, (0, 2)) + array([[[5, 4], + [7, 6]], + [[1, 0], + [3, 2]]]) >>> A = np.random.randn(3,4,5) >>> np.all(flip(A,2) == A[:,:,::-1,...]) True """ if not hasattr(m, 'ndim'): m = asarray(m) - indexer = [slice(None)] * m.ndim - try: - indexer[axis] = slice(None, None, -1) - except IndexError: - raise ValueError("axis=%i is invalid for the %i-dimensional input array" - % (axis, m.ndim)) - return m[tuple(indexer)] + if axis is None: + indexer = (np.s_[::-1],) * m.ndim + else: + axis = _nx.normalize_axis_tuple(axis, m.ndim) + indexer = [np.s_[:]] * m.ndim + for ax in axis: + indexer[ax] = np.s_[::-1] + indexer = tuple(indexer) + return m[indexer] def iterable(y): |