diff options
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index e1352ee6c..afa10747c 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None): return roll(a.ravel(), shift, 0).reshape(a.shape) else: + axis = _validate_axis(axis, a.ndim, allow_duplicate=True) broadcasted = broadcast(shift, axis) if broadcasted.ndim > 1: raise ValueError( "'shift' and 'axis' should be scalars or 1D sequences") shifts = {ax: 0 for ax in range(a.ndim)} for sh, ax in broadcasted: - if -a.ndim <= ax < a.ndim: - shifts[ax % a.ndim] += sh - else: - raise ValueError("'axis' entry is out of bounds") + shifts[ax] += sh rolls = [((slice(None), slice(None)),)] * a.ndim for ax, offset in shifts.items(): @@ -1544,13 +1542,13 @@ def rollaxis(a, axis, start=0): return a.transpose(axes) -def _validate_axis(axis, ndim, argname=None): +def _validate_axis(axis, ndim, argname=None, allow_duplicate=False): try: axis = [operator.index(axis)] except TypeError: axis = list(axis) axis = [normalize_axis_index(ax, ndim, argname) for ax in axis] - if len(set(axis)) != len(axis): + if not allow_duplicate and len(set(axis)) != len(axis): if argname: raise ValueError('repeated axis in `%s` argument' % argname) else: |