diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-07-09 00:13:21 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-07-09 00:13:21 -0400 |
commit | a52634d7ea38f74ba286580f96c35190b2b3b549 (patch) | |
tree | 46017cbba9f2b68434f5b5e0025aac054f6e5db7 /numpy/core/numeric.py | |
parent | 0da547de39fd1d335d7ad34c6f1e7f153ea72a9b (diff) | |
parent | 4bed228adba88f9c5b85eaad308582db79a95b7b (diff) | |
download | numpy-a52634d7ea38f74ba286580f96c35190b2b3b549.tar.gz |
Merge pull request #11518 from mhvk/normalize-axis-tuple-speedup
MAINT: Speed up normalize_axis_tuple by about 30%
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index b49a7f551..e5570791a 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1509,11 +1509,14 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): -------- normalize_axis_index : normalizing a single scalar axis """ - try: - axis = [operator.index(axis)] - except TypeError: - axis = tuple(axis) - axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis) + # Optimization to speed-up the most common cases. + if type(axis) not in (tuple, list): + try: + axis = [operator.index(axis)] + except TypeError: + pass + # Going via an iterator directly is slower than via list comprehension. + axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis]) if not allow_duplicate and len(set(axis)) != len(axis): if argname: raise ValueError('repeated axis in `{}` argument'.format(argname)) |