summaryrefslogtreecommitdiff
path: root/requests_cache/backends/redis.py
blob: 5209aa6590c99bcc696286a93148984cfec456b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from redis import Redis, StrictRedis

from . import BaseCache, BaseStorage, get_valid_kwargs


class RedisCache(BaseCache):
    """Redis cache backend.

    Cached responses and redirect mappings are stored in two separate Redis hashes,
    ``<namespace>:responses`` and ``<namespace>:redirects``, which share one connection.

    Args:
        namespace: Redis namespace used as the key prefix (default: ``'http_cache'``)
        connection: Redis connection instance to use instead of creating a new one
        kwargs: Additional keyword arguments. They are passed to :py:class:`.BaseCache`
            and also forwarded to :py:class:`RedisDict`, which uses the ones relevant to
            :py:class:`redis.client.Redis` when it has to create a connection
    """

    def __init__(self, namespace='http_cache', connection: Redis = None, **kwargs):
        super().__init__(**kwargs)
        self.responses = RedisDict(namespace, 'responses', connection=connection, **kwargs)
        # Reuse whichever connection the responses storage ended up with (given or newly
        # created), so both hashes share one client instead of opening a second one.
        self.redirects = RedisDict(
            namespace, 'redirects', connection=self.responses.connection, **kwargs
        )


class RedisDict(BaseStorage):
    """A dictionary-like interface for a single Redis hash.

    Notes:
        * In order to deal with how redis stores data/keys, all keys and values are passed
          through ``self.serialize`` before being sent to Redis, and through
          ``self.deserialize`` when read back.
        * The actual key name on the redis server will be ``namespace:collection_name``.

    Args:
        namespace: Redis namespace
        collection_name: Name of the Redis hash map
        connection: (optional) Redis connection instance to use instead of creating a new one
        kwargs: Additional keyword arguments for :py:class:`redis.client.Redis`
    """

    def __init__(self, namespace, collection_name='http_cache', connection=None, **kwargs):
        super().__init__(**kwargs)
        if connection is None:
            # Build a new client only from the kwargs selected by get_valid_kwargs for Redis;
            # the remaining kwargs are meant for BaseStorage.
            connection = StrictRedis(**get_valid_kwargs(Redis, kwargs))
        self.connection = connection
        # Name of the Redis hash holding every entry of this collection
        self._self_key = ':'.join([namespace, collection_name])

    def __getitem__(self, key):
        """Return the deserialized value stored under ``key``.

        Raises:
            KeyError: if ``key`` is not present in the hash
        """
        result = self.connection.hget(self._self_key, self.serialize(key))
        if result is None:
            raise KeyError(key)
        return self.deserialize(result)

    def __setitem__(self, key, item):
        """Serialize ``key`` and ``item`` and store them as a field of the hash."""
        self.connection.hset(self._self_key, self.serialize(key), self.serialize(item))

    def __delitem__(self, key):
        """Remove ``key`` from the hash.

        Raises:
            KeyError: if ``key`` is not present (HDEL reported zero fields removed)
        """
        if not self.connection.hdel(self._self_key, self.serialize(key)):
            raise KeyError(key)

    def __len__(self):
        """Return the number of fields in the hash."""
        return self.connection.hlen(self._self_key)

    def __iter__(self):
        """Yield every key in the hash, deserialized."""
        for serialized_key in self.connection.hkeys(self._self_key):
            yield self.deserialize(serialized_key)

    def clear(self):
        """Delete the entire hash (all keys in this collection) from Redis."""
        self.connection.delete(self._self_key)