summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py50
-rw-r--r--numpy/lib/tests/test_function_base.py36
2 files changed, 70 insertions, 16 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 8440be52e..a4516c276 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -145,7 +145,7 @@ def rot90(m, k=1, axes=(0,1)):
return flip(transpose(m, axes_list), axes[1])
-def flip(m, axis):
+def flip(m, axis=None):
"""
Reverse the order of elements in an array along the given axis.
@@ -157,9 +157,16 @@ def flip(m, axis):
----------
m : array_like
Input array.
- axis : integer
- Axis in array, which entries are reversed.
+ axis : None or int or tuple of ints, optional
+ Axis or axes along which to flip over. The default,
+ axis=None, will flip over all of the axes of the input array.
+ If axis is negative it counts from the last to the first axis.
+
+ If axis is a tuple of ints, flipping is performed on all of the axes
+ specified in the tuple.
+ .. versionchanged:: 1.15.0
+ None and tuples of axes are supported
Returns
-------
@@ -175,9 +182,17 @@ def flip(m, axis):
Notes
-----
flip(m, 0) is equivalent to flipud(m).
+
flip(m, 1) is equivalent to fliplr(m).
+
flip(m, n) corresponds to ``m[...,::-1,...]`` with ``::-1`` at position n.
+ flip(m) corresponds to ``m[::-1,::-1,...,::-1]`` with ``::-1`` at all
+ positions.
+
+ flip(m, (0, 1)) corresponds to ``m[::-1,::-1,...]`` with ``::-1`` at
+ position 0 and position 1.
+
Examples
--------
>>> A = np.arange(8).reshape((2,2,2))
@@ -186,32 +201,41 @@ def flip(m, axis):
[2, 3]],
[[4, 5],
[6, 7]]])
-
>>> flip(A, 0)
array([[[4, 5],
[6, 7]],
[[0, 1],
[2, 3]]])
-
>>> flip(A, 1)
array([[[2, 3],
[0, 1]],
[[6, 7],
[4, 5]]])
-
+ >>> np.flip(A)
+ array([[[7, 6],
+ [5, 4]],
+ [[3, 2],
+ [1, 0]]])
+ >>> np.flip(A, (0, 2))
+ array([[[5, 4],
+ [7, 6]],
+ [[1, 0],
+ [3, 2]]])
>>> A = np.random.randn(3,4,5)
>>> np.all(flip(A,2) == A[:,:,::-1,...])
True
"""
if not hasattr(m, 'ndim'):
m = asarray(m)
- indexer = [slice(None)] * m.ndim
- try:
- indexer[axis] = slice(None, None, -1)
- except IndexError:
- raise ValueError("axis=%i is invalid for the %i-dimensional input array"
- % (axis, m.ndim))
- return m[tuple(indexer)]
+ if axis is None:
+ indexer = (np.s_[::-1],) * m.ndim
+ else:
+ axis = _nx.normalize_axis_tuple(axis, m.ndim)
+ indexer = [np.s_[:]] * m.ndim
+ for ax in axis:
+ indexer[ax] = np.s_[::-1]
+ indexer = tuple(indexer)
+ return m[indexer]
def iterable(y):
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 0a4c7c370..6653b5ba1 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -104,9 +104,10 @@ class TestRot90(object):
class TestFlip(object):
def test_axes(self):
- assert_raises(ValueError, np.flip, np.ones(4), axis=1)
- assert_raises(ValueError, np.flip, np.ones((4, 4)), axis=2)
- assert_raises(ValueError, np.flip, np.ones((4, 4)), axis=-3)
+ assert_raises(np.AxisError, np.flip, np.ones(4), axis=1)
+ assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=2)
+ assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=-3)
+ assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=(0, 3))
def test_basic_lr(self):
a = get_mat(4)
@@ -173,6 +174,35 @@ class TestFlip(object):
assert_equal(np.flip(a, i),
np.flipud(a.swapaxes(0, i)).swapaxes(i, 0))
+ def test_default_axis(self):
+ a = np.array([[1, 2, 3],
+ [4, 5, 6]])
+ b = np.array([[6, 5, 4],
+ [3, 2, 1]])
+ assert_equal(np.flip(a), b)
+
+ def test_multiple_axes(self):
+ a = np.array([[[0, 1],
+ [2, 3]],
+ [[4, 5],
+ [6, 7]]])
+
+ assert_equal(np.flip(a, axis=()), a)
+
+ b = np.array([[[5, 4],
+ [7, 6]],
+ [[1, 0],
+ [3, 2]]])
+
+ assert_equal(np.flip(a, axis=(0, 2)), b)
+
+ c = np.array([[[3, 2],
+ [1, 0]],
+ [[7, 6],
+ [5, 4]]])
+
+ assert_equal(np.flip(a, axis=(1, 2)), c)
+
class TestAny(object):