summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorTyler Reddy <tyler.je.reddy@gmail.com>2021-02-01 09:01:17 -0700
committerTyler Reddy <tyler.je.reddy@gmail.com>2021-02-01 09:01:17 -0700
commit044550d603b0d7bbe201d703354dd4857bb3606d (patch)
tree20b75aee1d7ced4353a072f15d662afb64b2abd9 /numpy
parent2d6c55528956e2f9233364657bc865eb57603514 (diff)
downloadnumpy-044550d603b0d7bbe201d703354dd4857bb3606d.tar.gz
MAINT: gracefully shuffle memoryviews
* allow graceful shuffling of memoryviews, with same behavior as arrays, instead of producing a warning on `memoryview` shuffle
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/_generator.pyx3
-rw-r--r--numpy/random/mtrand.pyx3
-rw-r--r--numpy/random/tests/test_random.py15
3 files changed, 21 insertions, 0 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index 0a41f13b6..297642940 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -4398,6 +4398,9 @@ cdef class Generator:
char* x_ptr
char* buf_ptr
+ if isinstance(x, memoryview):
+ x = np.asarray(x)
+
axis = normalize_axis_index(axis, np.ndim(x))
if type(x) is np.ndarray and x.ndim == 1 and x.size:
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index 814630c03..daab2c6f1 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -4436,6 +4436,9 @@ cdef class RandomState:
char* x_ptr
char* buf_ptr
+ if isinstance(x, memoryview):
+ x = np.asarray(x)
+
if type(x) is np.ndarray and x.ndim == 1 and x.size:
# Fast, statically typed path: shuffle the underlying buffer.
# Only for non-empty, 1d objects of class ndarray (subclasses such
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index c13fc39e3..5f8b39ef9 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -510,6 +510,21 @@ class TestRandomDist:
assert_equal(
sorted(b.data[~b.mask]), sorted(b_orig.data[~b_orig.mask]))
+ def test_shuffle_memoryview(self):
+ # gh-18273
+ # allow graceful handling of memoryviews
+ # (treat the same as arrays)
+ np.random.seed(self.seed)
+ a = np.arange(5).data
+ np.random.shuffle(a)
+ assert_equal(np.asarray(a), [0, 1, 4, 3, 2])
+ rng = np.random.RandomState(self.seed)
+ rng.shuffle(a)
+ assert_equal(np.asarray(a), [0, 1, 2, 3, 4])
+ rng = np.random.default_rng(self.seed)
+ rng.shuffle(a)
+ assert_equal(np.asarray(a), [4, 1, 0, 3, 2])
+
def test_beta(self):
np.random.seed(self.seed)
actual = np.random.beta(.1, .9, size=(3, 2))