summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/mtrand.pyx119
-rw-r--r--numpy/random/tests/test_randomstate.py67
2 files changed, 171 insertions, 15 deletions
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index c9cdb5839..cbf37e51d 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -186,16 +186,7 @@ cdef class RandomState:
else:
bit_generator = seed
- self._bit_generator = bit_generator
- capsule = bit_generator.capsule
- cdef const char *name = "BitGenerator"
- if not PyCapsule_IsValid(capsule, name):
- raise ValueError("Invalid bit generator. The bit generator must "
- "be instantized.")
- self._bitgen = (<bitgen_t *> PyCapsule_GetPointer(capsule, name))[0]
- self._aug_state.bit_generator = &self._bitgen
- self._reset_gauss()
- self.lock = bit_generator.lock
+ self._initialize_bit_generator(bit_generator)
def __repr__(self):
return self.__str__() + ' at 0x{:X}'.format(id(self))
@@ -218,13 +209,25 @@ cdef class RandomState:
from ._pickle import __randomstate_ctor
return __randomstate_ctor, (name_tpl[0], ctor), self.get_state(legacy=False)
+ cdef _initialize_bit_generator(self, bit_generator):
+ self._bit_generator = bit_generator
+ capsule = bit_generator.capsule
+ cdef const char *name = "BitGenerator"
+ if not PyCapsule_IsValid(capsule, name):
+ raise ValueError("Invalid bit generator. The bit generator must "
+ "be instantized.")
+ self._bitgen = (<bitgen_t *> PyCapsule_GetPointer(capsule, name))[0]
+ self._aug_state.bit_generator = &self._bitgen
+ self._reset_gauss()
+ self.lock = bit_generator.lock
+
cdef _reset_gauss(self):
self._aug_state.has_gauss = 0
self._aug_state.gauss = 0.0
def seed(self, seed=None):
"""
- seed(self, seed=None)
+ seed(seed=None)
Reseed a legacy MT19937 BitGenerator
@@ -249,7 +252,7 @@ cdef class RandomState:
def get_state(self, legacy=True):
"""
- get_state()
+ get_state(legacy=True)
Return a tuple representing the internal state of the generator.
@@ -259,12 +262,13 @@ cdef class RandomState:
----------
legacy : bool, optional
Flag indicating to return a legacy tuple state when the BitGenerator
- is MT19937, instead of a dict.
+ is MT19937, instead of a dict. Raises ValueError if the underlying
+ bit generator is not an instance of MT19937.
Returns
-------
out : {tuple(str, ndarray of 624 uints, int, int, float), dict}
- The returned tuple has the following items:
+ If legacy is True, the returned tuple has the following items:
1. the string 'MT19937'.
2. a 1-D array of 624 unsigned integer keys.
@@ -294,6 +298,11 @@ cdef class RandomState:
legacy = False
st['has_gauss'] = self._aug_state.has_gauss
st['gauss'] = self._aug_state.gauss
+ if legacy and not isinstance(self._bit_generator, _MT19937):
+ raise ValueError(
+ "legacy can only be True when the underlyign bitgenerator is "
+ "an instance of MT19937."
+ )
if legacy:
return (st['bit_generator'], st['state']['key'], st['state']['pos'],
st['has_gauss'], st['gauss'])
@@ -4690,7 +4699,6 @@ random = _rand.random
random_integers = _rand.random_integers
random_sample = _rand.random_sample
rayleigh = _rand.rayleigh
-seed = _rand.seed
set_state = _rand.set_state
shuffle = _rand.shuffle
standard_cauchy = _rand.standard_cauchy
@@ -4705,6 +4713,85 @@ wald = _rand.wald
weibull = _rand.weibull
zipf = _rand.zipf
+def seed(seed=None):
+ """
+ seed(seed=None)
+
+ Reseed the singleton RandomState instance.
+
+ Notes
+ -----
+ This is a convenience, legacy function that exists to support
+ older code that uses the singleton RandomState. Best practice
+ is to use a dedicated ``Generator`` instance rather than
+ the random variate generation methods exposed directly in
+ the random module.
+
+ See Also
+ --------
+ numpy.random.Generator
+ """
+ if isinstance(_rand._bit_generator, _MT19937):
+ return _rand.seed(seed)
+ else:
+ bg_type = type(_rand._bit_generator)
+ _rand._bit_generator.state = bg_type(seed).state
+
+def get_bit_generator():
+ """
+ Returns the singleton RandomState's bit generator
+
+ Returns
+ -------
+ BitGenerator
+ The bit generator that underlies the singleton RandomState instance
+
+ Notes
+ -----
+ The singleton RandomState provides the random variate generators in the
+ NumPy random namespace. This function, and its counterpart set method,
+ provides a path to hot-swap the default MT19937 bit generator with a
+ user provided alternative. These function are intended to provide
+ a continuous path where a single underlying bit generator can be
+ used both with an instance of ``Generator`` and with the singleton
+ instance of RandomState.
+
+ See Also
+ --------
+ set_bit_generator
+ numpy.random.Generator
+ """
+ return _rand._bit_generator
+
+def set_bit_generator(bitgen):
+ """
+ Sets the singleton RandomState's bit generator
+
+ Parameters
+ ----------
+ bitgen
+ A bit generator instance
+
+ Notes
+ -----
+ The singleton RandomState provides the random variate generators in the
+ NumPy random namespace. This function, and its counterpart get method,
+ provides a path to hot-swap the default MT19937 bit generator with a
+ user provided alternative. These function are intended to provide
+ a continuous path where a single underlying bit generator can be
+ used both with an instance of ``Generator`` and with the singleton
+ instance of RandomState.
+
+ See Also
+ --------
+ get_bit_generator
+ numpy.random.Generator
+ """
+ cdef RandomState singleton
+ singleton = _rand
+ singleton._initialize_bit_generator(bitgen)
+
+
# Old aliases that should not be removed
def sample(*args, **kwargs):
"""
@@ -4731,6 +4818,7 @@ __all__ = [
'f',
'gamma',
'geometric',
+ 'get_bit_generator',
'get_state',
'gumbel',
'hypergeometric',
@@ -4758,6 +4846,7 @@ __all__ = [
'rayleigh',
'sample',
'seed',
+ 'set_bit_generator',
'set_state',
'shuffle',
'standard_cauchy',
diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py
index 1e22880ea..fdd194fb9 100644
--- a/numpy/random/tests/test_randomstate.py
+++ b/numpy/random/tests/test_randomstate.py
@@ -53,6 +53,14 @@ def int_func(request):
INT_FUNC_HASHES[request.param])
+@pytest.fixture
+def restore_singleton_bitgen():
+ """Ensures that the singleton bitgen is restored after a test"""
+ orig_bitgen = np.random.get_bit_generator()
+ yield
+ np.random.set_bit_generator(orig_bitgen)
+
+
def assert_mt19937_state_equal(a, b):
assert_equal(a['bit_generator'], b['bit_generator'])
assert_array_equal(a['state']['key'], b['state']['key'])
@@ -2038,3 +2046,62 @@ def test_randomstate_ctor_old_style_pickle():
assert_array_equal(state_a['state']['pos'], state_b['state']['pos'])
assert_equal(state_a['has_gauss'], state_b['has_gauss'])
assert_equal(state_a['gauss'], state_b['gauss'])
+
+def test_hot_swap(restore_singleton_bitgen):
+ # GH 21808
+ def_bg = np.random.default_rng(0)
+ bg = def_bg.bit_generator
+ np.random.set_bit_generator(bg)
+ assert isinstance(np.random.mtrand._rand._bit_generator, type(bg))
+
+ second_bg = np.random.get_bit_generator()
+ assert bg is second_bg
+
+
+def test_seed_alt_bit_gen(restore_singleton_bitgen):
+ # GH 21808
+ bg = PCG64(0)
+ np.random.set_bit_generator(bg)
+ state = np.random.get_state(legacy=False)
+ np.random.seed(1)
+ new_state = np.random.get_state(legacy=False)
+ print(state)
+ print(new_state)
+ assert state["bit_generator"] == "PCG64"
+ assert state["state"]["state"] != new_state["state"]["state"]
+ assert state["state"]["inc"] != new_state["state"]["inc"]
+
+
+def test_state_error_alt_bit_gen(restore_singleton_bitgen):
+ # GH 21808
+ state = np.random.get_state()
+ bg = PCG64(0)
+ np.random.set_bit_generator(bg)
+ with pytest.raises(ValueError, match="state must be for a PCG64"):
+ np.random.set_state(state)
+
+
+def test_swap_worked(restore_singleton_bitgen):
+ # GH 21808
+ np.random.seed(98765)
+ vals = np.random.randint(0, 2 ** 30, 10)
+ bg = PCG64(0)
+ state = bg.state
+ np.random.set_bit_generator(bg)
+ state_direct = np.random.get_state(legacy=False)
+ for field in state:
+ assert state[field] == state_direct[field]
+ np.random.seed(98765)
+ pcg_vals = np.random.randint(0, 2 ** 30, 10)
+ assert not np.all(vals == pcg_vals)
+ new_state = bg.state
+ assert new_state["state"]["state"] != state["state"]["state"]
+ assert new_state["state"]["inc"] == new_state["state"]["inc"]
+
+
+def test_swapped_singleton_against_direct(restore_singleton_bitgen):
+ np.random.set_bit_generator(PCG64(98765))
+ singleton_vals = np.random.randint(0, 2 ** 30, 10)
+ rg = np.random.RandomState(PCG64(98765))
+ non_singleton_vals = rg.randint(0, 2 ** 30, 10)
+ assert_equal(non_singleton_vals, singleton_vals)