diff options
Diffstat (limited to 'coverage/sqldata.py')
-rw-r--r-- | coverage/sqldata.py | 44 |
1 files changed, 34 insertions, 10 deletions
diff --git a/coverage/sqldata.py b/coverage/sqldata.py index db2b0b17..05680043 100644 --- a/coverage/sqldata.py +++ b/coverage/sqldata.py @@ -44,12 +44,6 @@ create table arc ( ); """ -def _create_db(filename, schema): - con = sqlite3.connect(filename) - with con: - for stmt in schema.split(';'): - con.execute(stmt.strip()) - con.close() class CoverageDataSqlite(object): @@ -73,8 +67,14 @@ class CoverageDataSqlite(object): if not os.path.exists(self.filename): if self._debug and self._debug.should('dataio'): self._debug.write("Creating data file %r" % (self.filename,)) - _create_db(self.filename, SCHEMA) - self._db = sqlite3.connect(self.filename) + self._db = Sqlite(self.filename, self._debug) + with self._db: + for stmt in SCHEMA.split(';'): + stmt = stmt.strip() + if stmt: + self._db.execute(stmt) + else: + self._db = Sqlite(self.filename, self._debug) for path, id in self._db.execute("select path, id from file"): self._file_map[path] = id return self._db @@ -83,8 +83,7 @@ class CoverageDataSqlite(object): self._start_writing() if filename not in self._file_map: with self._connect() as con: - cur = con.cursor() - cur.execute("insert into file (path) values (?)", (filename,)) + cur = con.execute("insert into file (path) values (?)", (filename,)) self._file_map[filename] = cur.lastrowid return self._file_map[filename] @@ -174,3 +173,28 @@ class CoverageDataSqlite(object): with self._connect() as con: file_id = self._file_id(filename) return [lineno for lineno, in con.execute("select lineno from line where file_id = ?", (file_id,))] + + +class Sqlite(object): + def __init__(self, filename, debug): + self.debug = debug if debug.should('sql') else None + if self.debug: + self.debug.write("Connecting to {!r}".format(filename)) + self.con = sqlite3.connect(filename) + + def close(self): + self.con.close() + + def __enter__(self): + self.con.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback): + return self.con.__exit__(exc_type, exc_value, traceback) + + def execute(self, sql, parameters=()): + if self.debug: + tail = " with {!r}".format(parameters) if parameters else "" + self.debug.write("Executing {!r}{}".format(sql, tail)) + cur = self.con.execute(sql, parameters) + return cur |