diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/random/mtrand/mtrand.pyx | 4 | ||||
-rw-r--r-- | numpy/random/tests/test_regression.py | 22 |
2 files changed, 24 insertions, 2 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index 5097ad88f..ab5f64336 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -4907,8 +4907,8 @@ cdef class RandomState: # shuffle has fast-path for 1-d if arr.ndim == 1: - # must return a copy - if arr is x: + # Return a copy if same memory + if np.may_share_memory(arr, x): arr = np.array(arr) self.shuffle(arr) return arr diff --git a/numpy/random/tests/test_regression.py b/numpy/random/tests/test_regression.py index 3b4b4ed40..ca9bbbc71 100644 --- a/numpy/random/tests/test_regression.py +++ b/numpy/random/tests/test_regression.py @@ -133,3 +133,25 @@ class TestRegression(object): # Force Garbage Collection - should not segfault. import gc gc.collect() + + def test_permutation_subclass(self): + class N(np.ndarray): + pass + + np.random.seed(1) + orig = np.arange(3).view(N) + perm = np.random.permutation(orig) + assert_array_equal(perm, np.array([0, 2, 1])) + assert_array_equal(orig, np.arange(3).view(N)) + + class M(object): + a = np.arange(5) + + def __array__(self): + return self.a + + np.random.seed(1) + m = M() + perm = np.random.permutation(m) + assert_array_equal(perm, np.array([2, 1, 4, 0, 3])) + assert_array_equal(m.__array__(), np.arange(5)) |