summaryrefslogtreecommitdiff
path: root/numpy/lib/arraypad.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2018-05-22 00:45:27 -0700
committerEric Wieser <wieser.eric@gmail.com>2018-05-22 00:58:15 -0700
commit5608636c360050e865abe4fcbe3511f84db9591b (patch)
treea26e2191f8a3bcdcb9c27f65e9595b4be030b4d3 /numpy/lib/arraypad.py
parent651e9cf480f043bb4a591a1d74c88721f01b3216 (diff)
downloadnumpy-5608636c360050e865abe4fcbe3511f84db9591b.tar.gz
MAINT: np.pad: Generalize the helper function to be used in more places
This makes `_slice_first` almost a factor of two faster
Diffstat (limited to 'numpy/lib/arraypad.py')
-rw-r--r--numpy/lib/arraypad.py27
1 files changed, 15 insertions, 12 deletions
diff --git a/numpy/lib/arraypad.py b/numpy/lib/arraypad.py
index 9a4c1350a..9f6b75560 100644
--- a/numpy/lib/arraypad.py
+++ b/numpy/lib/arraypad.py
@@ -74,16 +74,23 @@ def _round_ifneeded(arr, dtype):
arr.round(out=arr)
+def _slice_at_axis(shape, sl, axis):
+ """
+ Construct a slice tuple the length of shape, with sl at the specified axis
+ """
+ slice_tup = (slice(None),)
+ return slice_tup * axis + (sl,) + slice_tup * (len(shape) - axis - 1)
+
+
def _slice_first(shape, n, axis):
""" Construct a slice tuple to take the first n elements along axis """
- return tuple(slice(None) if i != axis else slice(0, n)
- for (i, x) in enumerate(shape))
+ return _slice_at_axis(shape, slice(0, n), axis=axis)
def _slice_last(shape, n, axis):
""" Construct a slice tuple to take the last n elements along axis """
- return tuple(slice(None) if i != axis else slice(x - n, x)
- for (i, x) in enumerate(shape))
+ dim = shape[axis] # doing this explicitly makes n=0 work
+ return _slice_at_axis(shape, slice(dim - n, dim), axis=axis)
def _prepend_const(arr, pad_amt, val, axis=-1):
@@ -729,8 +736,7 @@ def _pad_ref(arr, pad_amt, method, axis=-1):
# Prepended region
# Slice off a reverse indexed chunk from near edge to pad `arr` before
- ref_slice = tuple(slice(None) if i != axis else slice(pad_amt[0], 0, -1)
- for (i, x) in enumerate(arr.shape))
+ ref_slice = _slice_at_axis(arr.shape, slice(pad_amt[0], 0, -1), axis=axis)
ref_chunk1 = arr[ref_slice]
@@ -747,10 +753,8 @@ def _pad_ref(arr, pad_amt, method, axis=-1):
# Slice off a reverse indexed chunk from far edge to pad `arr` after
start = arr.shape[axis] - pad_amt[1] - 1
end = arr.shape[axis] - 1
- ref_slice = tuple(slice(None) if i != axis else slice(start, end)
- for (i, x) in enumerate(arr.shape))
- rev_idx = tuple(slice(None) if i != axis else slice(None, None, -1)
- for (i, x) in enumerate(arr.shape))
+ ref_slice = _slice_at_axis(arr.shape, slice(start, end), axis=axis)
+ rev_idx = _slice_at_axis(arr.shape, slice(None, None, -1), axis=axis)
ref_chunk2 = arr[ref_slice][rev_idx]
if 'odd' in method:
@@ -804,8 +808,7 @@ def _pad_sym(arr, pad_amt, method, axis=-1):
# Slice off a reverse indexed chunk from near edge to pad `arr` before
sym_slice = _slice_first(arr.shape, pad_amt[0], axis=axis)
- rev_idx = tuple(slice(None) if i != axis else slice(None, None, -1)
- for (i, x) in enumerate(arr.shape))
+ rev_idx = _slice_at_axis(arr.shape, slice(None, None, -1), axis=axis)
sym_chunk1 = arr[sym_slice][rev_idx]
# Memory/computationally more expensive, only do this if `method='odd'`