diff options
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 68 |
1 files changed, 46 insertions, 22 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 7f1206855..11a95fa7b 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1,9 +1,11 @@ from __future__ import division, absolute_import, print_function -import sys +import collections +import itertools import operator +import sys import warnings -import collections + from numpy.core import multiarray from . import umath from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE, @@ -1340,11 +1342,15 @@ def roll(a, shift, axis=None): ---------- a : array_like Input array. - shift : int - The number of places by which elements are shifted. - axis : int, optional - The axis along which elements are shifted. By default, the array - is flattened before shifting, after which the original + shift : int or tuple of ints + The number of places by which elements are shifted. If a tuple, + then `axis` must be a tuple of the same size, and each of the + given axes is shifted by the corresponding number. If an int + while `axis` is a tuple of ints, then the same value is used for + all given axes. + axis : int or tuple of ints, optional + Axis or axes along which elements are shifted. By default, the + array is flattened before shifting, after which the original shape is restored. Returns @@ -1357,6 +1363,12 @@ def roll(a, shift, axis=None): rollaxis : Roll the specified axis backwards, until it lies in a given position. + Notes + ----- + .. versionadded:: 1.12.0 + + Supports rolling over multiple dimensions simultaneously. + Examples -------- >>> x = np.arange(10) @@ -1380,22 +1392,34 @@ def roll(a, shift, axis=None): """ a = asanyarray(a) if axis is None: - n = a.size - reshape = True + return roll(a.ravel(), shift, 0).reshape(a.shape) + else: - try: - n = a.shape[axis] - except IndexError: - raise ValueError('axis must be >= 0 and < %d' % a.ndim) - reshape = False - if n == 0: - return a - shift %= n - indexes = concatenate((arange(n - shift, n), arange(n - shift))) - res = a.take(indexes, axis) - if reshape: - res = res.reshape(a.shape) - return res + broadcasted = broadcast(shift, axis) + if len(broadcasted.shape) > 1: + raise ValueError( + "'shift' and 'axis' should be scalars or 1D sequences") + shifts = {ax: 0 for ax in range(a.ndim)} + for sh, ax in broadcasted: + if -a.ndim <= ax < a.ndim: + shifts[ax % a.ndim] += sh + else: + raise ValueError("'axis' entry is out of bounds") + + rolls = [((slice(None), slice(None)),)] * a.ndim + for ax, offset in shifts.items(): + offset %= a.shape[ax] or 1 # If `a` is empty, nothing matters. + if offset: + # (original, result), (original, result) + rolls[ax] = ((slice(None, -offset), slice(offset, None)), + (slice(-offset, None), slice(None, offset))) + + result = empty_like(a) + for indices in itertools.product(*rolls): + arr_index, res_index = zip(*indices) + result[res_index] = a[arr_index] + + return result def rollaxis(a, axis, start=0): |