From 264e5dc48f2ff1f750423e127eb63b0906ea62ff Mon Sep 17 00:00:00 2001 From: Jordan Cook Date: Mon, 2 Aug 2021 16:29:00 -0500 Subject: Update `DbDict.bulk_delete()` to support deleting more items than SQLite's variable limit (999) --- requests_cache/backends/sqlite.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) (limited to 'requests_cache/backends/sqlite.py') diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 48f178b..d6fef95 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -10,6 +10,7 @@ from typing import Collection, Iterable, Iterator, List, Tuple, Type, Union from . import BaseCache, BaseStorage, get_valid_kwargs +SQLITE_MAX_VARIABLE_NUMBER = 999 logger = getLogger(__name__) @@ -96,9 +97,8 @@ class DbDict(BaseStorage): This must be done in shared connection, but all subsequent queries can use thread-local connections. """ self.close() - with self._lock: - with sqlite3.connect(self.db_path, **self.connection_kwargs) as con: - con.execute(f'CREATE TABLE IF NOT EXISTS {self.table_name} (key PRIMARY KEY, value)') + with self._lock, sqlite3.connect(self.db_path, **self.connection_kwargs) as con: + con.execute(f'CREATE TABLE IF NOT EXISTS {self.table_name} (key PRIMARY KEY, value)') @contextmanager def connection(self, commit=False) -> Iterator[sqlite3.Connection]: @@ -172,18 +172,19 @@ class DbDict(BaseStorage): return con.execute(f'SELECT COUNT(key) FROM {self.table_name}').fetchone()[0] def bulk_delete(self, keys=None, values=None): - """Delete multiple keys from the cache. Does not raise errors for missing keys. + """Delete multiple keys from the cache, without raising errors for any missing keys. Also supports deleting by value. """ if not keys and not values: return column = 'key' if keys else 'value' - marks, args = _format_sequence(keys or values) - statement = f'DELETE FROM {self.table_name} WHERE {column} IN ({marks})' - with self.connection(commit=True) as con: - con.execute(statement, args) + # Split into small enough chunks for SQLite to handle + for chunk in chunkify(keys or values): + marks, args = _format_sequence(chunk) + statement = f'DELETE FROM {self.table_name} WHERE {column} IN ({marks})' + con.execute(statement, args) def clear(self): with self.connection(commit=True) as con: @@ -209,6 +210,13 @@ class DbPickleDict(DbDict): return self.serializer.loads(super().__getitem__(key)) +def chunkify(iterable: Iterable, max_size=SQLITE_MAX_VARIABLE_NUMBER) -> Iterator[List]: + """Split an iterable into chunks of a max size""" + iterable = list(iterable) + for index in range(0, len(iterable), max_size): + yield iterable[index : index + max_size] + + def _format_sequence(values: Collection) -> Tuple[str, List]: """Get SQL parameter marks for a sequence-based query, and ensure value is a sequence""" if not isinstance(values, Iterable): -- cgit v1.2.1