From e94500bb40827729332db7aafd9dec0e0205f077 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Thu, 30 Nov 2017 00:16:39 -0800 Subject: MAINT: Simplify _leading_trailing Removes the list comprehensions in favor of numpy primitives. Now recurses over the indices, not the values. --- numpy/core/arrayprint.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) (limited to 'numpy/core/arrayprint.py') diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index e4be810b9..c4d761059 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -272,22 +272,27 @@ def get_printoptions(): """ return _format_options.copy() -def _leading_trailing(a): + +def _leading_trailing(a, index=()): + """ + Keep only the N-D corners (leading and trailing edges) of an array. + + Should be passed a base-class ndarray, since it makes no guarantees about + preserving subclasses. + """ edgeitems = _format_options['edgeitems'] - if a.ndim == 1: - if len(a) > 2*edgeitems: - b = concatenate((a[:edgeitems], a[-edgeitems:])) - else: - b = a + axis = len(index) + if axis == a.ndim: + return a[index] + + if a.shape[axis] > 2*edgeitems: + return concatenate(( + _leading_trailing(a, index + np.index_exp[ :edgeitems]), + _leading_trailing(a, index + np.index_exp[-edgeitems:]) + ), axis=axis) else: - if len(a) > 2*edgeitems: - l = [_leading_trailing(a[i]) for i in range(min(len(a), edgeitems))] - l.extend([_leading_trailing(a[-i]) for i in range( - min(len(a), edgeitems), 0, -1)]) - else: - l = [_leading_trailing(a[i]) for i in range(0, len(a))] - b = concatenate(tuple(l)) - return b + return _leading_trailing(a, index + np.index_exp[:]) + def _object_format(o): """ Object arrays containing lists should be printed unambiguously """ -- cgit v1.2.1