diff options
-rw-r--r-- | numpy/random/src/mt19937/mt19937-jump.c | 4 | ||||
-rw-r--r-- | numpy/random/src/mt19937/mt19937-poly.h | 8 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 46 |
3 files changed, 52 insertions, 6 deletions
diff --git a/numpy/random/src/mt19937/mt19937-jump.c b/numpy/random/src/mt19937/mt19937-jump.c index 46b28cf96..de8969bf9 100644 --- a/numpy/random/src/mt19937/mt19937-jump.c +++ b/numpy/random/src/mt19937/mt19937-jump.c @@ -187,11 +187,9 @@ void mt19937_jump_state(mt19937_state *state, const char *jump_str) { pf = (unsigned long *)calloc(P_SIZE, sizeof(unsigned long)); for (i = MEXP - 1; i > -1; i--) { - if (jump_str[i] == '1') + if (jump_str[(MEXP - 1) - i] == '1') set_coef(pf, i, 1); } - /* TODO: Should generate the next set and start from 0, but doesn't matter ?? - */ if (state->pos >= N) { state->pos = 0; } diff --git a/numpy/random/src/mt19937/mt19937-poly.h b/numpy/random/src/mt19937/mt19937-poly.h index b03747881..016b4f626 100644 --- a/numpy/random/src/mt19937/mt19937-poly.h +++ b/numpy/random/src/mt19937/mt19937-poly.h @@ -1,3 +1,11 @@ +/* + * 2**128 step polynomial produced using the file mt19937-generate-jump-poly.c + * (randomgen) which is a modified version of minipoly_mt19937.c as distributed + * in http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/JUMP/jump_ahead_1.02.tar.gz + * + * These files are not part of NumPy. + */ + static const char * poly = "0001000111110111011100100010101111000000010100100101000001110111100010101000110100101001011001010" "1110101101100101011100101101011001110011100011110100001000001011100101100010100000010011101110011" diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index a28b7ca11..a5cfa8a65 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -1,4 +1,5 @@ import sys +import hashlib import pytest @@ -13,6 +14,26 @@ from numpy.random import Generator, MT19937, SeedSequence random = Generator(MT19937()) +JUMP_TEST_DATA = [ + { + "seed": 0, + "steps": 10, + "initial": {"key_md5": "64eaf265d2203179fb5ffb73380cd589", "pos": 9}, + "jumped": {"key_md5": "8cb7b061136efceef5217a9ce2cc9a5a", "pos": 598}, + }, + { + "seed":384908324, + "steps":312, + "initial": {"key_md5": "e99708a47b82ff51a2c7b0625b81afb5", "pos": 311}, + "jumped": {"key_md5": "2ecdbfc47a895b253e6e19ccb2e74b90", "pos": 276}, + }, + { + "seed": [839438204, 980239840, 859048019, 821], + "steps": 511, + "initial": {"key_md5": "9fcd6280df9199785e17e93162ce283c", "pos": 510}, + "jumped": {"key_md5": "433b85229f2ed853cde06cd872818305", "pos": 475}, + }, +] @pytest.fixture(scope='module', params=[True, False]) def endpoint(request): @@ -462,7 +483,6 @@ class TestIntegers: assert_array_equal(scalar, array) def test_repeatability(self, endpoint): - import hashlib # We use a md5 hash of generated sequences of 1000 samples # in the range [0, 6) for all but bool, where the range # is [0, 2). Hashes are for little endian numbers. @@ -885,8 +905,6 @@ class TestRandomDist: assert actual.dtype == np.int64 def test_choice_large_sample(self): - import hashlib - choice_hash = 'd44962a0b1e92f4a3373c23222244e21' random = Generator(MT19937(self.seed)) actual = random.choice(10000, 5000, replace=False) @@ -2351,3 +2369,25 @@ class TestSingleEltArrayInput: out = func(self.argOne, self.argTwo[0], self.argThree) assert_equal(out.shape, self.tgtShape) + + +@pytest.mark.parametrize("config", JUMP_TEST_DATA) +def test_jumped(config): + # Each config contains the initial seed, a number of raw steps + # the md5 hashes of the initial and the final states' keys and + # the position of of the initial and the final state. + # These were produced using the original C implementation. + seed = config["seed"] + steps = config["steps"] + + mt19937 = MT19937(seed) + # Burn step + mt19937.random_raw(steps) + md5 = hashlib.md5(mt19937.state["state"]["key"]) + assert md5.hexdigest() == config["initial"]["key_md5"] + assert mt19937.state["state"]["pos"] == config["initial"]["pos"] + + jumped = mt19937.jumped() + md5 = hashlib.md5(jumped.state["state"]["key"]) + assert md5.hexdigest() == config["jumped"]["key_md5"] + assert jumped.state["state"]["pos"] == config["jumped"]["pos"] |