diff options
| author | Jordan Cook <jordan.cook@pioneer.com> | 2021-08-02 15:58:34 -0500 |
|---|---|---|
| committer | Jordan Cook <jordan.cook@pioneer.com> | 2021-08-02 16:33:09 -0500 |
| commit | 05c9f33b34026c50b800c6017f5a53670acdad7d (patch) | |
| tree | 3e1cf579e9872984d8829bab0198750eed5bdc6a /requests_cache/backends/sqlite.py | |
| parent | 097a6e098c2375ff14f74a29a682fda67cdb7ad1 (diff) | |
| download | requests-cache-05c9f33b34026c50b800c6017f5a53670acdad7d.tar.gz | |
Update `DbCache.clear()` to succeed even if the database is corrupted
Diffstat (limited to 'requests_cache/backends/sqlite.py')
| -rw-r--r-- | requests_cache/backends/sqlite.py | 44 |
1 files changed, 31 insertions, 13 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 10cfc38..48f178b 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -2,8 +2,8 @@ import sqlite3 import threading from contextlib import contextmanager from logging import getLogger -from os import makedirs -from os.path import abspath, basename, dirname, expanduser, isabs, join +from os import makedirs, unlink +from os.path import abspath, basename, dirname, expanduser, isabs, isfile, join from pathlib import Path from tempfile import gettempdir from typing import Collection, Iterable, Iterator, List, Tuple, Type, Union @@ -47,6 +47,15 @@ class DbCache(BaseCache): self.redirects.bulk_delete(values=keys) self.redirects.vacuum() + def clear(self): + """Clear the cache by deleting the cache file and re-initializing. This is done to allow + clear() to succeed even if the file is corrupted. + """ + if isfile(self.responses.db_path): + unlink(self.responses.db_path) + self.responses.init_db() + self.redirects.init_db() + class DbDict(BaseStorage): """A dictionary-like interface for SQLite. @@ -77,19 +86,24 @@ class DbDict(BaseStorage): self.fast_save = fast_save self.table_name = table_name + self._lock = threading.RLock() self._can_commit = True self._local_context = threading.local() - with sqlite3.connect(self.db_path, **self.connection_kwargs) as con: - self._create_table(con) + self.init_db() - # Initial CREATE TABLE must happen in shared connection; subsequent queries will use thread-local connections - def _create_table(self, connection): - connection.execute(f'CREATE TABLE IF NOT EXISTS {self.table_name} (key PRIMARY KEY, value)') + def init_db(self): + """Initialize the database, if it hasn't already been. + This must be done in shared connection, but all subsequent queries can use thread-local connections. + """ + self.close() + with self._lock: + with sqlite3.connect(self.db_path, **self.connection_kwargs) as con: + con.execute(f'CREATE TABLE IF NOT EXISTS {self.table_name} (key PRIMARY KEY, value)') @contextmanager def connection(self, commit=False) -> Iterator[sqlite3.Connection]: """Get a thread-local database connection""" - if not hasattr(self._local_context, 'con'): + if not getattr(self._local_context, 'con', None): logger.debug(f'Opening connection to {self.db_path}:{self.table_name}') self._local_context.con = sqlite3.connect(self.db_path, **self.connection_kwargs) if self.fast_save: @@ -98,6 +112,12 @@ class DbDict(BaseStorage): if commit and self._can_commit: self._local_context.con.commit() + def close(self): + """Close any active connections""" + if getattr(self._local_context, 'con', None): + self._local_context.con.close() + self._local_context.con = None + @contextmanager def bulk_commit(self): """Context manager used to speed up insertion of a large number of records @@ -119,9 +139,7 @@ class DbDict(BaseStorage): self._can_commit = True def __del__(self): - """Close any active connections""" - if hasattr(self._local_context, 'con'): - self._local_context.con.close() + self.close() def __delitem__(self, key): with self.connection(commit=True) as con: @@ -170,8 +188,8 @@ class DbDict(BaseStorage): def clear(self): with self.connection(commit=True) as con: con.execute(f'DROP TABLE IF EXISTS {self.table_name}') - self._create_table(con) - con.execute('VACUUM') + self.init_db() + self.vacuum() def vacuum(self): with self.connection(commit=True) as con: |
