summaryrefslogtreecommitdiff
path: root/coverage/sqldata.py
blob: ee0798e36d925ecc58b3aa438bde9847f8030377 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt

"""Sqlite coverage data."""

import glob
import os
import sqlite3

from coverage.backward import iitems
from coverage.misc import CoverageException, file_be_gone


SCHEMA = """
create table schema (
    version integer
);

insert into schema (version) values (1);

create table meta (
    name text,
    value text,
    unique(name)
);

create table file (
    id integer primary key,
    path text,
    tracer text,
    unique(path)
);

create table line (
    file_id integer,
    lineno integer,
    unique(file_id, lineno)
);

create table arc (
    file_id integer,
    fromno integer,
    tono integer,
    unique(file_id, fromno, tono)
);
"""

def _create_db(filename, schema):
    con = sqlite3.connect(filename)
    with con:
        for stmt in schema.split(';'):
            con.execute(stmt.strip())
    con.close()


class CoverageDataSqlite(object):
    def __init__(self, basename=None, warn=None, debug=None):
        self.filename = os.path.abspath(basename or ".coverage")
        self._warn = warn
        self._debug = debug

        self._file_map = {}
        self._db = None
        self._have_read = False

    def _reset(self):
        self._file_map = {}
        if self._db is not None:
            self._db.close()
        self._db = None

    def _connect(self):
        if self._db is None:
            if not os.path.exists(self.filename):
                if self._debug and self._debug.should('dataio'):
                    self._debug.write("Creating data file %r" % (self.filename,))
                _create_db(self.filename, SCHEMA)
            self._db = sqlite3.connect(self.filename)
            for path, id in self._db.execute("select path, id from file"):
                self._file_map[path] = id
        return self._db

    def _file_id(self, filename):
        self._start_writing()
        if filename not in self._file_map:
            with self._connect() as con:
                cur = con.cursor()
                cur.execute("insert into file (path) values (?)", (filename,))
                self._file_map[filename] = cur.lastrowid
        return self._file_map[filename]

    def add_lines(self, line_data):
        """Add measured line data.

        `line_data` is a dictionary mapping file names to dictionaries::

            { filename: { lineno: None, ... }, ...}

        """
        self._start_writing()
        with self._connect() as con:
            for filename, linenos in iitems(line_data):
                file_id = self._file_id(filename)
                for lineno in linenos:
                    con.execute(
                        "insert or ignore into line (file_id, lineno) values (?, ?)",
                        (file_id, lineno),
                    )

    def add_file_tracers(self, file_tracers):
        """Add per-file plugin information.

        `file_tracers` is { filename: plugin_name, ... }

        """
        self._start_writing()
        with self._connect() as con:
            for filename, tracer in iitems(file_tracers):
                con.execute(
                    "insert into file (path, tracer) values (?, ?) on duplicate key update",
                    (filename, tracer),
                )

    def erase(self, parallel=False):
        """Erase the data in this object.

        If `parallel` is true, then also deletes data files created from the
        basename by parallel-mode.

        """
        self._reset()
        if self._debug and self._debug.should('dataio'):
            self._debug.write("Erasing data file %r" % (self.filename,))
        file_be_gone(self.filename)
        if parallel:
            data_dir, local = os.path.split(self.filename)
            localdot = local + '.*'
            pattern = os.path.join(os.path.abspath(data_dir), localdot)
            for filename in glob.glob(pattern):
                if self._debug and self._debug.should('dataio'):
                    self._debug.write("Erasing parallel data file %r" % (filename,))
                file_be_gone(filename)

    def read(self):
        self._have_read = True

    def write(self, suffix=None):
        """Write the collected coverage data to a file."""
        pass

    def _start_writing(self):
        if not self._have_read:
            self.erase()
        self._have_read = True

    def has_arcs(self):
        return False    # TODO!

    def measured_files(self):
        """A list of all files that had been measured."""
        self._connect()
        return list(self._file_map)

    def file_tracer(self, filename):
        """Get the plugin name of the file tracer for a file.

        Returns the name of the plugin that handles this file.  If the file was
        measured, but didn't use a plugin, then "" is returned.  If the file
        was not measured, then None is returned.

        """
        with self._connect() as con:
            pass