diff options
author | Matti Picus <matti.picus@gmail.com> | 2019-06-03 20:48:53 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-03 20:48:53 +0300 |
commit | a9d09163b0e65e5f0de584d03585a4d9f5b47a64 (patch) | |
tree | c4b270c1e13fac450c0bc623ca8132375d7e87e0 /numpy | |
parent | 52d173d9087f6fb09994b433a127f58d19875c98 (diff) | |
parent | 2dd6f048fb0985e76018d87ccbce70445dcfa2b7 (diff) | |
download | numpy-a9d09163b0e65e5f0de584d03585a4d9f5b47a64.tar.gz |
Merge pull request #13695 from bashtage/choice-windows-int32-dtype
BUG: Ensure Windows choice returns int32
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/random/mtrand.pyx | 6 | ||||
-rw-r--r-- | numpy/random/tests/test_randomstate_regression.py | 11 |
2 files changed, 15 insertions, 2 deletions
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index e97af4596..4184d1ebe 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -807,7 +807,9 @@ cdef class RandomState: cdf /= cdf[-1] uniform_samples = self.random_sample(shape) idx = cdf.searchsorted(uniform_samples, side='right') - idx = np.array(idx, copy=False) # searchsorted returns a scalar + # searchsorted returns a scalar + # force cast to int for LLP64 + idx = np.array(idx, copy=False).astype(int, casting='unsafe') else: idx = self.randint(0, pop_size, size=shape) else: @@ -822,7 +824,7 @@ cdef class RandomState: raise ValueError("Fewer non-zero entries in p than size") n_uniq = 0 p = p.copy() - found = np.zeros(shape, dtype=np.int64) + found = np.zeros(shape, dtype=int) flat_found = found.ravel() while n_uniq < size: x = self.rand(size - n_uniq) diff --git a/numpy/random/tests/test_randomstate_regression.py b/numpy/random/tests/test_randomstate_regression.py index cdd905929..29870534a 100644 --- a/numpy/random/tests/test_randomstate_regression.py +++ b/numpy/random/tests/test_randomstate_regression.py @@ -170,3 +170,14 @@ class TestRegression(object): rs1 = np.random.RandomState(123456789) rs2 = np.random.RandomState(seed=123456789) assert rs1.randint(0, 100) == rs2.randint(0, 100) + + def test_choice_retun_dtype(self): + # GH 9867 + c = np.random.choice(10, p=[.1]*10, size=2) + assert c.dtype == np.dtype(int) + c = np.random.choice(10, p=[.1]*10, replace=False, size=2) + assert c.dtype == np.dtype(int) + c = np.random.choice(10, size=2) + assert c.dtype == np.dtype(int) + c = np.random.choice(10, replace=False, size=2) + assert c.dtype == np.dtype(int) |