summaryrefslogtreecommitdiff
path: root/numpy/random/tests/test_direct.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/random/tests/test_direct.py')
-rw-r--r--numpy/random/tests/test_direct.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/numpy/random/tests/test_direct.py b/numpy/random/tests/test_direct.py
index 58d966adf..fa2ae866b 100644
--- a/numpy/random/tests/test_direct.py
+++ b/numpy/random/tests/test_direct.py
@@ -148,6 +148,46 @@ def test_seedsequence():
assert len(dummy.spawn(10)) == 10
+def test_generator_spawning():
+ """ Test spawning new generators and bit_generators directly.
+ """
+ rng = np.random.default_rng()
+ seq = rng.bit_generator.seed_seq
+ new_ss = seq.spawn(5)
+ expected_keys = [seq.spawn_key + (i,) for i in range(5)]
+ assert [c.spawn_key for c in new_ss] == expected_keys
+
+ new_bgs = rng.bit_generator.spawn(5)
+ expected_keys = [seq.spawn_key + (i,) for i in range(5, 10)]
+ assert [bg.seed_seq.spawn_key for bg in new_bgs] == expected_keys
+
+ new_rngs = rng.spawn(5)
+ expected_keys = [seq.spawn_key + (i,) for i in range(10, 15)]
+ found_keys = [rng.bit_generator.seed_seq.spawn_key for rng in new_rngs]
+ assert found_keys == expected_keys
+
+ # Sanity check that streams are actually different:
+ assert new_rngs[0].uniform() != new_rngs[1].uniform()
+
+
+def test_non_spawnable():
+ from numpy.random.bit_generator import ISeedSequence
+
+ class FakeSeedSequence:
+ def generate_state(self, n_words, dtype=np.uint32):
+ return np.zeros(n_words, dtype=dtype)
+
+ ISeedSequence.register(FakeSeedSequence)
+
+ rng = np.random.default_rng(FakeSeedSequence())
+
+ with pytest.raises(TypeError, match="The underlying SeedSequence"):
+ rng.spawn(5)
+
+ with pytest.raises(TypeError, match="The underlying SeedSequence"):
+ rng.bit_generator.spawn(5)
+
+
class Base:
dtype = np.uint64
data2 = data1 = {}