summaryrefslogtreecommitdiff
path: root/requests_cache/response.py
blob: 0f2edaf09c1677e3f88c2d0566c136614a11fd5a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""Classes to wrap cached response objects"""
from copy import copy
from datetime import datetime, timedelta
from io import BytesIO
from logging import getLogger
from typing import Any, Dict, Optional, Union

from requests import Response
from urllib3.response import HTTPResponse

# Response attributes to copy from the original response (the same set requests uses for pickling)
RESPONSE_ATTRS = Response.__attrs__
# Attributes of the raw urllib3 response to store; they are passed back to CachedHTTPResponse
# when the raw response object is rebuilt (see CachedResponse.raw)
RAW_RESPONSE_ATTRS = [
    'decode_content',
    'headers',
    'reason',
    'request_method',
    'request_url',
    'status',
    'strict',
    'version',
]

# Accepted expiration values: None or -1 means "never expire"; int/float are seconds;
# a timedelta is relative to creation time; a datetime is an absolute expiration time
ExpirationTime = Union[None, int, float, datetime, timedelta]
logger = getLogger(__name__)


class CachedResponse(Response):
    """A serializable wrapper for :py:class:`requests.Response`. CachedResponse objects will behave
    the same as the original response, but with some additional cache-related details. This class is
    responsible for converting and setting cache expiration times, and converting response info into
    a serializable format.

    Args:
        original_response: Response object
        expire_after: Time after which this cached response will expire
    """

    def __init__(self, original_response: Response, expire_after: ExpirationTime = None):
        """Create a CachedResponse based on an original Response"""
        super().__init__()
        # Set cache-specific attrs
        self.created_at = datetime.utcnow()
        self.expires = self._get_expiration_datetime(expire_after)
        self.from_cache = True

        # Copy basic response attrs and original request
        for k in RESPONSE_ATTRS:
            setattr(self, k, getattr(original_response, k, None))
        # Copy the request so that clearing its hooks does not modify the original request.
        # A response may not have an associated request, so only strip hooks if there is one.
        self.request = copy(original_response.request)
        if self.request is not None:
            self.request.hooks = []

        # Read content to support streaming requests, and reset file pointer on original request
        if hasattr(original_response.raw, '_fp') and not original_response.raw.isclosed():
            # Cache raw (still-encoded) data
            raw_data = original_response.raw.read(decode_content=False)
            # Reset `_fp` so `.content` below can read (and decode) the same bytes
            original_response.raw._fp = BytesIO(raw_data)
            # Read and store (decoded) data
            self._content = original_response.content
            # Reset `_fp` again, so the original response can still be read by its consumer
            original_response.raw._fp = BytesIO(raw_data)
            original_response.raw._fp_bytes_read = 0
            original_response.raw.length_remaining = len(raw_data)
        else:
            self._content = original_response.content

        # Copy raw response attrs; the raw object itself is rebuilt lazily by the `raw` property
        self._raw_response = None
        self._raw_response_attrs: Dict[str, Any] = {}
        for k in RAW_RESPONSE_ATTRS:
            self._raw_response_attrs[k] = getattr(original_response.raw, k, None)

        # Copy redirect history, if any; avoid recursion by not copying redirects of redirects
        self.history = []
        if not self.is_redirect:
            for redirect in original_response.history:
                self.history.append(CachedResponse(redirect))

    def __getstate__(self):
        """Override pickling behavior in ``requests.Response.__getstate__``.

        The rebuilt raw response is excluded from the pickled state: it wraps an in-memory buffer
        that may be partially read or already closed (a closed ``BytesIO`` cannot be pickled).
        It is reconstructed on demand from the stored content by :py:attr:`raw`.
        """
        state = self.__dict__.copy()
        state['_raw_response'] = None
        return state

    def _get_expiration_datetime(self, expire_after: ExpirationTime) -> Optional[datetime]:
        """Convert a time value or delta to an absolute datetime, if it's not already.

        Args:
            expire_after: ``None`` or ``-1`` for no expiration; a number of seconds or a
                ``timedelta`` relative to :py:attr:`created_at`; or an absolute ``datetime``

        Returns:
            The absolute expiration time, or ``None`` if the response never expires
        """
        logger.debug(f'Determining expiration time based on: {expire_after}')
        if expire_after is None or expire_after == -1:
            return None
        elif isinstance(expire_after, datetime):
            return expire_after

        if not isinstance(expire_after, timedelta):
            expire_after = timedelta(seconds=expire_after)
        return self.created_at + expire_after

    def reset(self):
        """Reset raw response file handler, if previously initialized"""
        self._raw_response = None

    @property
    def is_expired(self) -> bool:
        """Determine if this cached response is expired"""
        return self.expires is not None and datetime.utcnow() > self.expires

    @property
    def raw(self) -> HTTPResponse:
        """Reconstruct a raw urllib response object from stored attrs"""
        # Compare against None explicitly instead of relying on the truthiness of an IO object
        if self._raw_response is None:
            logger.debug('Rebuilding raw response object')
            self._raw_response = CachedHTTPResponse(body=self._content, **self._raw_response_attrs)
        return self._raw_response

    @raw.setter
    def raw(self, value):
        """No-op to handle requests.Response attempting to set self.raw"""

    def revalidate(self, expire_after: ExpirationTime) -> bool:
        """Set a new expiration for this response, and determine if it is now expired"""
        self.expires = self._get_expiration_datetime(expire_after)
        return self.is_expired


class CachedHTTPResponse(HTTPResponse):
    """A wrapper for raw urllib response objects, which wraps cached content with support for
    streaming requests
    """

    def __init__(self, body: Optional[bytes] = None, **kwargs):
        """
        Args:
            body: Cached response content; ``None`` is treated as an empty body
            kwargs: Additional keyword arguments for :py:class:`urllib3.response.HTTPResponse`
        """
        # Default to not preloading, so content can be read/streamed from the in-memory buffer
        kwargs.setdefault('preload_content', False)
        super().__init__(body=BytesIO(body or b''), **kwargs)
        self._body = body

    def release_conn(self):
        """No-op for compatibility"""

    def read(self, amt=None, decode_content=None, **kwargs):
        """Simplified reader for cached content that emulates
        :py:meth:`urllib3.response.HTTPResponse.read()`

        Returns ``b''`` once the cached content has been exhausted (or the response was closed),
        instead of raising an error on the closed buffer.
        """
        if 'content-encoding' in self.headers and (
            decode_content is False or (decode_content is None and not self.decode_content)
        ):
            # Warn if content was encoded and decode_content is set to False
            logger.warning('read() returns decoded data for cached responses, even with decode_content=False set')

        # Reading from a closed BytesIO raises ValueError; behave like an exhausted stream instead
        if self._fp.closed:
            return b''

        data = self._fp.read(amt)
        # "close" the file to inform consumers to stop reading from it
        if not data:
            self._fp.close()
        return data

    def stream(self, amt=None, **kwargs):
        """Simplified generator over cached content that emulates
        :py:meth:`urllib3.response.HTTPResponse.stream()`

        Like urllib3, only non-empty chunks are yielded; the final empty read that marks the end
        of the content is not passed on to consumers.
        """
        while not self._fp.closed:
            data = self.read(amt=amt, **kwargs)
            if data:
                yield data


# Either a regular response from the network, or one served from the cache
AnyResponse = Union[Response, CachedResponse]


def set_response_defaults(response: AnyResponse) -> AnyResponse:
    """Ensure a response carries the cache-related attributes that CachedResponse defines.

    A response that is not a :py:class:`CachedResponse` gets placeholder values (no creation or
    expiration time, not from the cache, not expired), so callers can rely on these attributes
    being present on any response. CachedResponse objects are returned unchanged.
    """
    if isinstance(response, CachedResponse):
        return response

    placeholders = {
        'created_at': None,
        'expires': None,
        'from_cache': False,
        'is_expired': False,
    }
    for attr_name, default_value in placeholders.items():
        setattr(response, attr_name, default_value)
    return response