diff options
Diffstat (limited to 'coverage/sqldata.py')
-rw-r--r-- | coverage/sqldata.py | 37 |
1 files changed, 19 insertions, 18 deletions
diff --git a/coverage/sqldata.py b/coverage/sqldata.py index af1c837a..d35bb36a 100644 --- a/coverage/sqldata.py +++ b/coverage/sqldata.py @@ -17,6 +17,7 @@ import itertools import os import sqlite3 import sys +import threading from coverage.backward import get_thread_id, iitems from coverage.data import filename_suffix @@ -95,6 +96,7 @@ class CoverageSqliteData(SimpleReprMixin): # Maps thread ids to SqliteDb objects. self._dbs = {} self._pid = os.getpid() + self._lock = threading.Lock() # Are we in sync with the data file? self._have_used = False @@ -114,9 +116,8 @@ class CoverageSqliteData(SimpleReprMixin): self._filename += "." + suffix def _reset(self): - if self._dbs: - for db in self._dbs.values(): - db.close() + for db in self._dbs.values(): + db.close() self._dbs = {} self._file_map = {} self._have_used = False @@ -166,11 +167,12 @@ class CoverageSqliteData(SimpleReprMixin): def _connect(self): """Get the SqliteDb object to use.""" - if get_thread_id() not in self._dbs: - if os.path.exists(self._filename): - self._open_db() - else: - self._create_db() + with self._lock: + if get_thread_id() not in self._dbs: + if os.path.exists(self._filename): + self._open_db() + else: + self._create_db() return self._dbs[get_thread_id()] def __nonzero__(self): @@ -523,7 +525,8 @@ class CoverageSqliteData(SimpleReprMixin): def write(self): """Write the collected coverage data to a file.""" - pass + for db in self._dbs.values(): + db.close() def _start_using(self): if self._pid != os.getpid(): @@ -684,7 +687,7 @@ class SqliteDb(SimpleReprMixin): def __init__(self, filename, debug): self.debug = debug if debug.should('sql') else None self.filename = filename - self.nest = 0 + self.con = None def connect(self): # SQLite on Windows on py2 won't open a file if the filename argument @@ -708,20 +711,18 @@ class SqliteDb(SimpleReprMixin): self.execute("pragma synchronous=off").close() def close(self): - self.con.close() + if self.con is not None: + self.con.close() + self.con = None def __enter__(self): - if self.nest == 0: + if self.con is None: self.connect() - self.con.__enter__() - self.nest += 1 + self.con.__enter__() return self def __exit__(self, exc_type, exc_value, traceback): - self.nest -= 1 - if self.nest == 0: - self.con.__exit__(exc_type, exc_value, traceback) - self.close() + self.con.__exit__(exc_type, exc_value, traceback) def execute(self, sql, parameters=()): if self.debug: |