diff options
-rw-r--r-- | coverage/monkey.py | 30 | ||||
-rw-r--r-- | igor.py | 2 | ||||
-rw-r--r-- | tests/test_concurrency.py | 30 |
3 files changed, 51 insertions, 11 deletions
diff --git a/coverage/monkey.py b/coverage/monkey.py index b896dbf5..3f78d7dc 100644 --- a/coverage/monkey.py +++ b/coverage/monkey.py @@ -12,7 +12,6 @@ import sys PATCHED_MARKER = "_coverage$patched" if sys.version_info >= (3, 4): - klass = multiprocessing.process.BaseProcess else: klass = multiprocessing.Process @@ -49,4 +48,33 @@ def patch_multiprocessing(): else: multiprocessing.Process = ProcessWithCoverage + # When spawning processes rather than forking them, we have no state in the + # new process. We sneak in there with a Stowaway: we stuff one of our own + # objects into the data that gets pickled and sent to the sub-process. When + # the Stowaway is unpickled, it's __setstate__ method is called, which + # re-applies the monkey-patch. + # Windows only spawns, so this is needed to keep Windows working. + try: + from multiprocessing import spawn + original_get_preparation_data = spawn.get_preparation_data + except (ImportError, AttributeError): + pass + else: + def get_preparation_data_with_stowaway(name): + """Get the original preparation data, and also insert our stowaway.""" + d = original_get_preparation_data(name) + d['stowaway'] = Stowaway() + return d + + spawn.get_preparation_data = get_preparation_data_with_stowaway + setattr(multiprocessing, PATCHED_MARKER, True) + + +class Stowaway(object): + """An object to pickle, so when it is unpickled, it can apply the monkey-patch.""" + def __getstate__(self): + return {} + + def __setstate__(self, state): + patch_multiprocessing() @@ -332,7 +332,7 @@ def print_banner(label): which_python = os.path.relpath(sys.executable) except ValueError: # On Windows having a python executable on a different drives - # than the sources cannot be relative + # than the sources cannot be relative. which_python = sys.executable print('=== %s %s %s (%s) ===' % (impl, version, label, which_python)) sys.stdout.flush() diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 0f5ffe95..04eb9853 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -3,6 +3,7 @@ """Tests for concurrency libraries.""" +import multiprocessing import threading import coverage @@ -227,6 +228,7 @@ class MultiprocessingTest(CoverageTest): import multiprocessing import os import time + import sys def func(x): # Need to pause, or the tasks go too quick, and some processes @@ -240,6 +242,7 @@ class MultiprocessingTest(CoverageTest): return os.getpid(), y if __name__ == "__main__": + if len(sys.argv) > 1: multiprocessing.set_start_method(sys.argv[1]) pool = multiprocessing.Pool(3) inputs = range(30) outputs = pool.imap_unordered(func, inputs) @@ -253,16 +256,25 @@ class MultiprocessingTest(CoverageTest): pool.join() """) - out = self.run_command( - "coverage run --concurrency=multiprocessing multi.py" - ) - total = sum(x*x if x%2 else x*x*x for x in range(30)) - self.assertEqual(out.rstrip(), "3 pids, total = %d" % total) + if env.PYVERSION >= (3, 4): + start_methods = ['fork', 'spawn'] + else: + start_methods = [''] + + for start_method in start_methods: + if start_method and start_method not in multiprocessing.get_all_start_methods(): + continue + + out = self.run_command( + "coverage run --concurrency=multiprocessing multi.py %s" % start_method + ) + total = sum(x*x if x%2 else x*x*x for x in range(30)) + self.assertEqual(out.rstrip(), "3 pids, total = %d" % total) - self.run_command("coverage combine") - out = self.run_command("coverage report -m") - last_line = self.squeezed_lines(out)[-1] - self.assertEqual(last_line, "multi.py 21 0 100%") + self.run_command("coverage combine") + out = self.run_command("coverage report -m") + last_line = self.squeezed_lines(out)[-1] + self.assertEqual(last_line, "multi.py 23 0 100%") def print_simple_annotation(code, linenos): |