summaryrefslogtreecommitdiff
path: root/requests_cache/backends/sqlite.py
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook@pioneer.com>2021-08-02 15:58:34 -0500
committerJordan Cook <jordan.cook@pioneer.com>2021-08-02 16:33:09 -0500
commit05c9f33b34026c50b800c6017f5a53670acdad7d (patch)
tree3e1cf579e9872984d8829bab0198750eed5bdc6a /requests_cache/backends/sqlite.py
parent097a6e098c2375ff14f74a29a682fda67cdb7ad1 (diff)
downloadrequests-cache-05c9f33b34026c50b800c6017f5a53670acdad7d.tar.gz
Update `DbCache.clear()` to succeed even if the database is corrupted
Diffstat (limited to 'requests_cache/backends/sqlite.py')
-rw-r--r--requests_cache/backends/sqlite.py44
1 files changed, 31 insertions, 13 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 10cfc38..48f178b 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -2,8 +2,8 @@ import sqlite3
import threading
from contextlib import contextmanager
from logging import getLogger
-from os import makedirs
-from os.path import abspath, basename, dirname, expanduser, isabs, join
+from os import makedirs, unlink
+from os.path import abspath, basename, dirname, expanduser, isabs, isfile, join
from pathlib import Path
from tempfile import gettempdir
from typing import Collection, Iterable, Iterator, List, Tuple, Type, Union
@@ -47,6 +47,15 @@ class DbCache(BaseCache):
self.redirects.bulk_delete(values=keys)
self.redirects.vacuum()
+ def clear(self):
+ """Clear the cache by deleting the cache file and re-initializing. This is done to allow
+ clear() to succeed even if the file is corrupted.
+ """
+ if isfile(self.responses.db_path):
+ unlink(self.responses.db_path)
+ self.responses.init_db()
+ self.redirects.init_db()
+
class DbDict(BaseStorage):
"""A dictionary-like interface for SQLite.
@@ -77,19 +86,24 @@ class DbDict(BaseStorage):
self.fast_save = fast_save
self.table_name = table_name
+ self._lock = threading.RLock()
self._can_commit = True
self._local_context = threading.local()
- with sqlite3.connect(self.db_path, **self.connection_kwargs) as con:
- self._create_table(con)
+ self.init_db()
- # Initial CREATE TABLE must happen in shared connection; subsequent queries will use thread-local connections
- def _create_table(self, connection):
- connection.execute(f'CREATE TABLE IF NOT EXISTS {self.table_name} (key PRIMARY KEY, value)')
+ def init_db(self):
+ """Initialize the database, if it hasn't already been.
+ This must be done in shared connection, but all subsequent queries can use thread-local connections.
+ """
+ self.close()
+ with self._lock:
+ with sqlite3.connect(self.db_path, **self.connection_kwargs) as con:
+ con.execute(f'CREATE TABLE IF NOT EXISTS {self.table_name} (key PRIMARY KEY, value)')
@contextmanager
def connection(self, commit=False) -> Iterator[sqlite3.Connection]:
"""Get a thread-local database connection"""
- if not hasattr(self._local_context, 'con'):
+ if not getattr(self._local_context, 'con', None):
logger.debug(f'Opening connection to {self.db_path}:{self.table_name}')
self._local_context.con = sqlite3.connect(self.db_path, **self.connection_kwargs)
if self.fast_save:
@@ -98,6 +112,12 @@ class DbDict(BaseStorage):
if commit and self._can_commit:
self._local_context.con.commit()
+ def close(self):
+ """Close any active connections"""
+ if getattr(self._local_context, 'con', None):
+ self._local_context.con.close()
+ self._local_context.con = None
+
@contextmanager
def bulk_commit(self):
"""Context manager used to speed up insertion of a large number of records
@@ -119,9 +139,7 @@ class DbDict(BaseStorage):
self._can_commit = True
def __del__(self):
- """Close any active connections"""
- if hasattr(self._local_context, 'con'):
- self._local_context.con.close()
+ self.close()
def __delitem__(self, key):
with self.connection(commit=True) as con:
@@ -170,8 +188,8 @@ class DbDict(BaseStorage):
def clear(self):
with self.connection(commit=True) as con:
con.execute(f'DROP TABLE IF EXISTS {self.table_name}')
- self._create_table(con)
- con.execute('VACUUM')
+ self.init_db()
+ self.vacuum()
def vacuum(self):
with self.connection(commit=True) as con: