summaryrefslogtreecommitdiff
path: root/requests_cache/backends/sqlite.py
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook@pioneer.com>2022-04-10 15:39:14 -0500
committerJordan Cook <jordan.cook@pioneer.com>2022-04-10 20:05:27 -0500
commitb199db190c1ae6dfbd37f79fa910f45c3045dc02 (patch)
tree8c937a6e85fa5ec573b99cfa659240c0b609ca45 /requests_cache/backends/sqlite.py
parent3120a87395887831c95efcf3a218dcac033058ee (diff)
downloadrequests-cache-b199db190c1ae6dfbd37f79fa910f45c3045dc02.tar.gz
Add SQLiteDict.sorted() method with sorting and other query options
Diffstat (limited to 'requests_cache/backends/sqlite.py')
-rw-r--r--requests_cache/backends/sqlite.py55
1 files changed, 55 insertions, 0 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 54989ed..ce62749 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -157,6 +157,23 @@ class SQLiteCache(BaseCache):
')'
)
+ def sorted(
+ self,
+ key: str = 'expires',
+ reversed: bool = False,
+ limit: int = None,
+ exclude_expired=False,
+ ):
+ """Get cached responses, with sorting and other query options.
+
+ Args:
+ key: Key to sort by; either 'expires', 'size', or 'key'
+ reversed: Sort in descending order
+ limit: Maximum number of responses to return
+ exclude_expired: Only return unexpired responses
+ """
+ return self.responses.sorted(key, reversed, limit, exclude_expired)
+
class SQLiteDict(BaseStorage):
"""A dictionary-like interface for SQLite"""
@@ -308,6 +325,34 @@ class SQLiteDict(BaseStorage):
con.execute(f"DELETE FROM {self.table_name} WHERE expires <= ?", (posix_now,))
self.vacuum()
+ def sorted(
+ self, key: str = 'expires', reversed: bool = False, limit: int = None, exclude_expired=False
+ ):
+ """Get cache values in sorted order; see :py:meth:`.SQLiteCache.sorted` for usage details"""
+ # Get sort key, direction, and limit
+ if key not in ['expires', 'size', 'key']:
+ raise ValueError(f'Invalid sort key: {key}')
+ if key == 'size':
+ key = 'LENGTH(value)'
+ direction = 'DESC' if reversed else 'ASC'
+ limit_expr = f'LIMIT {limit}' if limit else ''
+
+ # Filter out expired items, if specified
+ filter_expr = ''
+ params: Tuple = ()
+ if exclude_expired:
+ posix_now = round(datetime.utcnow().timestamp())
+ filter_expr = 'WHERE expires is null or expires > ?'
+ params = (posix_now,)
+
+ with self.connection(commit=True) as con:
+ for row in con.execute(
+ f'SELECT value FROM {self.table_name} {filter_expr}'
+ f' ORDER BY {key} {direction} {limit_expr}',
+ params,
+ ):
+ yield row[0]
+
def vacuum(self):
with self.connection(commit=True) as con:
con.execute('VACUUM')
@@ -325,6 +370,16 @@ class SQLitePickleDict(SQLiteDict):
def __getitem__(self, key):
return self.serializer.loads(super().__getitem__(key))
+ def sorted(
+ self,
+ key: str = 'expires',
+ reversed: bool = False,
+ limit: int = None,
+ exclude_expired: bool = False,
+ ):
+ for value in super().sorted(key, reversed, limit, exclude_expired):
+ yield self.serializer.loads(value)
+
def _format_sequence(values: Collection) -> Tuple[str, List]:
"""Get SQL parameter marks for a sequence-based query, and ensure value is a sequence"""