diff options
author | Junjie Bai <bai@in.tum.de> | 2018-04-10 01:24:35 -0700 |
---|---|---|
committer | Junjie Bai <bai@in.tum.de> | 2018-04-17 13:51:12 -0700 |
commit | 9d1d44fc120c451a4925720027e85c70e13c26f2 (patch) | |
tree | 225480480d981b854c089aace86f6ca0cd282b77 /numpy/lib/tests/test_function_base.py | |
parent | 8323be1bc44c2811fc36f5b99c1a30ebcee8edbd (diff) | |
download | numpy-9d1d44fc120c451a4925720027e85c70e13c26f2.tar.gz |
ENH: Extend np.flip to work over multiple axes
Closes #10847
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 0a4c7c370..6653b5ba1 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -104,9 +104,10 @@ class TestRot90(object): class TestFlip(object): def test_axes(self): - assert_raises(ValueError, np.flip, np.ones(4), axis=1) - assert_raises(ValueError, np.flip, np.ones((4, 4)), axis=2) - assert_raises(ValueError, np.flip, np.ones((4, 4)), axis=-3) + assert_raises(np.AxisError, np.flip, np.ones(4), axis=1) + assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=2) + assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=-3) + assert_raises(np.AxisError, np.flip, np.ones((4, 4)), axis=(0, 3)) def test_basic_lr(self): a = get_mat(4) @@ -173,6 +174,35 @@ class TestFlip(object): assert_equal(np.flip(a, i), np.flipud(a.swapaxes(0, i)).swapaxes(i, 0)) + def test_default_axis(self): + a = np.array([[1, 2, 3], + [4, 5, 6]]) + b = np.array([[6, 5, 4], + [3, 2, 1]]) + assert_equal(np.flip(a), b) + + def test_multiple_axes(self): + a = np.array([[[0, 1], + [2, 3]], + [[4, 5], + [6, 7]]]) + + assert_equal(np.flip(a, axis=()), a) + + b = np.array([[[5, 4], + [7, 6]], + [[1, 0], + [3, 2]]]) + + assert_equal(np.flip(a, axis=(0, 2)), b) + + c = np.array([[[3, 2], + [1, 0]], + [[7, 6], + [5, 4]]]) + + assert_equal(np.flip(a, axis=(1, 2)), c) + class TestAny(object): |