diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-07-06 11:10:22 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-07-06 12:40:25 -0400 |
commit | 5c42cc31fa1ac6bfaed50d7069d7bb8e8fd8bb53 (patch) | |
tree | ea54a46186621588509b8645a559ba9aa7115cdb /numpy/core/numeric.py | |
parent | 739443679b50b43c34808b8fb767bac643fcd91d (diff) | |
download | numpy-5c42cc31fa1ac6bfaed50d7069d7bb8e8fd8bb53.tar.gz |
MAINT: Speed up normalize_axis_tuple by about 30%
This is used in `np.moveaxis`, which is relatively slow.
Timings change as follows:
```
import numpy as np
from numpy.core.numeric import normalize_axis_tuple
%timeit normalize_axis_tuple([0, 1, 2], 4, 'parrot')
a = np.zeros((10,20,30))
%timeit np.moveaxis(a, [0,1,2], [2,1,0])
100000 loops, best of 3: 5.56 µs per loop -> 4.17 µs
```
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index b49a7f551..a9fa3efe0 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1509,11 +1509,14 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): -------- normalize_axis_index : normalizing a single scalar axis """ - try: - axis = [operator.index(axis)] - except TypeError: - axis = tuple(axis) - axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis) + # Speed-up most common cases. + if not isinstance(axis, (list, tuple)): + try: + axis = [operator.index(axis)] + except TypeError: + pass + # Going via an iterator directly is slower than via list comprehension. + axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis]) if not allow_duplicate and len(set(axis)) != len(axis): if argname: raise ValueError('repeated axis in `{}` argument'.format(argname)) |