summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorJunjie Bai <bai@in.tum.de>2018-04-10 01:24:35 -0700
committerJunjie Bai <bai@in.tum.de>2018-04-17 13:51:12 -0700
commit9d1d44fc120c451a4925720027e85c70e13c26f2 (patch)
tree225480480d981b854c089aace86f6ca0cd282b77 /numpy/lib/function_base.py
parent8323be1bc44c2811fc36f5b99c1a30ebcee8edbd (diff)
downloadnumpy-9d1d44fc120c451a4925720027e85c70e13c26f2.tar.gz
ENH: Extend np.flip to work over multiple axes
Closes #10847
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py50
1 files changed, 37 insertions, 13 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 8440be52e..a4516c276 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -145,7 +145,7 @@ def rot90(m, k=1, axes=(0,1)):
return flip(transpose(m, axes_list), axes[1])
-def flip(m, axis):
+def flip(m, axis=None):
"""
Reverse the order of elements in an array along the given axis.
@@ -157,9 +157,16 @@ def flip(m, axis):
----------
m : array_like
Input array.
- axis : integer
- Axis in array, which entries are reversed.
+ axis : None or int or tuple of ints, optional
+ Axis or axes along which to flip over. The default,
+ axis=None, will flip over all of the axes of the input array.
+ If axis is negative it counts from the last to the first axis.
+
+ If axis is a tuple of ints, flipping is performed on all of the axes
+ specified in the tuple.
+ .. versionchanged:: 1.15.0
+ None and tuples of axes are supported
Returns
-------
@@ -175,9 +182,17 @@ def flip(m, axis):
Notes
-----
flip(m, 0) is equivalent to flipud(m).
+
flip(m, 1) is equivalent to fliplr(m).
+
flip(m, n) corresponds to ``m[...,::-1,...]`` with ``::-1`` at position n.
+ flip(m) corresponds to ``m[::-1,::-1,...,::-1]`` with ``::-1`` at all
+ positions.
+
+ flip(m, (0, 1)) corresponds to ``m[::-1,::-1,...]`` with ``::-1`` at
+ position 0 and position 1.
+
Examples
--------
>>> A = np.arange(8).reshape((2,2,2))
@@ -186,32 +201,41 @@ def flip(m, axis):
[2, 3]],
[[4, 5],
[6, 7]]])
-
>>> flip(A, 0)
array([[[4, 5],
[6, 7]],
[[0, 1],
[2, 3]]])
-
>>> flip(A, 1)
array([[[2, 3],
[0, 1]],
[[6, 7],
[4, 5]]])
-
+ >>> np.flip(A)
+ array([[[7, 6],
+ [5, 4]],
+ [[3, 2],
+ [1, 0]]])
+ >>> np.flip(A, (0, 2))
+ array([[[5, 4],
+ [7, 6]],
+ [[1, 0],
+ [3, 2]]])
>>> A = np.random.randn(3,4,5)
>>> np.all(flip(A,2) == A[:,:,::-1,...])
True
"""
if not hasattr(m, 'ndim'):
m = asarray(m)
- indexer = [slice(None)] * m.ndim
- try:
- indexer[axis] = slice(None, None, -1)
- except IndexError:
- raise ValueError("axis=%i is invalid for the %i-dimensional input array"
- % (axis, m.ndim))
- return m[tuple(indexer)]
+ if axis is None:
+ indexer = (np.s_[::-1],) * m.ndim
+ else:
+ axis = _nx.normalize_axis_tuple(axis, m.ndim)
+ indexer = [np.s_[:]] * m.ndim
+ for ax in axis:
+ indexer[ax] = np.s_[::-1]
+ indexer = tuple(indexer)
+ return m[indexer]
def iterable(y):