diff options
author | Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> | 2022-02-22 05:29:55 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-22 12:29:55 +0200 |
commit | d56baeb683fc1935cfa343fa2eeb0fa9bd955283 (patch) | |
tree | 47357a74bf1d1428cfbcf0d8b2c781f1f971cf77 /redis/asyncio/client.py | |
parent | e3c989d93e914e6502bd5a72f15ded49a135c5be (diff) | |
download | redis-py-d56baeb683fc1935cfa343fa2eeb0fa9bd955283.tar.gz |
Add Async Support (#1899)
Co-authored-by: Chayim I. Kirshen <c@kirshen.com>
Co-authored-by: dvora-h <dvora.heller@redis.com>
Diffstat (limited to 'redis/asyncio/client.py')
-rw-r--r-- | redis/asyncio/client.py | 1344 |
1 files changed, 1344 insertions, 0 deletions
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py new file mode 100644 index 0000000..619592e --- /dev/null +++ b/redis/asyncio/client.py @@ -0,0 +1,1344 @@ +import asyncio +import copy +import inspect +import re +import warnings +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from redis.asyncio.connection import ( + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from redis.client import ( + EMPTY_RESPONSE, + NEVER_DECODE, + AbstractRedis, + CaseInsensitiveDict, + bool_ok, +) +from redis.commands import ( + AsyncCoreCommands, + AsyncSentinelCommands, + RedisModuleCommands, + list_or_args, +) +from redis.compat import Protocol, TypedDict +from redis.exceptions import ( + ConnectionError, + ExecAbortError, + PubSubError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) +from redis.lock import Lock +from redis.retry import Retry +from redis.typing import ChannelT, EncodableT, KeyT +from redis.utils import safe_str, str_if_bytes + +PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] +_KeyT = TypeVar("_KeyT", bound=KeyT) +_ArgT = TypeVar("_ArgT", KeyT, EncodableT) +_RedisT = TypeVar("_RedisT", bound="Redis") +_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object]) +if TYPE_CHECKING: + from redis.commands.core import Script + + +class ResponseCallbackProtocol(Protocol): + def __call__(self, response: Any, **kwargs): + ... + + +class AsyncResponseCallbackProtocol(Protocol): + async def __call__(self, response: Any, **kwargs): + ... + + +ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol] + + +class Redis( + AbstractRedis, RedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands +): + """ + Implementation of the Redis protocol. + + This abstract class provides a Python interface to all Redis commands + and an implementation of the Redis protocol. + + Pipelines derive from this, implementing how + the commands are sent and received to the Redis server. Based on + configuration, an instance will either use a ConnectionPool, or + Connection object to talk to redis. + """ + + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] + + @classmethod + def from_url(cls, url: str, **kwargs): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/redis> + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/rediss> + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + connection_pool = ConnectionPool.from_url(url, **kwargs) + return cls(connection_pool=connection_pool) + + def __init__( + self, + *, + host: str = "localhost", + port: int = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_check_hostname: bool = False, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + username: Optional[str] = None, + retry: Optional[Retry] = None, + auto_close_connection_pool: bool = True, + ): + """ + Initialize a new Redis client. + To specify a retry policy, first set `retry_on_timeout` to `True` + then set `retry` to a valid `Retry` object + """ + kwargs: Dict[str, Any] + # auto_close_connection_pool only has an effect if connection_pool is + # None. This is a similar feature to the missing __del__ to resolve #1103, + # but it accounts for whether a user wants to manually close the connection + # pool, as a similar feature to ConnectionPool's __del__. + self.auto_close_connection_pool = ( + auto_close_connection_pool if connection_pool is None else False + ) + if not connection_pool: + kwargs = { + "db": db, + "username": username, + "password": password, + "socket_timeout": socket_timeout, + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + "retry_on_timeout": retry_on_timeout, + "retry": copy.deepcopy(retry), + "max_connections": max_connections, + "health_check_interval": health_check_interval, + "client_name": client_name, + } + # based on input, setup appropriate connection args + if unix_socket_path is not None: + kwargs.update( + { + "path": unix_socket_path, + "connection_class": UnixDomainSocketConnection, + } + ) + else: + # TCP specific options + kwargs.update( + { + "host": host, + "port": port, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + } + ) + + if ssl: + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_keyfile": ssl_keyfile, + "ssl_certfile": ssl_certfile, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": ssl_ca_certs, + "ssl_check_hostname": ssl_check_hostname, + } + ) + connection_pool = ConnectionPool(**kwargs) + self.connection_pool = connection_pool + self.single_connection_client = single_connection_client + self.connection: Optional[Connection] = None + + self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + + def __repr__(self): + return f"{self.__class__.__name__}<{self.connection_pool!r}>" + + def __await__(self): + return self.initialize().__await__() + + async def initialize(self: _RedisT) -> _RedisT: + if self.single_connection_client and self.connection is None: + self.connection = await self.connection_pool.get_connection("_") + return self + + def set_response_callback(self, command: str, callback: ResponseCallbackT): + """Set a custom Response Callback""" + self.response_callbacks[command] = callback + + def load_external_module( + self, + funcname, + func, + ): + """ + This function can be used to add externally defined redis modules, + and their namespaces to the redis client. + + funcname - A string containing the name of the function to create + func - The function, being added to this class. + + ex: Assume that one has a custom redis module named foomod that + creates command named 'foo.dothing' and 'foo.anotherthing' in redis. + To load function functions into this namespace: + + from redis import Redis + from foomodule import F + r = Redis() + r.load_external_module("foo", F) + r.foo().dothing('your', 'arguments') + + For a concrete example see the reimport of the redisjson module in + tests/test_connection.py::test_loading_external_modules + """ + setattr(self, funcname, func) + + def pipeline( + self, transaction: bool = True, shard_hint: Optional[str] = None + ) -> "Pipeline": + """ + Return a new pipeline object that can queue multiple commands for + later execution. ``transaction`` indicates whether all commands + should be executed atomically. Apart from making a group of operations + atomic, pipelines are useful for reducing the back-and-forth overhead + between the client and server. + """ + return Pipeline( + self.connection_pool, self.response_callbacks, transaction, shard_hint + ) + + async def transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single argument which is a Pipeline object. + """ + pipe: Pipeline + async with self.pipeline(True, shard_hint) as pipe: + while True: + try: + if watches: + await pipe.watch(*watches) + func_value = func(pipe) + if inspect.isawaitable(func_value): + func_value = await func_value + exec_value = await pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + if watch_delay is not None and watch_delay > 0: + await asyncio.sleep(watch_delay) + continue + + def lock( + self, + name: KeyT, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking_timeout: Optional[float] = None, + lock_class: Optional[Type[Lock]] = None, + thread_local: bool = True, + ) -> Lock: + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def pubsub(self, **kwargs) -> "PubSub": + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + return PubSub(self.connection_pool, **kwargs) + + def monitor(self) -> "Monitor": + return Monitor(self.connection_pool) + + def client(self) -> "Redis": + return self.__class__( + connection_pool=self.connection_pool, single_connection_client=True + ) + + async def __aenter__(self: _RedisT) -> _RedisT: + return await self.initialize() + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + _DEL_MESSAGE = "Unclosed Redis client" + + def __del__(self, _warnings: Any = warnings) -> None: + if self.connection is not None: + _warnings.warn( + f"Unclosed client session {self!r}", + ResourceWarning, + source=self, + ) + context = {"client": self, "message": self._DEL_MESSAGE} + asyncio.get_event_loop().call_exception_handler(context) + + async def close(self, close_connection_pool: Optional[bool] = None) -> None: + """ + Closes Redis client connection + + :param close_connection_pool: decides whether to close the connection pool used + by this Redis client, overriding Redis.auto_close_connection_pool. By default, + let Redis.auto_close_connection_pool decide whether to close the connection + pool. + """ + conn = self.connection + if conn: + self.connection = None + await self.connection_pool.release(conn) + if close_connection_pool or ( + close_connection_pool is None and self.auto_close_connection_pool + ): + await self.connection_pool.disconnect() + + async def _send_command_parse_response(self, conn, command_name, *args, **options): + """ + Send a command and parse the response + """ + await conn.send_command(*args) + return await self.parse_response(conn, command_name, **options) + + async def _disconnect_raise(self, conn: Connection, error: Exception): + """ + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError + """ + await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error + + # COMMAND EXECUTION AND PROTOCOL PARSING + async def execute_command(self, *args, **options): + """Execute a command and return a parsed response""" + await self.initialize() + pool = self.connection_pool + command_name = args[0] + conn = self.connection or await pool.get_connection(command_name, **options) + + try: + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + finally: + if not self.connection: + await pool.release(conn) + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + """Parses a response from the Redis server""" + try: + if NEVER_DECODE in options: + response = await connection.read_response(disable_encoding=True) + else: + response = await connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise + if command_name in self.response_callbacks: + # Mypy bug: https://github.com/python/mypy/issues/10977 + command_name = cast(str, command_name) + retval = self.response_callbacks[command_name](response, **options) + return await retval if inspect.isawaitable(retval) else retval + return response + + +StrictRedis = Redis + + +class MonitorCommandInfo(TypedDict): + time: float + db: int + client_address: str + client_port: str + client_type: str + command: str + + +class Monitor: + """ + Monitor is useful for handling the MONITOR command to the redis server. + next_command() method returns one command from monitor + listen() method yields commands from monitor. + """ + + monitor_re = re.compile(r"\[(\d+) (.*)\] (.*)") + command_re = re.compile(r'"(.*?)(?<!\\)"') + + def __init__(self, connection_pool: ConnectionPool): + self.connection_pool = connection_pool + self.connection: Optional[Connection] = None + + async def connect(self): + if self.connection is None: + self.connection = await self.connection_pool.get_connection("MONITOR") + + async def __aenter__(self): + await self.connect() + await self.connection.send_command("MONITOR") + # check that monitor returns 'OK', but don't return it to user + response = await self.connection.read_response() + if not bool_ok(response): + raise RedisError(f"MONITOR failed: {response}") + return self + + async def __aexit__(self, *args): + await self.connection.disconnect() + await self.connection_pool.release(self.connection) + + async def next_command(self) -> MonitorCommandInfo: + """Parse the response from a monitor command""" + await self.connect() + response = await self.connection.read_response() + if isinstance(response, bytes): + response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) + m = self.monitor_re.match(command_data) + db_id, client_info, command = m.groups() + command = " ".join(self.command_re.findall(command)) + # Redis escapes double quotes because each piece of the command + # string is surrounded by double quotes. We don't have that + # requirement so remove the escaping and leave the quote. + command = command.replace('\\"', '"') + + if client_info == "lua": + client_address = "lua" + client_port = "" + client_type = "lua" + elif client_info.startswith("unix"): + client_address = "unix" + client_port = client_info[5:] + client_type = "unix" + else: + # use rsplit as ipv6 addresses contain colons + client_address, client_port = client_info.rsplit(":", 1) + client_type = "tcp" + return { + "time": float(command_time), + "db": int(db_id), + "client_address": client_address, + "client_port": client_port, + "client_type": client_type, + "command": command, + } + + async def listen(self) -> AsyncIterator[MonitorCommandInfo]: + """Listen for commands coming to the server.""" + while True: + yield await self.next_command() + + +class PubSub: + """ + PubSub provides publish, subscribe and listen support to Redis channels. + + After subscribing to one or more channels, the listen() method will block + until a message arrives on one of the subscribed channels. That message + will be returned and it's safe to start listening again. + """ + + PUBLISH_MESSAGE_TYPES = ("message", "pmessage") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + HEALTH_CHECK_MESSAGE = "redis-py-health-check" + + def __init__( + self, + connection_pool: ConnectionPool, + shard_hint: Optional[str] = None, + ignore_subscribe_messages: bool = False, + encoder=None, + ): + self.connection_pool = connection_pool + self.shard_hint = shard_hint + self.ignore_subscribe_messages = ignore_subscribe_messages + self.connection = None + # we need to know the encoding options for this connection in order + # to lookup channel and pattern names for callback handlers. + self.encoder = encoder + if self.encoder is None: + self.encoder = self.connection_pool.get_encoder() + if self.encoder.decode_responses: + self.health_check_response: Iterable[Union[str, bytes]] = [ + "pong", + self.HEALTH_CHECK_MESSAGE, + ] + else: + self.health_check_response = [ + b"pong", + self.encoder.encode(self.HEALTH_CHECK_MESSAGE), + ] + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + self._lock = asyncio.Lock() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __del__(self): + if self.connection: + self.connection.clear_connect_callbacks() + + async def reset(self): + async with self._lock: + if self.connection: + await self.connection.disconnect() + self.connection.clear_connect_callbacks() + await self.connection_pool.release(self.connection) + self.connection = None + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + + def close(self) -> Awaitable[NoReturn]: + return self.reset() + + async def on_connect(self, connection: Connection): + """Re-subscribe to any channels and patterns previously subscribed to""" + # NOTE: for python3, we can't pass bytestrings as keyword arguments + # so we need to decode channel/pattern names back to unicode strings + # before passing them to [p]subscribe. + self.pending_unsubscribe_channels.clear() + self.pending_unsubscribe_patterns.clear() + if self.channels: + channels = {} + for k, v in self.channels.items(): + channels[self.encoder.decode(k, force=True)] = v + await self.subscribe(**channels) + if self.patterns: + patterns = {} + for k, v in self.patterns.items(): + patterns[self.encoder.decode(k, force=True)] = v + await self.psubscribe(**patterns) + + @property + def subscribed(self): + """Indicates if there are subscriptions to any channels or patterns""" + return bool(self.channels or self.patterns) + + async def execute_command(self, *args: EncodableT): + """Execute a publish/subscribe command""" + + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + self.connection = await self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) + connection = self.connection + kwargs = {"check_health": not self.subscribed} + await self._execute(connection, connection.send_command, *args, **kwargs) + + async def _disconnect_raise_connect(self, conn, error): + """ + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError. Otherwise, try to reconnect + """ + await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error + await conn.connect() + + async def _execute(self, conn, command, *args, **kwargs): + """ + Connect manually upon disconnection. If the Redis server is down, + this will fail and raise a ConnectionError as desired. + After reconnection, the ``on_connect`` callback should have been + called by the # connection to resubscribe us to any channels and + patterns we were previously listening to + """ + return await conn.retry.call_with_retry( + lambda: command(*args, **kwargs), + lambda error: self._disconnect_raise_connect(conn, error), + ) + + async def parse_response(self, block: bool = True, timeout: float = 0): + """Parse the response from a publish/subscribe command""" + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + await self.check_health() + + if not block and not await self._execute(conn, conn.can_read, timeout=timeout): + return None + response = await self._execute(conn, conn.read_response) + + if conn.health_check_interval and response == self.health_check_response: + # ignore the health check message as user might not expect it + return None + return response + + async def check_health(self): + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + if ( + conn.health_check_interval + and asyncio.get_event_loop().time() > conn.next_health_check + ): + await conn.send_command( + "PING", self.HEALTH_CHECK_MESSAGE, check_health=False + ) + + def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] # noqa: E501 + + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + parsed_args = list_or_args((args[0],), args[1:]) if args else args + new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_patterns.update(kwargs) # type: ignore[arg-type] + ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys()) + # update the patterns dict AFTER we send the command. we don't want to + # subscribe twice to these patterns, once for the command and again + # for the reconnection. + new_patterns = self._normalize_keys(new_patterns) + self.patterns.update(new_patterns) + self.pending_unsubscribe_patterns.difference_update(new_patterns) + return ret_val + + def punsubscribe(self, *args: ChannelT) -> Awaitable: + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + patterns: Iterable[ChannelT] + if args: + parsed_args = list_or_args((args[0],), args[1:]) + patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys() + else: + parsed_args = [] + patterns = self.patterns + self.pending_unsubscribe_patterns.update(patterns) + return self.execute_command("PUNSUBSCRIBE", *parsed_args) + + async def subscribe(self, *args: ChannelT, **kwargs: Callable): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + parsed_args = list_or_args((args[0],), args[1:]) if args else () + new_channels = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_channels.update(kwargs) # type: ignore[arg-type] + ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys()) + # update the channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_channels = self._normalize_keys(new_channels) + self.channels.update(new_channels) + self.pending_unsubscribe_channels.difference_update(new_channels) + return ret_val + + def unsubscribe(self, *args) -> Awaitable: + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + if args: + parsed_args = list_or_args(args[0], args[1:]) + channels = self._normalize_keys(dict.fromkeys(parsed_args)) + else: + parsed_args = [] + channels = self.channels + self.pending_unsubscribe_channels.update(channels) + return self.execute_command("UNSUBSCRIBE", *parsed_args) + + async def listen(self) -> AsyncIterator: + """Listen for messages on channels this client has been subscribed to""" + while self.subscribed: + response = await self.handle_message(await self.parse_response(block=True)) + if response is not None: + yield response + + async def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number. + """ + response = await self.parse_response(block=False, timeout=timeout) + if response: + return await self.handle_message(response, ignore_subscribe_messages) + return None + + def ping(self, message=None) -> Awaitable: + """ + Ping the Redis server + """ + message = "" if message is None else message + return self.execute_command("PING", message) + + async def handle_message(self, response, ignore_subscribe_messages=False): + """ + Parses a pub/sub message. If the channel or pattern was subscribed to + with a message handler, the handler is invoked instead of a parsed + message being returned. + """ + message_type = str_if_bytes(response[0]) + if message_type == "pmessage": + message = { + "type": message_type, + "pattern": response[1], + "channel": response[2], + "data": response[3], + } + elif message_type == "pong": + message = { + "type": message_type, + "pattern": None, + "channel": None, + "data": response[1], + } + else: + message = { + "type": message_type, + "pattern": None, + "channel": response[1], + "data": response[2], + } + + # if this is an unsubscribe message, remove it from memory + if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: + if message_type == "punsubscribe": + pattern = response[1] + if pattern in self.pending_unsubscribe_patterns: + self.pending_unsubscribe_patterns.remove(pattern) + self.patterns.pop(pattern, None) + else: + channel = response[1] + if channel in self.pending_unsubscribe_channels: + self.pending_unsubscribe_channels.remove(channel) + self.channels.pop(channel, None) + + if message_type in self.PUBLISH_MESSAGE_TYPES: + # if there's a message handler, invoke it + if message_type == "pmessage": + handler = self.patterns.get(message["pattern"], None) + else: + handler = self.channels.get(message["channel"], None) + if handler: + if inspect.iscoroutinefunction(handler): + await handler(message) + else: + handler(message) + return None + elif message_type != "pong": + # this is a subscribe/unsubscribe message. ignore if we don't + # want them + if ignore_subscribe_messages or self.ignore_subscribe_messages: + return None + + return message + + async def run( + self, + *, + exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, + poll_timeout: float = 1.0, + ) -> None: + """Process pub/sub messages using registered callbacks. + + This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in + redis-py, but it is a coroutine. To launch it as a separate task, use + ``asyncio.create_task``: + + >>> task = asyncio.create_task(pubsub.run()) + + To shut it down, use asyncio cancellation: + + >>> task.cancel() + >>> await task + """ + for channel, handler in self.channels.items(): + if handler is None: + raise PubSubError(f"Channel: '{channel}' has no handler registered") + for pattern, handler in self.patterns.items(): + if handler is None: + raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + + while True: + try: + await self.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) + except asyncio.CancelledError: + raise + except BaseException as e: + if exception_handler is None: + raise + res = exception_handler(e, self) + if inspect.isawaitable(res): + await res + # Ensure that other tasks on the event loop get a chance to run + # if we didn't have to block for I/O anywhere. + await asyncio.sleep(0) + + +class PubsubWorkerExceptionHandler(Protocol): + def __call__(self, e: BaseException, pubsub: PubSub): + ... + + +class AsyncPubsubWorkerExceptionHandler(Protocol): + async def __call__(self, e: BaseException, pubsub: PubSub): + ... + + +PSWorkerThreadExcHandlerT = Union[ + PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler +] + + +CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]] +CommandStackT = List[CommandT] + + +class Pipeline(Redis): # lgtm [py/init-calls-subclass] + """ + Pipelines provide a way to transmit multiple commands to the Redis server + in one transmission. This is convenient for batch processing, such as + saving all the values in a list to Redis. + + All commands executed within a pipeline are wrapped with MULTI and EXEC + calls. This guarantees all commands executed in the pipeline will be + executed atomically. + + Any command raising an exception does *not* halt the execution of + subsequent commands in the pipeline. Instead, the exception is caught + and its instance is placed into the response list returned by execute(). + Code iterating over the response list should be able to deal with an + instance of an exception as a potential value. In general, these will be + ResponseError exceptions, such as those raised when issuing a command + on a key of a different datatype. + """ + + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + + def __init__( + self, + connection_pool: ConnectionPool, + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT], + transaction: bool, + shard_hint: Optional[str], + ): + self.connection_pool = connection_pool + self.connection = None + self.response_callbacks = response_callbacks + self.is_transaction = transaction + self.shard_hint = shard_hint + self.watching = False + self.command_stack: CommandStackT = [] + self.scripts: Set["Script"] = set() + self.explicit_transaction = False + + async def __aenter__(self: _RedisT) -> _RedisT: + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __await__(self): + return self._async_self().__await__() + + _DEL_MESSAGE = "Unclosed Pipeline client" + + def __len__(self): + return len(self.command_stack) + + def __bool__(self): + """Pipeline instances should always evaluate to True""" + return True + + async def _async_self(self): + return self + + async def reset(self): + self.command_stack = [] + self.scripts = set() + # make sure to reset the connection state in the event that we were + # watching something + if self.watching and self.connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + await self.connection.send_command("UNWATCH") + await self.connection.read_response() + except ConnectionError: + # disconnect will also remove any previous WATCHes + if self.connection: + await self.connection.disconnect() + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + if self.connection: + await self.connection_pool.release(self.connection) + self.connection = None + + def multi(self): + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ + if self.explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self.command_stack: + raise RedisError( + "Commands without an initial WATCH have already " "been issued" + ) + self.explicit_transaction = True + + def execute_command( + self, *args, **kwargs + ) -> Union["Pipeline", Awaitable["Pipeline"]]: + if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: + return self.immediate_execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) + + async def _disconnect_reset_raise(self, conn, error): + """ + Close the connection, reset watching state and + raise an exception if we were watching, + retry_on_timeout is not set, + or the error is not a TimeoutError + """ + await conn.disconnect() + # if we were already watching a variable, the watch is no longer + # valid since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. + if self.watching: + await self.reset() + raise WatchError( + "A ConnectionError occurred on while " "watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.reset() + raise + + async def immediate_execute_command(self, *args, **options): + """ + Execute a command immediately, but don't auto-retry on a + ConnectionError if we're already WATCHing a variable. Used when + issuing WATCH or subsequent commands retrieving their values but before + MULTI is called. + """ + command_name = args[0] + conn = self.connection + # if this is the first call, we need a connection + if not conn: + conn = await self.connection_pool.get_connection( + command_name, self.shard_hint + ) + self.connection = conn + + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_reset_raise(conn, error), + ) + + def pipeline_execute_command(self, *args, **options): + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. + """ + self.command_stack.append((args, options)) + return self + + async def _execute_transaction( # noqa: C901 + self, connection: Connection, commands: CommandStackT, raise_on_error + ): + pre: CommandT = (("MULTI",), {}) + post: CommandT = (("EXEC",), {}) + cmds = (pre, *commands, post) + all_cmds = connection.pack_commands( + args for args, options in cmds if EMPTY_RESPONSE not in options + ) + await connection.send_packed_command(all_cmds) + errors = [] + + # parse off the response for MULTI + # NOTE: we need to handle ResponseErrors here and continue + # so that we read all the additional command messages from + # the socket + try: + await self.parse_response(connection, "_") + except ResponseError as err: + errors.append((0, err)) + + # and all the other commands + for i, command in enumerate(commands): + if EMPTY_RESPONSE in command[1]: + errors.append((i, command[1][EMPTY_RESPONSE])) + else: + try: + await self.parse_response(connection, "_") + except ResponseError as err: + self.annotate_exception(err, i + 1, command[0]) + errors.append((i, err)) + + # parse the EXEC. + try: + response = await self.parse_response(connection, "_") + except ExecAbortError as err: + if errors: + raise errors[0][1] from err + raise + + # EXEC clears any watched keys + self.watching = False + + if response is None: + raise WatchError("Watched variable changed.") from None + + # put any parse errors into the response + for i, e in errors: + response.insert(i, e) + + if len(response) != len(commands): + if self.connection: + await self.connection.disconnect() + raise ResponseError( + "Wrong number of response items from pipeline execution" + ) from None + + # find any errors in the response and raise if necessary + if raise_on_error: + self.raise_first_error(commands, response) + + # We have to run response callbacks manually + data = [] + for r, cmd in zip(response, commands): + if not isinstance(r, Exception): + args, options = cmd + command_name = args[0] + if command_name in self.response_callbacks: + r = self.response_callbacks[command_name](r, **options) + if inspect.isawaitable(r): + r = await r + data.append(r) + return data + + async def _execute_pipeline( + self, connection: Connection, commands: CommandStackT, raise_on_error: bool + ): + # build up all commands into a single request to increase network perf + all_cmds = connection.pack_commands([args for args, _ in commands]) + await connection.send_packed_command(all_cmds) + + response = [] + for args, options in commands: + try: + response.append( + await self.parse_response(connection, args[0], **options) + ) + except ResponseError as e: + response.append(e) + + if raise_on_error: + self.raise_first_error(commands, response) + return response + + def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]): + for i, r in enumerate(response): + if isinstance(r, ResponseError): + self.annotate_exception(r, i + 1, commands[i][0]) + raise r + + def annotate_exception( + self, exception: Exception, number: int, command: Iterable[object] + ) -> None: + cmd = " ".join(map(safe_str, command)) + msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" + exception.args = (msg,) + exception.args[1:] + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + result = await super().parse_response(connection, command_name, **options) + if command_name in self.UNWATCH_COMMANDS: + self.watching = False + elif command_name == "WATCH": + self.watching = True + return result + + async def load_scripts(self): + # make sure all scripts that are about to be run on this pipeline exist + scripts = list(self.scripts) + immediate = self.immediate_execute_command + shas = [s.sha for s in scripts] + # we can't use the normal script_* methods because they would just + # get buffered in the pipeline. + exists = await immediate("SCRIPT EXISTS", *shas) + if not all(exists): + for s, exist in zip(scripts, exists): + if not exist: + s.sha = await immediate("SCRIPT LOAD", s.script) + + async def _disconnect_raise_reset(self, conn: Connection, error: Exception): + """ + Close the connection, raise an exception if we were watching, + and raise an exception if retry_on_timeout is not set, + or the error is not a TimeoutError + """ + await conn.disconnect() + # if we were watching a variable, the watch is no longer valid + # since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. + if self.watching: + raise WatchError( + "A ConnectionError occurred on while " "watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.reset() + raise + + async def execute(self, raise_on_error: bool = True): + """Execute all the commands in the current pipeline""" + stack = self.command_stack + if not stack and not self.watching: + return [] + if self.scripts: + await self.load_scripts() + if self.is_transaction or self.explicit_transaction: + execute = self._execute_transaction + else: + execute = self._execute_pipeline + + conn = self.connection + if not conn: + conn = await self.connection_pool.get_connection("MULTI", self.shard_hint) + # assign to self.connection so reset() releases the connection + # back to the pool after we're done + self.connection = conn + conn = cast(Connection, conn) + + try: + return await conn.retry.call_with_retry( + lambda: execute(conn, stack, raise_on_error), + lambda error: self._disconnect_raise_reset(conn, error), + ) + finally: + await self.reset() + + async def discard(self): + """Flushes all previously queued commands + See: https://redis.io/commands/DISCARD + """ + await self.execute_command("DISCARD") + + async def watch(self, *names: KeyT): + """Watches the values at keys ``names``""" + if self.explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + return await self.execute_command("WATCH", *names) + + async def unwatch(self): + """Unwatches all previously specified keys""" + return self.watching and await self.execute_command("UNWATCH") or True |