diff options
author | Denis Alevi <mail@denisalevi.de> | 2016-02-27 03:59:26 +0100 |
---|---|---|
committer | Eren Sezener <erensezener@gmail.com> | 2016-03-20 14:41:40 +0100 |
commit | 7c76ef7ea3adbd9c70a08ee39fdf1b56fb940b3b (patch) | |
tree | 0ba0a9824993f010493d4c6549edbda72fc1ef83 /numpy/lib | |
parent | e7de401f5c9634a2a63eb8f44f2193be1a946191 (diff) | |
download | numpy-7c76ef7ea3adbd9c70a08ee39fdf1b56fb940b3b.tar.gz |
ENH: generalize rot90 with axes kwarg, move to function_base.py, and add tests
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/function_base.py | 92 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 70 | ||||
-rw-r--r-- | numpy/lib/tests/test_twodim_base.py | 33 | ||||
-rw-r--r-- | numpy/lib/twodim_base.py | 61 |
4 files changed, 163 insertions, 93 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 84690e4e3..c88228354 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -7,11 +7,11 @@ import operator import numpy as np import numpy.core.numeric as _nx -from numpy.core import linspace, atleast_1d, atleast_2d +from numpy.core import linspace, atleast_1d, atleast_2d, transpose from numpy.core.numeric import ( ones, zeros, arange, concatenate, array, asarray, asanyarray, empty, empty_like, ndarray, around, floor, ceil, take, dot, where, intp, - integer, isscalar + integer, isscalar, absolute ) from numpy.core.umath import ( pi, multiply, add, arctan2, frompyfunc, cos, less_equal, sqrt, sin, @@ -37,7 +37,7 @@ if sys.version_info[0] < 3: __all__ = [ 'select', 'piecewise', 'trim_zeros', 'copy', 'iterable', 'percentile', 'diff', 'gradient', 'angle', 'unwrap', 'sort_complex', 'disp', 'flip', - 'extract', 'place', 'vectorize', 'asarray_chkfinite', 'average', + 'rot90', 'extract', 'place', 'vectorize', 'asarray_chkfinite', 'average', 'histogram', 'histogramdd', 'bincount', 'digitize', 'cov', 'corrcoef', 'msort', 'median', 'sinc', 'hamming', 'hanning', 'bartlett', 'blackman', 'kaiser', 'trapz', 'i0', 'add_newdoc', 'add_docstring', @@ -45,6 +45,92 @@ __all__ = [ ] +def rot90(m, k=1, axes=(0,1)): + """ + Rotate an array by 90 degrees in the plane specified by axes. + + Rotation direction is from the first towards the second axis. + + .. versionadded:: 1.12.0 + + Parameters + ---------- + m : array_like + Array of two or more dimensions. + k : integer + Number of times the array is rotated by 90 degrees. + axes: (2,) array_like + The array is rotated in the plane defined by the axes. + Axes must be different. + + Returns + ------- + y : ndarray + A rotated view of `m`. + + See Also + -------- + flip : Reverse the order of elements in an array along the given axis. + fliplr : Flip an array horizontally. + flipud : Flip an array vertically. + + Notes + ----- + rot90(m, k=1, axes=(1,0)) is the reverse of rot90(m, k=1, axes=(0,1)) + rot90(m, k=1, axes=(1,0)) is equivalent to rot90(m, k=-1, axes=(0,1)) + + Examples + -------- + >>> m = np.array([[1,2],[3,4]], int) + >>> m + array([[1, 2], + [3, 4]]) + >>> np.rot90(m) + array([[2, 4], + [1, 3]]) + >>> np.rot90(m, 2) + array([[4, 3], + [2, 1]]) + >>> m = np.arange(8).reshape((2,2,2)) + >>> np.rot90(m, 1, (1,2)) + array([[[1, 3], + [0, 2]], + + [[5, 7], + [4, 6]]]) + + """ + axes = tuple(axes) + if len(axes) != 2: + raise ValueError("len(axes) must be 2.") + + m = asanyarray(m) + + if axes[0] == axes[1] or absolute(axes[0] - axes[1]) == m.ndim: + raise ValueError("Axes must be different.") + + if (axes[0] >= m.ndim or axes[0] < -m.ndim + or axes[1] >= m.ndim or axes[1] < -m.ndim): + raise ValueError("Axes={} out of range for array of ndim={}." + .format(axes, m.ndim)) + + k %= 4 + + if k == 0: + return m[:] + if k == 2: + return flip(flip(m, axes[0]), axes[1]) + + axes_list = arange(0, m.ndim) + axes_list[axes[0]], axes_list[axes[1]] = axes_list[axes[1]], axes_list[axes[0]] + + if k == 1: + return transpose(flip(m,axes[1]), axes_list) + else: + # k == 3 + return flip(transpose(m, axes_list), axes[1]) + + def flip(m, axis): """ Reverse the order of elements in an array along the given axis. diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 061ba8742..a26979f87 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -16,7 +16,7 @@ from numpy.lib import ( add_newdoc_ufunc, angle, average, bartlett, blackman, corrcoef, cov, delete, diff, digitize, extract, flipud, gradient, hamming, hanning, histogram, histogramdd, i0, insert, interp, kaiser, meshgrid, msort, - piecewise, place, select, setxor1d, sinc, split, trapz, trim_zeros, + piecewise, place, rot90, select, setxor1d, sinc, split, trapz, trim_zeros, unwrap, unique, vectorize, ) @@ -29,6 +29,74 @@ def get_mat(n): return data +class TestRot90(TestCase): + def test_basic(self): + self.assertRaises(ValueError, rot90, np.ones(4)) + assert_raises(ValueError, rot90, np.ones((2,2,2)), axes=(0,1,2)) + assert_raises(ValueError, rot90, np.ones((2,2)), axes=(0,2)) + assert_raises(ValueError, rot90, np.ones((2,2)), axes=(1,1)) + assert_raises(ValueError, rot90, np.ones((2,2,2)), axes=(-2,1)) + + a = [[0, 1, 2], + [3, 4, 5]] + b1 = [[2, 5], + [1, 4], + [0, 3]] + b2 = [[5, 4, 3], + [2, 1, 0]] + b3 = [[3, 0], + [4, 1], + [5, 2]] + b4 = [[0, 1, 2], + [3, 4, 5]] + + for k in range(-3, 13, 4): + assert_equal(rot90(a, k=k), b1) + for k in range(-2, 13, 4): + assert_equal(rot90(a, k=k), b2) + for k in range(-1, 13, 4): + assert_equal(rot90(a, k=k), b3) + for k in range(0, 13, 4): + assert_equal(rot90(a, k=k), b4) + + assert_equal(rot90(rot90(a, axes=(0,1)), axes=(1,0)), a) + assert_equal(rot90(a, k=1, axes=(1,0)), rot90(a, k=-1, axes=(0,1))) + + def test_axes(self): + a = np.ones((50, 40, 3)) + assert_equal(rot90(a).shape, (40, 50, 3)) + assert_equal(rot90(a, axes=(0,2)), rot90(a, axes=(0,-1))) + assert_equal(rot90(a, axes=(1,2)), rot90(a, axes=(-2,-1))) + + def test_rotation_axes(self): + a = np.arange(8).reshape((2,2,2)) + + a_rot90_01 = [[[2, 3], + [6, 7]], + [[0, 1], + [4, 5]]] + a_rot90_12 = [[[1, 3], + [0, 2]], + [[5, 7], + [4, 6]]] + a_rot90_20 = [[[4, 0], + [6, 2]], + [[5, 1], + [7, 3]]] + a_rot90_10 = [[[4, 5], + [0, 1]], + [[6, 7], + [2, 3]]] + + assert_equal(rot90(a, axes=(0, 1)), a_rot90_01) + assert_equal(rot90(a, axes=(1, 0)), a_rot90_10) + assert_equal(rot90(a, axes=(1, 2)), a_rot90_12) + + for k in range(1,5): + assert_equal(rot90(a, k=k, axes=(2, 0)), + rot90(a_rot90_20, k=k-1, axes=(2, 0))) + + class TestFlip(TestCase): def test_axes(self): self.assertRaises(ValueError, np.flip, np.ones(4), axis=1) diff --git a/numpy/lib/tests/test_twodim_base.py b/numpy/lib/tests/test_twodim_base.py index b65a8df97..738a02b37 100644 --- a/numpy/lib/tests/test_twodim_base.py +++ b/numpy/lib/tests/test_twodim_base.py @@ -9,7 +9,7 @@ from numpy.testing import ( ) from numpy import ( - arange, rot90, add, fliplr, flipud, zeros, ones, eye, array, diag, + arange, add, fliplr, flipud, zeros, ones, eye, array, diag, histogram2d, tri, mask_indices, triu_indices, triu_indices_from, tril_indices, tril_indices_from, vander, ) @@ -169,37 +169,6 @@ class TestFlipud(TestCase): assert_equal(flipud(a), b) -class TestRot90(TestCase): - def test_basic(self): - self.assertRaises(ValueError, rot90, ones(4)) - - a = [[0, 1, 2], - [3, 4, 5]] - b1 = [[2, 5], - [1, 4], - [0, 3]] - b2 = [[5, 4, 3], - [2, 1, 0]] - b3 = [[3, 0], - [4, 1], - [5, 2]] - b4 = [[0, 1, 2], - [3, 4, 5]] - - for k in range(-3, 13, 4): - assert_equal(rot90(a, k=k), b1) - for k in range(-2, 13, 4): - assert_equal(rot90(a, k=k), b2) - for k in range(-1, 13, 4): - assert_equal(rot90(a, k=k), b3) - for k in range(0, 13, 4): - assert_equal(rot90(a, k=k), b4) - - def test_axes(self): - a = ones((50, 40, 3)) - assert_equal(rot90(a).shape, (40, 50, 3)) - - class TestHistogram2d(TestCase): def test_simple(self): x = array( diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py index aefe8d64b..8858f5bad 100644 --- a/numpy/lib/twodim_base.py +++ b/numpy/lib/twodim_base.py @@ -4,14 +4,14 @@ from __future__ import division, absolute_import, print_function from numpy.core.numeric import ( - asanyarray, arange, zeros, greater_equal, multiply, ones, asarray, - where, int8, int16, int32, int64, empty, promote_types, diagonal, + absolute, asanyarray, arange, zeros, greater_equal, multiply, ones, + asarray, where, int8, int16, int32, int64, empty, promote_types, diagonal, ) -from numpy.core import iinfo +from numpy.core import iinfo, transpose __all__ = [ - 'diag', 'diagflat', 'eye', 'fliplr', 'flipud', 'rot90', 'tri', 'triu', + 'diag', 'diagflat', 'eye', 'fliplr', 'flipud', 'tri', 'triu', 'tril', 'vander', 'histogram2d', 'mask_indices', 'tril_indices', 'tril_indices_from', 'triu_indices', 'triu_indices_from', ] @@ -136,59 +136,6 @@ def flipud(m): return m[::-1, ...] -def rot90(m, k=1): - """ - Rotate an array by 90 degrees in the counter-clockwise direction. - - The first two dimensions are rotated; therefore, the array must be at - least 2-D. - - Parameters - ---------- - m : array_like - Array of two or more dimensions. - k : integer - Number of times the array is rotated by 90 degrees. - - Returns - ------- - y : ndarray - Rotated array. - - See Also - -------- - fliplr : Flip an array horizontally. - flipud : Flip an array vertically. - - Examples - -------- - >>> m = np.array([[1,2],[3,4]], int) - >>> m - array([[1, 2], - [3, 4]]) - >>> np.rot90(m) - array([[2, 4], - [1, 3]]) - >>> np.rot90(m, 2) - array([[4, 3], - [2, 1]]) - - """ - m = asanyarray(m) - if m.ndim < 2: - raise ValueError("Input must >= 2-d.") - k = k % 4 - if k == 0: - return m - elif k == 1: - return fliplr(m).swapaxes(0, 1) - elif k == 2: - return fliplr(flipud(m)) - else: - # k == 3 - return fliplr(m.swapaxes(0, 1)) - - def eye(N, M=None, k=0, dtype=float): """ Return a 2-D array with ones on the diagonal and zeros elsewhere. |