summaryrefslogtreecommitdiff
path: root/numpy/random/tests/test_random.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/random/tests/test_random.py')
-rw-r--r--numpy/random/tests/test_random.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index db4093ab4..b6c4fe3af 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -624,5 +624,45 @@ class TestRandomDist(TestCase):
[ 3, 13]])
np.testing.assert_array_equal(actual, desired)
+
+class TestThread:
+ """ make sure each state produces the same sequence even in threads """
+ def setUp(self):
+ self.seeds = range(4)
+
+ def check_function(self, function, sz):
+ from threading import Thread
+
+ out1 = np.empty((len(self.seeds),) + sz)
+ out2 = np.empty((len(self.seeds),) + sz)
+
+ # threaded generation
+ t = [Thread(target=function, args=(np.random.RandomState(s), o))
+ for s, o in zip(self.seeds, out1)]
+ [x.start() for x in t]
+ [x.join() for x in t]
+
+ # the same serial
+ for s, o in zip(self.seeds, out2):
+ function(np.random.RandomState(s), o)
+
+ np.testing.assert_array_equal(out1, out2)
+
+ def test_normal(self):
+ def gen_random(state, out):
+ out[...] = state.normal(size=10000)
+ self.check_function(gen_random, sz=(10000,))
+
+ def test_exp(self):
+ def gen_random(state, out):
+ out[...] = state.exponential(scale=np.ones((100, 1000)))
+ self.check_function(gen_random, sz=(100, 1000))
+
+ def test_multinomial(self):
+ def gen_random(state, out):
+ out[...] = state.multinomial(10, [1/6.]*6, size=10000)
+ self.check_function(gen_random, sz=(10000,6))
+
+
if __name__ == "__main__":
run_module_suite()