diff options
-rw-r--r-- | numpy/random/_generator.pyx | 5 | ||||
-rw-r--r-- | numpy/random/mtrand.pyx | 3 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 13 | ||||
-rw-r--r-- | numpy/random/tests/test_random.py | 6 |
4 files changed, 27 insertions, 0 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index 391987a1e..5dc761af9 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -32,6 +32,8 @@ from ._common cimport (POISSON_LAM_MAX, CONS_POSITIVE, CONS_NONE, cdef extern from "numpy/arrayobject.h": int PyArray_ResolveWritebackIfCopy(np.ndarray) + int PyArray_FailUnlessWriteable(np.PyArrayObject *obj, + const char *name) except -1 object PyArray_FromArray(np.PyArrayObject *, np.PyArray_Descr *, int) enum: @@ -4416,6 +4418,7 @@ cdef class Generator: else: if type(out) is not np.ndarray: raise TypeError('out must be a numpy array') + PyArray_FailUnlessWriteable(<np.PyArrayObject *>out, "out") if out.shape != x.shape: raise ValueError('out must have the same shape as x') np.copyto(out, x, casting='safe') @@ -4524,6 +4527,8 @@ cdef class Generator: char* buf_ptr if isinstance(x, np.ndarray): + if not x.flags.writeable: + raise ValueError('array is read-only') # Only call ndim on ndarrays, see GH 18142 axis = normalize_axis_index(axis, np.ndim(x)) diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index ce09a041c..d8f7c500f 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -4483,6 +4483,9 @@ cdef class RandomState: char* x_ptr char* buf_ptr + if isinstance(x, np.ndarray) and not x.flags.writeable: + raise ValueError('array is read-only') + if type(x) is np.ndarray and x.ndim == 1 and x.size: # Fast, statically typed path: shuffle the underlying buffer. # Only for non-empty, 1d objects of class ndarray (subclasses such diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index e5411b8ef..e16a82973 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -1020,6 +1020,13 @@ class TestRandomDist: arr = np.ones((3, 2)) assert_raises(np.AxisError, random.shuffle, arr, 2) + def test_shuffle_not_writeable(self): + random = Generator(MT19937(self.seed)) + a = np.zeros(5) + a.flags.writeable = False + with pytest.raises(ValueError, match='read-only'): + random.shuffle(a) + def test_permutation(self): random = Generator(MT19937(self.seed)) alist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0] @@ -1116,6 +1123,12 @@ class TestRandomDist: with pytest.raises(TypeError, match='Cannot cast'): random.permuted(x, axis=1, out=out) + def test_permuted_not_writeable(self): + x = np.zeros((2, 5)) + x.flags.writeable = False + with pytest.raises(ValueError, match='read-only'): + random.permuted(x, axis=1, out=x) + def test_beta(self): random = Generator(MT19937(self.seed)) actual = random.beta(.1, .9, size=(3, 2)) diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index 6a584a511..773b63653 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -564,6 +564,12 @@ class TestRandomDist: rng.shuffle(a) assert_equal(np.asarray(a), [4, 1, 0, 3, 2]) + def test_shuffle_not_writeable(self): + a = np.zeros(3) + a.flags.writeable = False + with pytest.raises(ValueError, match='read-only'): + np.random.shuffle(a) + def test_beta(self): np.random.seed(self.seed) actual = np.random.beta(.1, .9, size=(3, 2)) |