diff options
author | Kevin Sheppard <kevin.k.sheppard@gmail.com> | 2022-07-13 09:19:43 +0100 |
---|---|---|
committer | Kevin Sheppard <kevin.k.sheppard@gmail.com> | 2022-08-11 15:28:24 +0100 |
commit | dd79030184e7aaa072c12d9182d79b3d9eb74757 (patch) | |
tree | 8f100b0feccc8f981fddffc6a5c2f3a6441117c3 | |
parent | 86a368f53ba15a610be4c6cb3a173446aef3aa80 (diff) | |
download | numpy-dd79030184e7aaa072c12d9182d79b3d9eb74757.tar.gz |
ENH: Add the capability to swap the singleton bit generator
Add a new version of seed that supports seeding any bit generator
Add set/get_bit_generator as explicit methods to support swapping
closes #21808
-rw-r--r-- | doc/release/upcoming_changes/21976.new_feature.rst | 25 | ||||
-rw-r--r-- | numpy/random/mtrand.pyx | 119 | ||||
-rw-r--r-- | numpy/random/tests/test_randomstate.py | 67 |
3 files changed, 196 insertions, 15 deletions
diff --git a/doc/release/upcoming_changes/21976.new_feature.rst b/doc/release/upcoming_changes/21976.new_feature.rst new file mode 100644 index 000000000..fd2417bb9 --- /dev/null +++ b/doc/release/upcoming_changes/21976.new_feature.rst @@ -0,0 +1,25 @@ +The bit generator underlying the singleton RandomState can be changed +--------------------------------------------------------------------- +The singleton ``RandomState`` instance exposed in the ``numpy.random`` module +is initialized using system-provided entropy with the ``MT19937`` bit generator. +The function ``set_bit_generator`` allows the default bit generator to be +replaced with a user-provided bit generator. This function has been introduced to +provide a method allowing seamless integration of a high-quality, modern bit generator +in new code with existing code that makes use of the singleton-provided random +variate generating functions. + +The preferred method to generate reproducible random numbers is to use a modern +bit generator in an instance of ``Generator``. The function ``default_rng`` +simplifies instantiation. + + >>> rg = np.random.default_rng(3728973198) + >>> rg.random() + +The same bit generator can then be shared with the singleton instance so that +calling functions in the ``random`` module will use the same bit generator. + + >>> np.random.set_bit_generator(rg.bit_generator) + >>> np.random.normal() + +The swap is permanent (until reversed) and so any call to functions +in the ``random`` module will use the new bit generator. diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index c9cdb5839..cbf37e51d 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -186,16 +186,7 @@ cdef class RandomState: else: bit_generator = seed - self._bit_generator = bit_generator - capsule = bit_generator.capsule - cdef const char *name = "BitGenerator" - if not PyCapsule_IsValid(capsule, name): - raise ValueError("Invalid bit generator. 
The bit generator must " - "be instantized.") - self._bitgen = (<bitgen_t *> PyCapsule_GetPointer(capsule, name))[0] - self._aug_state.bit_generator = &self._bitgen - self._reset_gauss() - self.lock = bit_generator.lock + self._initialize_bit_generator(bit_generator) def __repr__(self): return self.__str__() + ' at 0x{:X}'.format(id(self)) @@ -218,13 +209,25 @@ cdef class RandomState: from ._pickle import __randomstate_ctor return __randomstate_ctor, (name_tpl[0], ctor), self.get_state(legacy=False) + cdef _initialize_bit_generator(self, bit_generator): + self._bit_generator = bit_generator + capsule = bit_generator.capsule + cdef const char *name = "BitGenerator" + if not PyCapsule_IsValid(capsule, name): + raise ValueError("Invalid bit generator. The bit generator must " + "be instantized.") + self._bitgen = (<bitgen_t *> PyCapsule_GetPointer(capsule, name))[0] + self._aug_state.bit_generator = &self._bitgen + self._reset_gauss() + self.lock = bit_generator.lock + cdef _reset_gauss(self): self._aug_state.has_gauss = 0 self._aug_state.gauss = 0.0 def seed(self, seed=None): """ - seed(self, seed=None) + seed(seed=None) Reseed a legacy MT19937 BitGenerator @@ -249,7 +252,7 @@ cdef class RandomState: def get_state(self, legacy=True): """ - get_state() + get_state(legacy=True) Return a tuple representing the internal state of the generator. @@ -259,12 +262,13 @@ cdef class RandomState: ---------- legacy : bool, optional Flag indicating to return a legacy tuple state when the BitGenerator - is MT19937, instead of a dict. + is MT19937, instead of a dict. Raises ValueError if the underlying + bit generator is not an instance of MT19937. Returns ------- out : {tuple(str, ndarray of 624 uints, int, int, float), dict} - The returned tuple has the following items: + If legacy is True, the returned tuple has the following items: 1. the string 'MT19937'. 2. a 1-D array of 624 unsigned integer keys. 
@@ -294,6 +298,11 @@ cdef class RandomState: legacy = False st['has_gauss'] = self._aug_state.has_gauss st['gauss'] = self._aug_state.gauss + if legacy and not isinstance(self._bit_generator, _MT19937): + raise ValueError( + "legacy can only be True when the underlyign bitgenerator is " + "an instance of MT19937." + ) if legacy: return (st['bit_generator'], st['state']['key'], st['state']['pos'], st['has_gauss'], st['gauss']) @@ -4690,7 +4699,6 @@ random = _rand.random random_integers = _rand.random_integers random_sample = _rand.random_sample rayleigh = _rand.rayleigh -seed = _rand.seed set_state = _rand.set_state shuffle = _rand.shuffle standard_cauchy = _rand.standard_cauchy @@ -4705,6 +4713,85 @@ wald = _rand.wald weibull = _rand.weibull zipf = _rand.zipf +def seed(seed=None): + """ + seed(seed=None) + + Reseed the singleton RandomState instance. + + Notes + ----- + This is a convenience, legacy function that exists to support + older code that uses the singleton RandomState. Best practice + is to use a dedicated ``Generator`` instance rather than + the random variate generation methods exposed directly in + the random module. + + See Also + -------- + numpy.random.Generator + """ + if isinstance(_rand._bit_generator, _MT19937): + return _rand.seed(seed) + else: + bg_type = type(_rand._bit_generator) + _rand._bit_generator.state = bg_type(seed).state + +def get_bit_generator(): + """ + Returns the singleton RandomState's bit generator + + Returns + ------- + BitGenerator + The bit generator that underlies the singleton RandomState instance + + Notes + ----- + The singleton RandomState provides the random variate generators in the + NumPy random namespace. This function, and its counterpart set method, + provides a path to hot-swap the default MT19937 bit generator with a + user provided alternative. 
These function are intended to provide + a continuous path where a single underlying bit generator can be + used both with an instance of ``Generator`` and with the singleton + instance of RandomState. + + See Also + -------- + set_bit_generator + numpy.random.Generator + """ + return _rand._bit_generator + +def set_bit_generator(bitgen): + """ + Sets the singleton RandomState's bit generator + + Parameters + ---------- + bitgen + A bit generator instance + + Notes + ----- + The singleton RandomState provides the random variate generators in the + NumPy random namespace. This function, and its counterpart get method, + provides a path to hot-swap the default MT19937 bit generator with a + user provided alternative. These function are intended to provide + a continuous path where a single underlying bit generator can be + used both with an instance of ``Generator`` and with the singleton + instance of RandomState. + + See Also + -------- + get_bit_generator + numpy.random.Generator + """ + cdef RandomState singleton + singleton = _rand + singleton._initialize_bit_generator(bitgen) + + # Old aliases that should not be removed def sample(*args, **kwargs): """ @@ -4731,6 +4818,7 @@ __all__ = [ 'f', 'gamma', 'geometric', + 'get_bit_generator', 'get_state', 'gumbel', 'hypergeometric', @@ -4758,6 +4846,7 @@ __all__ = [ 'rayleigh', 'sample', 'seed', + 'set_bit_generator', 'set_state', 'shuffle', 'standard_cauchy', diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py index 1e22880ea..fdd194fb9 100644 --- a/numpy/random/tests/test_randomstate.py +++ b/numpy/random/tests/test_randomstate.py @@ -53,6 +53,14 @@ def int_func(request): INT_FUNC_HASHES[request.param]) +@pytest.fixture +def restore_singleton_bitgen(): + """Ensures that the singleton bitgen is restored after a test""" + orig_bitgen = np.random.get_bit_generator() + yield + np.random.set_bit_generator(orig_bitgen) + + def assert_mt19937_state_equal(a, b): 
assert_equal(a['bit_generator'], b['bit_generator']) assert_array_equal(a['state']['key'], b['state']['key']) @@ -2038,3 +2046,62 @@ def test_randomstate_ctor_old_style_pickle(): assert_array_equal(state_a['state']['pos'], state_b['state']['pos']) assert_equal(state_a['has_gauss'], state_b['has_gauss']) assert_equal(state_a['gauss'], state_b['gauss']) + +def test_hot_swap(restore_singleton_bitgen): + # GH 21808 + def_bg = np.random.default_rng(0) + bg = def_bg.bit_generator + np.random.set_bit_generator(bg) + assert isinstance(np.random.mtrand._rand._bit_generator, type(bg)) + + second_bg = np.random.get_bit_generator() + assert bg is second_bg + + +def test_seed_alt_bit_gen(restore_singleton_bitgen): + # GH 21808 + bg = PCG64(0) + np.random.set_bit_generator(bg) + state = np.random.get_state(legacy=False) + np.random.seed(1) + new_state = np.random.get_state(legacy=False) + print(state) + print(new_state) + assert state["bit_generator"] == "PCG64" + assert state["state"]["state"] != new_state["state"]["state"] + assert state["state"]["inc"] != new_state["state"]["inc"] + + +def test_state_error_alt_bit_gen(restore_singleton_bitgen): + # GH 21808 + state = np.random.get_state() + bg = PCG64(0) + np.random.set_bit_generator(bg) + with pytest.raises(ValueError, match="state must be for a PCG64"): + np.random.set_state(state) + + +def test_swap_worked(restore_singleton_bitgen): + # GH 21808 + np.random.seed(98765) + vals = np.random.randint(0, 2 ** 30, 10) + bg = PCG64(0) + state = bg.state + np.random.set_bit_generator(bg) + state_direct = np.random.get_state(legacy=False) + for field in state: + assert state[field] == state_direct[field] + np.random.seed(98765) + pcg_vals = np.random.randint(0, 2 ** 30, 10) + assert not np.all(vals == pcg_vals) + new_state = bg.state + assert new_state["state"]["state"] != state["state"]["state"] + assert new_state["state"]["inc"] == new_state["state"]["inc"] + + +def test_swapped_singleton_against_direct(restore_singleton_bitgen): 
+ np.random.set_bit_generator(PCG64(98765)) + singleton_vals = np.random.randint(0, 2 ** 30, 10) + rg = np.random.RandomState(PCG64(98765)) + non_singleton_vals = rg.randint(0, 2 ** 30, 10) + assert_equal(non_singleton_vals, singleton_vals) |