From 5c42cc31fa1ac6bfaed50d7069d7bb8e8fd8bb53 Mon Sep 17 00:00:00 2001 From: Marten van Kerkwijk Date: Fri, 6 Jul 2018 11:10:22 -0400 Subject: MAINT: Speed up normalize_axis_tuple by about 30% MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is used in `np.moveaxis`, which is relatively slow. Timings change as follows: ``` import numpy as np from numpy.core.numeric import normalize_axis_tuple %timeit normalize_axis_tuple([0, 1, 2], 4, 'parrot') a = np.zeros((10,20,30)) %timeit np.moveaxis(a, [0,1,2], [2,1,0]) 100000 loops, best of 3: 5.56 µs per loop -> 4.17 µs ``` --- numpy/core/numeric.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'numpy/core/numeric.py') diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index b49a7f551..a9fa3efe0 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1509,11 +1509,14 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): -------- normalize_axis_index : normalizing a single scalar axis """ - try: - axis = [operator.index(axis)] - except TypeError: - axis = tuple(axis) - axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis) + # Speed-up most common cases. + if not isinstance(axis, (list, tuple)): + try: + axis = [operator.index(axis)] + except TypeError: + pass + # Going via an iterator directly is slower than via list comprehension. + axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis]) if not allow_duplicate and len(set(axis)) != len(axis): if argname: raise ValueError('repeated axis in `{}` argument'.format(argname)) -- cgit v1.2.1 From 4bed228adba88f9c5b85eaad308582db79a95b7b Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Sun, 8 Jul 2018 17:38:33 +0100 Subject: MAINT: Restore previous behavior on subclasses --- numpy/core/numeric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'numpy/core/numeric.py') diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index a9fa3efe0..e5570791a 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1509,8 +1509,8 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): -------- normalize_axis_index : normalizing a single scalar axis """ - # Speed-up most common cases. - if not isinstance(axis, (list, tuple)): + # Optimization to speed-up the most common cases. + if type(axis) not in (tuple, list): try: axis = [operator.index(axis)] except TypeError: -- cgit v1.2.1