summaryrefslogtreecommitdiff
path: root/requests_cache
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook@pioneer.com>2021-03-03 19:47:40 -0600
committerJordan Cook <jordan.cook@pioneer.com>2021-03-04 22:17:21 -0600
commit0f03e68b9840caf4bf40321cd3110347034fddd5 (patch)
tree28b72bf7a4756e1f426ff9ab76bdb9e719573286 /requests_cache
parentbfe214eedeb728b1e8899460a5bdde5590fc34ae (diff)
downloadrequests-cache-0f03e68b9840caf4bf40321cd3110347034fddd5.tar.gz
Refactor CachedSession to be usable as a mixin class
Diffstat (limited to 'requests_cache')
-rw-r--r--requests_cache/backends/__init__.py2
-rw-r--r--requests_cache/backends/base.py20
-rw-r--r--requests_cache/backends/dynamodb.py2
-rw-r--r--requests_cache/backends/gridfs.py2
-rw-r--r--requests_cache/backends/mongo.py2
-rw-r--r--requests_cache/backends/redis.py2
-rw-r--r--requests_cache/backends/sqlite.py2
-rw-r--r--requests_cache/backends/storage/dbdict.py4
-rw-r--r--requests_cache/backends/storage/mongodict.py4
-rw-r--r--requests_cache/core.py58
10 files changed, 62 insertions, 36 deletions
diff --git a/requests_cache/backends/__init__.py b/requests_cache/backends/__init__.py
index 82fb5e1..66df84f 100644
--- a/requests_cache/backends/__init__.py
+++ b/requests_cache/backends/__init__.py
@@ -7,7 +7,7 @@
"""
-from .base import BaseCache
+from .base import BACKEND_KWARGS, BaseCache
registry = {
'memory': BaseCache,
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index a14b884..7210a32 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -14,7 +14,23 @@ from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
import requests
-_DEFAULT_HEADERS = requests.utils.default_headers()
+# All backend-specific keyword arguments combined
+BACKEND_KWARGS = [
+ 'connection',
+ 'db_name',
+ 'endpont_url',
+ 'extension',
+ 'fast_save',
+ 'ignored_parameters',
+ 'include_get_headers',
+ 'location',
+ 'name',
+ 'namespace',
+ 'read_capacity_units',
+ 'region_name',
+ 'write_capacity_units',
+]
+DEFAULT_HEADERS = requests.utils.default_headers()
class BaseCache(object):
@@ -230,7 +246,7 @@ class BaseCache(object):
if request.body:
key.update(_to_bytes(body))
else:
- if self._include_get_headers and request.headers != _DEFAULT_HEADERS:
+ if self._include_get_headers and request.headers != DEFAULT_HEADERS:
for name, value in sorted(request.headers.items()):
key.update(_to_bytes(name))
key.update(_to_bytes(value))
diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py
index 55eaf1b..d3813c1 100644
--- a/requests_cache/backends/dynamodb.py
+++ b/requests_cache/backends/dynamodb.py
@@ -17,7 +17,7 @@ class DynamoDbCache(BaseCache):
:param namespace: dynamodb table name (default: ``'requests-cache'``)
:param connection: (optional) ``boto3.resource('dynamodb')``
"""
- super(DynamoDbCache, self).__init__(**options)
+ super().__init__(**options)
self.responses = DynamoDbDict(
table_name,
'responses',
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py
index e999043..7e8d69d 100644
--- a/requests_cache/backends/gridfs.py
+++ b/requests_cache/backends/gridfs.py
@@ -27,6 +27,6 @@ class GridFSCache(BaseCache):
:param db_name: database name
:param connection: (optional) ``pymongo.Connection``
"""
- super(GridFSCache, self).__init__(**options)
+ super().__init__(**options)
self.responses = GridFSPickleDict(db_name, options.get('connection'))
self.keys_map = MongoDict(db_name, 'http_redirects', self.responses.connection)
diff --git a/requests_cache/backends/mongo.py b/requests_cache/backends/mongo.py
index 6f7b4ca..13ed774 100644
--- a/requests_cache/backends/mongo.py
+++ b/requests_cache/backends/mongo.py
@@ -17,6 +17,6 @@ class MongoCache(BaseCache):
:param db_name: database name (default: ``'requests-cache'``)
:param connection: (optional) ``pymongo.Connection``
"""
- super(MongoCache, self).__init__(**options)
+ super().__init__(**options)
self.responses = MongoPickleDict(db_name, 'responses', options.get('connection'))
self.keys_map = MongoDict(db_name, 'urls', self.responses.connection)
diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py
index 14bb170..daad608 100644
--- a/requests_cache/backends/redis.py
+++ b/requests_cache/backends/redis.py
@@ -17,6 +17,6 @@ class RedisCache(BaseCache):
:param namespace: redis namespace (default: ``'requests-cache'``)
:param connection: (optional) ``redis.StrictRedis``
"""
- super(RedisCache, self).__init__(**options)
+ super().__init__(**options)
self.responses = RedisDict(namespace, 'responses', options.get('connection'))
self.keys_map = RedisDict(namespace, 'urls', self.responses.connection)
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 5fef40d..543915d 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -23,6 +23,6 @@ class DbCache(BaseCache):
See :ref:`backends.DbDict <backends_dbdict>` for more info
:param extension: extension for filename (default: ``'.sqlite'``)
"""
- super(DbCache, self).__init__(**options)
+ super().__init__(**options)
self.responses = DbPickleDict(str(location) + extension, 'responses', fast_save=fast_save)
self.keys_map = DbDict(location + extension, 'urls')
diff --git a/requests_cache/backends/storage/dbdict.py b/requests_cache/backends/storage/dbdict.py
index 07251a3..cd0958f 100644
--- a/requests_cache/backends/storage/dbdict.py
+++ b/requests_cache/backends/storage/dbdict.py
@@ -144,7 +144,7 @@ class DbPickleDict(DbDict):
"""Same as :class:`DbDict`, but pickles values before saving"""
def __setitem__(self, key, item):
- super(DbPickleDict, self).__setitem__(key, sqlite3.Binary(pickle.dumps(item)))
+ super().__setitem__(key, sqlite3.Binary(pickle.dumps(item)))
def __getitem__(self, key):
- return pickle.loads(bytes(super(DbPickleDict, self).__getitem__(key)))
+ return pickle.loads(bytes(super().__getitem__(key)))
diff --git a/requests_cache/backends/storage/mongodict.py b/requests_cache/backends/storage/mongodict.py
index 2d85c0c..1511228 100644
--- a/requests_cache/backends/storage/mongodict.py
+++ b/requests_cache/backends/storage/mongodict.py
@@ -67,7 +67,7 @@ class MongoPickleDict(MongoDict):
"""Same as :class:`MongoDict`, but pickles values before saving"""
def __setitem__(self, key, item):
- super(MongoPickleDict, self).__setitem__(key, pickle.dumps(item))
+ super().__setitem__(key, pickle.dumps(item))
def __getitem__(self, key):
- return pickle.loads(bytes(super(MongoPickleDict, self).__getitem__(key)))
+ return pickle.loads(bytes(super().__getitem__(key)))
diff --git a/requests_cache/core.py b/requests_cache/core.py
index cf14539..d76c74e 100644
--- a/requests_cache/core.py
+++ b/requests_cache/core.py
@@ -7,6 +7,7 @@
from contextlib import contextmanager
from datetime import datetime, timedelta
from operator import itemgetter
+from requests_cache.backends.base import BACKEND_KWARGS
from typing import Callable, Iterable, Union
import requests
@@ -16,8 +17,8 @@ from requests.hooks import dispatch_hook
from . import backends
-class CachedSession(OriginalSession):
- """Requests ``Sessions`` with caching support
+class CacheMixin:
+ """Mixin class that extends ``requests.Session`` with caching features.
Args:
cache_name: Cache prefix or namespace, depending on backend; see notes below
@@ -27,13 +28,15 @@ class CachedSession(OriginalSession):
never expire
allowable_codes: Only cache responses with one of these codes
allowable_methods: Cache only responses for one of these HTTP methods
- include_headers: Make request headers part of the cache key
+ include_get_headers: Make request headers part of the cache key
ignored_parameters: List of request parameters to be excluded from the cache key.
filter_fn: function that takes a :py:class:`aiohttp.ClientResponse` object and
returns a boolean indicating whether or not that response should be cached. Will be
applied to both new and previously cached responses
old_data_on_error: Return expired cached responses if new request fails
+ See individual backend classes for additional backend-specific arguments.
+
The ``cache_name`` parameter will be used as follows depending on the backend:
* ``sqlite``: Cache filename prefix, e.g ``my_cache.sqlite``
@@ -55,9 +58,9 @@ class CachedSession(OriginalSession):
allowable_methods: Iterable['str'] = ('GET',),
filter_fn: Callable = None,
old_data_on_error: bool = False,
- **backend_options
+ **kwargs
):
- self.cache = backends.create_backend(backend, cache_name, backend_options)
+ self.cache = backends.create_backend(backend, cache_name, kwargs)
self._cache_name = cache_name
if expire_after is not None and not isinstance(expire_after, timedelta):
@@ -69,41 +72,36 @@ class CachedSession(OriginalSession):
self._filter_fn = filter_fn or (lambda r: True)
self._return_old_data_on_error = old_data_on_error
self._is_cache_disabled = False
- super(CachedSession, self).__init__()
+
+ # Remove any requests-cache-specific kwargs before passing along to superclass
+ session_kwargs = {k: v for k, v in kwargs.items() if k not in BACKEND_KWARGS}
+ super().__init__(**session_kwargs)
def send(self, request, **kwargs):
if self._is_cache_disabled or request.method not in self._cache_allowable_methods:
- response = super(CachedSession, self).send(request, **kwargs)
+ response = super().send(request, **kwargs)
response.from_cache = False
response.cache_date = None
return response
cache_key = self.cache.create_key(request)
- def send_request_and_cache_response():
- response = super(CachedSession, self).send(request, **kwargs)
- if response.status_code in self._cache_allowable_codes:
- self.cache.save_response(cache_key, response)
- response.from_cache = False
- response.cache_date = None
- return response
-
try:
response, timestamp = self.cache.get_response_and_time(cache_key)
except (ImportError, TypeError):
- return send_request_and_cache_response()
+ response, timestamp = None, None
if response is None:
- return send_request_and_cache_response()
+ return self.send_request_and_cache_response(request, cache_key, **kwargs)
if self._cache_expire_after is not None:
is_expired = datetime.utcnow() - timestamp > self._cache_expire_after
if is_expired:
if not self._return_old_data_on_error:
self.cache.delete(cache_key)
- return send_request_and_cache_response()
+ return self.send_request_and_cache_response(request, cache_key, **kwargs)
try:
- new_response = send_request_and_cache_response()
+ new_response = self.send_request_and_cache_response(request, cache_key, **kwargs)
except Exception:
return response
else:
@@ -117,8 +115,16 @@ class CachedSession(OriginalSession):
response = dispatch_hook('response', request.hooks, response, **kwargs)
return response
+ def send_request_and_cache_response(self, request, cache_key, **kwargs):
+ response = super().send(request, **kwargs)
+ if response.status_code in self._cache_allowable_codes:
+ self.cache.save_response(cache_key, response)
+ response.from_cache = False
+ response.cache_date = None
+ return response
+
def request(self, method, url, params=None, data=None, **kwargs):
- response = super(CachedSession, self).request(
+ response = super().request(
method, url, _normalize_parameters(params), _normalize_parameters(data), **kwargs
)
if self._is_cache_disabled:
@@ -167,6 +173,10 @@ class CachedSession(OriginalSession):
)
+class CachedSession(CacheMixin, OriginalSession):
+ pass
+
+
def install_cache(
cache_name: str = 'cache',
backend: str = None,
@@ -176,7 +186,7 @@ def install_cache(
filter_fn: Callable = None,
old_data_on_error: bool = False,
session_factory=CachedSession,
- **backend_options
+ **kwargs
):
"""
Installs cache for all ``Requests`` requests by monkey-patching ``Session``
@@ -186,11 +196,11 @@ def install_cache(
:param session_factory: Session factory. It must be class which inherits :class:`CachedSession` (default)
"""
if backend:
- backend = backends.create_backend(backend, cache_name, backend_options)
+ backend = backends.create_backend(backend, cache_name, kwargs)
class _ConfiguredCachedSession(session_factory):
def __init__(self):
- super(_ConfiguredCachedSession, self).__init__(
+ super().__init__(
cache_name=cache_name,
backend=backend,
expire_after=expire_after,
@@ -198,7 +208,7 @@ def install_cache(
allowable_methods=allowable_methods,
filter_fn=filter_fn,
old_data_on_error=old_data_on_error,
- **backend_options
+ **kwargs
)
_patch_session_factory(_ConfiguredCachedSession)