summaryrefslogtreecommitdiff
path: root/numpy/random/tests/test_randomstate.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/random/tests/test_randomstate.py')
-rw-r--r--numpy/random/tests/test_randomstate.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py
index 1e22880ea..fdd194fb9 100644
--- a/numpy/random/tests/test_randomstate.py
+++ b/numpy/random/tests/test_randomstate.py
@@ -53,6 +53,14 @@ def int_func(request):
INT_FUNC_HASHES[request.param])
+@pytest.fixture
+def restore_singleton_bitgen():
+ """Ensures that the singleton bitgen is restored after a test"""
+ orig_bitgen = np.random.get_bit_generator()
+ yield
+ np.random.set_bit_generator(orig_bitgen)
+
+
def assert_mt19937_state_equal(a, b):
assert_equal(a['bit_generator'], b['bit_generator'])
assert_array_equal(a['state']['key'], b['state']['key'])
@@ -2038,3 +2046,62 @@ def test_randomstate_ctor_old_style_pickle():
assert_array_equal(state_a['state']['pos'], state_b['state']['pos'])
assert_equal(state_a['has_gauss'], state_b['has_gauss'])
assert_equal(state_a['gauss'], state_b['gauss'])
+
+def test_hot_swap(restore_singleton_bitgen):
+ # GH 21808
+ def_bg = np.random.default_rng(0)
+ bg = def_bg.bit_generator
+ np.random.set_bit_generator(bg)
+ assert isinstance(np.random.mtrand._rand._bit_generator, type(bg))
+
+ second_bg = np.random.get_bit_generator()
+ assert bg is second_bg
+
+
+def test_seed_alt_bit_gen(restore_singleton_bitgen):
+ # GH 21808
+ bg = PCG64(0)
+ np.random.set_bit_generator(bg)
+ state = np.random.get_state(legacy=False)
+ np.random.seed(1)
+ new_state = np.random.get_state(legacy=False)
+ print(state)
+ print(new_state)
+ assert state["bit_generator"] == "PCG64"
+ assert state["state"]["state"] != new_state["state"]["state"]
+ assert state["state"]["inc"] != new_state["state"]["inc"]
+
+
+def test_state_error_alt_bit_gen(restore_singleton_bitgen):
+ # GH 21808
+ state = np.random.get_state()
+ bg = PCG64(0)
+ np.random.set_bit_generator(bg)
+ with pytest.raises(ValueError, match="state must be for a PCG64"):
+ np.random.set_state(state)
+
+
+def test_swap_worked(restore_singleton_bitgen):
+ # GH 21808
+ np.random.seed(98765)
+ vals = np.random.randint(0, 2 ** 30, 10)
+ bg = PCG64(0)
+ state = bg.state
+ np.random.set_bit_generator(bg)
+ state_direct = np.random.get_state(legacy=False)
+ for field in state:
+ assert state[field] == state_direct[field]
+ np.random.seed(98765)
+ pcg_vals = np.random.randint(0, 2 ** 30, 10)
+ assert not np.all(vals == pcg_vals)
+ new_state = bg.state
+ assert new_state["state"]["state"] != state["state"]["state"]
+ assert new_state["state"]["inc"] == new_state["state"]["inc"]
+
+
+def test_swapped_singleton_against_direct(restore_singleton_bitgen):
+ np.random.set_bit_generator(PCG64(98765))
+ singleton_vals = np.random.randint(0, 2 ** 30, 10)
+ rg = np.random.RandomState(PCG64(98765))
+ non_singleton_vals = rg.randint(0, 2 ** 30, 10)
+ assert_equal(non_singleton_vals, singleton_vals)