summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorAntony Lee <anntzer.lee@gmail.com>2016-03-19 18:50:55 -0700
committerAntony Lee <anntzer.lee@gmail.com>2016-04-02 12:19:09 -0700
commit612caca61f82e3e20117ec5917d71fd0f48c42ea (patch)
tree216258cdc9d3a91d2947165a858340d93689856b /numpy/core/numeric.py
parent2af06c804931aae4b30bb3349bc60271b0b65381 (diff)
downloadnumpy-612caca61f82e3e20117ec5917d71fd0f48c42ea.tar.gz
ENH: Allow rolling multiple axes at the same time.
A quick test suggests that this implementation from @seberg, relying on slices rather than index arrays, is 1.5~3x faster than the previous (1D) roll (depending on the axis). Also switched the error message for invalid inputs to match the one of ufuncs, because the axis can actually also be negative.
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py68
1 files changed, 46 insertions, 22 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 7f1206855..11a95fa7b 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1,9 +1,11 @@
from __future__ import division, absolute_import, print_function
-import sys
+import collections
+import itertools
import operator
+import sys
import warnings
-import collections
+
from numpy.core import multiarray
from . import umath
from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
@@ -1340,11 +1342,15 @@ def roll(a, shift, axis=None):
----------
a : array_like
Input array.
- shift : int
- The number of places by which elements are shifted.
- axis : int, optional
- The axis along which elements are shifted. By default, the array
- is flattened before shifting, after which the original
+ shift : int or tuple of ints
+ The number of places by which elements are shifted. If a tuple,
+ then `axis` must be a tuple of the same size, and each of the
+ given axes is shifted by the corresponding number. If an int
+ while `axis` is a tuple of ints, then the same value is used for
+ all given axes.
+ axis : int or tuple of ints, optional
+ Axis or axes along which elements are shifted. By default, the
+ array is flattened before shifting, after which the original
shape is restored.
Returns
@@ -1357,6 +1363,12 @@ def roll(a, shift, axis=None):
rollaxis : Roll the specified axis backwards, until it lies in a
given position.
+ Notes
+ -----
+ .. versionadded:: 1.12.0
+
+ Supports rolling over multiple dimensions simultaneously.
+
Examples
--------
>>> x = np.arange(10)
@@ -1380,22 +1392,34 @@ def roll(a, shift, axis=None):
"""
a = asanyarray(a)
if axis is None:
- n = a.size
- reshape = True
+ return roll(a.ravel(), shift, 0).reshape(a.shape)
+
else:
- try:
- n = a.shape[axis]
- except IndexError:
- raise ValueError('axis must be >= 0 and < %d' % a.ndim)
- reshape = False
- if n == 0:
- return a
- shift %= n
- indexes = concatenate((arange(n - shift, n), arange(n - shift)))
- res = a.take(indexes, axis)
- if reshape:
- res = res.reshape(a.shape)
- return res
+ broadcasted = broadcast(shift, axis)
+ if len(broadcasted.shape) > 1:
+ raise ValueError(
+ "'shift' and 'axis' should be scalars or 1D sequences")
+ shifts = {ax: 0 for ax in range(a.ndim)}
+ for sh, ax in broadcasted:
+ if -a.ndim <= ax < a.ndim:
+ shifts[ax % a.ndim] += sh
+ else:
+ raise ValueError("'axis' entry is out of bounds")
+
+ rolls = [((slice(None), slice(None)),)] * a.ndim
+ for ax, offset in shifts.items():
+ offset %= a.shape[ax] or 1 # If `a` is empty, nothing matters.
+ if offset:
+ # (original, result), (original, result)
+ rolls[ax] = ((slice(None, -offset), slice(offset, None)),
+ (slice(-offset, None), slice(None, offset)))
+
+ result = empty_like(a)
+ for indices in itertools.product(*rolls):
+ arr_index, res_index = zip(*indices)
+ result[res_index] = a[arr_index]
+
+ return result
def rollaxis(a, axis, start=0):