summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-07-09 00:13:21 -0400
committerGitHub <noreply@github.com>2018-07-09 00:13:21 -0400
commita52634d7ea38f74ba286580f96c35190b2b3b549 (patch)
tree46017cbba9f2b68434f5b5e0025aac054f6e5db7 /numpy/core/numeric.py
parent0da547de39fd1d335d7ad34c6f1e7f153ea72a9b (diff)
parent4bed228adba88f9c5b85eaad308582db79a95b7b (diff)
downloadnumpy-a52634d7ea38f74ba286580f96c35190b2b3b549.tar.gz
Merge pull request #11518 from mhvk/normalize-axis-tuple-speedup
MAINT: Speed up normalize_axis_tuple by about 30%
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index b49a7f551..e5570791a 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1509,11 +1509,14 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
--------
normalize_axis_index : normalizing a single scalar axis
"""
- try:
- axis = [operator.index(axis)]
- except TypeError:
- axis = tuple(axis)
- axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
+ # Optimization to speed-up the most common cases.
+ if type(axis) not in (tuple, list):
+ try:
+ axis = [operator.index(axis)]
+ except TypeError:
+ pass
+ # Going via an iterator directly is slower than via list comprehension.
+ axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
if not allow_duplicate and len(set(axis)) != len(axis):
if argname:
raise ValueError('repeated axis in `{}` argument'.format(argname))