diff options
Diffstat (limited to 'numpy/random/_mt19937.pyx')
-rw-r--r-- | numpy/random/_mt19937.pyx | 273 |
1 files changed, 273 insertions, 0 deletions
# numpy/random/_mt19937.pyx
import operator

import numpy as np
cimport numpy as np

from libc.stdint cimport uint32_t, uint64_t
from ._bit_generator cimport BitGenerator, SeedSequence

__all__ = ['MT19937']

np.import_array()

# C-level declarations of the Mersenne Twister implementation.
cdef extern from "src/mt19937/mt19937.h":

    struct s_mt19937_state:
        uint32_t key[624]
        int pos

    ctypedef s_mt19937_state mt19937_state

    uint64_t mt19937_next64(mt19937_state *state) nogil
    uint32_t mt19937_next32(mt19937_state *state) nogil
    double mt19937_next_double(mt19937_state *state) nogil
    void mt19937_init_by_array(mt19937_state *state, uint32_t *init_key, int key_length)
    void mt19937_seed(mt19937_state *state, uint32_t seed)
    void mt19937_jump(mt19937_state *state)

    enum:
        RK_STATE_LEN

# The functions below adapt the typed C API to the ``void *`` based function
# pointers stored in ``self._bitgen``: each one casts the opaque state pointer
# back to ``mt19937_state *`` and forwards the call.

cdef uint64_t mt19937_uint64(void *st) nogil:
    return mt19937_next64(<mt19937_state *> st)

cdef uint32_t mt19937_uint32(void *st) nogil:
    return mt19937_next32(<mt19937_state *> st)

cdef double mt19937_double(void *st) nogil:
    return mt19937_next_double(<mt19937_state *> st)

cdef uint64_t mt19937_raw(void *st) nogil:
    # The raw output is a single 32-bit draw widened to 64 bits.
    return <uint64_t>mt19937_next32(<mt19937_state *> st)

cdef class MT19937(BitGenerator):
    """
    MT19937(seed=None)

    Container for the Mersenne Twister pseudo-random number generator.

    Parameters
    ----------
    seed : {None, int, array_like[ints], SeedSequence}, optional
        A seed to initialize the `BitGenerator`. If None, then fresh,
        unpredictable entropy will be pulled from the OS. If an ``int`` or
        ``array_like[ints]`` is passed, then it will be passed to
        `SeedSequence` to derive the initial `BitGenerator` state. One may also
        pass in a `SeedSequence` instance.

    Attributes
    ----------
    lock: threading.Lock
        Lock instance that is shared so that the same bit generator can
        be used in multiple Generators without corrupting the state. Code that
        generates values from a bit generator should hold the bit generator's
        lock.

    Notes
    -----
    ``MT19937`` provides a capsule containing function pointers that produce
    doubles, and unsigned 32 and 64- bit integers [1]_. These are not
    directly consumable in Python and must be consumed by a ``Generator``
    or similar object that supports low-level access.

    The Python stdlib module "random" also contains a Mersenne Twister
    pseudo-random number generator.

    **State and Seeding**

    The ``MT19937`` state vector consists of a 624-element array of
    32-bit unsigned integers plus a single integer value between 0 and 624
    that indexes the current position within the main array.

    The input seed is processed by `SeedSequence` to fill the whole state. The
    first element is reset such that only its most significant bit is set.

    **Parallel Features**

    The preferred way to use a BitGenerator in parallel applications is to use
    the `SeedSequence.spawn` method to obtain entropy values, and to use these
    to generate new BitGenerators:

    >>> from numpy.random import Generator, MT19937, SeedSequence
    >>> sg = SeedSequence(1234)
    >>> rg = [Generator(MT19937(s)) for s in sg.spawn(10)]

    Another method is to use `MT19937.jumped` which advances the state as-if
    :math:`2^{128}` random numbers have been generated ([1]_, [2]_). This
    allows the original sequence to be split so that distinct segments can be
    used in each worker process. All generators should be chained to ensure
    that the segments come from the same sequence.

    >>> from numpy.random import Generator, MT19937, SeedSequence
    >>> sg = SeedSequence(1234)
    >>> bit_generator = MT19937(sg)
    >>> rg = []
    >>> for _ in range(10):
    ...    rg.append(Generator(bit_generator))
    ...    # Chain the BitGenerators
    ...    bit_generator = bit_generator.jumped()

    **Compatibility Guarantee**

    ``MT19937`` makes a guarantee that a fixed seed will always produce
    the same random integer stream.

    References
    ----------
    .. [1] Hiroshi Haramoto, Makoto Matsumoto, and Pierre L\'Ecuyer, "A Fast
        Jump Ahead Algorithm for Linear Recurrences in a Polynomial Space",
        Sequences and Their Applications - SETA, 290--298, 2008.
    .. [2] Hiroshi Haramoto, Makoto Matsumoto, Takuji Nishimura, François
        Panneton, Pierre L\'Ecuyer, "Efficient Jump Ahead for F2-Linear
        Random Number Generators", INFORMS JOURNAL ON COMPUTING, Vol. 20,
        No. 3, Summer 2008, pp. 385-390.

    """
    cdef mt19937_state rng_state

    def __init__(self, seed=None):
        BitGenerator.__init__(self, seed)
        val = self._seed_seq.generate_state(RK_STATE_LEN, np.uint32)
        # MSB is 1; assuring non-zero initial array
        self.rng_state.key[0] = 0x80000000UL
        for i in range(1, RK_STATE_LEN):
            self.rng_state.key[i] = val[i]
        # ``i`` still holds the last loop value (RK_STATE_LEN - 1) here.
        # Kept as-is: changing the start position would change the stream
        # produced by a given seed.
        self.rng_state.pos = i

        # Publish the state and the draw functions through the bitgen struct.
        self._bitgen.state = &self.rng_state
        self._bitgen.next_uint64 = &mt19937_uint64
        self._bitgen.next_uint32 = &mt19937_uint32
        self._bitgen.next_double = &mt19937_double
        self._bitgen.next_raw = &mt19937_raw

    def _legacy_seeding(self, seed):
        """
        _legacy_seeding(seed)

        Seed the generator in a backward compatible way. For modern
        applications, creating a new instance is preferable. Calling this
        overrides self._seed_seq

        Parameters
        ----------
        seed : {None, int, array_like}
            Random seed initializing the pseudo-random number generator.
            Can be an integer in [0, 2**32-1], a 1-d array of integers in
            [0, 2**32-1], or ``None``. If `seed` is ``None``, then fresh,
            unpredictable entropy will be pulled from the OS.

        Raises
        ------
        ValueError
            If seed values are out of range for the PRNG, or if an array
            seed is empty or not 1-d.
        """
        cdef np.ndarray obj
        with self.lock:
            try:
                if seed is None:
                    seed = SeedSequence()
                    val = seed.generate_state(RK_STATE_LEN)
                    # MSB is 1; assuring non-zero initial array
                    self.rng_state.key[0] = 0x80000000UL
                    for i in range(1, RK_STATE_LEN):
                        self.rng_state.key[i] = val[i]
                    # Reset the position exactly as __init__ does so the
                    # reseeded state does not depend on how far the previous
                    # state had been consumed.
                    self.rng_state.pos = i
                else:
                    if hasattr(seed, 'squeeze'):
                        seed = seed.squeeze()
                    # Raises TypeError for non-integer input, which routes
                    # the seed to the array path below.
                    idx = operator.index(seed)
                    if idx > int(2**32 - 1) or idx < 0:
                        raise ValueError("Seed must be between 0 and 2**32 - 1")
                    # Pass the validated integer rather than the raw object.
                    mt19937_seed(&self.rng_state, idx)
            except TypeError:
                obj = np.asarray(seed)
                if obj.size == 0:
                    raise ValueError("Seed must be non-empty")
                # Safe cast first so that out-of-range values are detected
                # below instead of silently wrapping.
                obj = obj.astype(np.int64, casting='safe')
                if obj.ndim != 1:
                    raise ValueError("Seed array must be 1-d")
                if ((obj > int(2**32 - 1)) | (obj < 0)).any():
                    raise ValueError("Seed must be between 0 and 2**32 - 1")
                obj = obj.astype(np.uint32, casting='unsafe', order='C')
                mt19937_init_by_array(&self.rng_state, <uint32_t*> obj.data, np.PyArray_DIM(obj, 0))
        # The state no longer derives from a SeedSequence.
        self._seed_seq = None

    cdef jump_inplace(self, iter):
        """
        Jump state in-place

        Not part of public API

        Parameters
        ----------
        iter : integer, positive
            Number of times to jump the state of the rng.
        """
        cdef np.npy_intp i
        for i in range(iter):
            mt19937_jump(&self.rng_state)


    def jumped(self, np.npy_intp jumps=1):
        """
        jumped(jumps=1)

        Returns a new bit generator with the state jumped

        The state of the returned bit generator is jumped as-if
        2**(128 * jumps) random numbers have been generated.

        Parameters
        ----------
        jumps : integer, positive
            Number of times to jump the state of the bit generator returned

        Returns
        -------
        bit_generator : MT19937
            New instance of generator jumped ``jumps`` times
        """
        cdef MT19937 bit_generator

        # Build a fresh instance, copy the current state into it, then
        # advance only the copy; ``self`` is left untouched.
        bit_generator = self.__class__()
        bit_generator.state = self.state
        bit_generator.jump_inplace(jumps)

        return bit_generator

    @property
    def state(self):
        """
        Get or set the PRNG state

        Returns
        -------
        state : dict
            Dictionary containing the information required to describe the
            state of the PRNG
        """
        key = np.zeros(RK_STATE_LEN, dtype=np.uint32)
        for i in range(RK_STATE_LEN):
            key[i] = self.rng_state.key[i]

        return {'bit_generator': self.__class__.__name__,
                'state': {'key': key, 'pos': self.rng_state.pos}}

    @state.setter
    def state(self, value):
        cdef int pos
        if isinstance(value, tuple):
            # Legacy tuple layout: ('MT19937', key, pos[, ...]).
            if value[0] != 'MT19937' or len(value) not in (3, 5):
                raise ValueError('state is not a legacy MT19937 state')
            value = {'bit_generator': 'MT19937',
                     'state': {'key': value[1], 'pos': value[2]}}

        if not isinstance(value, dict):
            raise TypeError('state must be a dict')
        bitgen = value.get('bit_generator', '')
        if bitgen != self.__class__.__name__:
            raise ValueError('state must be for a {0} '
                             'PRNG'.format(self.__class__.__name__))
        key = value['state']['key']
        pos = value['state']['pos']
        # ``pos`` indexes the key array; the class docstring documents its
        # range as 0..624. Reject anything else before touching the state so
        # a malformed input can neither corrupt it nor leave it half-written.
        if pos < 0 or pos > RK_STATE_LEN:
            raise ValueError('state pos must be between 0 and '
                             '{0}'.format(RK_STATE_LEN))
        for i in range(RK_STATE_LEN):
            self.rng_state.key[i] = key[i]
        self.rng_state.pos = pos