summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py68
1 files changed, 46 insertions, 22 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 7f1206855..11a95fa7b 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1,9 +1,11 @@
from __future__ import division, absolute_import, print_function
-import sys
+import collections
+import itertools
import operator
+import sys
import warnings
-import collections
+
from numpy.core import multiarray
from . import umath
from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
@@ -1340,11 +1342,15 @@ def roll(a, shift, axis=None):
----------
a : array_like
Input array.
- shift : int
- The number of places by which elements are shifted.
- axis : int, optional
- The axis along which elements are shifted. By default, the array
- is flattened before shifting, after which the original
+ shift : int or tuple of ints
+ The number of places by which elements are shifted. If a tuple,
+ then `axis` must be a tuple of the same size, and each of the
+ given axes is shifted by the corresponding number. If an int
+ while `axis` is a tuple of ints, then the same value is used for
+ all given axes.
+ axis : int or tuple of ints, optional
+ Axis or axes along which elements are shifted. By default, the
+ array is flattened before shifting, after which the original
shape is restored.
Returns
@@ -1357,6 +1363,12 @@ def roll(a, shift, axis=None):
rollaxis : Roll the specified axis backwards, until it lies in a
given position.
+ Notes
+ -----
+ .. versionadded:: 1.12.0
+
+ Supports rolling over multiple dimensions simultaneously.
+
Examples
--------
>>> x = np.arange(10)
@@ -1380,22 +1392,34 @@ def roll(a, shift, axis=None):
"""
a = asanyarray(a)
if axis is None:
- n = a.size
- reshape = True
+ return roll(a.ravel(), shift, 0).reshape(a.shape)
+
else:
- try:
- n = a.shape[axis]
- except IndexError:
- raise ValueError('axis must be >= 0 and < %d' % a.ndim)
- reshape = False
- if n == 0:
- return a
- shift %= n
- indexes = concatenate((arange(n - shift, n), arange(n - shift)))
- res = a.take(indexes, axis)
- if reshape:
- res = res.reshape(a.shape)
- return res
+ broadcasted = broadcast(shift, axis)
+ if len(broadcasted.shape) > 1:
+ raise ValueError(
+ "'shift' and 'axis' should be scalars or 1D sequences")
+ shifts = {ax: 0 for ax in range(a.ndim)}
+ for sh, ax in broadcasted:
+ if -a.ndim <= ax < a.ndim:
+ shifts[ax % a.ndim] += sh
+ else:
+ raise ValueError("'axis' entry is out of bounds")
+
+ rolls = [((slice(None), slice(None)),)] * a.ndim
+ for ax, offset in shifts.items():
+ offset %= a.shape[ax] or 1 # If `a` is empty, nothing matters.
+ if offset:
+ # (original, result), (original, result)
+ rolls[ax] = ((slice(None, -offset), slice(offset, None)),
+ (slice(-offset, None), slice(None, offset)))
+
+ result = empty_like(a)
+ for indices in itertools.product(*rolls):
+ arr_index, res_index = zip(*indices)
+ result[res_index] = a[arr_index]
+
+ return result
def rollaxis(a, axis, start=0):