diff options
author | Stephan Hoyer <shoyer@climate.com> | 2015-11-04 20:08:57 -0800 |
---|---|---|
committer | Stephan Hoyer <shoyer@climate.com> | 2016-01-09 15:56:13 -0800 |
commit | 8ffde7f488eb583ed2a200702e85a6518c4f94ec (patch) | |
tree | fc7a3eed899e6677fbbf6f40f54b03a0257bed3f /numpy/core/numeric.py | |
parent | 8a76291c76aa9b33b18ceb06f6e8f37f990c9e27 (diff) | |
download | numpy-8ffde7f488eb583ed2a200702e85a6518c4f94ec.tar.gz |
ENH: moveaxis function
Fixes GH2039
This function provides a much more intuitive interface than `np.rollaxis`,
which has a confusing behavior with the position of the `start` argument:
http://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing
It was independently suggested several times over the years after discussions
on the mailing list and GitHub (GH2039), but never made it into a pull request:
https://mail.scipy.org/pipermail/numpy-discussion/2010-September/052882.html
My version adds support for a sequence of axis arguments. I find this behavior
to be very useful. It is often more intuitive than supplying a list of
arguments to `transpose` and also nicely generalizes NumPy's existing axis
manipulation routines, e.g.,
def transpose(a, order=None):
if order is None:
order = reversed(range(a.ndim))
return moveaxes(a, order, range(a.ndim))
def swapaxes(a, axis1, axis2):
return moveaxes(a, [axis1, axis2], [axis2, axis1])
def rollaxis(a, axis, start=0):
if axis < start:
start -= 1
return moveaxes(a, axis, start)
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 107 |
1 files changed, 98 insertions, 9 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 4f3d418e6..a18b38072 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1,6 +1,7 @@ from __future__ import division, absolute_import, print_function import sys +import operator import warnings import collections from numpy.core import multiarray @@ -15,8 +16,10 @@ from ._internal import TooHardError if sys.version_info[0] >= 3: import pickle basestring = str + import builtins else: import cPickle as pickle + import __builtin__ as builtins loads = pickle.loads @@ -31,15 +34,15 @@ __all__ = [ 'ascontiguousarray', 'asfortranarray', 'isfortran', 'empty_like', 'zeros_like', 'ones_like', 'correlate', 'convolve', 'inner', 'dot', 'einsum', 'outer', 'vdot', 'alterdot', 'restoredot', 'roll', - 'rollaxis', 'cross', 'tensordot', 'array2string', 'get_printoptions', - 'set_printoptions', 'array_repr', 'array_str', 'set_string_function', - 'little_endian', 'require', 'fromiter', 'array_equal', 'array_equiv', - 'indices', 'fromfunction', 'isclose', 'load', 'loads', 'isscalar', - 'binary_repr', 'base_repr', 'ones', 'identity', 'allclose', - 'compare_chararrays', 'putmask', 'seterr', 'geterr', 'setbufsize', - 'getbufsize', 'seterrcall', 'geterrcall', 'errstate', 'flatnonzero', - 'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_', 'True_', - 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', + 'rollaxis', 'moveaxis', 'cross', 'tensordot', 'array2string', + 'get_printoptions', 'set_printoptions', 'array_repr', 'array_str', + 'set_string_function', 'little_endian', 'require', 'fromiter', + 'array_equal', 'array_equiv', 'indices', 'fromfunction', 'isclose', 'load', + 'loads', 'isscalar', 'binary_repr', 'base_repr', 'ones', 'identity', + 'allclose', 'compare_chararrays', 'putmask', 'seterr', 'geterr', + 'setbufsize', 'getbufsize', 'seterrcall', 'geterrcall', 'errstate', + 'flatnonzero', 'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_', + 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul', 'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', 'TooHardError', @@ -1422,6 +1425,7 @@ def rollaxis(a, axis, start=0): See Also -------- + moveaxis : Move array axes to new positions. roll : Roll the elements of an array by a number of positions along a given axis. @@ -1457,6 +1461,91 @@ def rollaxis(a, axis, start=0): return a.transpose(axes) +def _validate_axis(axis, ndim, argname): + try: + axis = [operator.index(axis)] + except TypeError: + axis = list(axis) + axis = [a + ndim if a < 0 else a for a in axis] + if not builtins.all(0 <= a < ndim for a in axis): + raise ValueError('invalid axis for this array in `%s` argument' % + argname) + if len(set(axis)) != len(axis): + raise ValueError('repeated axis in `%s` argument' % argname) + return axis + + +def moveaxis(a, source, destination): + """ + Move axes of an array to new positions. + + Other axes remain in their original order. + + .. versionadded::1.11.0 + + Parameters + ---------- + a : np.ndarray + The array whose axes should be reordered. + source : int or sequence of int + Original positions of the axes to move. These must be unique. + destination : int or sequence of int + Destination positions for each of the original axes. These must also be + unique. + + Returns + ------- + result : np.ndarray + Array with moved axes. This array is a view of the input array. + + See Also + -------- + transpose: Permute the dimensions of an array. + swapaxes: Interchange two axes of an array. + + Examples + -------- + + >>> x = np.zeros((3, 4, 5)) + >>> np.moveaxis(x, 0, -1).shape + (4, 5, 3) + >>> np.moveaxis(x, -1, 0).shape + (5, 3, 4) + + These all achieve the same result: + + >>> np.transpose(x).shape + (5, 4, 3) + >>> np.swapaxis(x, 0, -1).shape + (5, 4, 3) + >>> np.moveaxis(x, [0, 1], [-1, -2]).shape + (5, 4, 3) + >>> np.moveaxis(x, [0, 1, 2], [-1, -2, -3]).shape + (5, 4, 3) + + """ + try: + # allow duck-array types if they define transpose + transpose = a.transpose + except AttributeError: + a = asarray(a) + transpose = a.transpose + + source = _validate_axis(source, a.ndim, 'source') + destination = _validate_axis(destination, a.ndim, 'destination') + if len(source) != len(destination): + raise ValueError('`source` and `destination` arguments must have ' + 'the same number of elements') + + order = [n for n in range(a.ndim) if n not in source] + + for dest, src in sorted(zip(destination, source)): + order.insert(dest, src) + + result = transpose(order) + return result + + # fix hack in scipy which imports this function def _move_axis_to_0(a, axis): return rollaxis(a, axis, 0) |