summaryrefslogtreecommitdiff
path: root/tests/test_coroutine.py
blob: 2b4d298f03014dfff13b51f543cfea75fa191caa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Tests for coroutining."""

import os, os.path, sys, threading

import coverage

from tests.coveragetest import CoverageTest


# eventlet and gevent are optional.  When one can't be imported its name is
# bound to None; the tests still run and then expect coverage.py to report
# that the requested coroutine library isn't installed.

try:
    import eventlet         # pylint: disable=import-error
except ImportError:
    eventlet = None

try:
    import gevent           # pylint: disable=import-error
except ImportError:
    gevent = None

# True when the C tracer is in use: COVERAGE_TEST_TRACER is 'c' or unset.
C_TRACER = os.environ.get('COVERAGE_TEST_TRACER', 'c') == 'c'


def line_count(s):
    """Return the number of lines in `s` that contain code.

    A line is skipped when it is empty or whitespace-only, or when its first
    non-blank character is "#" (a full-line comment).
    """
    total = 0
    for raw in s.splitlines():
        stripped = raw.strip()
        if stripped and not stripped.startswith('#'):
            total += 1
    return total


class CoroutineTest(CoverageTest):
    """Run small concurrent programs under coverage.py's coroutine modes.

    Each test writes a script and runs it with
    ``coverage run --coroutine=<regime>``. It then checks one of two things:
    the program output plus full line coverage, or the error message coverage
    prints when the regime can't be used.
    """

    # Number of items the producer sends; also the range SIMPLE sums.
    LIMIT = 1000

    # The code common to all the concurrency models.
    COMMON = """
        class Producer(threading.Thread):
            def __init__(self, q):
                threading.Thread.__init__(self)
                self.q = q

            def run(self):
                for i in range({LIMIT}):
                    self.q.put(i)
                self.q.put(None)

        class Consumer(threading.Thread):
            def __init__(self, q):
                threading.Thread.__init__(self)
                self.q = q

            def run(self):
                sum = 0
                while True:
                    i = self.q.get()
                    if i is None:
                        print(sum)
                        break
                    sum += i

        q = queue.Queue()
        c = Consumer(q)
        p = Producer(q)
        c.start()
        p.start()

        p.join()
        c.join()
        """.format(LIMIT=LIMIT)

    # Real threads: only the name of the queue module differs between
    # Python 2 and Python 3.
    if sys.version_info >= (3, 0):
        THREAD = """\
        import threading
        import queue
        """ + COMMON
    else:
        THREAD = """\
        import threading
        import Queue as queue
        """ + COMMON

    # eventlet's green threading and queue stand in for the stdlib ones.
    EVENTLET = """\
        import eventlet.green.threading as threading
        import eventlet.queue as queue
        """ + COMMON

    # gevent monkey-patches threading and supplies its own queue.
    GEVENT = """\
        from gevent import monkey
        monkey.patch_thread()
        import threading
        import gevent.queue as queue
        """ + COMMON

    # A plain loop with no concurrency at all, to exercise the simple case
    # under every regime.
    SIMPLE = """\
        total = 0
        for i in range({LIMIT}):
            total += i
        print(total)
        """.format(LIMIT=LIMIT)

    def try_some_code(self, code, coroutine, the_module):
        """Run `code` under the `coroutine` regime and check the outcome.

        `code` is the Python source to execute, `coroutine` is the regime name
        handed to ``--coroutine``, and `the_module` is the imported library
        that regime depends on (falsy when it isn't installed).

        """
        self.make_file("try_it.py", code)
        out = self.run_command(
            "coverage run --coroutine=%s try_it.py" % coroutine
        )

        if not the_module:
            # The library is missing, so coverage should report exactly that.
            self.assertEqual(
                out,
                "Couldn't trace with coroutine=%s, "
                "the module isn't installed.\n" % coroutine,
            )
            return

        if not (C_TRACER or coroutine == "thread"):
            # The Python tracer only handles real threads.
            self.assertEqual(
                out,
                "Can't support coroutine=%s with PyTracer, "
                "only threads are supported\n" % coroutine,
            )
            return

        # Full measurement is possible (C tracer, or plain threads): the
        # program must print the right total...
        self.assertEqual(out, "%d\n" % sum(range(self.LIMIT)))

        # ...and every code line of try_it.py must be recorded as executed.
        data = coverage.CoverageData()
        data.read_file(".coverage")

        # Diagnostic output, shown when the test fails.
        executed = data.executed_lines(os.path.abspath("try_it.py")).keys()
        print("{0}: {1}".format(len(executed), executed))
        print_simple_annotation(code, executed)

        self.assertEqual(data.summary()['try_it.py'], line_count(code))

    def test_threads(self):
        self.try_some_code(self.THREAD, "thread", threading)

    def test_threads_simple_code(self):
        self.try_some_code(self.SIMPLE, "thread", threading)

    def test_eventlet(self):
        self.try_some_code(self.EVENTLET, "eventlet", eventlet)

    def test_eventlet_simple_code(self):
        self.try_some_code(self.SIMPLE, "eventlet", eventlet)

    def test_gevent(self):
        self.try_some_code(self.GEVENT, "gevent", gevent)

    def test_gevent_simple_code(self):
        self.try_some_code(self.SIMPLE, "gevent", gevent)


def print_simple_annotation(code, linenos):
    """Echo `code` line by line, flagging the executed lines.

    Each printed line is: a space, a marker ("X" when that 1-based line number
    is in `linenos`, otherwise a space), another space, then the source line.
    """
    for index, text in enumerate(code.splitlines()):
        marker = " "
        if index + 1 in linenos:
            marker = "X"
        print(" {0} {1}".format(marker, text))