summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_function_base.py
diff options
context:
space:
mode:
authorDenis Alevi <mail@denisalevi.de>2016-02-27 03:59:26 +0100
committerEren Sezener <erensezener@gmail.com>2016-03-20 14:41:40 +0100
commit7c76ef7ea3adbd9c70a08ee39fdf1b56fb940b3b (patch)
tree0ba0a9824993f010493d4c6549edbda72fc1ef83 /numpy/lib/tests/test_function_base.py
parente7de401f5c9634a2a63eb8f44f2193be1a946191 (diff)
downloadnumpy-7c76ef7ea3adbd9c70a08ee39fdf1b56fb940b3b.tar.gz
ENH: generalize rot90 with axes kwarg, move to function_base.py, and add tests
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r--numpy/lib/tests/test_function_base.py70
1 files changed, 69 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 061ba8742..a26979f87 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -16,7 +16,7 @@ from numpy.lib import (
add_newdoc_ufunc, angle, average, bartlett, blackman, corrcoef, cov,
delete, diff, digitize, extract, flipud, gradient, hamming, hanning,
histogram, histogramdd, i0, insert, interp, kaiser, meshgrid, msort,
- piecewise, place, select, setxor1d, sinc, split, trapz, trim_zeros,
+ piecewise, place, rot90, select, setxor1d, sinc, split, trapz, trim_zeros,
unwrap, unique, vectorize,
)
@@ -29,6 +29,74 @@ def get_mat(n):
return data
+class TestRot90(TestCase):
+ def test_basic(self):
+ self.assertRaises(ValueError, rot90, np.ones(4))
+ assert_raises(ValueError, rot90, np.ones((2,2,2)), axes=(0,1,2))
+ assert_raises(ValueError, rot90, np.ones((2,2)), axes=(0,2))
+ assert_raises(ValueError, rot90, np.ones((2,2)), axes=(1,1))
+ assert_raises(ValueError, rot90, np.ones((2,2,2)), axes=(-2,1))
+
+ a = [[0, 1, 2],
+ [3, 4, 5]]
+ b1 = [[2, 5],
+ [1, 4],
+ [0, 3]]
+ b2 = [[5, 4, 3],
+ [2, 1, 0]]
+ b3 = [[3, 0],
+ [4, 1],
+ [5, 2]]
+ b4 = [[0, 1, 2],
+ [3, 4, 5]]
+
+ for k in range(-3, 13, 4):
+ assert_equal(rot90(a, k=k), b1)
+ for k in range(-2, 13, 4):
+ assert_equal(rot90(a, k=k), b2)
+ for k in range(-1, 13, 4):
+ assert_equal(rot90(a, k=k), b3)
+ for k in range(0, 13, 4):
+ assert_equal(rot90(a, k=k), b4)
+
+ assert_equal(rot90(rot90(a, axes=(0,1)), axes=(1,0)), a)
+ assert_equal(rot90(a, k=1, axes=(1,0)), rot90(a, k=-1, axes=(0,1)))
+
+ def test_axes(self):
+ a = np.ones((50, 40, 3))
+ assert_equal(rot90(a).shape, (40, 50, 3))
+ assert_equal(rot90(a, axes=(0,2)), rot90(a, axes=(0,-1)))
+ assert_equal(rot90(a, axes=(1,2)), rot90(a, axes=(-2,-1)))
+
+ def test_rotation_axes(self):
+ a = np.arange(8).reshape((2,2,2))
+
+ a_rot90_01 = [[[2, 3],
+ [6, 7]],
+ [[0, 1],
+ [4, 5]]]
+ a_rot90_12 = [[[1, 3],
+ [0, 2]],
+ [[5, 7],
+ [4, 6]]]
+ a_rot90_20 = [[[4, 0],
+ [6, 2]],
+ [[5, 1],
+ [7, 3]]]
+ a_rot90_10 = [[[4, 5],
+ [0, 1]],
+ [[6, 7],
+ [2, 3]]]
+
+ assert_equal(rot90(a, axes=(0, 1)), a_rot90_01)
+ assert_equal(rot90(a, axes=(1, 0)), a_rot90_10)
+ assert_equal(rot90(a, axes=(1, 2)), a_rot90_12)
+
+ for k in range(1,5):
+ assert_equal(rot90(a, k=k, axes=(2, 0)),
+ rot90(a_rot90_20, k=k-1, axes=(2, 0)))
+
+
class TestFlip(TestCase):
def test_axes(self):
self.assertRaises(ValueError, np.flip, np.ones(4), axis=1)