summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@climate.com>2015-11-04 20:08:57 -0800
committerStephan Hoyer <shoyer@climate.com>2016-01-09 15:56:13 -0800
commit8ffde7f488eb583ed2a200702e85a6518c4f94ec (patch)
treefc7a3eed899e6677fbbf6f40f54b03a0257bed3f /numpy/core/numeric.py
parent8a76291c76aa9b33b18ceb06f6e8f37f990c9e27 (diff)
downloadnumpy-8ffde7f488eb583ed2a200702e85a6518c4f94ec.tar.gz
ENH: moveaxis function
Fixes GH2039 This function provides a much more intuitive interface than `np.rollaxis`, which has a confusing behavior with the position of the `start` argument: http://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing It was independently suggested several times over the years after discussions on the mailing list and GitHub (GH2039), but never made it into a pull request: https://mail.scipy.org/pipermail/numpy-discussion/2010-September/052882.html My version adds support for a sequence of axis arguments. I find this behavior to be very useful. It is often more intuitive than supplying a list of arguments to `transpose` and also nicely generalizes NumPy's existing axis manipulation routines, e.g., def transpose(a, order=None): if order is None: order = reversed(range(a.ndim)) return moveaxes(a, order, range(a.ndim)) def swapaxes(a, axis1, axis2): return moveaxes(a, [axis1, axis2], [axis2, axis1]) def rollaxis(a, axis, start=0): if axis < start: start -= 1 return moveaxes(a, axis, start)
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py107
1 files changed, 98 insertions, 9 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 4f3d418e6..a18b38072 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1,6 +1,7 @@
from __future__ import division, absolute_import, print_function
import sys
+import operator
import warnings
import collections
from numpy.core import multiarray
@@ -15,8 +16,10 @@ from ._internal import TooHardError
if sys.version_info[0] >= 3:
import pickle
basestring = str
+ import builtins
else:
import cPickle as pickle
+ import __builtin__ as builtins
loads = pickle.loads
@@ -31,15 +34,15 @@ __all__ = [
'ascontiguousarray', 'asfortranarray', 'isfortran', 'empty_like',
'zeros_like', 'ones_like', 'correlate', 'convolve', 'inner', 'dot',
'einsum', 'outer', 'vdot', 'alterdot', 'restoredot', 'roll',
- 'rollaxis', 'cross', 'tensordot', 'array2string', 'get_printoptions',
- 'set_printoptions', 'array_repr', 'array_str', 'set_string_function',
- 'little_endian', 'require', 'fromiter', 'array_equal', 'array_equiv',
- 'indices', 'fromfunction', 'isclose', 'load', 'loads', 'isscalar',
- 'binary_repr', 'base_repr', 'ones', 'identity', 'allclose',
- 'compare_chararrays', 'putmask', 'seterr', 'geterr', 'setbufsize',
- 'getbufsize', 'seterrcall', 'geterrcall', 'errstate', 'flatnonzero',
- 'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_', 'True_',
- 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
+ 'rollaxis', 'moveaxis', 'cross', 'tensordot', 'array2string',
+ 'get_printoptions', 'set_printoptions', 'array_repr', 'array_str',
+ 'set_string_function', 'little_endian', 'require', 'fromiter',
+ 'array_equal', 'array_equiv', 'indices', 'fromfunction', 'isclose', 'load',
+ 'loads', 'isscalar', 'binary_repr', 'base_repr', 'ones', 'identity',
+ 'allclose', 'compare_chararrays', 'putmask', 'seterr', 'geterr',
+ 'setbufsize', 'getbufsize', 'seterrcall', 'geterrcall', 'errstate',
+ 'flatnonzero', 'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_',
+ 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT',
'TooHardError',
@@ -1422,6 +1425,7 @@ def rollaxis(a, axis, start=0):
See Also
--------
+ moveaxis : Move array axes to new positions.
roll : Roll the elements of an array by a number of positions along a
given axis.
@@ -1457,6 +1461,91 @@ def rollaxis(a, axis, start=0):
return a.transpose(axes)
+def _validate_axis(axis, ndim, argname):
+ try:
+ axis = [operator.index(axis)]
+ except TypeError:
+ axis = list(axis)
+ axis = [a + ndim if a < 0 else a for a in axis]
+ if not builtins.all(0 <= a < ndim for a in axis):
+ raise ValueError('invalid axis for this array in `%s` argument' %
+ argname)
+ if len(set(axis)) != len(axis):
+ raise ValueError('repeated axis in `%s` argument' % argname)
+ return axis
+
+
+def moveaxis(a, source, destination):
+ """
+ Move axes of an array to new positions.
+
+ Other axes remain in their original order.
+
+ .. versionadded::1.11.0
+
+ Parameters
+ ----------
+ a : np.ndarray
+ The array whose axes should be reordered.
+ source : int or sequence of int
+ Original positions of the axes to move. These must be unique.
+ destination : int or sequence of int
+ Destination positions for each of the original axes. These must also be
+ unique.
+
+ Returns
+ -------
+ result : np.ndarray
+ Array with moved axes. This array is a view of the input array.
+
+ See Also
+ --------
+ transpose: Permute the dimensions of an array.
+ swapaxes: Interchange two axes of an array.
+
+ Examples
+ --------
+
+ >>> x = np.zeros((3, 4, 5))
+ >>> np.moveaxis(x, 0, -1).shape
+ (4, 5, 3)
+ >>> np.moveaxis(x, -1, 0).shape
+ (5, 3, 4)
+
+ These all achieve the same result:
+
+ >>> np.transpose(x).shape
+ (5, 4, 3)
+ >>> np.swapaxis(x, 0, -1).shape
+ (5, 4, 3)
+ >>> np.moveaxis(x, [0, 1], [-1, -2]).shape
+ (5, 4, 3)
+ >>> np.moveaxis(x, [0, 1, 2], [-1, -2, -3]).shape
+ (5, 4, 3)
+
+ """
+ try:
+ # allow duck-array types if they define transpose
+ transpose = a.transpose
+ except AttributeError:
+ a = asarray(a)
+ transpose = a.transpose
+
+ source = _validate_axis(source, a.ndim, 'source')
+ destination = _validate_axis(destination, a.ndim, 'destination')
+ if len(source) != len(destination):
+ raise ValueError('`source` and `destination` arguments must have '
+ 'the same number of elements')
+
+ order = [n for n in range(a.ndim) if n not in source]
+
+ for dest, src in sorted(zip(destination, source)):
+ order.insert(dest, src)
+
+ result = transpose(order)
+ return result
+
+
# fix hack in scipy which imports this function
def _move_axis_to_0(a, axis):
return rollaxis(a, axis, 0)