diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2018-05-22 00:45:27 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-05-22 00:58:15 -0700 |
commit | 5608636c360050e865abe4fcbe3511f84db9591b (patch) | |
tree | a26e2191f8a3bcdcb9c27f65e9595b4be030b4d3 /numpy/lib/arraypad.py | |
parent | 651e9cf480f043bb4a591a1d74c88721f01b3216 (diff) | |
download | numpy-5608636c360050e865abe4fcbe3511f84db9591b.tar.gz |
MAINT: np.pad: Generalize the helper function to be used in more places
This makes `_slice_first` almost a factor of two faster
Diffstat (limited to 'numpy/lib/arraypad.py')
-rw-r--r-- | numpy/lib/arraypad.py | 27 |
1 files changed, 15 insertions, 12 deletions
diff --git a/numpy/lib/arraypad.py b/numpy/lib/arraypad.py index 9a4c1350a..9f6b75560 100644 --- a/numpy/lib/arraypad.py +++ b/numpy/lib/arraypad.py @@ -74,16 +74,23 @@ def _round_ifneeded(arr, dtype): arr.round(out=arr) +def _slice_at_axis(shape, sl, axis): + """ + Construct a slice tuple the length of shape, with sl at the specified axis + """ + slice_tup = (slice(None),) + return slice_tup * axis + (sl,) + slice_tup * (len(shape) - axis - 1) + + def _slice_first(shape, n, axis): """ Construct a slice tuple to take the first n elements along axis """ - return tuple(slice(None) if i != axis else slice(0, n) - for (i, x) in enumerate(shape)) + return _slice_at_axis(shape, slice(0, n), axis=axis) def _slice_last(shape, n, axis): """ Construct a slice tuple to take the last n elements along axis """ - return tuple(slice(None) if i != axis else slice(x - n, x) - for (i, x) in enumerate(shape)) + dim = shape[axis] # doing this explicitly makes n=0 work + return _slice_at_axis(shape, slice(dim - n, dim), axis=axis) def _prepend_const(arr, pad_amt, val, axis=-1): @@ -729,8 +736,7 @@ def _pad_ref(arr, pad_amt, method, axis=-1): # Prepended region # Slice off a reverse indexed chunk from near edge to pad `arr` before - ref_slice = tuple(slice(None) if i != axis else slice(pad_amt[0], 0, -1) - for (i, x) in enumerate(arr.shape)) + ref_slice = _slice_at_axis(arr.shape, slice(pad_amt[0], 0, -1), axis=axis) ref_chunk1 = arr[ref_slice] @@ -747,10 +753,8 @@ def _pad_ref(arr, pad_amt, method, axis=-1): # Slice off a reverse indexed chunk from far edge to pad `arr` after start = arr.shape[axis] - pad_amt[1] - 1 end = arr.shape[axis] - 1 - ref_slice = tuple(slice(None) if i != axis else slice(start, end) - for (i, x) in enumerate(arr.shape)) - rev_idx = tuple(slice(None) if i != axis else slice(None, None, -1) - for (i, x) in enumerate(arr.shape)) + ref_slice = _slice_at_axis(arr.shape, slice(start, end), axis=axis) + rev_idx = _slice_at_axis(arr.shape, slice(None, None, -1), axis=axis) ref_chunk2 = arr[ref_slice][rev_idx] if 'odd' in method: @@ -804,8 +808,7 @@ def _pad_sym(arr, pad_amt, method, axis=-1): # Slice off a reverse indexed chunk from near edge to pad `arr` before sym_slice = _slice_first(arr.shape, pad_amt[0], axis=axis) - rev_idx = tuple(slice(None) if i != axis else slice(None, None, -1) - for (i, x) in enumerate(arr.shape)) + rev_idx = _slice_at_axis(arr.shape, slice(None, None, -1), axis=axis) sym_chunk1 = arr[sym_slice][rev_idx] # Memory/computationally more expensive, only do this if `method='odd'` |