summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2013-04-01 15:25:27 -0700
committerCharles Harris <charlesr.harris@gmail.com>2013-04-01 15:25:27 -0700
commita6385b3a99e508e0d2ed5a6397da29d05da27ceb (patch)
tree61f036ade95db2dd8ffcd761674c6f1fcde74a58 /numpy
parent5ad97ea70b126a1a13abc4d7500f0281ba4ddd50 (diff)
parent8d83ae93706d3486447a9b40908995b406178111 (diff)
downloadnumpy-a6385b3a99e508e0d2ed5a6397da29d05da27ceb.tar.gz
Merge pull request #273 from dlax/fix/roll
somes fixes for the roll function
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/numeric.py14
-rw-r--r--numpy/core/tests/test_numeric.py22
2 files changed, 31 insertions, 5 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 57e366efb..fd7f06ca6 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1153,15 +1153,19 @@ def roll(a, shift, axis=None):
n = a.size
reshape = True
else:
- n = a.shape[axis]
+ try:
+ n = a.shape[axis]
+ except IndexError:
+ raise ValueError('axis must be >= 0 and < %d' % a.ndim)
reshape = False
+ if n == 0:
+ return a
shift %= n
- indexes = concatenate((arange(n-shift,n),arange(n-shift)))
+ indexes = concatenate((arange(n - shift, n), arange(n - shift)))
res = a.take(indexes, axis)
if reshape:
- return res.reshape(a.shape)
- else:
- return res
+ res = res.reshape(a.shape)
+ return res
def rollaxis(a, axis, start=0):
"""
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 6d3cbe923..ae9c26a36 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1504,5 +1504,27 @@ class TestStringFunction(object):
np.set_string_function(None, repr=False)
assert_equal(str(a), "[1]")
+class TestRoll(TestCase):
+ def test_roll1d(self):
+ x = np.arange(10)
+ xr = np.roll(x, 2)
+ assert_equal(xr, np.array([8, 9, 0, 1, 2, 3, 4, 5, 6, 7]))
+
+ def test_roll2d(self):
+ x2 = np.reshape(np.arange(10), (2,5))
+ x2r = np.roll(x2, 1)
+ assert_equal(x2r, np.array([[9, 0, 1, 2, 3], [4, 5, 6, 7, 8]]))
+
+ x2r = np.roll(x2, 1, axis=0)
+ assert_equal(x2r, np.array([[5, 6, 7, 8, 9], [0, 1, 2, 3, 4]]))
+
+ x2r = np.roll(x2, 1, axis=1)
+ assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))
+
+ def test_roll_empty(self):
+ x = np.array([])
+ assert_equal(np.roll(x, 1), np.array([]))
+
+
if __name__ == "__main__":
run_module_suite()