summaryrefslogtreecommitdiff
path: root/coverage/sqldata.py
diff options
context:
space:
mode:
Diffstat (limited to 'coverage/sqldata.py')
-rw-r--r--coverage/sqldata.py37
1 files changed, 19 insertions, 18 deletions
diff --git a/coverage/sqldata.py b/coverage/sqldata.py
index af1c837a..d35bb36a 100644
--- a/coverage/sqldata.py
+++ b/coverage/sqldata.py
@@ -17,6 +17,7 @@ import itertools
import os
import sqlite3
import sys
+import threading
from coverage.backward import get_thread_id, iitems
from coverage.data import filename_suffix
@@ -95,6 +96,7 @@ class CoverageSqliteData(SimpleReprMixin):
# Maps thread ids to SqliteDb objects.
self._dbs = {}
self._pid = os.getpid()
+ self._lock = threading.Lock()
# Are we in sync with the data file?
self._have_used = False
@@ -114,9 +116,8 @@ class CoverageSqliteData(SimpleReprMixin):
self._filename += "." + suffix
def _reset(self):
- if self._dbs:
- for db in self._dbs.values():
- db.close()
+ for db in self._dbs.values():
+ db.close()
self._dbs = {}
self._file_map = {}
self._have_used = False
@@ -166,11 +167,12 @@ class CoverageSqliteData(SimpleReprMixin):
def _connect(self):
"""Get the SqliteDb object to use."""
- if get_thread_id() not in self._dbs:
- if os.path.exists(self._filename):
- self._open_db()
- else:
- self._create_db()
+ with self._lock:
+ if get_thread_id() not in self._dbs:
+ if os.path.exists(self._filename):
+ self._open_db()
+ else:
+ self._create_db()
return self._dbs[get_thread_id()]
def __nonzero__(self):
@@ -523,7 +525,8 @@ class CoverageSqliteData(SimpleReprMixin):
def write(self):
"""Write the collected coverage data to a file."""
- pass
+ for db in self._dbs.values():
+ db.close()
def _start_using(self):
if self._pid != os.getpid():
@@ -684,7 +687,7 @@ class SqliteDb(SimpleReprMixin):
def __init__(self, filename, debug):
self.debug = debug if debug.should('sql') else None
self.filename = filename
- self.nest = 0
+ self.con = None
def connect(self):
# SQLite on Windows on py2 won't open a file if the filename argument
@@ -708,20 +711,18 @@ class SqliteDb(SimpleReprMixin):
self.execute("pragma synchronous=off").close()
def close(self):
- self.con.close()
+ if self.con is not None:
+ self.con.close()
+ self.con = None
def __enter__(self):
- if self.nest == 0:
+ if self.con is None:
self.connect()
- self.con.__enter__()
- self.nest += 1
+ self.con.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
- self.nest -= 1
- if self.nest == 0:
- self.con.__exit__(exc_type, exc_value, traceback)
- self.close()
+ self.con.__exit__(exc_type, exc_value, traceback)
def execute(self, sql, parameters=()):
if self.debug: