diff options
author | Warren Weckesser <warren.weckesser@gmail.com> | 2016-09-17 18:22:37 -0400 |
---|---|---|
committer | Warren Weckesser <warren.weckesser@gmail.com> | 2016-09-17 18:25:09 -0400 |
commit | f28957d2a9411958a6952687be9c4dcd1be2a23a (patch) | |
tree | 6ce535d9d2a34737054af6a2cbb8e160a7c512b1 | |
parent | c90d7c94fd2077d0beca48fa89a423da2b0bb663 (diff) | |
download | numpy-f28957d2a9411958a6952687be9c4dcd1be2a23a.tar.gz |
BUG: lib: Simplify (and fix) pad's handling of the pad_width
Simplify the expansion of the pad_width argument by using `broadcast_to()`.
This fixes the problem reported in gh-7808, where, for example,
`pad_width=((1, 2),)` resulted in an error.
Closes gh-7808.
-rw-r--r-- | numpy/lib/arraypad.py | 35 | ||||
-rw-r--r-- | numpy/lib/tests/test_arraypad.py | 18 |
2 files changed, 24 insertions, 29 deletions
diff --git a/numpy/lib/arraypad.py b/numpy/lib/arraypad.py index 3cfb1052a..15e3ed957 100644 --- a/numpy/lib/arraypad.py +++ b/numpy/lib/arraypad.py @@ -1030,43 +1030,20 @@ def _normalize_shape(ndarray, shape, cast_to_int=True): return ((None, None), ) * ndims # Convert any input `info` to a NumPy array - arr = np.asarray(shape) - - # Switch based on what input looks like - if arr.ndim <= 1: - if arr.shape == () or arr.shape == (1,): - # Single scalar input - # Create new array of ones, multiply by the scalar - arr = np.ones((ndims, 2), dtype=ndarray.dtype) * arr - elif arr.shape == (2,): - # Apply padding (before, after) each axis - # Create new axis 0, repeat along it for every axis - arr = arr[np.newaxis, :].repeat(ndims, axis=0) - else: - fmt = "Unable to create correctly shaped tuple from %s" - raise ValueError(fmt % (shape,)) - - elif arr.ndim == 2: - if arr.shape[1] == 1 and arr.shape[0] == ndims: - # Padded before and after by the same amount - arr = arr.repeat(2, axis=1) - elif arr.shape[0] == ndims: - # Input correctly formatted, pass it on as `arr` - pass - else: - fmt = "Unable to create correctly shaped tuple from %s" - raise ValueError(fmt % (shape,)) + shape_arr = np.asarray(shape) - else: + try: + shape_arr = np.broadcast_to(shape_arr, (ndims, 2)) + except ValueError: fmt = "Unable to create correctly shaped tuple from %s" raise ValueError(fmt % (shape,)) # Cast if necessary if cast_to_int is True: - arr = np.round(arr).astype(int) + shape_arr = np.round(shape_arr).astype(int) # Convert list of lists to tuple of tuples - return tuple(tuple(axis) for axis in arr.tolist()) + return tuple(tuple(axis) for axis in shape_arr.tolist()) def _validate_lengths(narray, number_elements): diff --git a/numpy/lib/tests/test_arraypad.py b/numpy/lib/tests/test_arraypad.py index 9ad05906d..d037962e6 100644 --- a/numpy/lib/tests/test_arraypad.py +++ b/numpy/lib/tests/test_arraypad.py @@ -914,6 +914,24 @@ class TestEdge(TestCase): ) assert_array_equal(a, b) + def test_check_width_shape_1_2(self): + # Check a pad_width of the form ((1, 2),). + # Regression test for issue gh-7808. + a = np.array([1, 2, 3]) + padded = pad(a, ((1, 2),), 'edge') + expected = np.array([1, 1, 2, 3, 3, 3]) + assert_array_equal(padded, expected) + + a = np.array([[1, 2, 3], [4, 5, 6]]) + padded = pad(a, ((1, 2),), 'edge') + expected = pad(a, ((1, 2), (1, 2)), 'edge') + assert_array_equal(padded, expected) + + a = np.arange(24).reshape(2, 3, 4) + padded = pad(a, ((1, 2),), 'edge') + expected = pad(a, ((1, 2), (1, 2), (1, 2)), 'edge') + assert_array_equal(padded, expected) + class TestZeroPadWidth(TestCase): def test_zero_pad_width(self): |