summaryrefslogtreecommitdiff
path: root/numpy/lib/arraypad.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/arraypad.py')
-rw-r--r--numpy/lib/arraypad.py35
1 files changed, 6 insertions, 29 deletions
diff --git a/numpy/lib/arraypad.py b/numpy/lib/arraypad.py
index 3cfb1052a..15e3ed957 100644
--- a/numpy/lib/arraypad.py
+++ b/numpy/lib/arraypad.py
@@ -1030,43 +1030,20 @@ def _normalize_shape(ndarray, shape, cast_to_int=True):
return ((None, None), ) * ndims
# Convert any input `info` to a NumPy array
- arr = np.asarray(shape)
-
- # Switch based on what input looks like
- if arr.ndim <= 1:
- if arr.shape == () or arr.shape == (1,):
- # Single scalar input
- # Create new array of ones, multiply by the scalar
- arr = np.ones((ndims, 2), dtype=ndarray.dtype) * arr
- elif arr.shape == (2,):
- # Apply padding (before, after) each axis
- # Create new axis 0, repeat along it for every axis
- arr = arr[np.newaxis, :].repeat(ndims, axis=0)
- else:
- fmt = "Unable to create correctly shaped tuple from %s"
- raise ValueError(fmt % (shape,))
-
- elif arr.ndim == 2:
- if arr.shape[1] == 1 and arr.shape[0] == ndims:
- # Padded before and after by the same amount
- arr = arr.repeat(2, axis=1)
- elif arr.shape[0] == ndims:
- # Input correctly formatted, pass it on as `arr`
- pass
- else:
- fmt = "Unable to create correctly shaped tuple from %s"
- raise ValueError(fmt % (shape,))
+ shape_arr = np.asarray(shape)
- else:
+ try:
+ shape_arr = np.broadcast_to(shape_arr, (ndims, 2))
+ except ValueError:
fmt = "Unable to create correctly shaped tuple from %s"
raise ValueError(fmt % (shape,))
# Cast if necessary
if cast_to_int is True:
- arr = np.round(arr).astype(int)
+ shape_arr = np.round(shape_arr).astype(int)
# Convert list of lists to tuple of tuples
- return tuple(tuple(axis) for axis in arr.tolist())
+ return tuple(tuple(axis) for axis in shape_arr.tolist())
def _validate_lengths(narray, number_elements):