diff options
Diffstat (limited to 'numpy/random/tests/test_randomstate.py')
-rw-r--r-- | numpy/random/tests/test_randomstate.py | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py index 1e22880ea..fdd194fb9 100644 --- a/numpy/random/tests/test_randomstate.py +++ b/numpy/random/tests/test_randomstate.py @@ -53,6 +53,14 @@ def int_func(request): INT_FUNC_HASHES[request.param]) +@pytest.fixture +def restore_singleton_bitgen(): + """Ensures that the singleton bitgen is restored after a test""" + orig_bitgen = np.random.get_bit_generator() + yield + np.random.set_bit_generator(orig_bitgen) + + def assert_mt19937_state_equal(a, b): assert_equal(a['bit_generator'], b['bit_generator']) assert_array_equal(a['state']['key'], b['state']['key']) @@ -2038,3 +2046,62 @@ def test_randomstate_ctor_old_style_pickle(): assert_array_equal(state_a['state']['pos'], state_b['state']['pos']) assert_equal(state_a['has_gauss'], state_b['has_gauss']) assert_equal(state_a['gauss'], state_b['gauss']) + +def test_hot_swap(restore_singleton_bitgen): + # GH 21808 + def_bg = np.random.default_rng(0) + bg = def_bg.bit_generator + np.random.set_bit_generator(bg) + assert isinstance(np.random.mtrand._rand._bit_generator, type(bg)) + + second_bg = np.random.get_bit_generator() + assert bg is second_bg + + +def test_seed_alt_bit_gen(restore_singleton_bitgen): + # GH 21808 + bg = PCG64(0) + np.random.set_bit_generator(bg) + state = np.random.get_state(legacy=False) + np.random.seed(1) + new_state = np.random.get_state(legacy=False) + print(state) + print(new_state) + assert state["bit_generator"] == "PCG64" + assert state["state"]["state"] != new_state["state"]["state"] + assert state["state"]["inc"] != new_state["state"]["inc"] + + +def test_state_error_alt_bit_gen(restore_singleton_bitgen): + # GH 21808 + state = np.random.get_state() + bg = PCG64(0) + np.random.set_bit_generator(bg) + with pytest.raises(ValueError, match="state must be for a PCG64"): + np.random.set_state(state) + + +def test_swap_worked(restore_singleton_bitgen): + # GH 21808 + np.random.seed(98765) + vals = np.random.randint(0, 2 ** 30, 10) + bg = PCG64(0) + state = bg.state + np.random.set_bit_generator(bg) + state_direct = np.random.get_state(legacy=False) + for field in state: + assert state[field] == state_direct[field] + np.random.seed(98765) + pcg_vals = np.random.randint(0, 2 ** 30, 10) + assert not np.all(vals == pcg_vals) + new_state = bg.state + assert new_state["state"]["state"] != state["state"]["state"] + assert new_state["state"]["inc"] == new_state["state"]["inc"] + + +def test_swapped_singleton_against_direct(restore_singleton_bitgen): + np.random.set_bit_generator(PCG64(98765)) + singleton_vals = np.random.randint(0, 2 ** 30, 10) + rg = np.random.RandomState(PCG64(98765)) + non_singleton_vals = rg.randint(0, 2 ** 30, 10) + assert_equal(non_singleton_vals, singleton_vals) |