summaryrefslogtreecommitdiff
path: root/requests_cache/models
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook@pioneer.com>2022-04-01 16:29:13 -0500
committerJordan Cook <jordan.cook@pioneer.com>2022-04-01 17:29:22 -0500
commit0d2d9c690a787f8894bb81fec25d65a4b774ad43 (patch)
tree4154088435cdbe152e271974215479cf50fff02c /requests_cache/models
parent026b627c63124c885ff734d3b30a15464f9b0c93 (diff)
downloadrequests-cache-0d2d9c690a787f8894bb81fec25d65a4b774ad43.tar.gz
Add an intermediate wrapper class, OriginalResponse, to provide type hints for extra attributes set on requests.Response objects
Diffstat (limited to 'requests_cache/models')
-rw-r--r--requests_cache/models/__init__.py4
-rwxr-xr-xrequests_cache/models/response.py59
2 files changed, 42 insertions, 21 deletions
diff --git a/requests_cache/models/__init__.py b/requests_cache/models/__init__.py
index 6ffc7ad..1824a6c 100644
--- a/requests_cache/models/__init__.py
+++ b/requests_cache/models/__init__.py
@@ -6,8 +6,8 @@ from requests import PreparedRequest, Request, Response
from .raw_response import CachedHTTPResponse
from .request import CachedRequest
-from .response import CachedResponse, set_response_defaults
+from .response import CachedResponse, OriginalResponse
-AnyResponse = Union[Response, CachedResponse]
+AnyResponse = Union[OriginalResponse, CachedResponse]
AnyRequest = Union[Request, PreparedRequest, CachedRequest]
AnyPreparedRequest = Union[PreparedRequest, CachedRequest]
diff --git a/requests_cache/models/response.py b/requests_cache/models/response.py
index 4ac24ce..b6ba3ae 100755
--- a/requests_cache/models/response.py
+++ b/requests_cache/models/response.py
@@ -1,6 +1,6 @@
from datetime import datetime, timedelta, timezone
from logging import getLogger
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
import attr
from attr import define, field
@@ -12,17 +12,53 @@ from urllib3._collections import HTTPHeaderDict
from ..expiration import ExpirationTime, get_expiration_datetime
from . import CachedHTTPResponse, CachedRequest
+if TYPE_CHECKING:
+ from ..cache_control import CacheActions
+
DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S %Z' # Format used for __str__ only
HeaderList = List[Tuple[str, str]]
logger = getLogger(__name__)
@define(auto_attribs=False, slots=False)
-class CachedResponse(Response):
- """A class that emulates :py:class:`requests.Response`, with some additional optimizations
- for serialization.
+class BaseResponse(Response):
+ """Wrapper class for responses returned by :py:class:`.CachedSession`. This mainly exists to
+ provide type hints for extra cache-related attributes that are added to non-cached responses.
"""
+ cache_key: Optional[str] = None
+ created_at: datetime = field(factory=datetime.utcnow)
+ expires: Optional[datetime] = field(default=None)
+
+ @property
+ def from_cache(self) -> bool:
+ return False
+
+ @property
+ def is_expired(self) -> bool:
+ return False
+
+
+@define(auto_attribs=False, repr=False, slots=False)
+class OriginalResponse(BaseResponse):
+ """Wrapper class for non-cached responses returned by :py:class:`.CachedSession`"""
+
+ @classmethod
+ def wrap_response(cls, response: Response, actions: 'CacheActions'):
+ """Modify a response object in-place and add extra cache-related attributes"""
+ if not isinstance(response, cls):
+ response.__class__ = cls
+ # Add expires and cache_key only if the response was written to the cache
+ response.expires = None if actions.skip_write else actions.expires # type: ignore
+ response.cache_key = None if actions.skip_write else actions.cache_key # type: ignore
+ response.created_at = datetime.utcnow() # type: ignore
+ return response
+
+
+@define(auto_attribs=False, slots=False)
+class CachedResponse(BaseResponse):
+ """A class that emulates :py:class:`requests.Response`, optimized for serialization"""
+
_content: bytes = field(default=None)
_next: Optional[CachedRequest] = field(default=None)
cache_key: Optional[str] = None # Not serialized; set by BaseCache.get_response()
@@ -155,18 +191,3 @@ def format_file_size(n_bytes: int) -> str:
if TYPE_CHECKING:
return _format(unit)
-
-
-def set_response_defaults(
- response: Union[Response, CachedResponse], cache_key: str = None
-) -> Union[Response, CachedResponse]:
- """Set some default CachedResponse values on a requests.Response object, so they can be
- expected to always be present
- """
- if not isinstance(response, CachedResponse):
- response.cache_key = cache_key # type: ignore
- response.created_at = None # type: ignore
- response.expires = None # type: ignore
- response.from_cache = False # type: ignore
- response.is_expired = False # type: ignore
- return response