diff options
Diffstat (limited to 'kombu')
75 files changed, 1500 insertions, 438 deletions
diff --git a/kombu/__init__.py b/kombu/__init__.py index da4ed466..999cf9da 100644 --- a/kombu/__init__.py +++ b/kombu/__init__.py @@ -1,11 +1,14 @@ """Messaging library for Python.""" +from __future__ import annotations + import os import re import sys from collections import namedtuple +from typing import Any, cast -__version__ = '5.2.0' +__version__ = '5.3.0b3' __author__ = 'Ask Solem' __contact__ = 'auvipy@gmail.com, ask@celeryproject.org' __homepage__ = 'https://kombu.readthedocs.io' @@ -19,12 +22,12 @@ version_info_t = namedtuple('version_info_t', ( # bumpversion can only search for {current_version} # so we have to parse the version here. -_temp = re.match( - r'(\d+)\.(\d+).(\d+)(.+)?', __version__).groups() +_temp = cast(re.Match, re.match( + r'(\d+)\.(\d+).(\d+)(.+)?', __version__)).groups() VERSION = version_info = version_info_t( int(_temp[0]), int(_temp[1]), int(_temp[2]), _temp[3] or '', '') -del(_temp) -del(re) +del _temp +del re STATICA_HACK = True globals()['kcah_acitats'[::-1].upper()] = False @@ -61,15 +64,15 @@ all_by_module = { } object_origins = {} -for module, items in all_by_module.items(): +for _module, items in all_by_module.items(): for item in items: - object_origins[item] = module + object_origins[item] = _module class module(ModuleType): """Customized Python module.""" - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name in object_origins: module = __import__(object_origins[name], None, None, [name]) for extra_name in all_by_module[module.__name__]: @@ -77,7 +80,7 @@ class module(ModuleType): return getattr(module, name) return ModuleType.__getattribute__(self, name) - def __dir__(self): + def __dir__(self) -> list[str]: result = list(new_module.__all__) result.extend(('__file__', '__path__', '__doc__', '__all__', '__docformat__', '__name__', '__path__', 'VERSION', @@ -86,12 +89,6 @@ class module(ModuleType): return result -# 2.5 does not define __package__ -try: - package = __package__ -except NameError: # pragma: no cover - package = 'kombu' - # keep a reference to this module so that it's not garbage collected old_module = sys.modules[__name__] @@ -106,7 +103,7 @@ new_module.__dict__.update({ '__contact__': __contact__, '__homepage__': __homepage__, '__docformat__': __docformat__, - '__package__': package, + '__package__': __package__, 'version_info_t': version_info_t, 'version_info': version_info, 'VERSION': VERSION diff --git a/kombu/abstract.py b/kombu/abstract.py index 38cff010..48a917c9 100644 --- a/kombu/abstract.py +++ b/kombu/abstract.py @@ -1,19 +1,35 @@ """Object utilities.""" +from __future__ import annotations + from copy import copy +from typing import TYPE_CHECKING, Any, Callable, TypeVar from .connection import maybe_channel from .exceptions import NotBoundError from .utils.functional import ChannelPromise +if TYPE_CHECKING: + from kombu.connection import Connection + from kombu.transport.virtual import Channel + + __all__ = ('Object', 'MaybeChannelBound') +_T = TypeVar("_T") +_ObjectType = TypeVar("_ObjectType", bound="Object") +_MaybeChannelBoundType = TypeVar( + "_MaybeChannelBoundType", bound="MaybeChannelBound" +) + -def unpickle_dict(cls, kwargs): +def unpickle_dict( + cls: type[_ObjectType], kwargs: dict[str, Any] +) -> _ObjectType: return cls(**kwargs) -def _any(v): +def _any(v: _T) -> _T: return v @@ -23,9 +39,9 @@ class Object: Supports automatic kwargs->attributes handling, and cloning. """ - attrs = () + attrs: tuple[tuple[str, Any], ...] = () - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: for name, type_ in self.attrs: value = kwargs.get(name) if value is not None: @@ -36,8 +52,8 @@ class Object: except AttributeError: setattr(self, name, None) - def as_dict(self, recurse=False): - def f(obj, type): + def as_dict(self, recurse: bool = False) -> dict[str, Any]: + def f(obj: Any, type: Callable[[Any], Any]) -> Any: if recurse and isinstance(obj, Object): return obj.as_dict(recurse=True) return type(obj) if type and obj is not None else obj @@ -45,31 +61,40 @@ class Object: attr: f(getattr(self, attr), type) for attr, type in self.attrs } - def __reduce__(self): + def __reduce__(self: _ObjectType) -> tuple[ + Callable[[type[_ObjectType], dict[str, Any]], _ObjectType], + tuple[type[_ObjectType], dict[str, Any]] + ]: return unpickle_dict, (self.__class__, self.as_dict()) - def __copy__(self): + def __copy__(self: _ObjectType) -> _ObjectType: return self.__class__(**self.as_dict()) class MaybeChannelBound(Object): """Mixin for classes that can be bound to an AMQP channel.""" - _channel = None + _channel: Channel | None = None _is_bound = False #: Defines whether maybe_declare can skip declaring this entity twice. can_cache_declaration = False - def __call__(self, channel): + def __call__( + self: _MaybeChannelBoundType, channel: (Channel | Connection) + ) -> _MaybeChannelBoundType: """`self(channel) -> self.bind(channel)`.""" return self.bind(channel) - def bind(self, channel): + def bind( + self: _MaybeChannelBoundType, channel: (Channel | Connection) + ) -> _MaybeChannelBoundType: """Create copy of the instance that is bound to a channel.""" return copy(self).maybe_bind(channel) - def maybe_bind(self, channel): + def maybe_bind( + self: _MaybeChannelBoundType, channel: (Channel | Connection) + ) -> _MaybeChannelBoundType: """Bind instance to channel if not already bound.""" if not self.is_bound and channel: self._channel = maybe_channel(channel) @@ -77,7 +102,7 @@ class MaybeChannelBound(Object): self._is_bound = True return self - def revive(self, channel): + def revive(self, channel: Channel) -> None: """Revive channel after the connection has been re-established. Used by :meth:`~kombu.Connection.ensure`. @@ -87,13 +112,13 @@ class MaybeChannelBound(Object): self._channel = channel self.when_bound() - def when_bound(self): + def when_bound(self) -> None: """Callback called when the class is bound.""" - def __repr__(self): + def __repr__(self) -> str: return self._repr_entity(type(self).__name__) - def _repr_entity(self, item=''): + def _repr_entity(self, item: str = '') -> str: item = item or type(self).__name__ if self.is_bound: return '<{} bound to chan:{}>'.format( @@ -101,12 +126,12 @@ class MaybeChannelBound(Object): return f'<unbound {item}>' @property - def is_bound(self): + def is_bound(self) -> bool: """Flag set if the channel is bound.""" return self._is_bound and self._channel is not None @property - def channel(self): + def channel(self) -> Channel: """Current channel if the object is bound.""" channel = self._channel if channel is None: diff --git a/kombu/asynchronous/__init__.py b/kombu/asynchronous/__init__.py index fb264aa5..53060753 100644 --- a/kombu/asynchronous/__init__.py +++ b/kombu/asynchronous/__init__.py @@ -1,5 +1,7 @@ """Event loop.""" +from __future__ import annotations + from kombu.utils.eventio import ERR, READ, WRITE from .hub import Hub, get_event_loop, set_event_loop diff --git a/kombu/asynchronous/aws/__init__.py b/kombu/asynchronous/aws/__init__.py index d8423c23..cbeb050f 100644 --- a/kombu/asynchronous/aws/__init__.py +++ b/kombu/asynchronous/aws/__init__.py @@ -1,4 +1,15 @@ -def connect_sqs(aws_access_key_id=None, aws_secret_access_key=None, **kwargs): +from __future__ import annotations + +from typing import Any + +from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection + + +def connect_sqs( + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + **kwargs: Any +) -> AsyncSQSConnection: """Return async connection to Amazon SQS.""" from .sqs.connection import AsyncSQSConnection return AsyncSQSConnection( diff --git a/kombu/asynchronous/aws/connection.py b/kombu/asynchronous/aws/connection.py index f3926388..887ab40c 100644 --- a/kombu/asynchronous/aws/connection.py +++ b/kombu/asynchronous/aws/connection.py @@ -1,5 +1,7 @@ """Amazon AWS Connection.""" +from __future__ import annotations + from email import message_from_bytes from email.mime.message import MIMEMessage diff --git a/kombu/asynchronous/aws/ext.py b/kombu/asynchronous/aws/ext.py index 2dedc812..1fa4a57e 100644 --- a/kombu/asynchronous/aws/ext.py +++ b/kombu/asynchronous/aws/ext.py @@ -1,5 +1,7 @@ """Amazon boto3 interface.""" +from __future__ import annotations + try: import boto3 from botocore import exceptions diff --git a/kombu/asynchronous/aws/sqs/connection.py b/kombu/asynchronous/aws/sqs/connection.py index 9db2523b..20b56344 100644 --- a/kombu/asynchronous/aws/sqs/connection.py +++ b/kombu/asynchronous/aws/sqs/connection.py @@ -1,5 +1,7 @@ """Amazon SQS Connection.""" +from __future__ import annotations + from vine import transform from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection diff --git a/kombu/asynchronous/aws/sqs/ext.py b/kombu/asynchronous/aws/sqs/ext.py index f6630936..72268b5d 100644 --- a/kombu/asynchronous/aws/sqs/ext.py +++ b/kombu/asynchronous/aws/sqs/ext.py @@ -1,6 +1,8 @@ """Amazon SQS boto3 interface.""" +from __future__ import annotations + try: import boto3 except ImportError: diff --git a/kombu/asynchronous/aws/sqs/message.py b/kombu/asynchronous/aws/sqs/message.py index 9425ff2d..52727bb7 100644 --- a/kombu/asynchronous/aws/sqs/message.py +++ b/kombu/asynchronous/aws/sqs/message.py @@ -1,5 +1,7 @@ """Amazon SQS message implementation.""" +from __future__ import annotations + import base64 from kombu.message import Message diff --git a/kombu/asynchronous/aws/sqs/queue.py b/kombu/asynchronous/aws/sqs/queue.py index 50b0be55..7ca78f75 100644 --- a/kombu/asynchronous/aws/sqs/queue.py +++ b/kombu/asynchronous/aws/sqs/queue.py @@ -1,5 +1,7 @@ """Amazon SQS queue implementation.""" +from __future__ import annotations + from vine import transform from .message import AsyncMessage @@ -12,7 +14,7 @@ def list_first(rs): return rs[0] if len(rs) == 1 else None -class AsyncQueue(): +class AsyncQueue: """Async SQS Queue.""" def __init__(self, connection=None, url=None, message_class=AsyncMessage): diff --git a/kombu/asynchronous/debug.py b/kombu/asynchronous/debug.py index 4fabb452..7c1e45c7 100644 --- a/kombu/asynchronous/debug.py +++ b/kombu/asynchronous/debug.py @@ -1,5 +1,7 @@ """Event-loop debugging tools.""" +from __future__ import annotations + from kombu.utils.eventio import ERR, READ, WRITE from kombu.utils.functional import reprcall diff --git a/kombu/asynchronous/http/__init__.py b/kombu/asynchronous/http/__init__.py index 1c45ebca..67d8b219 100644 --- a/kombu/asynchronous/http/__init__.py +++ b/kombu/asynchronous/http/__init__.py @@ -1,17 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from kombu.asynchronous import get_event_loop +from kombu.asynchronous.http.base import Headers, Request, Response +from kombu.asynchronous.hub import Hub -from .base import Headers, Request, Response +if TYPE_CHECKING: + from kombu.asynchronous.http.curl import CurlClient __all__ = ('Client', 'Headers', 'Response', 'Request') -def Client(hub=None, **kwargs): +def Client(hub: Hub | None = None, **kwargs: int) -> CurlClient: """Create new HTTP client.""" from .curl import CurlClient return CurlClient(hub, **kwargs) -def get_client(hub=None, **kwargs): +def get_client(hub: Hub | None = None, **kwargs: int) -> CurlClient: """Get or create HTTP client bound to the current event loop.""" hub = hub or get_event_loop() try: diff --git a/kombu/asynchronous/http/base.py b/kombu/asynchronous/http/base.py index e8d5043b..89be531f 100644 --- a/kombu/asynchronous/http/base.py +++ b/kombu/asynchronous/http/base.py @@ -1,7 +1,10 @@ """Base async HTTP client implementation.""" +from __future__ import annotations + import sys from http.client import responses +from typing import TYPE_CHECKING from vine import Thenable, maybe_promise, promise @@ -10,6 +13,9 @@ from kombu.utils.compat import coro from kombu.utils.encoding import bytes_to_str from kombu.utils.functional import maybe_list, memoize +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('Headers', 'Response', 'Request') PYPY = hasattr(sys, 'pypy_version_info') @@ -61,7 +67,7 @@ class Request: auth_password (str): Password for HTTP authentication. auth_mode (str): Type of HTTP authentication (``basic`` or ``digest``). user_agent (str): Custom user agent for this request. - network_interace (str): Network interface to use for this request. + network_interface (str): Network interface to use for this request. on_ready (Callable): Callback to be called when the response has been received. Must accept single ``response`` argument. on_stream (Callable): Optional callback to be called every time body @@ -253,5 +259,10 @@ class BaseClient: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() diff --git a/kombu/asynchronous/http/curl.py b/kombu/asynchronous/http/curl.py index ee70f3c6..6f879fa9 100644 --- a/kombu/asynchronous/http/curl.py +++ b/kombu/asynchronous/http/curl.py @@ -1,11 +1,13 @@ """HTTP Client using pyCurl.""" +from __future__ import annotations + from collections import deque from functools import partial from io import BytesIO from time import time -from kombu.asynchronous.hub import READ, WRITE, get_event_loop +from kombu.asynchronous.hub import READ, WRITE, Hub, get_event_loop from kombu.exceptions import HttpError from kombu.utils.encoding import bytes_to_str @@ -36,7 +38,7 @@ class CurlClient(BaseClient): Curl = Curl - def __init__(self, hub=None, max_clients=10): + def __init__(self, hub: Hub | None = None, max_clients: int = 10): if pycurl is None: raise ImportError('The curl client requires the pycurl library.') hub = hub or get_event_loop() @@ -231,9 +233,6 @@ class CurlClient(BaseClient): if request.proxy_username: setopt(_pycurl.PROXYUSERPWD, '{}:{}'.format( request.proxy_username, request.proxy_password or '')) - else: - setopt(_pycurl.PROXY, '') - curl.unsetopt(_pycurl.PROXYUSERPWD) setopt(_pycurl.SSL_VERIFYPEER, 1 if request.validate_cert else 0) setopt(_pycurl.SSL_VERIFYHOST, 2 if request.validate_cert else 0) @@ -253,7 +252,7 @@ class CurlClient(BaseClient): setopt(meth, True) if request.method in ('POST', 'PUT'): - body = request.body.encode('utf-8') if request.body else bytes() + body = request.body.encode('utf-8') if request.body else b'' reqbuffer = BytesIO(body) setopt(_pycurl.READFUNCTION, reqbuffer.read) if request.method == 'POST': diff --git a/kombu/asynchronous/hub.py b/kombu/asynchronous/hub.py index b1f7e241..e5b1163c 100644 --- a/kombu/asynchronous/hub.py +++ b/kombu/asynchronous/hub.py @@ -1,6 +1,9 @@ """Event loop implementation.""" +from __future__ import annotations + import errno +import threading from contextlib import contextmanager from queue import Empty from time import sleep @@ -18,7 +21,7 @@ from .timer import Timer __all__ = ('Hub', 'get_event_loop', 'set_event_loop') logger = get_logger(__name__) -_current_loop = None +_current_loop: Hub | None = None W_UNKNOWN_EVENT = """\ Received unknown event %r for fd %r, please contact support!\ @@ -38,12 +41,12 @@ def _dummy_context(*args, **kwargs): yield -def get_event_loop(): +def get_event_loop() -> Hub | None: """Get current event loop object.""" return _current_loop -def set_event_loop(loop): +def set_event_loop(loop: Hub | None) -> Hub | None: """Set the current event loop object.""" global _current_loop _current_loop = loop @@ -78,6 +81,7 @@ class Hub: self.on_tick = set() self.on_close = set() self._ready = set() + self._ready_lock = threading.Lock() self._running = False self._loop = None @@ -198,7 +202,8 @@ class Hub: def call_soon(self, callback, *args): if not isinstance(callback, Thenable): callback = promise(callback, args) - self._ready.add(callback) + with self._ready_lock: + self._ready.add(callback) return callback def call_later(self, delay, callback, *args): @@ -242,6 +247,12 @@ class Hub: except (AttributeError, KeyError, OSError): pass + def _pop_ready(self): + with self._ready_lock: + ready = self._ready + self._ready = set() + return ready + def close(self, *args): [self._unregister(fd) for fd in self.readers] self.readers.clear() @@ -257,8 +268,7 @@ class Hub: # To avoid infinite loop where one of the callables adds items # to self._ready (via call_soon or otherwise). # we create new list with current self._ready - todos = list(self._ready) - self._ready = set() + todos = self._pop_ready() for item in todos: item() @@ -288,17 +298,17 @@ class Hub: propagate = self.propagate_errors while 1: - todo = self._ready - self._ready = set() - - for tick_callback in on_tick: - tick_callback() + todo = self._pop_ready() for item in todo: if item: item() poll_timeout = fire_timers(propagate=propagate) if scheduled else 1 + + for tick_callback in on_tick: + tick_callback() + # print('[[[HUB]]]: %s' % (self.repr_active(),)) if readers or writers: to_consolidate = [] diff --git a/kombu/asynchronous/semaphore.py b/kombu/asynchronous/semaphore.py index 9fe34a04..07fb8a09 100644 --- a/kombu/asynchronous/semaphore.py +++ b/kombu/asynchronous/semaphore.py @@ -1,9 +1,23 @@ """Semaphores and concurrency primitives.""" +from __future__ import annotations +import sys from collections import deque +from typing import TYPE_CHECKING, Callable, Deque + +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec + +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('DummyLock', 'LaxBoundedSemaphore') +P = ParamSpec("P") + class LaxBoundedSemaphore: """Asynchronous Bounded Semaphore. @@ -12,18 +26,15 @@ class LaxBoundedSemaphore: range even if released more times than it was acquired. Example: - >>> from future import print_statement as printf - # ^ ignore: just fooling stupid pyflakes - >>> x = LaxBoundedSemaphore(2) - >>> x.acquire(printf, 'HELLO 1') + >>> x.acquire(print, 'HELLO 1') HELLO 1 - >>> x.acquire(printf, 'HELLO 2') + >>> x.acquire(print, 'HELLO 2') HELLO 2 - >>> x.acquire(printf, 'HELLO 3') + >>> x.acquire(print, 'HELLO 3') >>> x._waiters # private, do not access directly [print, ('HELLO 3',)] @@ -31,13 +42,18 @@ class LaxBoundedSemaphore: HELLO 3 """ - def __init__(self, value): + def __init__(self, value: int) -> None: self.initial_value = self.value = value - self._waiting = deque() + self._waiting: Deque[tuple] = deque() self._add_waiter = self._waiting.append self._pop_waiter = self._waiting.popleft - def acquire(self, callback, *partial_args, **partial_kwargs): + def acquire( + self, + callback: Callable[P, None], + *partial_args: P.args, + **partial_kwargs: P.kwargs + ) -> bool: """Acquire semaphore. This will immediately apply ``callback`` if @@ -57,7 +73,7 @@ class LaxBoundedSemaphore: callback(*partial_args, **partial_kwargs) return True - def release(self): + def release(self) -> None: """Release semaphore. Note: @@ -71,23 +87,24 @@ class LaxBoundedSemaphore: else: waiter(*args, **kwargs) - def grow(self, n=1): + def grow(self, n: int = 1) -> None: """Change the size of the semaphore to accept more users.""" self.initial_value += n self.value += n - [self.release() for _ in range(n)] + for _ in range(n): + self.release() - def shrink(self, n=1): + def shrink(self, n: int = 1) -> None: """Change the size of the semaphore to accept less users.""" self.initial_value = max(self.initial_value - n, 0) self.value = max(self.value - n, 0) - def clear(self): + def clear(self) -> None: """Reset the semaphore, which also wipes out any waiting callbacks.""" self._waiting.clear() self.value = self.initial_value - def __repr__(self): + def __repr__(self) -> str: return '<{} at {:#x} value:{} waiting:{}>'.format( self.__class__.__name__, id(self), self.value, len(self._waiting), ) @@ -96,8 +113,13 @@ class LaxBoundedSemaphore: class DummyLock: """Pretending to be a lock.""" - def __enter__(self): + def __enter__(self) -> DummyLock: return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: pass diff --git a/kombu/asynchronous/timer.py b/kombu/asynchronous/timer.py index 21ad37c1..f6be1346 100644 --- a/kombu/asynchronous/timer.py +++ b/kombu/asynchronous/timer.py @@ -1,5 +1,7 @@ """Timer scheduling Python callbacks.""" +from __future__ import annotations + import heapq import sys from collections import namedtuple @@ -7,6 +9,7 @@ from datetime import datetime from functools import total_ordering from time import monotonic from time import time as _time +from typing import TYPE_CHECKING from weakref import proxy as weakrefproxy from vine.utils import wraps @@ -18,6 +21,9 @@ try: except ImportError: # pragma: no cover utc = None +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('Entry', 'Timer', 'to_timestamp') logger = get_logger(__name__) @@ -101,7 +107,12 @@ class Timer: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.stop() def call_at(self, eta, fun, args=(), kwargs=None, priority=0): diff --git a/kombu/clocks.py b/kombu/clocks.py index 3c152720..d02e8b32 100644 --- a/kombu/clocks.py +++ b/kombu/clocks.py @@ -1,8 +1,11 @@ """Logical Clocks and Synchronization.""" +from __future__ import annotations + from itertools import islice from operator import itemgetter from threading import Lock +from typing import Any __all__ = ('LamportClock', 'timetuple') @@ -15,7 +18,7 @@ class timetuple(tuple): Can be used as part of a heap to keep events ordered. Arguments: - clock (int): Event clock value. + clock (Optional[int]): Event clock value. timestamp (float): Event UNIX timestamp value. id (str): Event host id (e.g. ``hostname:pid``). obj (Any): Optional obj to associate with this event. @@ -23,16 +26,18 @@ class timetuple(tuple): __slots__ = () - def __new__(cls, clock, timestamp, id, obj=None): + def __new__( + cls, clock: int | None, timestamp: float, id: str, obj: Any = None + ) -> timetuple: return tuple.__new__(cls, (clock, timestamp, id, obj)) - def __repr__(self): + def __repr__(self) -> str: return R_CLOCK.format(*self) - def __getnewargs__(self): + def __getnewargs__(self) -> tuple: return tuple(self) - def __lt__(self, other): + def __lt__(self, other: tuple) -> bool: # 0: clock 1: timestamp 3: process id try: A, B = self[0], other[0] @@ -45,13 +50,13 @@ class timetuple(tuple): except IndexError: return NotImplemented - def __gt__(self, other): + def __gt__(self, other: tuple) -> bool: return other < self - def __le__(self, other): + def __le__(self, other: tuple) -> bool: return not other < self - def __ge__(self, other): + def __ge__(self, other: tuple) -> bool: return not self < other clock = property(itemgetter(0)) @@ -99,21 +104,23 @@ class LamportClock: #: The clocks current value. value = 0 - def __init__(self, initial_value=0, Lock=Lock): + def __init__( + self, initial_value: int = 0, Lock: type[Lock] = Lock + ) -> None: self.value = initial_value self.mutex = Lock() - def adjust(self, other): + def adjust(self, other: int) -> int: with self.mutex: value = self.value = max(self.value, other) + 1 return value - def forward(self): + def forward(self) -> int: with self.mutex: self.value += 1 return self.value - def sort_heap(self, h): + def sort_heap(self, h: list[tuple[int, str]]) -> tuple[int, str]: """Sort heap of events. List of tuples containing at least two elements, representing @@ -140,8 +147,8 @@ class LamportClock: # clock values unique, return first item return h[0] - def __str__(self): + def __str__(self) -> str: return str(self.value) - def __repr__(self): + def __repr__(self) -> str: return f'<LamportClock: {self.value}>' diff --git a/kombu/common.py b/kombu/common.py index 08bc1aff..c7b2d50a 100644 --- a/kombu/common.py +++ b/kombu/common.py @@ -1,5 +1,7 @@ """Common Utilities.""" +from __future__ import annotations + import os import socket import threading diff --git a/kombu/compat.py b/kombu/compat.py index 1fa3f631..d90aec75 100644 --- a/kombu/compat.py +++ b/kombu/compat.py @@ -3,11 +3,17 @@ See https://pypi.org/project/carrot/ for documentation. """ +from __future__ import annotations + from itertools import count +from typing import TYPE_CHECKING from . import messaging from .entity import Exchange, Queue +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('Publisher', 'Consumer') # XXX compat attribute @@ -65,7 +71,12 @@ class Publisher(messaging.Producer): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() @property @@ -127,7 +138,12 @@ class Consumer(messaging.Consumer): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() def __iter__(self): diff --git a/kombu/compression.py b/kombu/compression.py index d9438539..f98c971b 100644 --- a/kombu/compression.py +++ b/kombu/compression.py @@ -1,5 +1,7 @@ """Compression utilities.""" +from __future__ import annotations + import zlib from kombu.utils.encoding import ensure_bytes diff --git a/kombu/connection.py b/kombu/connection.py index a63154f5..0c9779b5 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -1,11 +1,14 @@ """Client (Connection).""" +from __future__ import annotations + import os import socket -from collections import OrderedDict +import sys from contextlib import contextmanager from itertools import count, cycle from operator import itemgetter +from typing import TYPE_CHECKING, Any try: from ssl import CERT_NONE @@ -14,6 +17,7 @@ except ImportError: # pragma: no cover CERT_NONE = None ssl_available = False + # jython breaks on relative import for .exceptions for some reason # (Issue #112) from kombu import exceptions @@ -26,6 +30,16 @@ from .utils.functional import dictfilter, lazy, retry_over_time, shufflecycle from .utils.objects import cached_property from .utils.url import as_url, maybe_sanitize_url, parse_url, quote, urlparse +if TYPE_CHECKING: + from kombu.transport.virtual import Channel + + if sys.version_info < (3, 10): + from typing_extensions import TypeGuard + else: + from typing import TypeGuard + + from types import TracebackType + __all__ = ('Connection', 'ConnectionPool', 'ChannelPool') logger = get_logger(__name__) @@ -412,7 +426,7 @@ class Connection: callback (Callable): Optional callback that is called for every internal iteration (1 s). timeout (int): Maximum amount of time in seconds to spend - waiting for connection + attempting to connect, total over all retries. """ if self.connected: return self._connection @@ -468,7 +482,7 @@ class Connection: def ensure(self, obj, fun, errback=None, max_retries=None, interval_start=1, interval_step=1, interval_max=1, - on_revive=None): + on_revive=None, retry_errors=None): """Ensure operation completes. Regardless of any channel/connection errors occurring. @@ -497,6 +511,9 @@ class Connection: each retry. on_revive (Callable): Optional callback called whenever revival completes successfully + retry_errors (tuple): Optional list of errors to retry on + regardless of the connection state. Must provide max_retries + if this is specified. Examples: >>> from kombu import Connection, Producer @@ -511,6 +528,15 @@ class Connection: ... errback=errback, max_retries=3) >>> publish({'hello': 'world'}, routing_key='dest') """ + if retry_errors is None: + retry_errors = tuple() + elif max_retries is None: + # If the retry_errors is specified, but max_retries is not, + # this could lead into an infinite loop potentially. + raise ValueError( + "max_retries must be specified if retry_errors is specified" + ) + def _ensured(*args, **kwargs): got_connection = 0 conn_errors = self.recoverable_connection_errors @@ -522,6 +548,11 @@ class Connection: for retries in count(0): # for infinity try: return fun(*args, **kwargs) + except retry_errors as exc: + if max_retries is not None and retries >= max_retries: + raise + self._debug('ensure retry policy error: %r', + exc, exc_info=1) except conn_errors as exc: if got_connection and not has_modern_errors: # transport can not distinguish between @@ -529,7 +560,7 @@ class Connection: # the error if it persists after a new connection # was successfully established. raise - if max_retries is not None and retries > max_retries: + if max_retries is not None and retries >= max_retries: raise self._debug('ensure connection error: %r', exc, exc_info=1) @@ -626,7 +657,7 @@ class Connection: transport_cls, transport_cls) D = self.transport.default_connection_params - if not self.hostname: + if not self.hostname and D.get('hostname'): logger.warning( "No hostname was supplied. " f"Reverting to default '{D.get('hostname')}'") @@ -658,7 +689,7 @@ class Connection: def info(self): """Get connection info.""" - return OrderedDict(self._info()) + return dict(self._info()) def __eqhash__(self): return HashedSeq(self.transport_cls, self.hostname, self.userid, @@ -829,7 +860,12 @@ class Connection: def __enter__(self): return self - def __exit__(self, *args): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.release() @property @@ -837,7 +873,7 @@ class Connection: return self.transport.qos_semantics_matches_spec(self.connection) def _extract_failover_opts(self): - conn_opts = {} + conn_opts = {'timeout': self.connect_timeout} transport_opts = self.transport_options if transport_opts: if 'max_retries' in transport_opts: @@ -848,6 +884,9 @@ class Connection: conn_opts['interval_step'] = transport_opts['interval_step'] if 'interval_max' in transport_opts: conn_opts['interval_max'] = transport_opts['interval_max'] + if 'connect_retries_timeout' in transport_opts: + conn_opts['timeout'] = \ + transport_opts['connect_retries_timeout'] return conn_opts @property @@ -880,7 +919,7 @@ class Connection: return self._connection @property - def default_channel(self): + def default_channel(self) -> Channel: """Default channel. Created upon access and closed when the connection is closed. @@ -932,7 +971,7 @@ class Connection: but where the connection must be closed and re-established first. """ try: - return self.transport.recoverable_connection_errors + return self.get_transport_cls().recoverable_connection_errors except AttributeError: # There were no such classification before, # and all errors were assumed to be recoverable, @@ -948,19 +987,19 @@ class Connection: recovered from without re-establishing the connection. """ try: - return self.transport.recoverable_channel_errors + return self.get_transport_cls().recoverable_channel_errors except AttributeError: return () @cached_property def connection_errors(self): """List of exceptions that may be raised by the connection.""" - return self.transport.connection_errors + return self.get_transport_cls().connection_errors @cached_property def channel_errors(self): """List of exceptions that may be raised by the channel.""" - return self.transport.channel_errors + return self.get_transport_cls().channel_errors @property def supports_heartbeats(self): @@ -1043,7 +1082,7 @@ class ChannelPool(Resource): return channel -def maybe_channel(channel): +def maybe_channel(channel: Channel | Connection) -> Channel: """Get channel from object. Return the default channel if argument is a connection instance, @@ -1054,5 +1093,5 @@ def maybe_channel(channel): return channel -def is_connection(obj): +def is_connection(obj: Any) -> TypeGuard[Connection]: return isinstance(obj, Connection) diff --git a/kombu/entity.py b/kombu/entity.py index a89fabb9..2329e748 100644 --- a/kombu/entity.py +++ b/kombu/entity.py @@ -1,5 +1,7 @@ """Exchange and Queue declarations.""" +from __future__ import annotations + import numbers from .abstract import MaybeChannelBound, Object diff --git a/kombu/exceptions.py b/kombu/exceptions.py index f2501437..825baa12 100644 --- a/kombu/exceptions.py +++ b/kombu/exceptions.py @@ -1,9 +1,16 @@ """Exceptions.""" +from __future__ import annotations + from socket import timeout as TimeoutError +from types import TracebackType +from typing import TYPE_CHECKING, TypeVar from amqp import ChannelError, ConnectionError, ResourceError +if TYPE_CHECKING: + from kombu.asynchronous.http import Response + __all__ = ( 'reraise', 'KombuError', 'OperationalError', 'NotBoundError', 'MessageStateError', 'TimeoutError', @@ -14,8 +21,14 @@ __all__ = ( 'InconsistencyError', ) +BaseExceptionType = TypeVar('BaseExceptionType', bound=BaseException) + -def reraise(tp, value, tb=None): +def reraise( + tp: type[BaseExceptionType], + value: BaseExceptionType, + tb: TracebackType | None = None +) -> BaseExceptionType: """Reraise exception.""" if value.__traceback__ is not tb: raise value.with_traceback(tb) @@ -84,11 +97,16 @@ class InconsistencyError(ConnectionError): class HttpError(Exception): """HTTP Client Error.""" - def __init__(self, code, message=None, response=None): + def __init__( + self, + code: int, + message: str | None = None, + response: Response | None = None + ) -> None: self.code = code self.message = message self.response = response super().__init__(code, message, response) - def __str__(self): + def __str__(self) -> str: return 'HTTP {0.code}: {0.message}'.format(self) diff --git a/kombu/log.py b/kombu/log.py index de77e7f3..ed8d0a50 100644 --- a/kombu/log.py +++ b/kombu/log.py @@ -1,5 +1,7 @@ """Logging Utilities.""" +from __future__ import annotations + import logging import numbers import os diff --git a/kombu/matcher.py b/kombu/matcher.py index 7dcab8cd..a4d71bb1 100644 --- a/kombu/matcher.py +++ b/kombu/matcher.py @@ -1,11 +1,16 @@ """Pattern matching registry.""" +from __future__ import annotations + from fnmatch import fnmatch from re import match as rematch +from typing import Callable, cast from .utils.compat import entrypoints from .utils.encoding import bytes_to_str +MatcherFunction = Callable[[str, str], bool] + class MatcherNotInstalled(Exception): """Matcher not installed/found.""" @@ -17,15 +22,15 @@ class MatcherRegistry: MatcherNotInstalled = MatcherNotInstalled matcher_pattern_first = ["pcre", ] - def __init__(self): - self._matchers = {} - self._default_matcher = None + def __init__(self) -> None: + self._matchers: dict[str, MatcherFunction] = {} + self._default_matcher: MatcherFunction | None = None - def register(self, name, matcher): + def register(self, name: str, matcher: MatcherFunction) -> None: """Add matcher by name to the registry.""" self._matchers[name] = matcher - def unregister(self, name): + def unregister(self, name: str) -> None: """Remove matcher by name from the registry.""" try: self._matchers.pop(name) @@ -34,7 +39,7 @@ class MatcherRegistry: f'No matcher installed for {name}' ) - def _set_default_matcher(self, name): + def _set_default_matcher(self, name: str) -> None: """Set the default matching method. :param name: The name of the registered matching method. @@ -51,7 +56,13 @@ class MatcherRegistry: f'No matcher installed for {name}' ) - def match(self, data, pattern, matcher=None, matcher_kwargs=None): + def match( + self, + data: bytes, + pattern: bytes, + matcher: str | None = None, + matcher_kwargs: dict[str, str] | None = None + ) -> bool: """Call the matcher.""" if matcher and not self._matchers.get(matcher): raise self.MatcherNotInstalled( @@ -97,7 +108,7 @@ match = registry.match .. function:: register(name, matcher): Register a new matching method. - :param name: A convience name for the mathing method. + :param name: A convenient name for the mathing method. :param matcher: A method that will be passed data and pattern. """ register = registry.register @@ -111,14 +122,14 @@ register = registry.register unregister = registry.unregister -def register_glob(): +def register_glob() -> None: """Register glob into default registry.""" registry.register('glob', fnmatch) -def register_pcre(): +def register_pcre() -> None: """Register pcre into default registry.""" - registry.register('pcre', rematch) + registry.register('pcre', cast(MatcherFunction, rematch)) # Register the base matching methods. diff --git a/kombu/message.py b/kombu/message.py index bcc90d1a..f2af1686 100644 --- a/kombu/message.py +++ b/kombu/message.py @@ -1,5 +1,7 @@ """Message class.""" +from __future__ import annotations + import sys from .compression import decompress diff --git a/kombu/messaging.py b/kombu/messaging.py index 0bed52c5..2b600224 100644 --- a/kombu/messaging.py +++ b/kombu/messaging.py @@ -1,6 +1,9 @@ """Sending and receiving messages.""" +from __future__ import annotations + from itertools import count +from typing import TYPE_CHECKING from .common import maybe_declare from .compression import compress @@ -10,6 +13,9 @@ from .exceptions import ContentDisallowed from .serialization import dumps, prepare_accept_content from .utils.functional import ChannelPromise, maybe_list +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('Exchange', 'Queue', 'Producer', 'Consumer') @@ -236,7 +242,12 @@ class Producer: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.release() def release(self): @@ -435,7 +446,12 @@ class Consumer: self.consume() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: if self.channel and self.channel.connection: conn_errors = self.channel.connection.client.connection_errors if not isinstance(exc_val, conn_errors): diff --git a/kombu/mixins.py b/kombu/mixins.py index b87e4b92..f1b3c1c9 100644 --- a/kombu/mixins.py +++ b/kombu/mixins.py @@ -1,5 +1,7 @@ """Mixins.""" +from __future__ import annotations + import socket from contextlib import contextmanager from functools import partial diff --git a/kombu/pidbox.py b/kombu/pidbox.py index 7649736a..ee639b3c 100644 --- a/kombu/pidbox.py +++ b/kombu/pidbox.py @@ -1,5 +1,7 @@ """Generic process mailbox.""" +from __future__ import annotations + import socket import warnings from collections import defaultdict, deque diff --git a/kombu/pools.py b/kombu/pools.py index 373bc06c..106be183 100644 --- a/kombu/pools.py +++ b/kombu/pools.py @@ -1,5 +1,7 @@ """Public resource pools.""" +from __future__ import annotations + import os from itertools import chain diff --git a/kombu/resource.py b/kombu/resource.py index e3617dc4..53ba1145 100644 --- a/kombu/resource.py +++ b/kombu/resource.py @@ -1,14 +1,20 @@ """Generic resource pool implementation.""" +from __future__ import annotations + import os from collections import deque from queue import Empty from queue import LifoQueue as _LifoQueue +from typing import TYPE_CHECKING from . import exceptions from .utils.compat import register_after_fork from .utils.functional import lazy +if TYPE_CHECKING: + from types import TracebackType + def _after_fork_cleanup_resource(resource): try: @@ -191,7 +197,12 @@ class Resource: def __enter__(self): pass - def __exit__(self, type, value, traceback): + def __exit__( + self, + exc_type: type, + exc_val: Exception, + exc_tb: TracebackType + ) -> None: pass resource = self._resource diff --git a/kombu/serialization.py b/kombu/serialization.py index 58c28717..5cddeb0b 100644 --- a/kombu/serialization.py +++ b/kombu/serialization.py @@ -1,5 +1,7 @@ """Serialization utilities.""" +from __future__ import annotations + import codecs import os import pickle @@ -382,18 +384,6 @@ register_msgpack() # Default serializer is 'json' registry._set_default_serializer('json') - -_setupfuns = { - 'json': register_json, - 'pickle': register_pickle, - 'yaml': register_yaml, - 'msgpack': register_msgpack, - 'application/json': register_json, - 'application/x-yaml': register_yaml, - 'application/x-python-serialize': register_pickle, - 'application/x-msgpack': register_msgpack, -} - NOTSET = object() diff --git a/kombu/simple.py b/kombu/simple.py index eee037be..a33e5f9e 100644 --- a/kombu/simple.py +++ b/kombu/simple.py @@ -1,13 +1,19 @@ """Simple messaging interface.""" +from __future__ import annotations + import socket from collections import deque from queue import Empty from time import monotonic +from typing import TYPE_CHECKING from . import entity, messaging from .connection import maybe_channel +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('SimpleQueue', 'SimpleBuffer') @@ -18,7 +24,12 @@ class SimpleBase: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() def __init__(self, channel, producer, consumer, no_ack=False): diff --git a/kombu/transport/SLMQ.py b/kombu/transport/SLMQ.py index 750f67bd..50efca72 100644 --- a/kombu/transport/SLMQ.py +++ b/kombu/transport/SLMQ.py @@ -18,6 +18,8 @@ Transport Options *Unreviewed* """ +from __future__ import annotations + import os import socket import string diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index fb6d3780..ac199aa1 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -59,8 +59,8 @@ exist in AWS) you can tell this transport about them as follows: 'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional 'backoff_tasks': ['svc.tasks.tasks.task1'] # optional }, - 'queue-2': { - 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb', + 'queue-2.fifo': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb.fifo', 'access_key_id': 'c', 'secret_access_key': 'd', 'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional @@ -71,6 +71,9 @@ exist in AWS) you can tell this transport about them as follows: 'sts_token_timeout': 900 # optional } +Note that FIFO and standard queues must be named accordingly (the name of +a FIFO queue must end with the .fifo suffix). + backoff_policy & backoff_tasks are optional arguments. These arguments automatically change the message visibility timeout, in order to have different times between specific task retries. This would apply after @@ -119,6 +122,8 @@ Features """ # noqa: E501 +from __future__ import annotations + import base64 import socket import string @@ -167,6 +172,10 @@ class UndefinedQueueException(Exception): """Predefined queues are being used and an undefined queue was used.""" +class InvalidQueueException(Exception): + """Predefined queues are being used and configuration is not valid.""" + + class QoS(virtual.QoS): """Quality of Service guarantees implementation for SQS.""" @@ -208,8 +217,8 @@ class QoS(virtual.QoS): VisibilityTimeout=policy_value ) - @staticmethod - def extract_task_name_and_number_of_retries(message): + def extract_task_name_and_number_of_retries(self, delivery_tag): + message = self._delivered[delivery_tag] message_headers = message.headers task_name = message_headers['task'] number_of_retries = int( @@ -237,6 +246,7 @@ class Channel(virtual.Channel): if boto3 is None: raise ImportError('boto3 is not installed') super().__init__(*args, **kwargs) + self._validate_predifined_queues() # SQS blows up if you try to create a new queue when one already # exists but with a different visibility_timeout. This prepopulates @@ -246,6 +256,26 @@ class Channel(virtual.Channel): self.hub = kwargs.get('hub') or get_event_loop() + def _validate_predifined_queues(self): + """Check that standard and FIFO queues are named properly. + + AWS requires FIFO queues to have a name + that ends with the .fifo suffix. + """ + for queue_name, q in self.predefined_queues.items(): + fifo_url = q['url'].endswith('.fifo') + fifo_name = queue_name.endswith('.fifo') + if fifo_url and not fifo_name: + raise InvalidQueueException( + "Queue with url '{}' must have a name " + "ending with .fifo".format(q['url']) + ) + elif not fifo_url and fifo_name: + raise InvalidQueueException( + "Queue with name '{}' is not a FIFO queue: " + "'{}'".format(queue_name, q['url']) + ) + def _update_queue_cache(self, queue_name_prefix): if self.predefined_queues: for queue_name, q in self.predefined_queues.items(): @@ -367,20 +397,28 @@ class Channel(virtual.Channel): def _put(self, queue, message, **kwargs): """Put message onto queue.""" q_url = self._new_queue(queue) - kwargs = {'QueueUrl': q_url, - 'MessageBody': AsyncMessage().encode(dumps(message))} - if queue.endswith('.fifo'): - if 'MessageGroupId' in message['properties']: - kwargs['MessageGroupId'] = \ - message['properties']['MessageGroupId'] - else: - kwargs['MessageGroupId'] = 'default' - if 'MessageDeduplicationId' in message['properties']: - kwargs['MessageDeduplicationId'] = \ - message['properties']['MessageDeduplicationId'] - else: - kwargs['MessageDeduplicationId'] = str(uuid.uuid4()) + if self.sqs_base64_encoding: + body = AsyncMessage().encode(dumps(message)) + else: + body = dumps(message) + kwargs = {'QueueUrl': q_url, 'MessageBody': body} + if 'properties' in message: + if queue.endswith('.fifo'): + if 'MessageGroupId' in message['properties']: + kwargs['MessageGroupId'] = \ + message['properties']['MessageGroupId'] + else: + kwargs['MessageGroupId'] = 'default' + if 'MessageDeduplicationId' in message['properties']: + kwargs['MessageDeduplicationId'] = \ + message['properties']['MessageDeduplicationId'] + else: + kwargs['MessageDeduplicationId'] = str(uuid.uuid4()) + else: + if "DelaySeconds" in message['properties']: + kwargs['DelaySeconds'] = \ + message['properties']['DelaySeconds'] c = self.sqs(queue=self.canonical_queue_name(queue)) if message.get('redelivered'): c.change_message_visibility( @@ -392,22 +430,19 @@ class Channel(virtual.Channel): c.send_message(**kwargs) @staticmethod - def __b64_encoded(byte_string): + def _optional_b64_decode(byte_string): try: - return base64.b64encode( - base64.b64decode(byte_string) - ) == byte_string + data = base64.b64decode(byte_string) + if base64.b64encode(data) == byte_string: + return data + # else the base64 module found some embedded base64 content + # that should be ignored. except Exception: # pylint: disable=broad-except - return False - - def _message_to_python(self, message, queue_name, queue): - body = message['Body'].encode() - try: - if self.__b64_encoded(body): - body = base64.b64decode(body) - except TypeError: pass + return byte_string + def _message_to_python(self, message, queue_name, queue): + body = self._optional_b64_decode(message['Body'].encode()) payload = loads(bytes_to_str(body)) if queue_name in self._noack_queues: queue = self._new_queue(queue_name) @@ -809,6 +844,10 @@ class Channel(virtual.Channel): return self.transport_options.get('wait_time_seconds', self.default_wait_time_seconds) + @cached_property + def sqs_base64_encoding(self): + return self.transport_options.get('sqs_base64_encoding', True) + class Transport(virtual.Transport): """SQS Transport. diff --git a/kombu/transport/__init__.py b/kombu/transport/__init__.py index 5fb5047b..8a217691 100644 --- a/kombu/transport/__init__.py +++ b/kombu/transport/__init__.py @@ -1,10 +1,12 @@ """Built-in transports.""" +from __future__ import annotations + from kombu.utils.compat import _detect_environment from kombu.utils.imports import symbol_by_name -def supports_librabbitmq(): +def supports_librabbitmq() -> bool | None: """Return true if :pypi:`librabbitmq` can be used.""" if _detect_environment() == 'default': try: @@ -13,6 +15,7 @@ def supports_librabbitmq(): pass else: # pragma: no cover return True + return None TRANSPORT_ALIASES = { @@ -20,6 +23,7 @@ TRANSPORT_ALIASES = { 'amqps': 'kombu.transport.pyamqp:SSLTransport', 'pyamqp': 'kombu.transport.pyamqp:Transport', 'librabbitmq': 'kombu.transport.librabbitmq:Transport', + 'confluentkafka': 'kombu.transport.confluentkafka:Transport', 'memory': 'kombu.transport.memory:Transport', 'redis': 'kombu.transport.redis:Transport', 'rediss': 'kombu.transport.redis:Transport', @@ -44,7 +48,7 @@ TRANSPORT_ALIASES = { _transport_cache = {} -def resolve_transport(transport=None): +def resolve_transport(transport: str | None = None) -> str | None: """Get transport by name. Arguments: @@ -71,7 +75,7 @@ def resolve_transport(transport=None): return transport -def get_transport_cls(transport=None): +def get_transport_cls(transport: str | None = None) -> str | None: """Get transport class by name. The transport string is the full path to a transport class, e.g.:: diff --git a/kombu/transport/azureservicebus.py b/kombu/transport/azureservicebus.py index 83237424..e7e2c0cc 100644 --- a/kombu/transport/azureservicebus.py +++ b/kombu/transport/azureservicebus.py @@ -53,9 +53,11 @@ Transport Options * ``retry_backoff_max`` - Azure SDK retry total time. Default ``120`` """ +from __future__ import annotations + import string from queue import Empty -from typing import Any, Dict, Optional, Set, Tuple, Union +from typing import Any, Dict, Set import azure.core.exceptions import azure.servicebus.exceptions @@ -83,10 +85,10 @@ class SendReceive: """Container for Sender and Receiver.""" def __init__(self, - receiver: Optional[ServiceBusReceiver] = None, - sender: Optional[ServiceBusSender] = None): - self.receiver = receiver # type: ServiceBusReceiver - self.sender = sender # type: ServiceBusSender + receiver: ServiceBusReceiver | None = None, + sender: ServiceBusSender | None = None): + self.receiver: ServiceBusReceiver = receiver + self.sender: ServiceBusSender = sender def close(self) -> None: if self.receiver: @@ -100,21 +102,19 @@ class SendReceive: class Channel(virtual.Channel): """Azure Service Bus channel.""" - default_wait_time_seconds = 5 # in seconds - default_peek_lock_seconds = 60 # in seconds (default 60, max 300) + default_wait_time_seconds: int = 5 # in seconds + default_peek_lock_seconds: int = 60 # in seconds (default 60, max 300) # in seconds (is the default from service bus repo) - default_uamqp_keep_alive_interval = 30 + default_uamqp_keep_alive_interval: int = 30 # number of retries (is the default from service bus repo) - default_retry_total = 3 + default_retry_total: int = 3 # exponential backoff factor (is the default from service bus repo) - default_retry_backoff_factor = 0.8 + default_retry_backoff_factor: float = 0.8 # Max time to backoff (is the default from service bus repo) - default_retry_backoff_max = 120 - domain_format = 'kombu%(vhost)s' - _queue_service = None # type: ServiceBusClient - _queue_mgmt_service = None # type: ServiceBusAdministrationClient - _queue_cache = {} # type: Dict[str, SendReceive] - _noack_queues = set() # type: Set[str] + default_retry_backoff_max: int = 120 + domain_format: str = 'kombu%(vhost)s' + _queue_cache: Dict[str, SendReceive] = {} + _noack_queues: Set[str] = set() def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -160,8 +160,8 @@ class Channel(virtual.Channel): def _add_queue_to_cache( self, name: str, - receiver: Optional[ServiceBusReceiver] = None, - sender: Optional[ServiceBusSender] = None + receiver: ServiceBusReceiver | None = None, + sender: ServiceBusSender | None = None ) -> SendReceive: if name in self._queue_cache: obj = self._queue_cache[name] @@ -183,7 +183,7 @@ class Channel(virtual.Channel): def _get_asb_receiver( self, queue: str, recv_mode: ServiceBusReceiveMode = ServiceBusReceiveMode.PEEK_LOCK, - queue_cache_key: Optional[str] = None) -> SendReceive: + queue_cache_key: str | None = None) -> SendReceive: cache_key = queue_cache_key or queue queue_obj = self._queue_cache.get(cache_key, None) if queue_obj is None or queue_obj.receiver is None: @@ -194,7 +194,7 @@ class Channel(virtual.Channel): return queue_obj def entity_name( - self, name: str, table: Optional[Dict[int, int]] = None) -> str: + self, name: str, table: dict[int, int] | None = None) -> str: """Format AMQP queue name into a valid ServiceBus queue name.""" return str(safe_str(name)).translate(table or CHARS_REPLACE_TABLE) @@ -227,7 +227,7 @@ class Channel(virtual.Channel): """Delete queue by name.""" queue = self.entity_name(self.queue_name_prefix + queue) - self._queue_mgmt_service.delete_queue(queue) + self.queue_mgmt_service.delete_queue(queue) send_receive_obj = self._queue_cache.pop(queue, None) if send_receive_obj: send_receive_obj.close() @@ -242,8 +242,8 @@ class Channel(virtual.Channel): def _get( self, queue: str, - timeout: Optional[Union[float, int]] = None - ) -> Dict[str, Any]: + timeout: float | int | None = None + ) -> dict[str, Any]: """Try to retrieve a single message off ``queue``.""" # If we're not ack'ing for this queue, just change receive_mode recv_mode = ServiceBusReceiveMode.RECEIVE_AND_DELETE \ @@ -298,7 +298,7 @@ class Channel(virtual.Channel): return props.total_message_count - def _purge(self, queue): + def _purge(self, queue) -> int: """Delete all current messages in a queue.""" # Azure doesn't provide a purge api yet n = 0 @@ -337,24 +337,19 @@ class Channel(virtual.Channel): if self.connection is not None: self.connection.close_channel(self) - @property + @cached_property def queue_service(self) -> ServiceBusClient: - if self._queue_service is None: - self._queue_service = ServiceBusClient.from_connection_string( - self._connection_string, - retry_total=self.retry_total, - retry_backoff_factor=self.retry_backoff_factor, - retry_backoff_max=self.retry_backoff_max - ) - return self._queue_service + return ServiceBusClient.from_connection_string( + self._connection_string, + retry_total=self.retry_total, + retry_backoff_factor=self.retry_backoff_factor, + retry_backoff_max=self.retry_backoff_max + ) - @property + @cached_property def queue_mgmt_service(self) -> ServiceBusAdministrationClient: - if self._queue_mgmt_service is None: - self._queue_mgmt_service = \ - ServiceBusAdministrationClient.from_connection_string( + return ServiceBusAdministrationClient.from_connection_string( self._connection_string) - return self._queue_mgmt_service @property def conninfo(self): @@ -412,7 +407,7 @@ class Transport(virtual.Transport): can_parse_url = True @staticmethod - def parse_uri(uri: str) -> Tuple[str, str, str]: + def parse_uri(uri: str) -> tuple[str, str, str]: # URL like: # azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace} # urllib parse does not work as the sas key could contain a slash diff --git a/kombu/transport/azurestoragequeues.py b/kombu/transport/azurestoragequeues.py index e83a20d3..16d22f0b 100644 --- a/kombu/transport/azurestoragequeues.py +++ b/kombu/transport/azurestoragequeues.py @@ -15,14 +15,34 @@ Features Connection String ================= -Connection string has the following format: +Connection string has the following formats: .. code-block:: - azurestoragequeues://:STORAGE_ACCOUNT_ACCESS kEY@STORAGE_ACCOUNT_NAME + azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL> + azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL> + azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL> + azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL> -Note that if the access key for the storage account contains a slash, it will -have to be regenerated before it can be used in the connection URL. +Note that if the access key for the storage account contains a forward slash +(``/``), it will have to be regenerated before it can be used in the connection +URL. + +.. code-block:: + + azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL> + azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL> + +If you wish to use an `Azure Managed Identity` you may use the +``DefaultAzureCredential`` format of the connection string which will use +``DefaultAzureCredential`` class in the azure-identity package. You may want to +read the `azure-identity documentation` for more information on how the +``DefaultAzureCredential`` works. + +.. _azure-identity documentation: +https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python +.. _Azure Managed Identity: +https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview Transport Options ================= @@ -30,8 +50,13 @@ Transport Options * ``queue_name_prefix`` """ +from __future__ import annotations + import string from queue import Empty +from typing import Any, Optional + +from azure.core.exceptions import ResourceExistsError from kombu.utils.encoding import safe_str from kombu.utils.json import dumps, loads @@ -40,9 +65,16 @@ from kombu.utils.objects import cached_property from . import virtual try: - from azure.storage.queue import QueueService + from azure.storage.queue import QueueServiceClient except ImportError: # pragma: no cover - QueueService = None + QueueServiceClient = None + +try: + from azure.identity import (DefaultAzureCredential, + ManagedIdentityCredential) +except ImportError: + DefaultAzureCredential = None + ManagedIdentityCredential = None # Azure storage queues allow only alphanumeric and dashes # so, replace everything with a dash @@ -54,21 +86,25 @@ CHARS_REPLACE_TABLE = { class Channel(virtual.Channel): """Azure Storage Queues channel.""" - domain_format = 'kombu%(vhost)s' - _queue_service = None - _queue_name_cache = {} - no_ack = True - _noack_queues = set() + domain_format: str = 'kombu%(vhost)s' + _queue_service: Optional[QueueServiceClient] = None + _queue_name_cache: dict[Any, Any] = {} + no_ack: bool = True + _noack_queues: set[Any] = set() def __init__(self, *args, **kwargs): - if QueueService is None: + if QueueServiceClient is None: raise ImportError('Azure Storage Queues transport requires the ' 'azure-storage-queue library') super().__init__(*args, **kwargs) - for queue_name in self.queue_service.list_queues(): - self._queue_name_cache[queue_name] = queue_name + self._credential, self._url = Transport.parse_uri( + self.conninfo.hostname + ) + + for queue in self.queue_service.list_queues(): + self._queue_name_cache[queue['name']] = queue def basic_consume(self, queue, no_ack, *args, **kwargs): if no_ack: @@ -77,7 +113,7 @@ class Channel(virtual.Channel): return super().basic_consume(queue, no_ack, *args, **kwargs) - def entity_name(self, name, table=CHARS_REPLACE_TABLE): + def entity_name(self, name, table=CHARS_REPLACE_TABLE) -> str: """Format AMQP queue name into a valid Azure Storage Queue name.""" return str(safe_str(name)).translate(table) @@ -85,61 +121,64 @@ class Channel(virtual.Channel): """Ensure a queue exists.""" queue = self.entity_name(self.queue_name_prefix + queue) try: - return self._queue_name_cache[queue] + q = self._queue_service.get_queue_client( + queue=self._queue_name_cache[queue] + ) except KeyError: - self.queue_service.create_queue(queue, fail_on_exist=False) - q = self._queue_name_cache[queue] = queue - return q + try: + q = self.queue_service.create_queue(queue) + except ResourceExistsError: + q = self._queue_service.get_queue_client(queue=queue) + + self._queue_name_cache[queue] = q.get_queue_properties() + return q def _delete(self, queue, *args, **kwargs): """Delete queue by name.""" queue_name = self.entity_name(queue) self._queue_name_cache.pop(queue_name, None) self.queue_service.delete_queue(queue_name) - super()._delete(queue_name) def _put(self, queue, message, **kwargs): """Put message onto queue.""" q = self._ensure_queue(queue) encoded_message = dumps(message) - self.queue_service.put_message(q, encoded_message) + q.send_message(encoded_message) def _get(self, queue, timeout=None): """Try to retrieve a single message off ``queue``.""" q = self._ensure_queue(queue) - messages = self.queue_service.get_messages(q, num_messages=1, - timeout=timeout) - if not messages: + messages = q.receive_messages(messages_per_page=1, timeout=timeout) + try: + message = next(messages) + except StopIteration: raise Empty() - message = messages[0] - raw_content = self.queue_service.decode_function(message.content) - content = loads(raw_content) + content = loads(message.content) - self.queue_service.delete_message(q, message.id, message.pop_receipt) + q.delete_message(message=message) return content def _size(self, queue): """Return the number of messages in a queue.""" q = self._ensure_queue(queue) - metadata = self.queue_service.get_queue_metadata(q) - return metadata.approximate_message_count + return q.get_queue_properties().approximate_message_count def _purge(self, queue): """Delete all current messages in a queue.""" q = self._ensure_queue(queue) - n = self._size(q) - self.queue_service.clear_messages(q) + n = self._size(q.queue_name) + q.clear_messages() return n @property - def queue_service(self): + def queue_service(self) -> QueueServiceClient: if self._queue_service is None: - self._queue_service = QueueService( - account_name=self.conninfo.hostname, - account_key=self.conninfo.password) + self._queue_service = QueueServiceClient( + account_url=self._url, credential=self._credential + ) return self._queue_service @@ -152,7 +191,7 @@ class Channel(virtual.Channel): return self.connection.client.transport_options @cached_property - def queue_name_prefix(self): + def queue_name_prefix(self) -> str: return self.transport_options.get('queue_name_prefix', '') @@ -161,5 +200,64 @@ class Transport(virtual.Transport): Channel = Channel - polling_interval = 1 - default_port = None + polling_interval: int = 1 + default_port: Optional[int] = None + can_parse_url: bool = True + + @staticmethod + def parse_uri(uri: str) -> tuple[str | dict, str]: + # URL like: + # azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL> + # azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL> + # azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL> + # azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL> + + # urllib parse does not work as the sas key could contain a slash + # e.g.: azurestoragequeues://some/key@someurl + + try: + # > 'some/key@url' + uri = uri.replace('azurestoragequeues://', '') + # > 'some/key', 'url' + credential, url = uri.rsplit('@', 1) + + if "DefaultAzureCredential".lower() == credential.lower(): + if DefaultAzureCredential is None: + raise ImportError('Azure Storage Queues transport with a ' + 'DefaultAzureCredential requires the ' + 'azure-identity library') + credential = DefaultAzureCredential() + elif "ManagedIdentityCredential".lower() == credential.lower(): + if ManagedIdentityCredential is None: + raise ImportError('Azure Storage Queues transport with a ' + 'ManagedIdentityCredential requires the ' + 'azure-identity library') + credential = ManagedIdentityCredential() + elif "devstoreaccount1" in url and ".core.windows.net" not in url: + # parse credential as a dict if Azurite is being used + credential = { + "account_name": "devstoreaccount1", + "account_key": credential, + } + + # Validate parameters + assert all([credential, url]) + except Exception: + raise ValueError( + 'Need a URI like ' + 'azurestoragequeues://{SAS or access key}@{URL}, ' + 'azurestoragequeues://DefaultAzureCredential@{URL}, ' + ', or ' + 'azurestoragequeues://ManagedIdentityCredential@{URL}' + ) + + return credential, url + + @classmethod + def as_uri( + cls, uri: str, include_password: bool = False, mask: str = "**" + ) -> str: + credential, url = cls.parse_uri(uri) + return "azurestoragequeues://{}@{}".format( + credential if include_password else mask, url + ) diff --git a/kombu/transport/base.py b/kombu/transport/base.py index 3083acf4..ec4c0aca 100644 --- a/kombu/transport/base.py +++ b/kombu/transport/base.py @@ -2,8 +2,11 @@ # flake8: noqa +from __future__ import annotations + import errno import socket +from typing import TYPE_CHECKING from amqp.exceptions import RecoverableConnectionError @@ -13,6 +16,9 @@ from kombu.utils.functional import dictfilter from kombu.utils.objects import cached_property from kombu.utils.time import maybe_s_to_ms +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('Message', 'StdChannel', 'Management', 'Transport') RABBITMQ_QUEUE_ARGUMENTS = { @@ -100,7 +106,12 @@ class StdChannel: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() diff --git a/kombu/transport/confluentkafka.py b/kombu/transport/confluentkafka.py new file mode 100644 index 00000000..5332a310 --- /dev/null +++ b/kombu/transport/confluentkafka.py @@ -0,0 +1,379 @@ +"""confluent-kafka transport module for Kombu. + +Kafka transport using confluent-kafka library. + +**References** + +- http://docs.confluent.io/current/clients/confluent-kafka-python + +**Limitations** + +The confluent-kafka transport does not support PyPy environment. + +Features +======== +* Type: Virtual +* Supports Direct: Yes +* Supports Topic: Yes +* Supports Fanout: No +* Supports Priority: No +* Supports TTL: No + +Connection String +================= +Connection string has the following format: + +.. code-block:: + + confluentkafka://[USER:PASSWORD@]KAFKA_ADDRESS[:PORT] + +Transport Options +================= +* ``connection_wait_time_seconds`` - Time in seconds to wait for connection + to succeed. Default ``5`` +* ``wait_time_seconds`` - Time in seconds to wait to receive messages. + Default ``5`` +* ``security_protocol`` - Protocol used to communicate with broker. + Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for + an explanation of valid values. Default ``plaintext`` +* ``sasl_mechanism`` - SASL mechanism to use for authentication. + Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for + an explanation of valid values. +* ``num_partitions`` - Number of partitions to create. Default ``1`` +* ``replication_factor`` - Replication factor of partitions. Default ``1`` +* ``topic_config`` - Topic configuration. Must be a dict whose key-value pairs + correspond with attributes in the + http://kafka.apache.org/documentation.html#topicconfigs. +* ``kafka_common_config`` - Configuration applied to producer, consumer and + admin client. Must be a dict whose key-value pairs correspond with attributes + in the https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md. +* ``kafka_producer_config`` - Producer configuration. Must be a dict whose + key-value pairs correspond with attributes in the + https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md. +* ``kafka_consumer_config`` - Consumer configuration. Must be a dict whose + key-value pairs correspond with attributes in the + https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md. +* ``kafka_admin_config`` - Admin client configuration. Must be a dict whose + key-value pairs correspond with attributes in the + https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md. +""" + +from __future__ import annotations + +from queue import Empty + +from kombu.transport import virtual +from kombu.utils import cached_property +from kombu.utils.encoding import str_to_bytes +from kombu.utils.json import dumps, loads + +try: + import confluent_kafka + from confluent_kafka import Consumer, Producer, TopicPartition + from confluent_kafka.admin import AdminClient, NewTopic + + KAFKA_CONNECTION_ERRORS = () + KAFKA_CHANNEL_ERRORS = () + +except ImportError: + confluent_kafka = None + KAFKA_CONNECTION_ERRORS = KAFKA_CHANNEL_ERRORS = () + +from kombu.log import get_logger + +logger = get_logger(__name__) + +DEFAULT_PORT = 9092 + + +class NoBrokersAvailable(confluent_kafka.KafkaException): + """Kafka broker is not available exception.""" + + retriable = True + + +class Message(virtual.Message): + """Message object.""" + + def __init__(self, payload, channel=None, **kwargs): + self.topic = payload.get('topic') + super().__init__(payload, channel=channel, **kwargs) + + +class QoS(virtual.QoS): + """Quality of Service guarantees.""" + + _not_yet_acked = {} + + def can_consume(self): + """Return true if the channel can be consumed from. + + :returns: True, if this QoS object can accept a message. + :rtype: bool + """ + return not self.prefetch_count or len(self._not_yet_acked) < self \ + .prefetch_count + + def can_consume_max_estimate(self): + if self.prefetch_count: + return self.prefetch_count - len(self._not_yet_acked) + else: + return 1 + + def append(self, message, delivery_tag): + self._not_yet_acked[delivery_tag] = message + + def get(self, delivery_tag): + return self._not_yet_acked[delivery_tag] + + def ack(self, delivery_tag): + if delivery_tag not in self._not_yet_acked: + return + message = self._not_yet_acked.pop(delivery_tag) + consumer = self.channel._get_consumer(message.topic) + consumer.commit() + + def reject(self, delivery_tag, requeue=False): + """Reject a message by delivery tag. + + If requeue is True, then the last consumed message is reverted so + it'll be refetched on the next attempt. + If False, that message is consumed and ignored. + """ + if requeue: + message = self._not_yet_acked.pop(delivery_tag) + consumer = self.channel._get_consumer(message.topic) + for assignment in consumer.assignment(): + topic_partition = TopicPartition(message.topic, + assignment.partition) + [committed_offset] = consumer.committed([topic_partition]) + consumer.seek(committed_offset) + else: + self.ack(delivery_tag) + + def restore_unacked_once(self, stderr=None): + pass + + +class Channel(virtual.Channel): + """Kafka Channel.""" + + QoS = QoS + Message = Message + + default_wait_time_seconds = 5 + default_connection_wait_time_seconds = 5 + _client = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._kafka_consumers = {} + self._kafka_producers = {} + + self._client = self._open() + + def sanitize_queue_name(self, queue): + """Need to sanitize the name, celery sometimes pushes in @ signs.""" + return str(queue).replace('@', '') + + def _get_producer(self, queue): + """Create/get a producer instance for the given topic/queue.""" + queue = self.sanitize_queue_name(queue) + producer = self._kafka_producers.get(queue, None) + if producer is None: + producer = Producer({ + **self.common_config, + **(self.options.get('kafka_producer_config') or {}), + }) + self._kafka_producers[queue] = producer + + return producer + + def _get_consumer(self, queue): + """Create/get a consumer instance for the given topic/queue.""" + queue = self.sanitize_queue_name(queue) + consumer = self._kafka_consumers.get(queue, None) + if consumer is None: + consumer = Consumer({ + 'group.id': f'{queue}-consumer-group', + 'auto.offset.reset': 'earliest', + 'enable.auto.commit': False, + **self.common_config, + **(self.options.get('kafka_consumer_config') or {}), + }) + consumer.subscribe([queue]) + self._kafka_consumers[queue] = consumer + + return consumer + + def _put(self, queue, message, **kwargs): + """Put a message on the topic/queue.""" + queue = self.sanitize_queue_name(queue) + producer = self._get_producer(queue) + producer.produce(queue, str_to_bytes(dumps(message))) + producer.flush() + + def _get(self, queue, **kwargs): + """Get a message from the topic/queue.""" + queue = self.sanitize_queue_name(queue) + consumer = self._get_consumer(queue) + message = None + + try: + message = consumer.poll(self.wait_time_seconds) + except StopIteration: + pass + + if not message: + raise Empty() + + error = message.error() + if error: + logger.error(error) + raise Empty() + + return {**loads(message.value()), 'topic': message.topic()} + + def _delete(self, queue, *args, **kwargs): + """Delete a queue/topic.""" + queue = self.sanitize_queue_name(queue) + self._kafka_consumers[queue].close() + self._kafka_consumers.pop(queue) + self.client.delete_topics([queue]) + + def _size(self, queue): + """Get the number of pending messages in the topic/queue.""" + queue = self.sanitize_queue_name(queue) + + consumer = self._kafka_consumers.get(queue, None) + if consumer is None: + return 0 + + size = 0 + for assignment in consumer.assignment(): + topic_partition = TopicPartition(queue, assignment.partition) + (_, end_offset) = consumer.get_watermark_offsets(topic_partition) + [committed_offset] = consumer.committed([topic_partition]) + size += end_offset - committed_offset.offset + return size + + def _new_queue(self, queue, **kwargs): + """Create a new topic if it does not exist.""" + queue = self.sanitize_queue_name(queue) + if queue in self.client.list_topics().topics: + return + + topic = NewTopic( + queue, + num_partitions=self.options.get('num_partitions', 1), + replication_factor=self.options.get('replication_factor', 1), + config=self.options.get('topic_config', {}) + ) + self.client.create_topics(new_topics=[topic]) + + def _has_queue(self, queue, **kwargs): + """Check if a topic already exists.""" + queue = self.sanitize_queue_name(queue) + return queue in self.client.list_topics().topics + + def _open(self): + client = AdminClient({ + **self.common_config, + **(self.options.get('kafka_admin_config') or {}), + }) + + try: + # seems to be the only way to check connection + client.list_topics(timeout=self.wait_time_seconds) + except confluent_kafka.KafkaException as e: + raise NoBrokersAvailable(e) + + return client + + @property + def client(self): + if self._client is None: + self._client = self._open() + return self._client + + @property + def options(self): + return self.connection.client.transport_options + + @property + def conninfo(self): + return self.connection.client + + @cached_property + def wait_time_seconds(self): + return self.options.get( + 'wait_time_seconds', self.default_wait_time_seconds + ) + + @cached_property + def connection_wait_time_seconds(self): + return self.options.get( + 'connection_wait_time_seconds', + self.default_connection_wait_time_seconds, + ) + + @cached_property + def common_config(self): + conninfo = self.connection.client + config = { + 'bootstrap.servers': + f'{conninfo.hostname}:{int(conninfo.port) or DEFAULT_PORT}', + } + security_protocol = self.options.get('security_protocol', 'plaintext') + if security_protocol.lower() != 'plaintext': + config.update({ + 'security.protocol': security_protocol, + 'sasl.username': conninfo.userid, + 'sasl.password': conninfo.password, + 'sasl.mechanism': self.options.get('sasl_mechanism'), + }) + + config.update(self.options.get('kafka_common_config') or {}) + return config + + def close(self): + super().close() + self._kafka_producers = {} + + for consumer in self._kafka_consumers.values(): + consumer.close() + + self._kafka_consumers = {} + + +class Transport(virtual.Transport): + """Kafka Transport.""" + + def as_uri(self, uri: str, include_password=False, mask='**') -> str: + pass + + Channel = Channel + + default_port = DEFAULT_PORT + + driver_type = 'kafka' + driver_name = 'confluentkafka' + + recoverable_connection_errors = ( + NoBrokersAvailable, + ) + + def __init__(self, client, **kwargs): + if confluent_kafka is None: + raise ImportError('The confluent-kafka library is not installed') + super().__init__(client, **kwargs) + + def driver_version(self): + return confluent_kafka.__version__ + + def establish_connection(self): + return super().establish_connection() + + def close_connection(self, connection): + return super().close_connection(connection) diff --git a/kombu/transport/consul.py b/kombu/transport/consul.py index ea275c95..7ace52f6 100644 --- a/kombu/transport/consul.py +++ b/kombu/transport/consul.py @@ -27,6 +27,8 @@ Connection string has the following format: """ +from __future__ import annotations + import socket import uuid from collections import defaultdict @@ -276,24 +278,25 @@ class Transport(virtual.Transport): driver_type = 'consul' driver_name = 'consul' - def __init__(self, *args, **kwargs): - if consul is None: - raise ImportError('Missing python-consul library') - - super().__init__(*args, **kwargs) - - self.connection_errors = ( + if consul: + connection_errors = ( virtual.Transport.connection_errors + ( consul.ConsulException, consul.base.ConsulException ) ) - self.channel_errors = ( + channel_errors = ( virtual.Transport.channel_errors + ( consul.ConsulException, consul.base.ConsulException ) ) + def __init__(self, *args, **kwargs): + if consul is None: + raise ImportError('Missing python-consul library') + + super().__init__(*args, **kwargs) + def verify_connection(self, connection): port = connection.client.port or self.default_port host = connection.client.hostname or DEFAULT_HOST diff --git a/kombu/transport/etcd.py b/kombu/transport/etcd.py index 4d0b0364..2ab85841 100644 --- a/kombu/transport/etcd.py +++ b/kombu/transport/etcd.py @@ -24,6 +24,8 @@ Connection string has the following format: """ +from __future__ import annotations + import os import socket from collections import defaultdict @@ -242,6 +244,15 @@ class Transport(virtual.Transport): implements = virtual.Transport.implements.extend( exchange_type=frozenset(['direct'])) + if etcd: + connection_errors = ( + virtual.Transport.connection_errors + (etcd.EtcdException, ) + ) + + channel_errors = ( + virtual.Transport.channel_errors + (etcd.EtcdException, ) + ) + def __init__(self, *args, **kwargs): """Create a new instance of etcd.Transport.""" if etcd is None: @@ -249,14 +260,6 @@ class Transport(virtual.Transport): super().__init__(*args, **kwargs) - self.connection_errors = ( - virtual.Transport.connection_errors + (etcd.EtcdException, ) - ) - - self.channel_errors = ( - virtual.Transport.channel_errors + (etcd.EtcdException, ) - ) - def verify_connection(self, connection): """Verify the connection works.""" port = connection.client.port or self.default_port diff --git a/kombu/transport/filesystem.py b/kombu/transport/filesystem.py index d66c42d6..9d2b3581 100644 --- a/kombu/transport/filesystem.py +++ b/kombu/transport/filesystem.py @@ -65,7 +65,7 @@ Features * Type: Virtual * Supports Direct: Yes * Supports Topic: Yes -* Supports Fanout: No +* Supports Fanout: Yes * Supports Priority: No * Supports TTL: No @@ -86,22 +86,26 @@ Transport Options * ``store_processed`` - if set to True, all processed messages are backed up to ``processed_folder``. * ``processed_folder`` - directory where are backed up processed files. +* ``control_folder`` - directory where are exchange-queue table stored. """ +from __future__ import annotations + import os import shutil import tempfile import uuid +from collections import namedtuple +from pathlib import Path from queue import Empty from time import monotonic from kombu.exceptions import ChannelError +from kombu.transport import virtual from kombu.utils.encoding import bytes_to_str, str_to_bytes from kombu.utils.json import dumps, loads from kombu.utils.objects import cached_property -from . import virtual - VERSION = (1, 0, 0) __version__ = '.'.join(map(str, VERSION)) @@ -128,10 +132,11 @@ if os.name == 'nt': hfile = win32file._get_osfhandle(file.fileno()) win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped) + elif os.name == 'posix': import fcntl - from fcntl import LOCK_EX, LOCK_NB, LOCK_SH # noqa + from fcntl import LOCK_EX, LOCK_SH def lock(file, flags): """Create file lock.""" @@ -140,14 +145,66 @@ elif os.name == 'posix': def unlock(file): """Remove file lock.""" fcntl.flock(file.fileno(), fcntl.LOCK_UN) + + else: raise RuntimeError( 'Filesystem plugin only defined for NT and POSIX platforms') +exchange_queue_t = namedtuple("exchange_queue_t", + ["routing_key", "pattern", "queue"]) + + class Channel(virtual.Channel): """Filesystem Channel.""" + supports_fanout = True + + def get_table(self, exchange): + file = self.control_folder / f"{exchange}.exchange" + try: + f_obj = file.open("r") + try: + lock(f_obj, LOCK_SH) + exchange_table = loads(bytes_to_str(f_obj.read())) + return [exchange_queue_t(*q) for q in exchange_table] + finally: + unlock(f_obj) + f_obj.close() + except FileNotFoundError: + return [] + except OSError: + raise ChannelError(f"Cannot open {file}") + + def _queue_bind(self, exchange, routing_key, pattern, queue): + file = self.control_folder / f"{exchange}.exchange" + self.control_folder.mkdir(exist_ok=True) + queue_val = exchange_queue_t(routing_key or "", pattern or "", + queue or "") + try: + if file.exists(): + f_obj = file.open("rb+", buffering=0) + lock(f_obj, LOCK_EX) + exchange_table = loads(bytes_to_str(f_obj.read())) + queues = [exchange_queue_t(*q) for q in exchange_table] + if queue_val not in queues: + queues.insert(0, queue_val) + f_obj.seek(0) + f_obj.write(str_to_bytes(dumps(queues))) + else: + f_obj = file.open("wb", buffering=0) + lock(f_obj, LOCK_EX) + queues = [queue_val] + f_obj.write(str_to_bytes(dumps(queues))) + finally: + unlock(f_obj) + f_obj.close() + + def _put_fanout(self, exchange, payload, routing_key, **kwargs): + for q in self.get_table(exchange): + self._put(q.queue, payload, **kwargs) + def _put(self, queue, payload, **kwargs): """Put `message` onto `queue`.""" filename = '{}_{}.{}.msg'.format(int(round(monotonic() * 1000)), @@ -155,7 +212,7 @@ class Channel(virtual.Channel): filename = os.path.join(self.data_folder_out, filename) try: - f = open(filename, 'wb') + f = open(filename, 'wb', buffering=0) lock(f, LOCK_EX) f.write(str_to_bytes(dumps(payload))) except OSError: @@ -187,7 +244,8 @@ class Channel(virtual.Channel): shutil.move(os.path.join(self.data_folder_in, filename), processed_folder) except OSError: - pass # file could be locked, or removed in meantime so ignore + # file could be locked, or removed in meantime so ignore + continue filename = os.path.join(processed_folder, filename) try: @@ -266,10 +324,19 @@ class Channel(virtual.Channel): def processed_folder(self): return self.transport_options.get('processed_folder', 'processed') + @property + def control_folder(self): + return Path(self.transport_options.get('control_folder', 'control')) + class Transport(virtual.Transport): """Filesystem Transport.""" + implements = virtual.Transport.implements.extend( + asynchronous=False, + exchange_type=frozenset(['direct', 'topic', 'fanout']) + ) + Channel = Channel # filesystem backend state is global. global_state = virtual.BrokerState() diff --git a/kombu/transport/librabbitmq.py b/kombu/transport/librabbitmq.py index dec50ccf..37015b18 100644 --- a/kombu/transport/librabbitmq.py +++ b/kombu/transport/librabbitmq.py @@ -3,6 +3,8 @@ .. _`librabbitmq`: https://pypi.org/project/librabbitmq/ """ +from __future__ import annotations + import os import socket import warnings diff --git a/kombu/transport/memory.py b/kombu/transport/memory.py index 3073d1cf..9bfaff8d 100644 --- a/kombu/transport/memory.py +++ b/kombu/transport/memory.py @@ -22,6 +22,8 @@ Connection string is in the following format: """ +from __future__ import annotations + from collections import defaultdict from queue import Queue diff --git a/kombu/transport/mongodb.py b/kombu/transport/mongodb.py index db758c18..b923f5f4 100644 --- a/kombu/transport/mongodb.py +++ b/kombu/transport/mongodb.py @@ -33,6 +33,8 @@ Transport Options * ``calc_queue_size``, """ +from __future__ import annotations + import datetime from queue import Empty @@ -63,11 +65,10 @@ class BroadcastCursor: def __init__(self, cursor): self._cursor = cursor - self.purge(rewind=False) def get_size(self): - return self._cursor.count() - self._offset + return self._cursor.collection.count_documents({}) - self._offset def close(self): self._cursor.close() @@ -77,7 +78,7 @@ class BroadcastCursor: self._cursor.rewind() # Fast forward the cursor past old events - self._offset = self._cursor.count() + self._offset = self._cursor.collection.count_documents({}) self._cursor = self._cursor.skip(self._offset) def __iter__(self): @@ -149,11 +150,17 @@ class Channel(virtual.Channel): def _new_queue(self, queue, **kwargs): if self.ttl: - self.queues.update( + self.queues.update_one( {'_id': queue}, - {'_id': queue, - 'options': kwargs, - 'expire_at': self._get_expire(kwargs, 'x-expires')}, + { + '$set': { + '_id': queue, + 'options': kwargs, + 'expire_at': self._get_queue_expire( + kwargs, 'x-expires' + ), + }, + }, upsert=True) def _get(self, queue): @@ -163,10 +170,9 @@ class Channel(virtual.Channel): except StopIteration: msg = None else: - msg = self.messages.find_and_modify( - query={'queue': queue}, + msg = self.messages.find_one_and_delete( + {'queue': queue}, sort=[('priority', pymongo.ASCENDING)], - remove=True, ) if self.ttl: @@ -186,7 +192,7 @@ class Channel(virtual.Channel): if queue in self._fanout_queues: return self._get_broadcast_cursor(queue).get_size() - return self.messages.find({'queue': queue}).count() + return self.messages.count_documents({'queue': queue}) def _put(self, queue, message, **kwargs): data = { @@ -196,13 +202,18 @@ class Channel(virtual.Channel): } if self.ttl: - data['expire_at'] = self._get_expire(queue, 'x-message-ttl') + data['expire_at'] = self._get_queue_expire(queue, 'x-message-ttl') + msg_expire = self._get_message_expire(message) + if msg_expire is not None and ( + data['expire_at'] is None or msg_expire < data['expire_at'] + ): + data['expire_at'] = msg_expire - self.messages.insert(data) + self.messages.insert_one(data) def _put_fanout(self, exchange, message, routing_key, **kwargs): - self.broadcast.insert({'payload': dumps(message), - 'queue': exchange}) + self.broadcast.insert_one({'payload': dumps(message), + 'queue': exchange}) def _purge(self, queue): size = self._size(queue) @@ -241,9 +252,9 @@ class Channel(virtual.Channel): data = lookup.copy() if self.ttl: - data['expire_at'] = self._get_expire(queue, 'x-expires') + data['expire_at'] = self._get_queue_expire(queue, 'x-expires') - self.routing.update(lookup, data, upsert=True) + self.routing.update_one(lookup, {'$set': data}, upsert=True) def queue_delete(self, queue, **kwargs): self.routing.remove({'queue': queue}) @@ -346,7 +357,7 @@ class Channel(virtual.Channel): def _create_broadcast(self, database): """Create capped collection for broadcast messages.""" - if self.broadcast_collection in database.collection_names(): + if self.broadcast_collection in database.list_collection_names(): return database.create_collection(self.broadcast_collection, @@ -356,20 +367,20 @@ class Channel(virtual.Channel): def _ensure_indexes(self, database): """Ensure indexes on collections.""" messages = database[self.messages_collection] - messages.ensure_index( + messages.create_index( [('queue', 1), ('priority', 1), ('_id', 1)], background=True, ) - database[self.broadcast_collection].ensure_index([('queue', 1)]) + database[self.broadcast_collection].create_index([('queue', 1)]) routing = database[self.routing_collection] - routing.ensure_index([('queue', 1), ('exchange', 1)]) + routing.create_index([('queue', 1), ('exchange', 1)]) if self.ttl: - messages.ensure_index([('expire_at', 1)], expireAfterSeconds=0) - routing.ensure_index([('expire_at', 1)], expireAfterSeconds=0) + messages.create_index([('expire_at', 1)], expireAfterSeconds=0) + routing.create_index([('expire_at', 1)], expireAfterSeconds=0) - database[self.queues_collection].ensure_index( + database[self.queues_collection].create_index( [('expire_at', 1)], expireAfterSeconds=0) def _create_client(self): @@ -427,7 +438,12 @@ class Channel(virtual.Channel): ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor) return ret - def _get_expire(self, queue, argument): + def _get_message_expire(self, message): + value = message.get('properties', {}).get('expiration') + if value is not None: + return self.get_now() + datetime.timedelta(milliseconds=int(value)) + + def _get_queue_expire(self, queue, argument): """Get expiration header named `argument` of queue definition. Note: @@ -452,15 +468,15 @@ class Channel(virtual.Channel): def _update_queues_expire(self, queue): """Update expiration field on queues documents.""" - expire_at = self._get_expire(queue, 'x-expires') + expire_at = self._get_queue_expire(queue, 'x-expires') if not expire_at: return - self.routing.update( - {'queue': queue}, {'$set': {'expire_at': expire_at}}, multi=True) - self.queues.update( - {'_id': queue}, {'$set': {'expire_at': expire_at}}, multi=True) + self.routing.update_many( + {'queue': queue}, {'$set': {'expire_at': expire_at}}) + self.queues.update_many( + {'_id': queue}, {'$set': {'expire_at': expire_at}}) def get_now(self): """Return current time in UTC.""" diff --git a/kombu/transport/pyamqp.py b/kombu/transport/pyamqp.py index f230f911..c8fd3c86 100644 --- a/kombu/transport/pyamqp.py +++ b/kombu/transport/pyamqp.py @@ -68,6 +68,8 @@ hostname from broker URL. This is usefull when failover is used to fill """ +from __future__ import annotations + import amqp from kombu.utils.amq_manager import get_manager diff --git a/kombu/transport/pyro.py b/kombu/transport/pyro.py index 833d9792..7b27cb61 100644 --- a/kombu/transport/pyro.py +++ b/kombu/transport/pyro.py @@ -32,6 +32,8 @@ Transport Options """ +from __future__ import annotations + import sys from queue import Empty, Queue diff --git a/kombu/transport/qpid.py b/kombu/transport/qpid.py index b0f8df13..cfd864d8 100644 --- a/kombu/transport/qpid.py +++ b/kombu/transport/qpid.py @@ -86,13 +86,14 @@ Celery, this can be accomplished by setting the *BROKER_TRANSPORT_OPTIONS* Celery option. """ +from __future__ import annotations + import os import select import socket import ssl import sys import uuid -from collections import OrderedDict from gettext import gettext as _ from queue import Empty from time import monotonic @@ -189,7 +190,7 @@ class QoS: def __init__(self, session, prefetch_count=1): self.session = session self.prefetch_count = 1 - self._not_yet_acked = OrderedDict() + self._not_yet_acked = {} def can_consume(self): """Return True if the :class:`Channel` can consume more messages. @@ -229,8 +230,8 @@ class QoS: """Append message to the list of un-ACKed messages. Add a message, referenced by the delivery_tag, for ACKing, - rejecting, or getting later. Messages are saved into an - :class:`collections.OrderedDict` by delivery_tag. + rejecting, or getting later. Messages are saved into a + dict by delivery_tag. :param message: A received message that has not yet been ACKed. :type message: qpid.messaging.Message diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index 103a8466..6cbfbdcf 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -51,6 +51,8 @@ Transport Options * ``priority_steps`` """ +from __future__ import annotations + import functools import numbers import socket @@ -189,6 +191,7 @@ class GlobalKeyPrefixMixin: PREFIXED_SIMPLE_COMMANDS = [ "HDEL", "HGET", + "HLEN", "HSET", "LLEN", "LPUSH", @@ -208,6 +211,7 @@ class GlobalKeyPrefixMixin: "DEL": {"args_start": 0, "args_end": None}, "BRPOP": {"args_start": 0, "args_end": -1}, "EVALSHA": {"args_start": 2, "args_end": 3}, + "WATCH": {"args_start": 0, "args_end": None}, } def _prefix_args(self, args): @@ -216,8 +220,7 @@ class GlobalKeyPrefixMixin: if command in self.PREFIXED_SIMPLE_COMMANDS: args[0] = self.global_keyprefix + str(args[0]) - - if command in self.PREFIXED_COMPLEX_COMMANDS.keys(): + elif command in self.PREFIXED_COMPLEX_COMMANDS: args_start = self.PREFIXED_COMPLEX_COMMANDS[command]["args_start"] args_end = self.PREFIXED_COMPLEX_COMMANDS[command]["args_end"] @@ -267,6 +270,13 @@ class PrefixedStrictRedis(GlobalKeyPrefixMixin, redis.Redis): self.global_keyprefix = kwargs.pop('global_keyprefix', '') redis.Redis.__init__(self, *args, **kwargs) + def pubsub(self, **kwargs): + return PrefixedRedisPubSub( + self.connection_pool, + global_keyprefix=self.global_keyprefix, + **kwargs, + ) + class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.client.Pipeline): """Custom Redis pipeline that takes global_keyprefix into consideration. @@ -281,6 +291,58 @@ class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.client.Pipeline): redis.client.Pipeline.__init__(self, *args, **kwargs) +class PrefixedRedisPubSub(redis.client.PubSub): + """Redis pubsub client that takes global_keyprefix into consideration.""" + + PUBSUB_COMMANDS = ( + "SUBSCRIBE", + "UNSUBSCRIBE", + "PSUBSCRIBE", + "PUNSUBSCRIBE", + ) + + def __init__(self, *args, **kwargs): + self.global_keyprefix = kwargs.pop('global_keyprefix', '') + super().__init__(*args, **kwargs) + + def _prefix_args(self, args): + args = list(args) + command = args.pop(0) + + if command in self.PUBSUB_COMMANDS: + args = [ + self.global_keyprefix + str(arg) + for arg in args + ] + + return [command, *args] + + def parse_response(self, *args, **kwargs): + """Parse a response from the Redis server. + + Method wraps ``PubSub.parse_response()`` to remove prefixes of keys + returned by redis command. + """ + ret = super().parse_response(*args, **kwargs) + if ret is None: + return ret + + # response formats + # SUBSCRIBE and UNSUBSCRIBE + # -> [message type, channel, message] + # PSUBSCRIBE and PUNSUBSCRIBE + # -> [message type, pattern, channel, message] + message_type, *channels, message = ret + return [ + message_type, + *[channel[len(self.global_keyprefix):] for channel in channels], + message, + ] + + def execute_command(self, *args, **kwargs): + return super().execute_command(*self._prefix_args(args), **kwargs) + + class QoS(virtual.QoS): """Redis Ack Emulation.""" @@ -353,13 +415,17 @@ class QoS(virtual.QoS): pass def restore_by_tag(self, tag, client=None, leftmost=False): - with self.channel.conn_or_acquire(client) as client: - with client.pipeline() as pipe: - p, _, _ = self._remove_from_indices( - tag, pipe.hget(self.unacked_key, tag)).execute() + + def restore_transaction(pipe): + p = pipe.hget(self.unacked_key, tag) + pipe.multi() + self._remove_from_indices(tag, pipe) if p: M, EX, RK = loads(bytes_to_str(p)) # json is unicode - self.channel._do_restore_message(M, EX, RK, client, leftmost) + self.channel._do_restore_message(M, EX, RK, pipe, leftmost) + + with self.channel.conn_or_acquire(client) as client: + client.transaction(restore_transaction, self.unacked_key) @cached_property def unacked_key(self): @@ -709,32 +775,35 @@ class Channel(virtual.Channel): self.connection.cycle._on_connection_disconnect(connection) def _do_restore_message(self, payload, exchange, routing_key, - client=None, leftmost=False): - with self.conn_or_acquire(client) as client: + pipe, leftmost=False): + try: try: - try: - payload['headers']['redelivered'] = True - except KeyError: - pass - for queue in self._lookup(exchange, routing_key): - (client.lpush if leftmost else client.rpush)( - queue, dumps(payload), - ) - except Exception: - crit('Could not restore message: %r', payload, exc_info=True) + payload['headers']['redelivered'] = True + payload['properties']['delivery_info']['redelivered'] = True + except KeyError: + pass + for queue in self._lookup(exchange, routing_key): + (pipe.lpush if leftmost else pipe.rpush)( + queue, dumps(payload), + ) + except Exception: + crit('Could not restore message: %r', payload, exc_info=True) def _restore(self, message, leftmost=False): if not self.ack_emulation: return super()._restore(message) tag = message.delivery_tag - with self.conn_or_acquire() as client: - with client.pipeline() as pipe: - P, _ = pipe.hget(self.unacked_key, tag) \ - .hdel(self.unacked_key, tag) \ - .execute() + + def restore_transaction(pipe): + P = pipe.hget(self.unacked_key, tag) + pipe.multi() + pipe.hdel(self.unacked_key, tag) if P: M, EX, RK = loads(bytes_to_str(P)) # json is unicode - self._do_restore_message(M, EX, RK, client, leftmost) + self._do_restore_message(M, EX, RK, pipe, leftmost) + + with self.conn_or_acquire() as client: + client.transaction(restore_transaction, self.unacked_key) def _restore_at_beginning(self, message): return self._restore(message, leftmost=True) @@ -1116,8 +1185,8 @@ class Channel(virtual.Channel): if asynchronous: class Connection(connection_cls): - def disconnect(self): - super().disconnect() + def disconnect(self, *args): + super().disconnect(*args) channel._on_connection_disconnect(self) connection_cls = Connection @@ -1208,13 +1277,14 @@ class Transport(virtual.Transport): exchange_type=frozenset(['direct', 'topic', 'fanout']) ) + if redis: + connection_errors, channel_errors = get_redis_error_classes() + def __init__(self, *args, **kwargs): if redis is None: raise ImportError('Missing redis library (pip install redis)') super().__init__(*args, **kwargs) - # Get redis-py exceptions. - self.connection_errors, self.channel_errors = self._get_errors() # All channels share the same poller. self.cycle = MultiChannelPoller() @@ -1231,6 +1301,14 @@ class Transport(virtual.Transport): def _on_disconnect(connection): if connection._sock: loop.remove(connection._sock) + + # must have started polling or this will break reconnection + if cycle.fds: + # stop polling in the event loop + try: + loop.on_tick.remove(on_poll_start) + except KeyError: + pass cycle._on_connection_disconnect = _on_disconnect def on_poll_start(): @@ -1251,10 +1329,6 @@ class Transport(virtual.Transport): """Handle AIO event for one of our file descriptors.""" self.cycle.on_readable(fileno) - def _get_errors(self): - """Utility to import redis-py's exceptions at runtime.""" - return get_redis_error_classes() - if sentinel: class SentinelManagedSSLConnection( diff --git a/kombu/transport/sqlalchemy/__init__.py b/kombu/transport/sqlalchemy/__init__.py index 91f87a86..a61c8ea8 100644 --- a/kombu/transport/sqlalchemy/__init__.py +++ b/kombu/transport/sqlalchemy/__init__.py @@ -50,15 +50,13 @@ Transport Options Moreover parameters of :func:`sqlalchemy.create_engine()` function can be passed as transport options. """ -# SQLAlchemy overrides != False to have special meaning and pep8 complains -# flake8: noqa - +from __future__ import annotations import threading from json import dumps, loads from queue import Empty -from sqlalchemy import create_engine +from sqlalchemy import create_engine, text from sqlalchemy.exc import OperationalError from sqlalchemy.orm import sessionmaker @@ -71,6 +69,13 @@ from .models import ModelBase from .models import Queue as QueueBase from .models import class_registry, metadata +# SQLAlchemy overrides != False to have special meaning and pep8 complains +# flake8: noqa + + + + + VERSION = (1, 4, 1) __version__ = '.'.join(map(str, VERSION)) @@ -164,7 +169,7 @@ class Channel(virtual.Channel): def _get(self, queue): obj = self._get_or_create(queue) if self.session.bind.name == 'sqlite': - self.session.execute('BEGIN IMMEDIATE TRANSACTION') + self.session.execute(text('BEGIN IMMEDIATE TRANSACTION')) try: msg = self.session.query(self.message_cls) \ .with_for_update() \ diff --git a/kombu/transport/sqlalchemy/models.py b/kombu/transport/sqlalchemy/models.py index 45863852..edff572a 100644 --- a/kombu/transport/sqlalchemy/models.py +++ b/kombu/transport/sqlalchemy/models.py @@ -1,10 +1,12 @@ """Kombu transport using SQLAlchemy as the message store.""" +from __future__ import annotations + import datetime from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer, Sequence, SmallInteger, String, Text) -from sqlalchemy.orm import relation +from sqlalchemy.orm import relationship from sqlalchemy.schema import MetaData try: @@ -35,7 +37,7 @@ class Queue: @declared_attr def messages(cls): - return relation('Message', backref='queue', lazy='noload') + return relationship('Message', backref='queue', lazy='noload') class Message: diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py index 7ab11772..54e84665 100644 --- a/kombu/transport/virtual/__init__.py +++ b/kombu/transport/virtual/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .base import (AbstractChannel, Base64, BrokerState, Channel, Empty, Management, Message, NotEquivalentError, QoS, Transport, UndeliverableWarning, binding_key_t, queue_binding_t) diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py index 95e539ac..552ebec7 100644 --- a/kombu/transport/virtual/base.py +++ b/kombu/transport/virtual/base.py @@ -3,6 +3,8 @@ Emulates the AMQ API for non-AMQ transports. """ +from __future__ import annotations + import base64 import socket import sys @@ -13,6 +15,7 @@ from itertools import count from multiprocessing.util import Finalize from queue import Empty from time import monotonic, sleep +from typing import TYPE_CHECKING from amqp.protocol import queue_declare_ok_t @@ -26,6 +29,9 @@ from kombu.utils.uuid import uuid from .exchange import STANDARD_EXCHANGE_TYPES +if TYPE_CHECKING: + from types import TracebackType + ARRAY_TYPE_H = 'H' UNDELIVERABLE_FMT = """\ @@ -177,6 +183,8 @@ class QoS: self.channel = channel self.prefetch_count = prefetch_count or 0 + # Standard Python dictionaries do not support setting attributes + # on the object, hence the use of OrderedDict self._delivered = OrderedDict() self._delivered.restored = False self._dirty = set() @@ -462,14 +470,7 @@ class Channel(AbstractChannel, base.StdChannel): typ: cls(self) for typ, cls in self.exchange_types.items() } - try: - self.channel_id = self.connection._avail_channel_ids.pop() - except IndexError: - raise ResourceError( - 'No free channel ids, current={}, channel_max={}'.format( - len(self.connection.channels), - self.connection.channel_max), (20, 10), - ) + self.channel_id = self._get_free_channel_id() topts = self.connection.client.transport_options for opt_name in self.from_transport_options: @@ -727,7 +728,8 @@ class Channel(AbstractChannel, base.StdChannel): message = message.serializable() message['redelivered'] = True for queue in self._lookup( - delivery_info['exchange'], delivery_info['routing_key']): + delivery_info['exchange'], + delivery_info['routing_key']): self._put(queue, message) def _restore_at_beginning(self, message): @@ -804,7 +806,12 @@ class Channel(AbstractChannel, base.StdChannel): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() @property @@ -844,6 +851,22 @@ class Channel(AbstractChannel, base.StdChannel): return (self.max_priority - priority) if reverse else priority + def _get_free_channel_id(self): + # Cast to a set for fast lookups, and keep stored as an array + # for lower memory usage. + used_channel_ids = set(self.connection._used_channel_ids) + + for channel_id in range(1, self.connection.channel_max + 1): + if channel_id not in used_channel_ids: + self.connection._used_channel_ids.append(channel_id) + return channel_id + + raise ResourceError( + 'No free channel ids, current={}, channel_max={}'.format( + len(self.connection.channels), + self.connection.channel_max), (20, 10), + ) + class Management(base.Management): """Base class for the AMQP management API.""" @@ -907,9 +930,7 @@ class Transport(base.Transport): polling_interval = client.transport_options.get('polling_interval') if polling_interval is not None: self.polling_interval = polling_interval - self._avail_channel_ids = array( - ARRAY_TYPE_H, range(self.channel_max, 0, -1), - ) + self._used_channel_ids = array(ARRAY_TYPE_H) def create_channel(self, connection): try: @@ -921,7 +942,11 @@ class Transport(base.Transport): def close_channel(self, channel): try: - self._avail_channel_ids.append(channel.channel_id) + try: + self._used_channel_ids.remove(channel.channel_id) + except ValueError: + # channel id already removed + pass try: self.channels.remove(channel) except ValueError: @@ -934,7 +959,7 @@ class Transport(base.Transport): # this channel is then used as the next requested channel. # (returned by ``create_channel``). self._avail_channels.append(self.create_channel(self)) - return self # for drain events + return self # for drain events def close_connection(self, connection): self.cycle.close() diff --git a/kombu/transport/virtual/exchange.py b/kombu/transport/virtual/exchange.py index c6b6161c..b70544cd 100644 --- a/kombu/transport/virtual/exchange.py +++ b/kombu/transport/virtual/exchange.py @@ -4,6 +4,8 @@ Implementations of the standard exchanges defined by the AMQ protocol (excluding the `headers` exchange). """ +from __future__ import annotations + import re from kombu.utils.text import escape_regex diff --git a/kombu/transport/zookeeper.py b/kombu/transport/zookeeper.py index 1a2ab63c..c72ce2f5 100644 --- a/kombu/transport/zookeeper.py +++ b/kombu/transport/zookeeper.py @@ -42,6 +42,8 @@ Transport Options """ +from __future__ import annotations + import os import socket from queue import Empty diff --git a/kombu/utils/__init__.py b/kombu/utils/__init__.py index 304e2dfa..94bb3cdf 100644 --- a/kombu/utils/__init__.py +++ b/kombu/utils/__init__.py @@ -1,5 +1,7 @@ """DEPRECATED - Import from modules below.""" +from __future__ import annotations + from .collections import EqualityDict from .compat import fileno, maybe_fileno, nested, register_after_fork from .div import emergency_dump_state diff --git a/kombu/utils/amq_manager.py b/kombu/utils/amq_manager.py index 7491bb25..f3e429fd 100644 --- a/kombu/utils/amq_manager.py +++ b/kombu/utils/amq_manager.py @@ -1,6 +1,9 @@ """AMQP Management API utilities.""" +from __future__ import annotations + + def get_manager(client, hostname=None, port=None, userid=None, password=None): """Get pyrabbit manager.""" diff --git a/kombu/utils/collections.py b/kombu/utils/collections.py index 77781047..1a0a6d0d 100644 --- a/kombu/utils/collections.py +++ b/kombu/utils/collections.py @@ -1,6 +1,9 @@ """Custom maps, sequences, etc.""" +from __future__ import annotations + + class HashedSeq(list): """Hashed Sequence. diff --git a/kombu/utils/compat.py b/kombu/utils/compat.py index ffc224c1..e1b22f66 100644 --- a/kombu/utils/compat.py +++ b/kombu/utils/compat.py @@ -1,5 +1,7 @@ """Python Compatibility Utilities.""" +from __future__ import annotations + import numbers import sys from contextlib import contextmanager @@ -77,9 +79,18 @@ def detect_environment(): def entrypoints(namespace): """Return setuptools entrypoints for namespace.""" + if sys.version_info >= (3,10): + entry_points = importlib_metadata.entry_points(group=namespace) + else: + entry_points = importlib_metadata.entry_points() + try: + entry_points = entry_points.get(namespace, []) + except AttributeError: + entry_points = entry_points.select(group=namespace) + return ( (ep, ep.load()) - for ep in importlib_metadata.entry_points().get(namespace, []) + for ep in entry_points ) diff --git a/kombu/utils/debug.py b/kombu/utils/debug.py index acc2d60b..bd20948f 100644 --- a/kombu/utils/debug.py +++ b/kombu/utils/debug.py @@ -1,5 +1,7 @@ """Debugging support.""" +from __future__ import annotations + import logging from vine.utils import wraps diff --git a/kombu/utils/div.py b/kombu/utils/div.py index 45be7f94..439b6639 100644 --- a/kombu/utils/div.py +++ b/kombu/utils/div.py @@ -1,5 +1,7 @@ """Div. Utilities.""" +from __future__ import annotations + import sys from .encoding import default_encode diff --git a/kombu/utils/encoding.py b/kombu/utils/encoding.py index 5f58f0fa..42bf2ce9 100644 --- a/kombu/utils/encoding.py +++ b/kombu/utils/encoding.py @@ -5,6 +5,8 @@ applications without crashing from the infamous :exc:`UnicodeDecodeError` exception. """ +from __future__ import annotations + import sys import traceback diff --git a/kombu/utils/eventio.py b/kombu/utils/eventio.py index 48260a48..f8d89d45 100644 --- a/kombu/utils/eventio.py +++ b/kombu/utils/eventio.py @@ -1,5 +1,7 @@ """Selector Utilities.""" +from __future__ import annotations + import errno import math import select as __select__ diff --git a/kombu/utils/functional.py b/kombu/utils/functional.py index 366a0b99..6beb17d7 100644 --- a/kombu/utils/functional.py +++ b/kombu/utils/functional.py @@ -1,5 +1,7 @@ """Functional Utilities.""" +from __future__ import annotations + import inspect import random import threading diff --git a/kombu/utils/imports.py b/kombu/utils/imports.py index fd4482a8..8752fa1a 100644 --- a/kombu/utils/imports.py +++ b/kombu/utils/imports.py @@ -1,5 +1,7 @@ """Import related utilities.""" +from __future__ import annotations + import importlib import sys diff --git a/kombu/utils/json.py b/kombu/utils/json.py index cedaa793..ec6269e2 100644 --- a/kombu/utils/json.py +++ b/kombu/utils/json.py @@ -1,75 +1,75 @@ """JSON Serialization Utilities.""" -import datetime -import decimal -import json as stdjson +from __future__ import annotations + +import base64 +import json import uuid +from datetime import date, datetime, time +from decimal import Decimal +from typing import Any, Callable, TypeVar -try: - from django.utils.functional import Promise as DjangoPromise -except ImportError: # pragma: no cover - class DjangoPromise: - """Dummy object.""" +textual_types = () try: - import json - _json_extra_kwargs = {} - - class _DecodeError(Exception): - pass -except ImportError: # pragma: no cover - import simplejson as json - from simplejson.decoder import JSONDecodeError as _DecodeError - _json_extra_kwargs = { - 'use_decimal': False, - 'namedtuple_as_object': False, - } - + from django.utils.functional import Promise -_encoder_cls = type(json._default_encoder) -_default_encoder = None # ... set to JSONEncoder below. + textual_types += (Promise,) +except ImportError: + pass -class JSONEncoder(_encoder_cls): +class JSONEncoder(json.JSONEncoder): """Kombu custom json encoder.""" - def default(self, o, - dates=(datetime.datetime, datetime.date), - times=(datetime.time,), - textual=(decimal.Decimal, uuid.UUID, DjangoPromise), - isinstance=isinstance, - datetime=datetime.datetime, - text_t=str): - reducer = getattr(o, '__json__', None) + def default(self, o): + reducer = getattr(o, "__json__", None) if reducer is not None: return reducer() - else: - if isinstance(o, dates): - if not isinstance(o, datetime): - o = datetime(o.year, o.month, o.day, 0, 0, 0, 0) - r = o.isoformat() - if r.endswith("+00:00"): - r = r[:-6] + "Z" - return r - elif isinstance(o, times): - return o.isoformat() - elif isinstance(o, textual): - return text_t(o) - return super().default(o) + if isinstance(o, textual_types): + return str(o) + + for t, (marker, encoder) in _encoders.items(): + if isinstance(o, t): + return _as(marker, encoder(o)) + + # Bytes is slightly trickier, so we cannot put them directly + # into _encoders, because we use two formats: bytes, and base64. + if isinstance(o, bytes): + try: + return _as("bytes", o.decode("utf-8")) + except UnicodeDecodeError: + return _as("base64", base64.b64encode(o).decode("utf-8")) -_default_encoder = JSONEncoder + return super().default(o) -def dumps(s, _dumps=json.dumps, cls=None, default_kwargs=None, **kwargs): +def _as(t: str, v: Any): + return {"__type__": t, "__value__": v} + + +def dumps( + s, _dumps=json.dumps, cls=JSONEncoder, default_kwargs=None, **kwargs +): """Serialize object to json string.""" - if not default_kwargs: - default_kwargs = _json_extra_kwargs - return _dumps(s, cls=cls or _default_encoder, - **dict(default_kwargs, **kwargs)) + default_kwargs = default_kwargs or {} + return _dumps(s, cls=cls, **dict(default_kwargs, **kwargs)) + + +def object_hook(o: dict): + """Hook function to perform custom deserialization.""" + if o.keys() == {"__type__", "__value__"}: + decoder = _decoders.get(o["__type__"]) + if decoder: + return decoder(o["__value__"]) + else: + raise ValueError("Unsupported type", type, o) + else: + return o -def loads(s, _loads=json.loads, decode_bytes=True): +def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook): """Deserialize json from string.""" # None of the json implementations supports decoding from # a buffer/memoryview, or even reading from a stream @@ -78,14 +78,51 @@ def loads(s, _loads=json.loads, decode_bytes=True): # over. Note that pickle does support buffer/memoryview # </rant> if isinstance(s, memoryview): - s = s.tobytes().decode('utf-8') + s = s.tobytes().decode("utf-8") elif isinstance(s, bytearray): - s = s.decode('utf-8') + s = s.decode("utf-8") elif decode_bytes and isinstance(s, bytes): - s = s.decode('utf-8') - - try: - return _loads(s) - except _DecodeError: - # catch "Unpaired high surrogate" error - return stdjson.loads(s) + s = s.decode("utf-8") + + return _loads(s, object_hook=object_hook) + + +DecoderT = EncoderT = Callable[[Any], Any] +T = TypeVar("T") +EncodedT = TypeVar("EncodedT") + + +def register_type( + t: type[T], + marker: str, + encoder: Callable[[T], EncodedT], + decoder: Callable[[EncodedT], T], +): + """Add support for serializing/deserializing native python type.""" + _encoders[t] = (marker, encoder) + _decoders[marker] = decoder + + +_encoders: dict[type, tuple[str, EncoderT]] = {} +_decoders: dict[str, DecoderT] = { + "bytes": lambda o: o.encode("utf-8"), + "base64": lambda o: base64.b64decode(o.encode("utf-8")), +} + +# NOTE: datetime should be registered before date, +# because datetime is also instance of date. +register_type(datetime, "datetime", datetime.isoformat, datetime.fromisoformat) +register_type( + date, + "date", + lambda o: o.isoformat(), + lambda o: datetime.fromisoformat(o).date(), +) +register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) +register_type(Decimal, "decimal", str, Decimal) +register_type( + uuid.UUID, + "uuid", + lambda o: {"hex": o.hex, "version": o.version}, + lambda o: uuid.UUID(**o), +) diff --git a/kombu/utils/limits.py b/kombu/utils/limits.py index d82884f5..36d11f1f 100644 --- a/kombu/utils/limits.py +++ b/kombu/utils/limits.py @@ -1,5 +1,7 @@ """Token bucket implementation for rate limiting.""" +from __future__ import annotations + from collections import deque from time import monotonic diff --git a/kombu/utils/objects.py b/kombu/utils/objects.py index 7fef4a2f..eb4dfc2a 100644 --- a/kombu/utils/objects.py +++ b/kombu/utils/objects.py @@ -1,5 +1,7 @@ """Object Utilities.""" +from __future__ import annotations + __all__ = ('cached_property',) try: diff --git a/kombu/utils/scheduling.py b/kombu/utils/scheduling.py index 1875fce4..94286be8 100644 --- a/kombu/utils/scheduling.py +++ b/kombu/utils/scheduling.py @@ -1,5 +1,7 @@ """Scheduling Utilities.""" +from __future__ import annotations + from itertools import count from .imports import symbol_by_name diff --git a/kombu/utils/text.py b/kombu/utils/text.py index 1d5fb9de..fea53347 100644 --- a/kombu/utils/text.py +++ b/kombu/utils/text.py @@ -2,7 +2,10 @@ # flake8: noqa +from __future__ import annotations + from difflib import SequenceMatcher +from typing import Iterable, Iterator from kombu import version_info_t @@ -16,8 +19,7 @@ def escape_regex(p, white=''): for c in p) -def fmatch_iter(needle, haystack, min_ratio=0.6): - # type: (str, Sequence[str], float) -> Iterator[Tuple[float, str]] +def fmatch_iter(needle: str, haystack: Iterable[str], min_ratio: float = 0.6) -> Iterator[tuple[float, str]]: """Fuzzy match: iteratively. Yields: @@ -29,19 +31,17 @@ def fmatch_iter(needle, haystack, min_ratio=0.6): yield ratio, key -def fmatch_best(needle, haystack, min_ratio=0.6): - # type: (str, Sequence[str], float) -> str +def fmatch_best(needle: str, haystack: Iterable[str], min_ratio: float = 0.6) -> str | None: """Fuzzy match - Find best match (scalar).""" try: return sorted( fmatch_iter(needle, haystack, min_ratio), reverse=True, )[0][1] except IndexError: - pass + return None -def version_string_as_tuple(s): - # type: (str) -> version_info_t +def version_string_as_tuple(s: str) -> version_info_t: """Convert version string to version info tuple.""" v = _unpack_version(*s.split('.')) # X.Y.3a1 -> (X, Y, 3, 'a1') @@ -53,13 +53,17 @@ def version_string_as_tuple(s): return v -def _unpack_version(major, minor=0, micro=0, releaselevel='', serial=''): - # type: (int, int, int, str, str) -> version_info_t +def _unpack_version( + major: str, + minor: str | int = 0, + micro: str | int = 0, + releaselevel: str = '', + serial: str = '' +) -> version_info_t: return version_info_t(int(major), int(minor), micro, releaselevel, serial) -def _splitmicro(micro, releaselevel='', serial=''): - # type: (int, str, str) -> Tuple[int, str, str] +def _splitmicro(micro: str, releaselevel: str = '', serial: str = '') -> tuple[int, str, str]: for index, char in enumerate(micro): if not char.isdigit(): break diff --git a/kombu/utils/time.py b/kombu/utils/time.py index 863f4017..8228d2be 100644 --- a/kombu/utils/time.py +++ b/kombu/utils/time.py @@ -1,11 +1,9 @@ """Time Utilities.""" -# flake8: noqa - +from __future__ import annotations __all__ = ('maybe_s_to_ms',) -def maybe_s_to_ms(v): - # type: (Optional[Union[int, float]]) -> int +def maybe_s_to_ms(v: int | float | None) -> int | None: """Convert seconds to milliseconds, but return None for None.""" return int(float(v) * 1000.0) if v is not None else v diff --git a/kombu/utils/url.py b/kombu/utils/url.py index de3a9139..f5f47701 100644 --- a/kombu/utils/url.py +++ b/kombu/utils/url.py @@ -2,6 +2,8 @@ # flake8: noqa +from __future__ import annotations + from collections.abc import Mapping from functools import partial from typing import NamedTuple diff --git a/kombu/utils/uuid.py b/kombu/utils/uuid.py index 010b3440..9f77dad9 100644 --- a/kombu/utils/uuid.py +++ b/kombu/utils/uuid.py @@ -1,9 +1,11 @@ """UUID utilities.""" +from __future__ import annotations -from uuid import uuid4 +from typing import Callable +from uuid import UUID, uuid4 -def uuid(_uuid=uuid4): +def uuid(_uuid: Callable[[], UUID] = uuid4) -> str: """Generate unique id in UUID4 format. See Also: |