diff options
author | Ned Batchelder <nedbat@gmail.com> | 2017-10-21 20:29:09 +0000 |
---|---|---|
committer | Ned Batchelder <nedbat@gmail.com> | 2017-10-21 20:29:09 +0000 |
commit | 5e142b6c029f3a7672ff89a9883895622e8b57dc (patch) | |
tree | ce0da676d10be369c077b7111787259af2f6f063 | |
parent | 9e22cdaf8d1e1b39814d04c9b8badda4c0d4769a (diff) | |
parent | 6f79cf1d20452bfd6222410165e13a1b75e1b5bb (diff) | |
download | python-coveragepy-git-5e142b6c029f3a7672ff89a9883895622e8b57dc.tar.gz |
Merged in ogrisel/coverage.py/fix-thread-safety (pull request #127)
FIX thread-safe Collector.save_data()
-rw-r--r-- | coverage/collector.py | 5 | ||||
-rw-r--r-- | coverage/ctracer/tracer.c | 13 | ||||
-rw-r--r-- | coverage/pytracer.py | 16 | ||||
-rw-r--r-- | tests/test_concurrency.py | 95 |
4 files changed, 123 insertions, 6 deletions
diff --git a/coverage/collector.py b/coverage/collector.py index 90c32756..cea341a1 100644 --- a/coverage/collector.py +++ b/coverage/collector.py @@ -375,7 +375,10 @@ class Collector(object): def abs_file_dict(d): """Return a dict like d, but with keys modified by `abs_file`.""" - return dict((abs_file(k), v) for k, v in iitems(d)) + # The call to list() ensures that the GIL protects the dictionary + # iterator against concurrent modifications by tracers running + # in other threads. + return dict((abs_file(k), v) for k, v in list(iitems(d))) if self.branch: covdata.add_arcs(abs_file_dict(self.data)) diff --git a/coverage/ctracer/tracer.c b/coverage/ctracer/tracer.c index 625a45a6..0ade7412 100644 --- a/coverage/ctracer/tracer.c +++ b/coverage/ctracer/tracer.c @@ -833,6 +833,14 @@ CTracer_trace(CTracer *self, PyFrameObject *frame, int what, PyObject *arg_unuse return RET_OK; #endif + if (!self->started) { + /* If CTracer.stop() has been called from another thread, the tracer + is still active in the current thread. Let's deactivate ourselves + now. */ + PyEval_SetTrace(NULL, NULL); + return RET_OK; + } + #if WHAT_LOG || TRACE_LOG PyObject * ascii = NULL; #endif @@ -1027,7 +1035,10 @@ static PyObject * CTracer_stop(CTracer *self, PyObject *args_unused) { if (self->started) { - PyEval_SetTrace(NULL, NULL); + /* Set the started flag only. The actual call to + PyEval_SetTrace(NULL, NULL) is delegated to the callback + itself to ensure that it called from the right thread. + */ self->started = FALSE; } diff --git a/coverage/pytracer.py b/coverage/pytracer.py index b41f4059..8f1b4b98 100644 --- a/coverage/pytracer.py +++ b/coverage/pytracer.py @@ -68,7 +68,10 @@ class PyTracer(object): def _trace(self, frame, event, arg_unused): """The trace function passed to sys.settrace.""" - if self.stopped: + if (self.stopped and sys.gettrace() == self._trace): + # The PyTrace.stop() method has been called, possibly by another + # thread, let's deactivate ourselves now. + sys.settrace(None) return if self.last_exc_back: @@ -152,7 +155,15 @@ class PyTracer(object): def stop(self): """Stop this Tracer.""" + # Get the activate tracer callback before setting the stop flag to be + # able to detect if the tracer was changed prior to stopping it. + tf = sys.gettrace() + + # Set the stop flag. The actual call to sys.settrace(None) will happen + # in the self._trace callback itself to make sure to call it from the + # right thread. self.stopped = True + if self.threading and self.thread.ident != self.threading.currentThread().ident: # Called on a different thread than started us: we can't unhook # ourselves, but we've set the flag that we should stop, so we @@ -163,7 +174,6 @@ class PyTracer(object): # PyPy clears the trace function before running atexit functions, # so don't warn if we are in atexit on PyPy and the trace function # has changed to None. - tf = sys.gettrace() dont_warn = (env.PYPY and env.PYPYVERSION >= (5, 4) and self.in_atexit and tf is None) if (not dont_warn) and tf != self._trace: self.warn( @@ -171,8 +181,6 @@ class PyTracer(object): slug="trace-changed", ) - sys.settrace(None) - def activity(self): """Has there been any activity?""" return self._activity diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index a57ee93b..5b25e4c2 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -3,7 +3,11 @@ """Tests for concurrency libraries.""" +import importlib +import random +import sys import threading +import time from flaky import flaky @@ -453,3 +457,94 @@ class MultiprocessingTest(CoverageTest): total = sum(x*x if x%2 else x*x*x for x in range(upto)) expected_out = "{nprocs} pids, total = {total}".format(nprocs=nprocs, total=total) self.try_multiprocessing_code_with_branching(code, expected_out) + + +def test_coverage_stop_in_threads(): + has_started_coverage = [] + has_stopped_coverage = [] + + def run_thread(): + deadline = time.time() + 5 + ident = threading.currentThread().ident + if sys.gettrace() is not None: + has_started_coverage.append(ident) + while sys.gettrace() is not None: + # Wait for coverage to stop + time.sleep(0.01) + if time.time() > deadline: + return + has_stopped_coverage.append(ident) + + cov = coverage.coverage() + cov.start() + + t = threading.Thread(target=run_thread) + t.start() + + time.sleep(0.1) + cov.stop() + time.sleep(0.1) + + assert has_started_coverage == [t.ident] + assert has_stopped_coverage == [t.ident] + t.join() + + +def test_thread_safe_save_data(tmpdir): + # Non-regression test for: + # https://bitbucket.org/ned/coveragepy/issues/581 + + # Create some Python modules and put them in the path + modules_dir = tmpdir.mkdir('test_modules') + module_names = ["m{:03d}".format(i) for i in range(1000)] + for module_name in module_names: + modules_dir.join(module_name + ".py").write("def f(): pass\n") + + # Shared variables for threads + should_run = [True] + imported = [] + + sys.path.insert(0, modules_dir.strpath) + try: + # Make sure that all dummy modules can be imported. + for module_name in module_names: + importlib.import_module(module_name) + + def random_load(): + while should_run[0]: + module_name = random.choice(module_names) + mod = importlib.import_module(module_name) + mod.f() + imported.append(mod) + + # Spawn some threads with coverage enabled and attempt to read the + # results right after stopping coverage collection with the threads + # still running. + duration = 0.01 + for i in range(3): + cov = coverage.coverage() + cov.start() + + threads = [threading.Thread(target=random_load) for i in range(10)] + should_run[0] = True + for t in threads: + t.start() + + time.sleep(duration) + + cov.stop() + + # The following call use to crash with running background threads. + cov.get_data() + + # Stop the threads + should_run[0] = False + for t in threads: + t.join() + + if len(imported) == 0 and duration < 10: + duration *= 2 + + finally: + sys.path.remove(modules_dir.strpath) + should_run[0] = False |