summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--coverage/collector.py5
-rw-r--r--coverage/ctracer/tracer.c8
-rw-r--r--coverage/pytracer.py3
-rw-r--r--tests/test_concurrency.py95
4 files changed, 110 insertions, 1 deletions
diff --git a/coverage/collector.py b/coverage/collector.py
index cfdcf402..a318c106 100644
--- a/coverage/collector.py
+++ b/coverage/collector.py
@@ -374,7 +374,10 @@ class Collector(object):
def abs_file_dict(d):
"""Return a dict like d, but with keys modified by `abs_file`."""
- return dict((abs_file(k), v) for k, v in iitems(d))
+ # The call to list() ensures that the GIL protects the dictionary
+ # iterator against concurrent modifications by tracers running
+ # in other threads.
+ return dict((abs_file(k), v) for k, v in list(iitems(d)))
if self.branch:
covdata.add_arcs(abs_file_dict(self.data))
diff --git a/coverage/ctracer/tracer.c b/coverage/ctracer/tracer.c
index 625a45a6..236fbbfa 100644
--- a/coverage/ctracer/tracer.c
+++ b/coverage/ctracer/tracer.c
@@ -833,6 +833,14 @@ CTracer_trace(CTracer *self, PyFrameObject *frame, int what, PyObject *arg_unuse
return RET_OK;
#endif
+ if (!self->started) {
+ /* If CTracer.stop() has been called from another thread, the tracer
+ is still active in the current thread. Let's deactivate ourselves
+ now. */
+ PyEval_SetTrace(NULL, NULL);
+ return RET_OK;
+ }
+
#if WHAT_LOG || TRACE_LOG
PyObject * ascii = NULL;
#endif
diff --git a/coverage/pytracer.py b/coverage/pytracer.py
index b41f4059..6ebb60c7 100644
--- a/coverage/pytracer.py
+++ b/coverage/pytracer.py
@@ -69,6 +69,9 @@ class PyTracer(object):
"""The trace function passed to sys.settrace."""
if self.stopped:
+ # The PyTrace.stop() method has been called by another thread,
+ # let's deactivate ourselves now.
+ self.stop()
return
if self.last_exc_back:
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
index 841b5df4..131fa72d 100644
--- a/tests/test_concurrency.py
+++ b/tests/test_concurrency.py
@@ -3,7 +3,11 @@
"""Tests for concurrency libraries."""
+import importlib
+import random
+import sys
import threading
+import time
from flaky import flaky
@@ -453,3 +457,94 @@ class MultiprocessingTest(CoverageTest):
total = sum(x*x if x%2 else x*x*x for x in range(upto))
expected_out = "{nprocs} pids, total = {total}".format(nprocs=nprocs, total=total)
self.try_multiprocessing_code_with_branching(code, expected_out)
+
+
+def test_coverage_stop_in_threads():
+ has_started_coverage = []
+ has_stopped_coverage = []
+
+ def run_thread():
+ deadline = time.time() + 5
+ ident = threading.currentThread().ident
+ if sys.gettrace() is not None:
+ has_started_coverage.append(ident)
+ while sys.gettrace() is not None:
+ # Wait for coverage to stop
+ time.sleep(0.01)
+ if time.time() > deadline:
+ return
+ has_stopped_coverage.append(ident)
+
+ cov = coverage.coverage()
+ cov.start()
+
+ t = threading.Thread(target=run_thread)
+ t.start()
+
+ time.sleep(0.1)
+ cov.stop()
+ time.sleep(0.1)
+
+ assert has_started_coverage == [t.ident]
+ assert has_started_coverage == [t.ident]
+ t.join()
+
+
+def test_thread_safe_save_data(tmpdir):
+ # Non-regression test for:
+ # https://bitbucket.org/ned/coveragepy/issues/581
+
+ # Create some Python modules and put them in the path
+ modules_dir = tmpdir.mkdir('test_modules')
+ module_names = ["m{:03d}".format(i) for i in range(1000)]
+ for module_name in module_names:
+ modules_dir.join(module_name + ".py").write("def f(): pass\n")
+
+ # Shared variables for threads
+ should_run = [True]
+ imported = []
+
+ sys.path.insert(0, modules_dir.strpath)
+ try:
+ # Make sure that all dummy modules can be imported.
+ for module_name in module_names:
+ importlib.import_module(module_name)
+
+ def random_load():
+ while should_run[0]:
+ module_name = random.choice(module_names)
+ mod = importlib.import_module(module_name)
+ mod.f()
+ imported.append(mod)
+
+ # Spawn some threads with coverage enabled and attempt to read the
+ # results right after stopping coverage collection with the threads
+ # still running.
+ duration = 0.01
+ for i in range(3):
+ cov = coverage.coverage()
+ cov.start()
+
+ threads = [threading.Thread(target=random_load) for i in range(10)]
+ should_run[0] = True
+ for t in threads:
+ t.start()
+
+ time.sleep(duration)
+
+ cov.stop()
+
+ # The following call use to crash with running background threads.
+ cov.get_data()
+
+ # Stop the threads
+ should_run[0] = False
+ for t in threads:
+ t.join()
+
+ if len(imported) == 0 and duration < 10:
+ duration *= 2
+
+ finally:
+ sys.path.remove(modules_dir.strpath)
+ should_run[0] = False