summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-03-26 10:50:54 +0100
committerEric Wieser <wieser.eric@gmail.com>2017-03-28 20:44:45 +0100
commitb6850e9c00055666fe9bd5f984750743ffcb85d2 (patch)
treec51de944d0c6f8926140dc5564bd6eac66746671 /numpy/core/numeric.py
parentefa1bd2c0de6cd9c07b410002f2e1c0bae8cf385 (diff)
downloadnumpy-b6850e9c00055666fe9bd5f984750743ffcb85d2.tar.gz
MAINT: use normalize_axis_index in np.roll
Interestingly np.roll allows axes to be repeated.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py10
1 files changed, 4 insertions, 6 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index e1352ee6c..afa10747c 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None):
return roll(a.ravel(), shift, 0).reshape(a.shape)
else:
+ axis = _validate_axis(axis, a.ndim, allow_duplicate=True)
broadcasted = broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError(
"'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(a.ndim)}
for sh, ax in broadcasted:
- if -a.ndim <= ax < a.ndim:
- shifts[ax % a.ndim] += sh
- else:
- raise ValueError("'axis' entry is out of bounds")
+ shifts[ax] += sh
rolls = [((slice(None), slice(None)),)] * a.ndim
for ax, offset in shifts.items():
@@ -1544,13 +1542,13 @@ def rollaxis(a, axis, start=0):
return a.transpose(axes)
-def _validate_axis(axis, ndim, argname=None):
+def _validate_axis(axis, ndim, argname=None, allow_duplicate=False):
try:
axis = [operator.index(axis)]
except TypeError:
axis = list(axis)
axis = [normalize_axis_index(ax, ndim, argname) for ax in axis]
- if len(set(axis)) != len(axis):
+ if not allow_duplicate and len(set(axis)) != len(axis):
if argname:
raise ValueError('repeated axis in `%s` argument' % argname)
else: