summaryrefslogtreecommitdiff
path: root/numpy/random/tests/test_generator_mt19937.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/random/tests/test_generator_mt19937.py')
-rw-r--r--numpy/random/tests/test_generator_mt19937.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 895e7fc6c..d76291d1a 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -639,6 +639,18 @@ class TestRandomDist(object):
p = [None, None, None]
assert_raises(ValueError, random.choice, a, p=p)
+ def test_choice_return_type(self):
+ # gh 9867
+ p = np.ones(4) / 4.
+ actual = random.choice(4, 2)
+ assert actual.dtype == np.int64
+ actual = random.choice(4, 2, replace=False)
+ assert actual.dtype == np.int64
+ actual = random.choice(4, 2, p=p)
+ assert actual.dtype == np.int64
+ actual = random.choice(4, 2, p=p, replace=False)
+ assert actual.dtype == np.int64
+
def test_bytes(self):
random.brng.seed(self.seed)
actual = random.bytes(10)