diff options
Diffstat (limited to 'requests_cache/backends/sqlite.py')
-rw-r--r-- | requests_cache/backends/sqlite.py | 86 |
1 files changed, 86 insertions, 0 deletions
class SQLiteLRUDict(SQLiteDict):
    """SQLite backend that stays within a certain size limit.

    Every row carries a ``last_used`` timestamp that is refreshed on each read and write.
    After a write, if the size reported by ``self.size()`` exceeds ``max_size``, the least
    recently used rows are deleted until roughly ``prune_target_size_percent`` of
    ``max_size`` is reached.

    Args:
        max_size: Upper size limit, compared against ``self.size()``. ``None`` (the default)
            disables pruning entirely.

    NOTE(review): pruning assumes ``self.size()`` and SQLite's ``LENGTH(value)`` are expressed
    in comparable units (presumably bytes) -- confirm against ``SQLiteDict.size()``.
    """

    def __init__(self, *args, max_size: int = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_size = max_size

        # Maximum number of rows to remove in a single prune call
        self.max_remove_per_prune = SQLITE_MAX_VARIABLE_NUMBER

        # Minimum interval (in seconds) between prune calls
        self.prune_interval = 10
        self.last_prune = time()

        # Fraction of max size to target with each prune call
        # (a higher value results in more frequent pruning)
        self.prune_target_size_percent = 0.9

    def init_db(self):
        """Add the ``last_used`` column and its index, if they don't already exist"""
        super().init_db()
        with self._lock, self.connection() as con:
            try:
                con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN last_used INTEGER')
            except sqlite3.OperationalError as e:
                # An existing column is expected on a previously initialized database;
                # any other failure (e.g. a locked database) must not be hidden.
                if 'duplicate column name' not in str(e).lower():
                    raise
            # Index names are unique per database, not per table, so the table name is
            # included; otherwise only the first table in a file would get an index.
            con.execute(
                f'CREATE INDEX IF NOT EXISTS {self.table_name}_last_used_idx '
                f'ON {self.table_name}(last_used)'
            )

    def __getitem__(self, key):
        """Retrieve an item and record the access time in ``last_used``.

        The lookup happens first, so any error raised by the parent lookup propagates
        before ``last_used`` is touched.
        """
        value = super().__getitem__(key)
        with self.connection(commit=True) as con:
            con.execute(
                f'UPDATE {self.table_name} SET last_used=? WHERE key=?',
                (round(time()), key),
            )
        return value

    def __setitem__(self, key, value):
        """Set an item, remove any expired responses, and prune if over the size limit"""
        # If available, set expiration as a timestamp in unix format
        expires = getattr(value, 'expires_unix', None)
        now = round(time())
        value = self.serialize(value)
        with self.connection(commit=True) as con:
            con.execute(
                f'INSERT OR REPLACE INTO {self.table_name} '
                '(key,value,expires,last_used) VALUES (?,?,?,?)',
                (key, value, expires, now),
            )
            con.execute(f'DELETE FROM {self.table_name} WHERE expires <= ?', (now,))

        self._prune()

    def _check_prune(self) -> bool:
        """Return True if a size limit is set, the prune interval has elapsed, and the
        current size exceeds the limit
        """
        if self.max_size is None:
            return False
        return time() - self.last_prune > self.prune_interval and self.size() > self.max_size

    # WIP; just one of multiple approaches
    def _prune(self):
        """Delete least recently used rows until the target size is (approximately) reached.

        At most ``max_remove_per_prune`` rows are removed per call, so a single call may
        not be enough to get back under the limit.
        """
        if not self._check_prune():
            return

        with self.connection(commit=True) as con:
            # Get a batch of least recently used rows (up to a max query size)
            rows = con.execute(
                f'SELECT key, LENGTH(value) from {self.table_name} '
                'ORDER BY last_used ASC LIMIT ?',
                (self.max_remove_per_prune,),
            )

            # Figure out how many rows we need to delete to reach the target size
            delete_keys = []
            delete_size = self.size() - int(self.max_size * self.prune_target_size_percent)
            pending_delete_size = 0
            for key, row_size in rows:
                delete_keys.append(key)
                # LENGTH() is NULL for a NULL value; count such rows as zero-sized
                pending_delete_size += row_size or 0
                if pending_delete_size >= delete_size:
                    break

            # Delete selected rows (nothing to do if the table is empty)
            if delete_keys:
                marks, args = _format_sequence(delete_keys)
                con.execute(f'DELETE FROM {self.table_name} WHERE key IN ({marks})', args)

        self.last_prune = time()


def _format_sequence(values: Collection) -> Tuple[str, List]:
    """Get SQL parameter marks for a sequence-based query.

    Returns a comma-separated string with one ``?`` per value, and the values as a list.
    """
    return ','.join(['?'] * len(values)), list(values)