summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/random/mtrand/mtrand.pyx18
1 files changed, 11 insertions, 7 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index 3f6af86b1..8f79780d3 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -691,10 +691,13 @@ cdef class RandomState:
"""
cdef ndarray state "arrayObject_state"
state = <ndarray>np.empty(624, np.uint)
- memcpy(<void*>PyArray_DATA(state), <void*>(self.internal_state.key), 624*sizeof(long))
+ with self.lock:
+ memcpy(<void*>PyArray_DATA(state), <void*>(self.internal_state.key), 624*sizeof(long))
+ has_gauss = self.internal_state.has_gauss
+ gauss = self.internal_state.gauss
+ pos = self.internal_state.pos
state = <ndarray>np.asarray(state, np.uint32)
- return ('MT19937', state, self.internal_state.pos,
- self.internal_state.has_gauss, self.internal_state.gauss)
+ return ('MT19937', state, pos, has_gauss, gauss)
def set_state(self, state):
"""
@@ -761,10 +764,11 @@ cdef class RandomState:
obj = <ndarray>PyArray_ContiguousFromObject(key, NPY_LONG, 1, 1)
if PyArray_DIM(obj, 0) != 624:
raise ValueError("state must be 624 longs")
- memcpy(<void*>(self.internal_state.key), <void*>PyArray_DATA(obj), 624*sizeof(long))
- self.internal_state.pos = pos
- self.internal_state.has_gauss = has_gauss
- self.internal_state.gauss = cached_gaussian
+ with self.lock:
+ memcpy(<void*>(self.internal_state.key), <void*>PyArray_DATA(obj), 624*sizeof(long))
+ self.internal_state.pos = pos
+ self.internal_state.has_gauss = has_gauss
+ self.internal_state.gauss = cached_gaussian
# Pickling support:
def __getstate__(self):