diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2013-04-01 15:25:27 -0700 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2013-04-01 15:25:27 -0700 |
commit | a6385b3a99e508e0d2ed5a6397da29d05da27ceb (patch) | |
tree | 61f036ade95db2dd8ffcd761674c6f1fcde74a58 /numpy | |
parent | 5ad97ea70b126a1a13abc4d7500f0281ba4ddd50 (diff) | |
parent | 8d83ae93706d3486447a9b40908995b406178111 (diff) | |
download | numpy-a6385b3a99e508e0d2ed5a6397da29d05da27ceb.tar.gz |
Merge pull request #273 from dlax/fix/roll
somes fixes for the roll function
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/numeric.py | 14 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 22 |
2 files changed, 31 insertions, 5 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 57e366efb..fd7f06ca6 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1153,15 +1153,19 @@ def roll(a, shift, axis=None): n = a.size reshape = True else: - n = a.shape[axis] + try: + n = a.shape[axis] + except IndexError: + raise ValueError('axis must be >= 0 and < %d' % a.ndim) reshape = False + if n == 0: + return a shift %= n - indexes = concatenate((arange(n-shift,n),arange(n-shift))) + indexes = concatenate((arange(n - shift, n), arange(n - shift))) res = a.take(indexes, axis) if reshape: - return res.reshape(a.shape) - else: - return res + res = res.reshape(a.shape) + return res def rollaxis(a, axis, start=0): """ diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 6d3cbe923..ae9c26a36 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -1504,5 +1504,27 @@ class TestStringFunction(object): np.set_string_function(None, repr=False) assert_equal(str(a), "[1]") +class TestRoll(TestCase): + def test_roll1d(self): + x = np.arange(10) + xr = np.roll(x, 2) + assert_equal(xr, np.array([8, 9, 0, 1, 2, 3, 4, 5, 6, 7])) + + def test_roll2d(self): + x2 = np.reshape(np.arange(10), (2,5)) + x2r = np.roll(x2, 1) + assert_equal(x2r, np.array([[9, 0, 1, 2, 3], [4, 5, 6, 7, 8]])) + + x2r = np.roll(x2, 1, axis=0) + assert_equal(x2r, np.array([[5, 6, 7, 8, 9], [0, 1, 2, 3, 4]])) + + x2r = np.roll(x2, 1, axis=1) + assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]])) + + def test_roll_empty(self): + x = np.array([]) + assert_equal(np.roll(x, 1), np.array([])) + + if __name__ == "__main__": run_module_suite() |