summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/random/_generator.pyx5
-rw-r--r--numpy/random/_pickle.py49
-rw-r--r--numpy/random/mtrand.pyx5
-rw-r--r--numpy/random/tests/test_generator_mt19937.py13
-rw-r--r--numpy/random/tests/test_randomstate.py18
5 files changed, 61 insertions, 29 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index 77f70e9fe..da8ab64e2 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -220,8 +220,11 @@ cdef class Generator:
self.bit_generator.state = state
def __reduce__(self):
+ ctor, name_tpl, state = self._bit_generator.__reduce__()
+
from ._pickle import __generator_ctor
- return __generator_ctor, (self.bit_generator.state['bit_generator'],), self.bit_generator.state
+ # Requirements of __generator_ctor are (name, ctor)
+ return __generator_ctor, (name_tpl[0], ctor), state
@property
def bit_generator(self):
diff --git a/numpy/random/_pickle.py b/numpy/random/_pickle.py
index 5e89071e8..073993726 100644
--- a/numpy/random/_pickle.py
+++ b/numpy/random/_pickle.py
@@ -14,19 +14,19 @@ BitGenerators = {'MT19937': MT19937,
}
-def __generator_ctor(bit_generator_name='MT19937'):
+def __bit_generator_ctor(bit_generator_name='MT19937'):
"""
- Pickling helper function that returns a Generator object
+ Pickling helper function that returns a bit generator object
Parameters
----------
bit_generator_name : str
- String containing the core BitGenerator
+ String containing the name of the BitGenerator
Returns
-------
- rg : Generator
- Generator using the named core BitGenerator
+ bit_generator : BitGenerator
+ BitGenerator instance
"""
if bit_generator_name in BitGenerators:
bit_generator = BitGenerators[bit_generator_name]
@@ -34,50 +34,47 @@ def __generator_ctor(bit_generator_name='MT19937'):
raise ValueError(str(bit_generator_name) + ' is not a known '
'BitGenerator module.')
- return Generator(bit_generator())
+ return bit_generator()
-def __bit_generator_ctor(bit_generator_name='MT19937'):
+def __generator_ctor(bit_generator_name="MT19937",
+ bit_generator_ctor=__bit_generator_ctor):
"""
- Pickling helper function that returns a bit generator object
+ Pickling helper function that returns a Generator object
Parameters
----------
bit_generator_name : str
- String containing the name of the BitGenerator
+ String containing the core BitGenerator's name
+ bit_generator_ctor : callable, optional
+ Callable function that takes bit_generator_name as its only argument
+ and returns an instantized bit generator.
Returns
-------
- bit_generator : BitGenerator
- BitGenerator instance
+ rg : Generator
+ Generator using the named core BitGenerator
"""
- if bit_generator_name in BitGenerators:
- bit_generator = BitGenerators[bit_generator_name]
- else:
- raise ValueError(str(bit_generator_name) + ' is not a known '
- 'BitGenerator module.')
-
- return bit_generator()
+ return Generator(bit_generator_ctor(bit_generator_name))
-def __randomstate_ctor(bit_generator_name='MT19937'):
+def __randomstate_ctor(bit_generator_name="MT19937",
+ bit_generator_ctor=__bit_generator_ctor):
"""
Pickling helper function that returns a legacy RandomState-like object
Parameters
----------
bit_generator_name : str
- String containing the core BitGenerator
+ String containing the core BitGenerator's name
+ bit_generator_ctor : callable, optional
+ Callable function that takes bit_generator_name as its only argument
+ and returns an instantized bit generator.
Returns
-------
rs : RandomState
Legacy RandomState using the named core BitGenerator
"""
- if bit_generator_name in BitGenerators:
- bit_generator = BitGenerators[bit_generator_name]
- else:
- raise ValueError(str(bit_generator_name) + ' is not a known '
- 'BitGenerator module.')
- return RandomState(bit_generator())
+ return RandomState(bit_generator_ctor(bit_generator_name))
diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx
index 19d23f6a8..c9cdb5839 100644
--- a/numpy/random/mtrand.pyx
+++ b/numpy/random/mtrand.pyx
@@ -213,9 +213,10 @@ cdef class RandomState:
self.set_state(state)
def __reduce__(self):
- state = self.get_state(legacy=False)
+ ctor, name_tpl, _ = self._bit_generator.__reduce__()
+
from ._pickle import __randomstate_ctor
- return __randomstate_ctor, (state['bit_generator'],), state
+ return __randomstate_ctor, (name_tpl[0], ctor), self.get_state(legacy=False)
cdef _reset_gauss(self):
self._aug_state.has_gauss = 0
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index fa55ac0ee..0f4437bcb 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -2695,3 +2695,16 @@ def test_contig_req_out(dist, order, dtype):
assert variates is out
variates = dist(out=out, dtype=dtype, size=out.shape)
assert variates is out
+
+
+def test_generator_ctor_old_style_pickle():
+ rg = np.random.Generator(np.random.PCG64DXSM(0))
+ rg.standard_normal(1)
+ # Directly call reduce which is used in pickline
+ ctor, args, state_a = rg.__reduce__()
+ # Simulate unpickling an old pickle that only has the name
+ assert args[:1] == ("PCG64DXSM",)
+ b = ctor(*args[:1])
+ b.bit_generator.state = state_a
+ state_b = b.bit_generator.state
+ assert state_a == state_b
diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py
index 861813a95..aadfe3d78 100644
--- a/numpy/random/tests/test_randomstate.py
+++ b/numpy/random/tests/test_randomstate.py
@@ -2020,3 +2020,21 @@ def test_broadcast_size_error():
random.binomial([1, 2], 0.3, size=(2, 1))
with pytest.raises(ValueError):
random.binomial([1, 2], [0.3, 0.7], size=(2, 1))
+
+
+def test_randomstate_ctor_old_style_pickle():
+ rs = np.random.RandomState(MT19937(0))
+ rs.standard_normal(1)
+ # Directly call reduce which is used in pickline
+ ctor, args, state_a = rs.__reduce__()
+ # Simulate unpickling an old pickle that only has the name
+ assert args[:1] == ("MT19937",)
+ b = ctor(*args[:1])
+ b.set_state(state_a)
+ state_b = b.get_state(legacy=False)
+
+ assert_equal(state_a['bit_generator'], state_b['bit_generator'])
+ assert_array_equal(state_a['state']['key'], state_b['state']['key'])
+ assert_array_equal(state_a['state']['pos'], state_b['state']['pos'])
+ assert_equal(state_a['has_gauss'], state_b['has_gauss'])
+ assert_equal(state_a['gauss'], state_b['gauss'])