summaryrefslogtreecommitdiff
path: root/numpy/lib/arraypad.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/arraypad.py')
-rw-r--r--numpy/lib/arraypad.py27
1 files changed, 15 insertions, 12 deletions
diff --git a/numpy/lib/arraypad.py b/numpy/lib/arraypad.py
index 9a4c1350a..9f6b75560 100644
--- a/numpy/lib/arraypad.py
+++ b/numpy/lib/arraypad.py
@@ -74,16 +74,23 @@ def _round_ifneeded(arr, dtype):
arr.round(out=arr)
+def _slice_at_axis(shape, sl, axis):
+ """
+ Construct a slice tuple the length of shape, with sl at the specified axis
+ """
+ slice_tup = (slice(None),)
+ return slice_tup * axis + (sl,) + slice_tup * (len(shape) - axis - 1)
+
+
def _slice_first(shape, n, axis):
""" Construct a slice tuple to take the first n elements along axis """
- return tuple(slice(None) if i != axis else slice(0, n)
- for (i, x) in enumerate(shape))
+ return _slice_at_axis(shape, slice(0, n), axis=axis)
def _slice_last(shape, n, axis):
""" Construct a slice tuple to take the last n elements along axis """
- return tuple(slice(None) if i != axis else slice(x - n, x)
- for (i, x) in enumerate(shape))
+ dim = shape[axis] # doing this explicitly makes n=0 work
+ return _slice_at_axis(shape, slice(dim - n, dim), axis=axis)
def _prepend_const(arr, pad_amt, val, axis=-1):
@@ -729,8 +736,7 @@ def _pad_ref(arr, pad_amt, method, axis=-1):
# Prepended region
# Slice off a reverse indexed chunk from near edge to pad `arr` before
- ref_slice = tuple(slice(None) if i != axis else slice(pad_amt[0], 0, -1)
- for (i, x) in enumerate(arr.shape))
+ ref_slice = _slice_at_axis(arr.shape, slice(pad_amt[0], 0, -1), axis=axis)
ref_chunk1 = arr[ref_slice]
@@ -747,10 +753,8 @@ def _pad_ref(arr, pad_amt, method, axis=-1):
# Slice off a reverse indexed chunk from far edge to pad `arr` after
start = arr.shape[axis] - pad_amt[1] - 1
end = arr.shape[axis] - 1
- ref_slice = tuple(slice(None) if i != axis else slice(start, end)
- for (i, x) in enumerate(arr.shape))
- rev_idx = tuple(slice(None) if i != axis else slice(None, None, -1)
- for (i, x) in enumerate(arr.shape))
+ ref_slice = _slice_at_axis(arr.shape, slice(start, end), axis=axis)
+ rev_idx = _slice_at_axis(arr.shape, slice(None, None, -1), axis=axis)
ref_chunk2 = arr[ref_slice][rev_idx]
if 'odd' in method:
@@ -804,8 +808,7 @@ def _pad_sym(arr, pad_amt, method, axis=-1):
# Slice off a reverse indexed chunk from near edge to pad `arr` before
sym_slice = _slice_first(arr.shape, pad_amt[0], axis=axis)
- rev_idx = tuple(slice(None) if i != axis else slice(None, None, -1)
- for (i, x) in enumerate(arr.shape))
+ rev_idx = _slice_at_axis(arr.shape, slice(None, None, -1), axis=axis)
sym_chunk1 = arr[sym_slice][rev_idx]
# Memory/computationally more expensive, only do this if `method='odd'`