summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2019-06-03 20:48:53 +0300
committerGitHub <noreply@github.com>2019-06-03 20:48:53 +0300
commita9d09163b0e65e5f0de584d03585a4d9f5b47a64 (patch)
treec4b270c1e13fac450c0bc623ca8132375d7e87e0 /numpy
parent52d173d9087f6fb09994b433a127f58d19875c98 (diff)
parent2dd6f048fb0985e76018d87ccbce70445dcfa2b7 (diff)
downloadnumpy-a9d09163b0e65e5f0de584d03585a4d9f5b47a64.tar.gz
Merge pull request #13695 from bashtage/choice-windows-int32-dtype
BUG: Ensure Windows choice returns int32
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/mtrand.pyx6
-rw-r--r--numpy/random/tests/test_randomstate_regression.py11
2 files changed, 15 insertions, 2 deletions
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index e97af4596..4184d1ebe 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -807,7 +807,9 @@ cdef class RandomState:
cdf /= cdf[-1]
uniform_samples = self.random_sample(shape)
idx = cdf.searchsorted(uniform_samples, side='right')
- idx = np.array(idx, copy=False) # searchsorted returns a scalar
+ # searchsorted returns a scalar
+ # force cast to int for LLP64
+ idx = np.array(idx, copy=False).astype(int, casting='unsafe')
else:
idx = self.randint(0, pop_size, size=shape)
else:
@@ -822,7 +824,7 @@ cdef class RandomState:
raise ValueError("Fewer non-zero entries in p than size")
n_uniq = 0
p = p.copy()
- found = np.zeros(shape, dtype=np.int64)
+ found = np.zeros(shape, dtype=int)
flat_found = found.ravel()
while n_uniq < size:
x = self.rand(size - n_uniq)
diff --git a/numpy/random/tests/test_randomstate_regression.py b/numpy/random/tests/test_randomstate_regression.py
index cdd905929..29870534a 100644
--- a/numpy/random/tests/test_randomstate_regression.py
+++ b/numpy/random/tests/test_randomstate_regression.py
@@ -170,3 +170,14 @@ class TestRegression(object):
rs1 = np.random.RandomState(123456789)
rs2 = np.random.RandomState(seed=123456789)
assert rs1.randint(0, 100) == rs2.randint(0, 100)
+
+ def test_choice_retun_dtype(self):
+ # GH 9867
+ c = np.random.choice(10, p=[.1]*10, size=2)
+ assert c.dtype == np.dtype(int)
+ c = np.random.choice(10, p=[.1]*10, replace=False, size=2)
+ assert c.dtype == np.dtype(int)
+ c = np.random.choice(10, size=2)
+ assert c.dtype == np.dtype(int)
+ c = np.random.choice(10, replace=False, size=2)
+ assert c.dtype == np.dtype(int)