diff options
| author | Jordan Cook <jordan.cook@pioneer.com> | 2021-03-03 19:47:40 -0600 |
|---|---|---|
| committer | Jordan Cook <jordan.cook@pioneer.com> | 2021-03-04 22:17:21 -0600 |
| commit | 0f03e68b9840caf4bf40321cd3110347034fddd5 (patch) | |
| tree | 28b72bf7a4756e1f426ff9ab76bdb9e719573286 /requests_cache | |
| parent | bfe214eedeb728b1e8899460a5bdde5590fc34ae (diff) | |
| download | requests-cache-0f03e68b9840caf4bf40321cd3110347034fddd5.tar.gz | |
Refactor CachedSession to be usable as a mixin class
Diffstat (limited to 'requests_cache')
| -rw-r--r-- | requests_cache/backends/__init__.py | 2 | ||||
| -rw-r--r-- | requests_cache/backends/base.py | 20 | ||||
| -rw-r--r-- | requests_cache/backends/dynamodb.py | 2 | ||||
| -rw-r--r-- | requests_cache/backends/gridfs.py | 2 | ||||
| -rw-r--r-- | requests_cache/backends/mongo.py | 2 | ||||
| -rw-r--r-- | requests_cache/backends/redis.py | 2 | ||||
| -rw-r--r-- | requests_cache/backends/sqlite.py | 2 | ||||
| -rw-r--r-- | requests_cache/backends/storage/dbdict.py | 4 | ||||
| -rw-r--r-- | requests_cache/backends/storage/mongodict.py | 4 | ||||
| -rw-r--r-- | requests_cache/core.py | 58 |
10 files changed, 62 insertions, 36 deletions
diff --git a/requests_cache/backends/__init__.py b/requests_cache/backends/__init__.py index 82fb5e1..66df84f 100644 --- a/requests_cache/backends/__init__.py +++ b/requests_cache/backends/__init__.py @@ -7,7 +7,7 @@ """ -from .base import BaseCache +from .base import BACKEND_KWARGS, BaseCache registry = { 'memory': BaseCache, diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index a14b884..7210a32 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -14,7 +14,23 @@ from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse import requests -_DEFAULT_HEADERS = requests.utils.default_headers() +# All backend-specific keyword arguments combined +BACKEND_KWARGS = [ + 'connection', + 'db_name', + 'endpont_url', + 'extension', + 'fast_save', + 'ignored_parameters', + 'include_get_headers', + 'location', + 'name', + 'namespace', + 'read_capacity_units', + 'region_name', + 'write_capacity_units', +] +DEFAULT_HEADERS = requests.utils.default_headers() class BaseCache(object): @@ -230,7 +246,7 @@ class BaseCache(object): if request.body: key.update(_to_bytes(body)) else: - if self._include_get_headers and request.headers != _DEFAULT_HEADERS: + if self._include_get_headers and request.headers != DEFAULT_HEADERS: for name, value in sorted(request.headers.items()): key.update(_to_bytes(name)) key.update(_to_bytes(value)) diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py index 55eaf1b..d3813c1 100644 --- a/requests_cache/backends/dynamodb.py +++ b/requests_cache/backends/dynamodb.py @@ -17,7 +17,7 @@ class DynamoDbCache(BaseCache): :param namespace: dynamodb table name (default: ``'requests-cache'``) :param connection: (optional) ``boto3.resource('dynamodb')`` """ - super(DynamoDbCache, self).__init__(**options) + super().__init__(**options) self.responses = DynamoDbDict( table_name, 'responses', diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py index e999043..7e8d69d 100644 --- a/requests_cache/backends/gridfs.py +++ b/requests_cache/backends/gridfs.py @@ -27,6 +27,6 @@ class GridFSCache(BaseCache): :param db_name: database name :param connection: (optional) ``pymongo.Connection`` """ - super(GridFSCache, self).__init__(**options) + super().__init__(**options) self.responses = GridFSPickleDict(db_name, options.get('connection')) self.keys_map = MongoDict(db_name, 'http_redirects', self.responses.connection) diff --git a/requests_cache/backends/mongo.py b/requests_cache/backends/mongo.py index 6f7b4ca..13ed774 100644 --- a/requests_cache/backends/mongo.py +++ b/requests_cache/backends/mongo.py @@ -17,6 +17,6 @@ class MongoCache(BaseCache): :param db_name: database name (default: ``'requests-cache'``) :param connection: (optional) ``pymongo.Connection`` """ - super(MongoCache, self).__init__(**options) + super().__init__(**options) self.responses = MongoPickleDict(db_name, 'responses', options.get('connection')) self.keys_map = MongoDict(db_name, 'urls', self.responses.connection) diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py index 14bb170..daad608 100644 --- a/requests_cache/backends/redis.py +++ b/requests_cache/backends/redis.py @@ -17,6 +17,6 @@ class RedisCache(BaseCache): :param namespace: redis namespace (default: ``'requests-cache'``) :param connection: (optional) ``redis.StrictRedis`` """ - super(RedisCache, self).__init__(**options) + super().__init__(**options) self.responses = RedisDict(namespace, 'responses', options.get('connection')) self.keys_map = RedisDict(namespace, 'urls', self.responses.connection) diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 5fef40d..543915d 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -23,6 +23,6 @@ class DbCache(BaseCache): See :ref:`backends.DbDict <backends_dbdict>` for more info :param extension: extension for filename (default: ``'.sqlite'``) """ - super(DbCache, self).__init__(**options) + super().__init__(**options) self.responses = DbPickleDict(str(location) + extension, 'responses', fast_save=fast_save) self.keys_map = DbDict(location + extension, 'urls') diff --git a/requests_cache/backends/storage/dbdict.py b/requests_cache/backends/storage/dbdict.py index 07251a3..cd0958f 100644 --- a/requests_cache/backends/storage/dbdict.py +++ b/requests_cache/backends/storage/dbdict.py @@ -144,7 +144,7 @@ class DbPickleDict(DbDict): """Same as :class:`DbDict`, but pickles values before saving""" def __setitem__(self, key, item): - super(DbPickleDict, self).__setitem__(key, sqlite3.Binary(pickle.dumps(item))) + super().__setitem__(key, sqlite3.Binary(pickle.dumps(item))) def __getitem__(self, key): - return pickle.loads(bytes(super(DbPickleDict, self).__getitem__(key))) + return pickle.loads(bytes(super().__getitem__(key))) diff --git a/requests_cache/backends/storage/mongodict.py b/requests_cache/backends/storage/mongodict.py index 2d85c0c..1511228 100644 --- a/requests_cache/backends/storage/mongodict.py +++ b/requests_cache/backends/storage/mongodict.py @@ -67,7 +67,7 @@ class MongoPickleDict(MongoDict): """Same as :class:`MongoDict`, but pickles values before saving""" def __setitem__(self, key, item): - super(MongoPickleDict, self).__setitem__(key, pickle.dumps(item)) + super().__setitem__(key, pickle.dumps(item)) def __getitem__(self, key): - return pickle.loads(bytes(super(MongoPickleDict, self).__getitem__(key))) + return pickle.loads(bytes(super().__getitem__(key))) diff --git a/requests_cache/core.py b/requests_cache/core.py index cf14539..d76c74e 100644 --- a/requests_cache/core.py +++ b/requests_cache/core.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from datetime import datetime, timedelta from operator import itemgetter +from requests_cache.backends.base import BACKEND_KWARGS from typing import Callable, Iterable, Union import requests @@ -16,8 +17,8 @@ from requests.hooks import dispatch_hook from . import backends -class CachedSession(OriginalSession): - """Requests ``Sessions`` with caching support +class CacheMixin: + """Mixin class that extends ``requests.Session`` with caching features. Args: cache_name: Cache prefix or namespace, depending on backend; see notes below @@ -27,13 +28,15 @@ class CachedSession(OriginalSession): never expire allowable_codes: Only cache responses with one of these codes allowable_methods: Cache only responses for one of these HTTP methods - include_headers: Make request headers part of the cache key + include_get_headers: Make request headers part of the cache key ignored_parameters: List of request parameters to be excluded from the cache key. filter_fn: function that takes a :py:class:`aiohttp.ClientResponse` object and returns a boolean indicating whether or not that response should be cached. Will be applied to both new and previously cached responses old_data_on_error: Return expired cached responses if new request fails + See individual backend classes for additional backend-specific arguments. + The ``cache_name`` parameter will be used as follows depending on the backend: * ``sqlite``: Cache filename prefix, e.g ``my_cache.sqlite`` @@ -55,9 +58,9 @@ class CachedSession(OriginalSession): allowable_methods: Iterable['str'] = ('GET',), filter_fn: Callable = None, old_data_on_error: bool = False, - **backend_options + **kwargs ): - self.cache = backends.create_backend(backend, cache_name, backend_options) + self.cache = backends.create_backend(backend, cache_name, kwargs) self._cache_name = cache_name if expire_after is not None and not isinstance(expire_after, timedelta): @@ -69,41 +72,36 @@ class CachedSession(OriginalSession): self._filter_fn = filter_fn or (lambda r: True) self._return_old_data_on_error = old_data_on_error self._is_cache_disabled = False - super(CachedSession, self).__init__() + + # Remove any requests-cache-specific kwargs before passing along to superclass + session_kwargs = {k: v for k, v in kwargs.items() if k not in BACKEND_KWARGS} + super().__init__(**session_kwargs) def send(self, request, **kwargs): if self._is_cache_disabled or request.method not in self._cache_allowable_methods: - response = super(CachedSession, self).send(request, **kwargs) + response = super().send(request, **kwargs) response.from_cache = False response.cache_date = None return response cache_key = self.cache.create_key(request) - def send_request_and_cache_response(): - response = super(CachedSession, self).send(request, **kwargs) - if response.status_code in self._cache_allowable_codes: - self.cache.save_response(cache_key, response) - response.from_cache = False - response.cache_date = None - return response - try: response, timestamp = self.cache.get_response_and_time(cache_key) except (ImportError, TypeError): - return send_request_and_cache_response() + response, timestamp = None, None if response is None: - return send_request_and_cache_response() + return self.send_request_and_cache_response(request, cache_key, **kwargs) if self._cache_expire_after is not None: is_expired = datetime.utcnow() - timestamp > self._cache_expire_after if is_expired: if not self._return_old_data_on_error: self.cache.delete(cache_key) - return send_request_and_cache_response() + return self.send_request_and_cache_response(request, cache_key, **kwargs) try: - new_response = send_request_and_cache_response() + new_response = self.send_request_and_cache_response(request, cache_key, **kwargs) except Exception: return response else: @@ -117,8 +115,16 @@ class CachedSession(OriginalSession): response = dispatch_hook('response', request.hooks, response, **kwargs) return response + def send_request_and_cache_response(self, request, cache_key, **kwargs): + response = super().send(request, **kwargs) + if response.status_code in self._cache_allowable_codes: + self.cache.save_response(cache_key, response) + response.from_cache = False + response.cache_date = None + return response + def request(self, method, url, params=None, data=None, **kwargs): - response = super(CachedSession, self).request( + response = super().request( method, url, _normalize_parameters(params), _normalize_parameters(data), **kwargs ) if self._is_cache_disabled: @@ -167,6 +173,10 @@ class CachedSession(OriginalSession): ) +class CachedSession(CacheMixin, OriginalSession): + pass + + def install_cache( cache_name: str = 'cache', backend: str = None, @@ -176,7 +186,7 @@ def install_cache( filter_fn: Callable = None, old_data_on_error: bool = False, session_factory=CachedSession, - **backend_options + **kwargs ): """ Installs cache for all ``Requests`` requests by monkey-patching ``Session`` @@ -186,11 +196,11 @@ def install_cache( :param session_factory: Session factory. It must be class which inherits :class:`CachedSession` (default) """ if backend: - backend = backends.create_backend(backend, cache_name, backend_options) + backend = backends.create_backend(backend, cache_name, kwargs) class _ConfiguredCachedSession(session_factory): def __init__(self): - super(_ConfiguredCachedSession, self).__init__( + super().__init__( cache_name=cache_name, backend=backend, expire_after=expire_after, @@ -198,7 +208,7 @@ def install_cache( allowable_methods=allowable_methods, filter_fn=filter_fn, old_data_on_error=old_data_on_error, - **backend_options + **kwargs ) _patch_session_factory(_ConfiguredCachedSession) |
