summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-07-06 11:10:22 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-07-06 12:40:25 -0400
commit5c42cc31fa1ac6bfaed50d7069d7bb8e8fd8bb53 (patch)
treeea54a46186621588509b8645a559ba9aa7115cdb /numpy/core/numeric.py
parent739443679b50b43c34808b8fb767bac643fcd91d (diff)
downloadnumpy-5c42cc31fa1ac6bfaed50d7069d7bb8e8fd8bb53.tar.gz
MAINT: Speed up normalize_axis_tuple by about 30%
This is used in `np.moveaxis`, which is relatively slow. Timings change as follows: ``` import numpy as np from numpy.core.numeric import normalize_axis_tuple %timeit normalize_axis_tuple([0, 1, 2], 4, 'parrot') a = np.zeros((10,20,30)) %timeit np.moveaxis(a, [0,1,2], [2,1,0]) 100000 loops, best of 3: 5.56 µs per loop -> 4.17 µs ```
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index b49a7f551..a9fa3efe0 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1509,11 +1509,14 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
--------
normalize_axis_index : normalizing a single scalar axis
"""
- try:
- axis = [operator.index(axis)]
- except TypeError:
- axis = tuple(axis)
- axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
+ # Speed-up most common cases.
+ if not isinstance(axis, (list, tuple)):
+ try:
+ axis = [operator.index(axis)]
+ except TypeError:
+ pass
+ # Going via an iterator directly is slower than via list comprehension.
+ axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
if not allow_duplicate and len(set(axis)) != len(axis):
if argname:
raise ValueError('repeated axis in `{}` argument'.format(argname))