summary | refs | log | tree | commit | diff
path: root/requests_cache
diff options
context:
space:
mode:
author Jordan Cook <jordan.cook@pioneer.com> 2021-04-27 16:26:46 -0500
committer Jordan Cook <jordan.cook@pioneer.com> 2021-04-27 23:19:45 -0500
commit cfd5c391f546e5e9cfaeb5a0339750f9a790325a (patch)
tree 59e9a7356ae3dc27bf68e2ab29bdbf81e9a1f7ab /requests_cache
parent 51cbd04e8f3af869fe57c7a57d15f40c04cc4dea (diff)
download requests-cache-cfd5c391f546e5e9cfaeb5a0339750f9a790325a.tar.gz
Add BaseCache.keys() and values() methods
* Update `BaseCache.urls` to return a generator instead of the whole list at once (which may use too much memory for particularly large caches)
* Update `BaseCache.urls` to just skip invalid responses instead of deleting them
Diffstat (limited to 'requests_cache')
-rw-r--r-- requests_cache/backends/base.py 57
-rw-r--r-- requests_cache/backends/sqlite.py 6
2 files changed, 45 insertions, 18 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index def4dc1..0ac4164 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -3,7 +3,7 @@ import warnings
from abc import ABC
from collections.abc import MutableMapping
from logging import DEBUG, WARNING, getLogger
-from typing import Iterable, List, Tuple, Union
+from typing import Iterable, Iterator, Tuple, Union
import requests
from requests.models import PreparedRequest
@@ -11,6 +11,9 @@ from requests.models import PreparedRequest
from ..cache_keys import create_key, url_to_key
from ..response import AnyResponse, CachedResponse, ExpirationTime
+# Specific exceptions that may be raised during deserialization
+DESERIALIZE_ERRORS = (AttributeError, TypeError, ValueError, pickle.PickleError)
+
ResponseOrKey = Union[CachedResponse, str]
logger = getLogger(__name__)
@@ -34,9 +37,10 @@ class BaseCache:
self.ignored_parameters = ignored_parameters
@property
- def urls(self) -> List[str]:
+ def urls(self) -> Iterator[str]:
"""Get all URLs currently in the cache (excluding redirects)"""
- return [r.url for _, r in self._get_valid_responses()]
+ for response in self.values():
+ yield response.url
def save_response(self, response: AnyResponse, key: str = None, expire_after: ExpirationTime = None):
"""Save response to cache
@@ -75,7 +79,7 @@ class BaseCache:
return response
except KeyError:
return default
- except (AttributeError, TypeError, ValueError, pickle.PickleError) as e:
+ except DESERIALIZE_ERRORS as e:
logger.error(f'Unable to deserialize response with key {key}: {str(e)}')
return default
@@ -117,7 +121,8 @@ class BaseCache:
expire_after: A new expiration time used to revalidate the cache
"""
logger.info('Removing expired responses.' + (f'Revalidating with: {expire_after}' if expire_after else ''))
- for key, response in self._get_valid_responses():
+ # _get_valid_responses must be consumed before making any additional writes
+ for key, response in list(self._get_valid_responses(delete_invalid=True)):
# If we're revalidating and it's not yet expired, update the cached item's expiration
if expire_after is not None and not response.revalidate(expire_after):
self.responses[key] = response
@@ -125,21 +130,10 @@ class BaseCache:
self.delete(key)
def remove_old_entries(self, *args, **kwargs):
- msg = 'BaseCache.remove_old_entries() is deprecated; ' 'please use CachedSession.remove_expired_responses()'
+ msg = 'BaseCache.remove_old_entries() is deprecated; please use CachedSession.remove_expired_responses()'
warnings.warn(DeprecationWarning(msg))
self.remove_expired_responses(*args, **kwargs)
- def _get_valid_responses(self) -> Iterable[Tuple[str, CachedResponse]]:
- """Get all responses from the cache, and delete any invalid ones"""
- for key in list(self.responses.keys()):
- # If a response is invalid, delete it
- try:
- yield key, self.responses[key]
- except Exception as e:
- logger.debug(f'Unable to deserialize response with key {key}: {str(e)}')
- self.delete(key)
- continue
-
def create_key(self, request: requests.PreparedRequest, **kwargs) -> str:
"""Create a normalized cache key from a request object"""
return create_key(request, self.ignored_parameters, self.include_get_headers, **kwargs)
@@ -152,6 +146,35 @@ class BaseCache:
"""Returns `True` if cache has `url`, `False` otherwise. Works only for GET request urls"""
return self.has_key(url_to_key(url, self.ignored_parameters)) # noqa: W601
+ def keys(self) -> Iterator[str]:
+ """Get all cache keys for redirects and (valid) responses combined"""
+ yield from self.redirects.keys()
+ for key, _ in self._get_valid_responses():
+ yield key
+
+ def values(self) -> Iterator[CachedResponse]:
+ """Get all valid response objects from the cache"""
+ for _, response in self._get_valid_responses():
+ yield response
+
+ def _get_valid_responses(self, delete_invalid=False) -> Iterator[Tuple[str, CachedResponse]]:
+ """Get all responses from the cache, and skip (+ optionally delete) any invalid ones that
+ can't be deserialized"""
+ keys_to_delete = []
+
+ for key in self.responses.keys():
+ try:
+ yield key, self.responses[key]
+ except DESERIALIZE_ERRORS as e:
+ logger.debug(f'Unable to deserialize response with key {key}: {e}')
+ keys_to_delete.append(key)
+
+ # Delay deletion until the end, to slightly improve responsiveness when used as a generator
+ if delete_invalid:
+ logger.debug(f'Deleting {len(keys_to_delete)} invalid responses')
+ for key in keys_to_delete:
+ self.delete(key)
+
def __str__(self):
return f'redirects: {len(self.redirects)}\nresponses: {len(self.responses)}'
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 29ac083..5ad100a 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -27,7 +27,11 @@ class DbCache(BaseCache):
"""
def __init__(
- self, db_path: Union[Path, str] = 'http_cache', use_temp: bool = False, fast_save: bool = False, **kwargs
+ self,
+ db_path: Union[Path, str] = 'http_cache',
+ use_temp: bool = False,
+ fast_save: bool = False,
+ **kwargs,
):
super().__init__(**kwargs)
self.responses = DbPickleDict(db_path, table_name='responses', use_temp=use_temp, fast_save=fast_save, **kwargs)