summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py10
1 files changed, 4 insertions, 6 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index e1352ee6c..afa10747c 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None):
return roll(a.ravel(), shift, 0).reshape(a.shape)
else:
+ axis = _validate_axis(axis, a.ndim, allow_duplicate=True)
broadcasted = broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError(
"'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(a.ndim)}
for sh, ax in broadcasted:
- if -a.ndim <= ax < a.ndim:
- shifts[ax % a.ndim] += sh
- else:
- raise ValueError("'axis' entry is out of bounds")
+ shifts[ax] += sh
rolls = [((slice(None), slice(None)),)] * a.ndim
for ax, offset in shifts.items():
@@ -1544,13 +1542,13 @@ def rollaxis(a, axis, start=0):
return a.transpose(axes)
-def _validate_axis(axis, ndim, argname=None):
+def _validate_axis(axis, ndim, argname=None, allow_duplicate=False):
try:
axis = [operator.index(axis)]
except TypeError:
axis = list(axis)
axis = [normalize_axis_index(ax, ndim, argname) for ax in axis]
- if len(set(axis)) != len(axis):
+ if not allow_duplicate and len(set(axis)) != len(axis):
if argname:
raise ValueError('repeated axis in `%s` argument' % argname)
else: