diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-03-26 10:50:54 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-03-28 20:44:45 +0100 |
commit | b6850e9c00055666fe9bd5f984750743ffcb85d2 (patch) | |
tree | c51de944d0c6f8926140dc5564bd6eac66746671 /numpy/core/numeric.py | |
parent | efa1bd2c0de6cd9c07b410002f2e1c0bae8cf385 (diff) | |
download | numpy-b6850e9c00055666fe9bd5f984750743ffcb85d2.tar.gz |
MAINT: use normalize_axis_index in np.roll
Interestingly np.roll allows axes to be repeated.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index e1352ee6c..afa10747c 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None): return roll(a.ravel(), shift, 0).reshape(a.shape) else: + axis = _validate_axis(axis, a.ndim, allow_duplicate=True) broadcasted = broadcast(shift, axis) if broadcasted.ndim > 1: raise ValueError( "'shift' and 'axis' should be scalars or 1D sequences") shifts = {ax: 0 for ax in range(a.ndim)} for sh, ax in broadcasted: - if -a.ndim <= ax < a.ndim: - shifts[ax % a.ndim] += sh - else: - raise ValueError("'axis' entry is out of bounds") + shifts[ax] += sh rolls = [((slice(None), slice(None)),)] * a.ndim for ax, offset in shifts.items(): @@ -1544,13 +1542,13 @@ def rollaxis(a, axis, start=0): return a.transpose(axes) -def _validate_axis(axis, ndim, argname=None): +def _validate_axis(axis, ndim, argname=None, allow_duplicate=False): try: axis = [operator.index(axis)] except TypeError: axis = list(axis) axis = [normalize_axis_index(ax, ndim, argname) for ax in axis] - if len(set(axis)) != len(axis): + if not allow_duplicate and len(set(axis)) != len(axis): if argname: raise ValueError('repeated axis in `%s` argument' % argname) else: |