summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-05-26 16:27:57 -0500
committerGitHub <noreply@github.com>2020-05-26 16:27:57 -0500
commit171ea2a856e7b40d0806501613f7361e09f0a2f8 (patch)
treea6fdd57a2cd60398c66bf98c439f29f94b1d68e6 /numpy/core
parentc02feb9c70befed4cc078902d28f9df37d9f984f (diff)
parent7ef6b65fa960fcedcef8a27b1c3ef8c67ebb2078 (diff)
downloadnumpy-171ea2a856e7b40d0806501613f7361e09f0a2f8.tar.gz
Merge pull request #15037 from tirthasheshpatel/patch-1
BUG: `np.resize` negative shape and subclasses edge case fixes
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/fromnumeric.py30
-rw-r--r--numpy/core/tests/test_numeric.py28
2 files changed, 46 insertions, 12 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index 7193af839..0c63bcf73 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -1374,10 +1374,17 @@ def resize(a, new_shape):
See Also
--------
+ np.reshape : Reshape an array without changing the total size.
+ np.pad : Enlarge and pad an array.
+ np.repeat: Repeat elements of an array.
ndarray.resize : resize an array in-place.
Notes
-----
+ When the total size of the array does not change `~numpy.reshape` should
+ be used. In most other cases either indexing (to reduce the size)
+ or padding (to increase the size) may be a more appropriate solution.
+
Warning: This functionality does **not** consider axes separately,
i.e. it does not apply interpolation/extrapolation.
It fills the return array with the required number of elements, taken
@@ -1401,22 +1408,21 @@ def resize(a, new_shape):
"""
if isinstance(new_shape, (int, nt.integer)):
new_shape = (new_shape,)
+
a = ravel(a)
- Na = len(a)
- total_size = um.multiply.reduce(new_shape)
- if Na == 0 or total_size == 0:
- return mu.zeros(new_shape, a.dtype)
- n_copies = int(total_size / Na)
- extra = total_size % Na
+ new_size = 1
+ for dim_length in new_shape:
+ new_size *= dim_length
+ if dim_length < 0:
+ raise ValueError('all elements of `new_shape` must be non-negative')
- if extra != 0:
- n_copies = n_copies + 1
- extra = Na - extra
+ if a.size == 0 or new_size == 0:
+ # First case must zero fill. The second would have repeats == 0.
+ return np.zeros_like(a, shape=new_shape)
- a = concatenate((a,) * n_copies)
- if extra > 0:
- a = a[:-extra]
+ repeats = -(-new_size // a.size) # ceil division
+ a = concatenate((a,) * repeats)[:new_size]
return reshape(a, new_shape)
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index acd442e2f..2a87ffaf8 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -31,6 +31,17 @@ class TestResize:
Ar3 = np.array([[1, 2, 3], [4, 1, 2], [3, 4, 1], [2, 3, 4]])
assert_equal(np.resize(A, (4, 3)), Ar3)
+ def test_repeats(self):
+ A = np.array([1, 2, 3])
+ Ar1 = np.array([[1, 2, 3, 1], [2, 3, 1, 2]])
+ assert_equal(np.resize(A, (2, 4)), Ar1)
+
+ Ar2 = np.array([[1, 2], [3, 1], [2, 3], [1, 2]])
+ assert_equal(np.resize(A, (4, 2)), Ar2)
+
+ Ar3 = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
+ assert_equal(np.resize(A, (4, 3)), Ar3)
+
def test_zeroresize(self):
A = np.array([[1, 2], [3, 4]])
Ar = np.resize(A, (0,))
@@ -50,6 +61,23 @@ class TestResize:
assert_array_equal(Ar, np.zeros((2, 1), Ar.dtype))
assert_equal(A.dtype, Ar.dtype)
+ def test_negative_resize(self):
+ A = np.arange(0, 10, dtype=np.float32)
+ new_shape = (-10, -1)
+ with pytest.raises(ValueError, match=r"negative"):
+ np.resize(A, new_shape=new_shape)
+
+ def test_subclass(self):
+ class MyArray(np.ndarray):
+ __array_priority__ = 1.
+
+ my_arr = np.array([1]).view(MyArray)
+ assert type(np.resize(my_arr, 5)) is MyArray
+ assert type(np.resize(my_arr, 0)) is MyArray
+
+ my_arr = np.array([]).view(MyArray)
+ assert type(np.resize(my_arr, 5)) is MyArray
+
class TestNonarrayArgs:
# check that non-array arguments to functions wrap them in arrays