diff options
author | Antony Lee <anntzer.lee@gmail.com> | 2016-03-19 18:50:55 -0700 |
---|---|---|
committer | Antony Lee <anntzer.lee@gmail.com> | 2016-04-02 12:19:09 -0700 |
commit | 612caca61f82e3e20117ec5917d71fd0f48c42ea (patch) | |
tree | 216258cdc9d3a91d2947165a858340d93689856b /numpy/core/numeric.py | |
parent | 2af06c804931aae4b30bb3349bc60271b0b65381 (diff) | |
download | numpy-612caca61f82e3e20117ec5917d71fd0f48c42ea.tar.gz |
ENH: Allow rolling multiple axes at the same time.
A quick test suggests that this implementation from @seberg, relying on
slices rather than index arrays, is 1.5~3x faster than the previous (1D)
roll (depending on the axis).
Also switched the error message for invalid inputs to match the one of
ufuncs, because the axis can actually also be negative.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 68 |
1 files changed, 46 insertions, 22 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 7f1206855..11a95fa7b 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1,9 +1,11 @@ from __future__ import division, absolute_import, print_function -import sys +import collections +import itertools import operator +import sys import warnings -import collections + from numpy.core import multiarray from . import umath from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE, @@ -1340,11 +1342,15 @@ def roll(a, shift, axis=None): ---------- a : array_like Input array. - shift : int - The number of places by which elements are shifted. - axis : int, optional - The axis along which elements are shifted. By default, the array - is flattened before shifting, after which the original + shift : int or tuple of ints + The number of places by which elements are shifted. If a tuple, + then `axis` must be a tuple of the same size, and each of the + given axes is shifted by the corresponding number. If an int + while `axis` is a tuple of ints, then the same value is used for + all given axes. + axis : int or tuple of ints, optional + Axis or axes along which elements are shifted. By default, the + array is flattened before shifting, after which the original shape is restored. Returns @@ -1357,6 +1363,12 @@ def roll(a, shift, axis=None): rollaxis : Roll the specified axis backwards, until it lies in a given position. + Notes + ----- + .. versionadded:: 1.12.0 + + Supports rolling over multiple dimensions simultaneously. + Examples -------- >>> x = np.arange(10) @@ -1380,22 +1392,34 @@ def roll(a, shift, axis=None): """ a = asanyarray(a) if axis is None: - n = a.size - reshape = True + return roll(a.ravel(), shift, 0).reshape(a.shape) + else: - try: - n = a.shape[axis] - except IndexError: - raise ValueError('axis must be >= 0 and < %d' % a.ndim) - reshape = False - if n == 0: - return a - shift %= n - indexes = concatenate((arange(n - shift, n), arange(n - shift))) - res = a.take(indexes, axis) - if reshape: - res = res.reshape(a.shape) - return res + broadcasted = broadcast(shift, axis) + if len(broadcasted.shape) > 1: + raise ValueError( + "'shift' and 'axis' should be scalars or 1D sequences") + shifts = {ax: 0 for ax in range(a.ndim)} + for sh, ax in broadcasted: + if -a.ndim <= ax < a.ndim: + shifts[ax % a.ndim] += sh + else: + raise ValueError("'axis' entry is out of bounds") + + rolls = [((slice(None), slice(None)),)] * a.ndim + for ax, offset in shifts.items(): + offset %= a.shape[ax] or 1 # If `a` is empty, nothing matters. + if offset: + # (original, result), (original, result) + rolls[ax] = ((slice(None, -offset), slice(offset, None)), + (slice(-offset, None), slice(None, offset))) + + result = empty_like(a) + for indices in itertools.product(*rolls): + arr_index, res_index = zip(*indices) + result[res_index] = a[arr_index] + + return result def rollaxis(a, axis, start=0): |