summaryrefslogtreecommitdiff
path: root/coverage/sqldata.py
diff options
context:
space:
mode:
Diffstat (limited to 'coverage/sqldata.py')
-rw-r--r--coverage/sqldata.py44
1 files changed, 34 insertions, 10 deletions
diff --git a/coverage/sqldata.py b/coverage/sqldata.py
index db2b0b17..05680043 100644
--- a/coverage/sqldata.py
+++ b/coverage/sqldata.py
@@ -44,12 +44,6 @@ create table arc (
);
"""
-def _create_db(filename, schema):
- con = sqlite3.connect(filename)
- with con:
- for stmt in schema.split(';'):
- con.execute(stmt.strip())
- con.close()
class CoverageDataSqlite(object):
@@ -73,8 +67,14 @@ class CoverageDataSqlite(object):
if not os.path.exists(self.filename):
if self._debug and self._debug.should('dataio'):
self._debug.write("Creating data file %r" % (self.filename,))
- _create_db(self.filename, SCHEMA)
- self._db = sqlite3.connect(self.filename)
+ self._db = Sqlite(self.filename, self._debug)
+ with self._db:
+ for stmt in SCHEMA.split(';'):
+ stmt = stmt.strip()
+ if stmt:
+ self._db.execute(stmt)
+ else:
+ self._db = Sqlite(self.filename, self._debug)
for path, id in self._db.execute("select path, id from file"):
self._file_map[path] = id
return self._db
@@ -83,8 +83,7 @@ class CoverageDataSqlite(object):
self._start_writing()
if filename not in self._file_map:
with self._connect() as con:
- cur = con.cursor()
- cur.execute("insert into file (path) values (?)", (filename,))
+ cur = con.execute("insert into file (path) values (?)", (filename,))
self._file_map[filename] = cur.lastrowid
return self._file_map[filename]
@@ -174,3 +173,28 @@ class CoverageDataSqlite(object):
with self._connect() as con:
file_id = self._file_id(filename)
return [lineno for lineno, in con.execute("select lineno from line where file_id = ?", (file_id,))]
+
+
+class Sqlite(object):
+ def __init__(self, filename, debug):
+ self.debug = debug if debug.should('sql') else None
+ if self.debug:
+ self.debug.write("Connecting to {!r}".format(filename))
+ self.con = sqlite3.connect(filename)
+
+ def close(self):
+ self.con.close()
+
+ def __enter__(self):
+ self.con.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ return self.con.__exit__(exc_type, exc_value, traceback)
+
+ def execute(self, sql, parameters=()):
+ if self.debug:
+ tail = " with {!r}".format(parameters) if parameters else ""
+ self.debug.write("Executing {!r}{}".format(sql, tail))
+ cur = self.con.execute(sql, parameters)
+ return cur