diff options
Diffstat (limited to 'coverage/sqldata.py')
-rw-r--r-- | coverage/sqldata.py | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/coverage/sqldata.py b/coverage/sqldata.py index ea4922ca..00a72b9f 100644 --- a/coverage/sqldata.py +++ b/coverage/sqldata.py @@ -15,9 +15,10 @@ import itertools import os import sqlite3 import sys +import zlib from coverage.backward import get_thread_id, iitems -from coverage.backward import bytes_to_ints, binary_bytes, zip_longest +from coverage.backward import bytes_to_ints, binary_bytes, zip_longest, to_bytes, to_string from coverage.debug import NoDebugging, SimpleReprMixin from coverage import env from coverage.files import PathAliases @@ -219,8 +220,11 @@ class CoverageData(SimpleReprMixin): def _open_db(self): if self._debug.should('dataio'): self._debug.write("Opening data file {!r}".format(self._filename)) - self._dbs[get_thread_id()] = db = SqliteDb(self._filename, self._debug) - with db: + self._dbs[get_thread_id()] = SqliteDb(self._filename, self._debug) + self._read_db() + + def _read_db(self): + with self._dbs[get_thread_id()] as db: try: schema_version, = db.execute("select version from coverage_schema").fetchone() except Exception as exc: @@ -264,22 +268,23 @@ class CoverageData(SimpleReprMixin): __bool__ = __nonzero__ - def dump(self): # pragma: debugging - """Write a dump of the database.""" - if self._debug: - with self._connect() as con: - self._debug.write(con.dump()) - + @contract(returns='bytes') def dumps(self): with self._connect() as con: - return con.dump() + return b'z' + zlib.compress(to_bytes(con.dump())) + @contract(data='bytes') def loads(self, data): if self._debug.should('dataio'): self._debug.write("Loading data into data file {!r}".format(self._filename)) + if data[:1] != b'z': + raise CoverageException("Unrecognized serialization: {!r} (head of {} bytes)".format(data[:40], len(data))) + script = to_string(zlib.decompress(data[1:])) self._dbs[get_thread_id()] = db = SqliteDb(self._filename, self._debug) with db: - db.executescript(data) + db.executescript(script) + self._read_db() + self._have_used = True def _file_id(self, filename, add=False): """Get the file id for `filename`. @@ -858,7 +863,7 @@ class SqliteDb(SimpleReprMixin): self.debug.write("Executing script with {} chars".format(len(script))) self.con.executescript(script) - def dump(self): # pragma: debugging + def dump(self): """Return a multi-line string, the dump of the database.""" return "\n".join(self.con.iterdump()) |