summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_function_base.py
diff options
context:
space:
mode:
authorJunjie Bai <bai@in.tum.de>2018-04-10 01:24:35 -0700
committerJunjie Bai <bai@in.tum.de>2018-04-17 13:51:12 -0700
commit9d1d44fc120c451a4925720027e85c70e13c26f2 (patch)
tree225480480d981b854c089aace86f6ca0cd282b77 /numpy/lib/tests/test_function_base.py
parent8323be1bc44c2811fc36f5b99c1a30ebcee8edbd (diff)
downloadnumpy-9d1d44fc120c451a4925720027e85c70e13c26f2.tar.gz
ENH: Extend np.flip to work over multiple axes
Closes #10847
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r--numpy/lib/tests/test_function_base.py36
1 files changed, 33 insertions, 3 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 0a4c7c370..6653b5ba1 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -104,9 +104,10 @@ class TestRot90(object):
class TestFlip(object):
def test_axes(self):
- assert_raises(ValueError, np.flip, np.ones(4), axis=1)
- assert_raises(ValueError, np.flip, np.ones((4, 4)), axis=2)
- assert_raises(ValueError, np.flip, np.ones((4, 4)), axis=-3)
+ assert_raises(np.AxisError, np.flip, np.ones(4), axis=1)
+ assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=2)
+ assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=-3)
+ assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=(0, 3))
def test_basic_lr(self):
a = get_mat(4)
@@ -173,6 +174,35 @@ class TestFlip(object):
assert_equal(np.flip(a, i),
np.flipud(a.swapaxes(0, i)).swapaxes(i, 0))
+ def test_default_axis(self):
+ a = np.array([[1, 2, 3],
+ [4, 5, 6]])
+ b = np.array([[6, 5, 4],
+ [3, 2, 1]])
+ assert_equal(np.flip(a), b)
+
+ def test_multiple_axes(self):
+ a = np.array([[[0, 1],
+ [2, 3]],
+ [[4, 5],
+ [6, 7]]])
+
+ assert_equal(np.flip(a, axis=()), a)
+
+ b = np.array([[[5, 4],
+ [7, 6]],
+ [[1, 0],
+ [3, 2]]])
+
+ assert_equal(np.flip(a, axis=(0, 2)), b)
+
+ c = np.array([[[3, 2],
+ [1, 0]],
+ [[7, 6],
+ [5, 4]]])
+
+ assert_equal(np.flip(a, axis=(1, 2)), c)
+
class TestAny(object):