Diffstat (limited to 'kombu')
-rw-r--r--  kombu/__init__.py | 29
-rw-r--r--  kombu/abstract.py | 61
-rw-r--r--  kombu/asynchronous/__init__.py | 2
-rw-r--r--  kombu/asynchronous/aws/__init__.py | 13
-rw-r--r--  kombu/asynchronous/aws/connection.py | 2
-rw-r--r--  kombu/asynchronous/aws/ext.py | 2
-rw-r--r--  kombu/asynchronous/aws/sqs/connection.py | 2
-rw-r--r--  kombu/asynchronous/aws/sqs/ext.py | 2
-rw-r--r--  kombu/asynchronous/aws/sqs/message.py | 2
-rw-r--r--  kombu/asynchronous/aws/sqs/queue.py | 4
-rw-r--r--  kombu/asynchronous/debug.py | 2
-rw-r--r--  kombu/asynchronous/http/__init__.py | 13
-rw-r--r--  kombu/asynchronous/http/base.py | 15
-rw-r--r--  kombu/asynchronous/http/curl.py | 11
-rw-r--r--  kombu/asynchronous/hub.py | 32
-rw-r--r--  kombu/asynchronous/semaphore.py | 56
-rw-r--r--  kombu/asynchronous/timer.py | 13
-rw-r--r--  kombu/clocks.py | 35
-rw-r--r--  kombu/common.py | 2
-rw-r--r--  kombu/compat.py | 20
-rw-r--r--  kombu/compression.py | 2
-rw-r--r--  kombu/connection.py | 69
-rw-r--r--  kombu/entity.py | 2
-rw-r--r--  kombu/exceptions.py | 24
-rw-r--r--  kombu/log.py | 2
-rw-r--r--  kombu/matcher.py | 33
-rw-r--r--  kombu/message.py | 2
-rw-r--r--  kombu/messaging.py | 20
-rw-r--r--  kombu/mixins.py | 2
-rw-r--r--  kombu/pidbox.py | 2
-rw-r--r--  kombu/pools.py | 2
-rw-r--r--  kombu/resource.py | 13
-rw-r--r--  kombu/serialization.py | 14
-rw-r--r--  kombu/simple.py | 13
-rw-r--r--  kombu/transport/SLMQ.py | 2
-rw-r--r--  kombu/transport/SQS.py | 97
-rw-r--r--  kombu/transport/__init__.py | 10
-rw-r--r--  kombu/transport/azureservicebus.py | 73
-rw-r--r--  kombu/transport/azurestoragequeues.py | 176
-rw-r--r--  kombu/transport/base.py | 13
-rw-r--r--  kombu/transport/confluentkafka.py | 379
-rw-r--r--  kombu/transport/consul.py | 19
-rw-r--r--  kombu/transport/etcd.py | 19
-rw-r--r--  kombu/transport/filesystem.py | 79
-rw-r--r--  kombu/transport/librabbitmq.py | 2
-rw-r--r--  kombu/transport/memory.py | 2
-rw-r--r--  kombu/transport/mongodb.py | 76
-rw-r--r--  kombu/transport/pyamqp.py | 2
-rw-r--r--  kombu/transport/pyro.py | 2
-rw-r--r--  kombu/transport/qpid.py | 9
-rw-r--r--  kombu/transport/redis.py | 140
-rw-r--r--  kombu/transport/sqlalchemy/__init__.py | 15
-rw-r--r--  kombu/transport/sqlalchemy/models.py | 6
-rw-r--r--  kombu/transport/virtual/__init__.py | 2
-rw-r--r--  kombu/transport/virtual/base.py | 55
-rw-r--r--  kombu/transport/virtual/exchange.py | 2
-rw-r--r--  kombu/transport/zookeeper.py | 2
-rw-r--r--  kombu/utils/__init__.py | 2
-rw-r--r--  kombu/utils/amq_manager.py | 3
-rw-r--r--  kombu/utils/collections.py | 3
-rw-r--r--  kombu/utils/compat.py | 13
-rw-r--r--  kombu/utils/debug.py | 2
-rw-r--r--  kombu/utils/div.py | 2
-rw-r--r--  kombu/utils/encoding.py | 2
-rw-r--r--  kombu/utils/eventio.py | 2
-rw-r--r--  kombu/utils/functional.py | 2
-rw-r--r--  kombu/utils/imports.py | 2
-rw-r--r--  kombu/utils/json.py | 159
-rw-r--r--  kombu/utils/limits.py | 2
-rw-r--r--  kombu/utils/objects.py | 2
-rw-r--r--  kombu/utils/scheduling.py | 2
-rw-r--r--  kombu/utils/text.py | 26
-rw-r--r--  kombu/utils/time.py | 6
-rw-r--r--  kombu/utils/url.py | 2
-rw-r--r--  kombu/utils/uuid.py | 6
75 files changed, 1500 insertions, 438 deletions
diff --git a/kombu/__init__.py b/kombu/__init__.py
index da4ed466..999cf9da 100644
--- a/kombu/__init__.py
+++ b/kombu/__init__.py
@@ -1,11 +1,14 @@
"""Messaging library for Python."""
+from __future__ import annotations
+
import os
import re
import sys
from collections import namedtuple
+from typing import Any, cast
-__version__ = '5.2.0'
+__version__ = '5.3.0b3'
__author__ = 'Ask Solem'
__contact__ = 'auvipy@gmail.com, ask@celeryproject.org'
__homepage__ = 'https://kombu.readthedocs.io'
@@ -19,12 +22,12 @@ version_info_t = namedtuple('version_info_t', (
# bumpversion can only search for {current_version}
# so we have to parse the version here.
-_temp = re.match(
- r'(\d+)\.(\d+).(\d+)(.+)?', __version__).groups()
+_temp = cast(re.Match, re.match(
+ r'(\d+)\.(\d+).(\d+)(.+)?', __version__)).groups()
VERSION = version_info = version_info_t(
int(_temp[0]), int(_temp[1]), int(_temp[2]), _temp[3] or '', '')
-del(_temp)
-del(re)
+del _temp
+del re
STATICA_HACK = True
globals()['kcah_acitats'[::-1].upper()] = False
@@ -61,15 +64,15 @@ all_by_module = {
}
object_origins = {}
-for module, items in all_by_module.items():
+for _module, items in all_by_module.items():
for item in items:
- object_origins[item] = module
+ object_origins[item] = _module
class module(ModuleType):
"""Customized Python module."""
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
if name in object_origins:
module = __import__(object_origins[name], None, None, [name])
for extra_name in all_by_module[module.__name__]:
@@ -77,7 +80,7 @@ class module(ModuleType):
return getattr(module, name)
return ModuleType.__getattribute__(self, name)
- def __dir__(self):
+ def __dir__(self) -> list[str]:
result = list(new_module.__all__)
result.extend(('__file__', '__path__', '__doc__', '__all__',
'__docformat__', '__name__', '__path__', 'VERSION',
@@ -86,12 +89,6 @@ class module(ModuleType):
return result
-# 2.5 does not define __package__
-try:
- package = __package__
-except NameError: # pragma: no cover
- package = 'kombu'
-
# keep a reference to this module so that it's not garbage collected
old_module = sys.modules[__name__]
@@ -106,7 +103,7 @@ new_module.__dict__.update({
'__contact__': __contact__,
'__homepage__': __homepage__,
'__docformat__': __docformat__,
- '__package__': package,
+ '__package__': __package__,
'version_info_t': version_info_t,
'version_info': version_info,
'VERSION': VERSION
diff --git a/kombu/abstract.py b/kombu/abstract.py
index 38cff010..48a917c9 100644
--- a/kombu/abstract.py
+++ b/kombu/abstract.py
@@ -1,19 +1,35 @@
"""Object utilities."""
+from __future__ import annotations
+
from copy import copy
+from typing import TYPE_CHECKING, Any, Callable, TypeVar
from .connection import maybe_channel
from .exceptions import NotBoundError
from .utils.functional import ChannelPromise
+if TYPE_CHECKING:
+ from kombu.connection import Connection
+ from kombu.transport.virtual import Channel
+
+
__all__ = ('Object', 'MaybeChannelBound')
+_T = TypeVar("_T")
+_ObjectType = TypeVar("_ObjectType", bound="Object")
+_MaybeChannelBoundType = TypeVar(
+ "_MaybeChannelBoundType", bound="MaybeChannelBound"
+)
+
-def unpickle_dict(cls, kwargs):
+def unpickle_dict(
+ cls: type[_ObjectType], kwargs: dict[str, Any]
+) -> _ObjectType:
return cls(**kwargs)
-def _any(v):
+def _any(v: _T) -> _T:
return v
@@ -23,9 +39,9 @@ class Object:
Supports automatic kwargs->attributes handling, and cloning.
"""
- attrs = ()
+ attrs: tuple[tuple[str, Any], ...] = ()
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
for name, type_ in self.attrs:
value = kwargs.get(name)
if value is not None:
@@ -36,8 +52,8 @@ class Object:
except AttributeError:
setattr(self, name, None)
- def as_dict(self, recurse=False):
- def f(obj, type):
+ def as_dict(self, recurse: bool = False) -> dict[str, Any]:
+ def f(obj: Any, type: Callable[[Any], Any]) -> Any:
if recurse and isinstance(obj, Object):
return obj.as_dict(recurse=True)
return type(obj) if type and obj is not None else obj
@@ -45,31 +61,40 @@ class Object:
attr: f(getattr(self, attr), type) for attr, type in self.attrs
}
- def __reduce__(self):
+ def __reduce__(self: _ObjectType) -> tuple[
+ Callable[[type[_ObjectType], dict[str, Any]], _ObjectType],
+ tuple[type[_ObjectType], dict[str, Any]]
+ ]:
return unpickle_dict, (self.__class__, self.as_dict())
- def __copy__(self):
+ def __copy__(self: _ObjectType) -> _ObjectType:
return self.__class__(**self.as_dict())
class MaybeChannelBound(Object):
"""Mixin for classes that can be bound to an AMQP channel."""
- _channel = None
+ _channel: Channel | None = None
_is_bound = False
#: Defines whether maybe_declare can skip declaring this entity twice.
can_cache_declaration = False
- def __call__(self, channel):
+ def __call__(
+ self: _MaybeChannelBoundType, channel: (Channel | Connection)
+ ) -> _MaybeChannelBoundType:
"""`self(channel) -> self.bind(channel)`."""
return self.bind(channel)
- def bind(self, channel):
+ def bind(
+ self: _MaybeChannelBoundType, channel: (Channel | Connection)
+ ) -> _MaybeChannelBoundType:
"""Create copy of the instance that is bound to a channel."""
return copy(self).maybe_bind(channel)
- def maybe_bind(self, channel):
+ def maybe_bind(
+ self: _MaybeChannelBoundType, channel: (Channel | Connection)
+ ) -> _MaybeChannelBoundType:
"""Bind instance to channel if not already bound."""
if not self.is_bound and channel:
self._channel = maybe_channel(channel)
@@ -77,7 +102,7 @@ class MaybeChannelBound(Object):
self._is_bound = True
return self
- def revive(self, channel):
+ def revive(self, channel: Channel) -> None:
"""Revive channel after the connection has been re-established.
Used by :meth:`~kombu.Connection.ensure`.
@@ -87,13 +112,13 @@ class MaybeChannelBound(Object):
self._channel = channel
self.when_bound()
- def when_bound(self):
+ def when_bound(self) -> None:
"""Callback called when the class is bound."""
- def __repr__(self):
+ def __repr__(self) -> str:
return self._repr_entity(type(self).__name__)
- def _repr_entity(self, item=''):
+ def _repr_entity(self, item: str = '') -> str:
item = item or type(self).__name__
if self.is_bound:
return '<{} bound to chan:{}>'.format(
@@ -101,12 +126,12 @@ class MaybeChannelBound(Object):
return f'<unbound {item}>'
@property
- def is_bound(self):
+ def is_bound(self) -> bool:
"""Flag set if the channel is bound."""
return self._is_bound and self._channel is not None
@property
- def channel(self):
+ def channel(self) -> Channel:
"""Current channel if the object is bound."""
channel = self._channel
if channel is None:
diff --git a/kombu/asynchronous/__init__.py b/kombu/asynchronous/__init__.py
index fb264aa5..53060753 100644
--- a/kombu/asynchronous/__init__.py
+++ b/kombu/asynchronous/__init__.py
@@ -1,5 +1,7 @@
"""Event loop."""
+from __future__ import annotations
+
from kombu.utils.eventio import ERR, READ, WRITE
from .hub import Hub, get_event_loop, set_event_loop
diff --git a/kombu/asynchronous/aws/__init__.py b/kombu/asynchronous/aws/__init__.py
index d8423c23..cbeb050f 100644
--- a/kombu/asynchronous/aws/__init__.py
+++ b/kombu/asynchronous/aws/__init__.py
@@ -1,4 +1,15 @@
-def connect_sqs(aws_access_key_id=None, aws_secret_access_key=None, **kwargs):
+from __future__ import annotations
+
+from typing import Any
+
+from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection
+
+
+def connect_sqs(
+ aws_access_key_id: str | None = None,
+ aws_secret_access_key: str | None = None,
+ **kwargs: Any
+) -> AsyncSQSConnection:
"""Return async connection to Amazon SQS."""
from .sqs.connection import AsyncSQSConnection
return AsyncSQSConnection(
diff --git a/kombu/asynchronous/aws/connection.py b/kombu/asynchronous/aws/connection.py
index f3926388..887ab40c 100644
--- a/kombu/asynchronous/aws/connection.py
+++ b/kombu/asynchronous/aws/connection.py
@@ -1,5 +1,7 @@
"""Amazon AWS Connection."""
+from __future__ import annotations
+
from email import message_from_bytes
from email.mime.message import MIMEMessage
diff --git a/kombu/asynchronous/aws/ext.py b/kombu/asynchronous/aws/ext.py
index 2dedc812..1fa4a57e 100644
--- a/kombu/asynchronous/aws/ext.py
+++ b/kombu/asynchronous/aws/ext.py
@@ -1,5 +1,7 @@
"""Amazon boto3 interface."""
+from __future__ import annotations
+
try:
import boto3
from botocore import exceptions
diff --git a/kombu/asynchronous/aws/sqs/connection.py b/kombu/asynchronous/aws/sqs/connection.py
index 9db2523b..20b56344 100644
--- a/kombu/asynchronous/aws/sqs/connection.py
+++ b/kombu/asynchronous/aws/sqs/connection.py
@@ -1,5 +1,7 @@
"""Amazon SQS Connection."""
+from __future__ import annotations
+
from vine import transform
from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection
diff --git a/kombu/asynchronous/aws/sqs/ext.py b/kombu/asynchronous/aws/sqs/ext.py
index f6630936..72268b5d 100644
--- a/kombu/asynchronous/aws/sqs/ext.py
+++ b/kombu/asynchronous/aws/sqs/ext.py
@@ -1,6 +1,8 @@
"""Amazon SQS boto3 interface."""
+from __future__ import annotations
+
try:
import boto3
except ImportError:
diff --git a/kombu/asynchronous/aws/sqs/message.py b/kombu/asynchronous/aws/sqs/message.py
index 9425ff2d..52727bb7 100644
--- a/kombu/asynchronous/aws/sqs/message.py
+++ b/kombu/asynchronous/aws/sqs/message.py
@@ -1,5 +1,7 @@
"""Amazon SQS message implementation."""
+from __future__ import annotations
+
import base64
from kombu.message import Message
diff --git a/kombu/asynchronous/aws/sqs/queue.py b/kombu/asynchronous/aws/sqs/queue.py
index 50b0be55..7ca78f75 100644
--- a/kombu/asynchronous/aws/sqs/queue.py
+++ b/kombu/asynchronous/aws/sqs/queue.py
@@ -1,5 +1,7 @@
"""Amazon SQS queue implementation."""
+from __future__ import annotations
+
from vine import transform
from .message import AsyncMessage
@@ -12,7 +14,7 @@ def list_first(rs):
return rs[0] if len(rs) == 1 else None
-class AsyncQueue():
+class AsyncQueue:
"""Async SQS Queue."""
def __init__(self, connection=None, url=None, message_class=AsyncMessage):
diff --git a/kombu/asynchronous/debug.py b/kombu/asynchronous/debug.py
index 4fabb452..7c1e45c7 100644
--- a/kombu/asynchronous/debug.py
+++ b/kombu/asynchronous/debug.py
@@ -1,5 +1,7 @@
"""Event-loop debugging tools."""
+from __future__ import annotations
+
from kombu.utils.eventio import ERR, READ, WRITE
from kombu.utils.functional import reprcall
diff --git a/kombu/asynchronous/http/__init__.py b/kombu/asynchronous/http/__init__.py
index 1c45ebca..67d8b219 100644
--- a/kombu/asynchronous/http/__init__.py
+++ b/kombu/asynchronous/http/__init__.py
@@ -1,17 +1,24 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
from kombu.asynchronous import get_event_loop
+from kombu.asynchronous.http.base import Headers, Request, Response
+from kombu.asynchronous.hub import Hub
-from .base import Headers, Request, Response
+if TYPE_CHECKING:
+ from kombu.asynchronous.http.curl import CurlClient
__all__ = ('Client', 'Headers', 'Response', 'Request')
-def Client(hub=None, **kwargs):
+def Client(hub: Hub | None = None, **kwargs: int) -> CurlClient:
"""Create new HTTP client."""
from .curl import CurlClient
return CurlClient(hub, **kwargs)
-def get_client(hub=None, **kwargs):
+def get_client(hub: Hub | None = None, **kwargs: int) -> CurlClient:
"""Get or create HTTP client bound to the current event loop."""
hub = hub or get_event_loop()
try:
diff --git a/kombu/asynchronous/http/base.py b/kombu/asynchronous/http/base.py
index e8d5043b..89be531f 100644
--- a/kombu/asynchronous/http/base.py
+++ b/kombu/asynchronous/http/base.py
@@ -1,7 +1,10 @@
"""Base async HTTP client implementation."""
+from __future__ import annotations
+
import sys
from http.client import responses
+from typing import TYPE_CHECKING
from vine import Thenable, maybe_promise, promise
@@ -10,6 +13,9 @@ from kombu.utils.compat import coro
from kombu.utils.encoding import bytes_to_str
from kombu.utils.functional import maybe_list, memoize
+if TYPE_CHECKING:
+ from types import TracebackType
+
__all__ = ('Headers', 'Response', 'Request')
PYPY = hasattr(sys, 'pypy_version_info')
@@ -61,7 +67,7 @@ class Request:
auth_password (str): Password for HTTP authentication.
auth_mode (str): Type of HTTP authentication (``basic`` or ``digest``).
user_agent (str): Custom user agent for this request.
- network_interace (str): Network interface to use for this request.
+ network_interface (str): Network interface to use for this request.
on_ready (Callable): Callback to be called when the response has been
received. Must accept single ``response`` argument.
on_stream (Callable): Optional callback to be called every time body
@@ -253,5 +259,10 @@ class BaseClient:
def __enter__(self):
return self
- def __exit__(self, *exc_info):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
self.close()
diff --git a/kombu/asynchronous/http/curl.py b/kombu/asynchronous/http/curl.py
index ee70f3c6..6f879fa9 100644
--- a/kombu/asynchronous/http/curl.py
+++ b/kombu/asynchronous/http/curl.py
@@ -1,11 +1,13 @@
"""HTTP Client using pyCurl."""
+from __future__ import annotations
+
from collections import deque
from functools import partial
from io import BytesIO
from time import time
-from kombu.asynchronous.hub import READ, WRITE, get_event_loop
+from kombu.asynchronous.hub import READ, WRITE, Hub, get_event_loop
from kombu.exceptions import HttpError
from kombu.utils.encoding import bytes_to_str
@@ -36,7 +38,7 @@ class CurlClient(BaseClient):
Curl = Curl
- def __init__(self, hub=None, max_clients=10):
+ def __init__(self, hub: Hub | None = None, max_clients: int = 10):
if pycurl is None:
raise ImportError('The curl client requires the pycurl library.')
hub = hub or get_event_loop()
@@ -231,9 +233,6 @@ class CurlClient(BaseClient):
if request.proxy_username:
setopt(_pycurl.PROXYUSERPWD, '{}:{}'.format(
request.proxy_username, request.proxy_password or ''))
- else:
- setopt(_pycurl.PROXY, '')
- curl.unsetopt(_pycurl.PROXYUSERPWD)
setopt(_pycurl.SSL_VERIFYPEER, 1 if request.validate_cert else 0)
setopt(_pycurl.SSL_VERIFYHOST, 2 if request.validate_cert else 0)
@@ -253,7 +252,7 @@ class CurlClient(BaseClient):
setopt(meth, True)
if request.method in ('POST', 'PUT'):
- body = request.body.encode('utf-8') if request.body else bytes()
+ body = request.body.encode('utf-8') if request.body else b''
reqbuffer = BytesIO(body)
setopt(_pycurl.READFUNCTION, reqbuffer.read)
if request.method == 'POST':
diff --git a/kombu/asynchronous/hub.py b/kombu/asynchronous/hub.py
index b1f7e241..e5b1163c 100644
--- a/kombu/asynchronous/hub.py
+++ b/kombu/asynchronous/hub.py
@@ -1,6 +1,9 @@
"""Event loop implementation."""
+from __future__ import annotations
+
import errno
+import threading
from contextlib import contextmanager
from queue import Empty
from time import sleep
@@ -18,7 +21,7 @@ from .timer import Timer
__all__ = ('Hub', 'get_event_loop', 'set_event_loop')
logger = get_logger(__name__)
-_current_loop = None
+_current_loop: Hub | None = None
W_UNKNOWN_EVENT = """\
Received unknown event %r for fd %r, please contact support!\
@@ -38,12 +41,12 @@ def _dummy_context(*args, **kwargs):
yield
-def get_event_loop():
+def get_event_loop() -> Hub | None:
"""Get current event loop object."""
return _current_loop
-def set_event_loop(loop):
+def set_event_loop(loop: Hub | None) -> Hub | None:
"""Set the current event loop object."""
global _current_loop
_current_loop = loop
@@ -78,6 +81,7 @@ class Hub:
self.on_tick = set()
self.on_close = set()
self._ready = set()
+ self._ready_lock = threading.Lock()
self._running = False
self._loop = None
@@ -198,7 +202,8 @@ class Hub:
def call_soon(self, callback, *args):
if not isinstance(callback, Thenable):
callback = promise(callback, args)
- self._ready.add(callback)
+ with self._ready_lock:
+ self._ready.add(callback)
return callback
def call_later(self, delay, callback, *args):
@@ -242,6 +247,12 @@ class Hub:
except (AttributeError, KeyError, OSError):
pass
+ def _pop_ready(self):
+ with self._ready_lock:
+ ready = self._ready
+ self._ready = set()
+ return ready
+
def close(self, *args):
[self._unregister(fd) for fd in self.readers]
self.readers.clear()
@@ -257,8 +268,7 @@ class Hub:
# To avoid infinite loop where one of the callables adds items
# to self._ready (via call_soon or otherwise).
# we create new list with current self._ready
- todos = list(self._ready)
- self._ready = set()
+ todos = self._pop_ready()
for item in todos:
item()
@@ -288,17 +298,17 @@ class Hub:
propagate = self.propagate_errors
while 1:
- todo = self._ready
- self._ready = set()
-
- for tick_callback in on_tick:
- tick_callback()
+ todo = self._pop_ready()
for item in todo:
if item:
item()
poll_timeout = fire_timers(propagate=propagate) if scheduled else 1
+
+ for tick_callback in on_tick:
+ tick_callback()
+
# print('[[[HUB]]]: %s' % (self.repr_active(),))
if readers or writers:
to_consolidate = []
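Taken together, these hunks make ``call_soon`` safe to call from other threads by swapping the ready set out under a lock before draining it. A minimal sketch of that pattern, with names shortened for illustration::

    import threading

    class Ready:
        """Sketch of the swap-under-lock pattern used by Hub._pop_ready."""

        def __init__(self):
            self._ready = set()
            self._lock = threading.Lock()

        def add(self, callback):
            # may be called from any thread (cf. Hub.call_soon)
            with self._lock:
                self._ready.add(callback)

        def pop_all(self):
            # swap out the whole set so callbacks scheduled while
            # draining land in the next iteration (cf. Hub._pop_ready)
            with self._lock:
                ready, self._ready = self._ready, set()
            return ready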
diff --git a/kombu/asynchronous/semaphore.py b/kombu/asynchronous/semaphore.py
index 9fe34a04..07fb8a09 100644
--- a/kombu/asynchronous/semaphore.py
+++ b/kombu/asynchronous/semaphore.py
@@ -1,9 +1,23 @@
"""Semaphores and concurrency primitives."""
+from __future__ import annotations
+import sys
from collections import deque
+from typing import TYPE_CHECKING, Callable, Deque
+
+if sys.version_info < (3, 10):
+ from typing_extensions import ParamSpec
+else:
+ from typing import ParamSpec
+
+if TYPE_CHECKING:
+ from types import TracebackType
+
__all__ = ('DummyLock', 'LaxBoundedSemaphore')
+P = ParamSpec("P")
+
class LaxBoundedSemaphore:
"""Asynchronous Bounded Semaphore.
@@ -12,18 +26,15 @@ class LaxBoundedSemaphore:
range even if released more times than it was acquired.
Example:
- >>> from future import print_statement as printf
- # ^ ignore: just fooling stupid pyflakes
-
>>> x = LaxBoundedSemaphore(2)
- >>> x.acquire(printf, 'HELLO 1')
+ >>> x.acquire(print, 'HELLO 1')
HELLO 1
- >>> x.acquire(printf, 'HELLO 2')
+ >>> x.acquire(print, 'HELLO 2')
HELLO 2
- >>> x.acquire(printf, 'HELLO 3')
+ >>> x.acquire(print, 'HELLO 3')
>>> x._waiters # private, do not access directly
[print, ('HELLO 3',)]
@@ -31,13 +42,18 @@ class LaxBoundedSemaphore:
HELLO 3
"""
- def __init__(self, value):
+ def __init__(self, value: int) -> None:
self.initial_value = self.value = value
- self._waiting = deque()
+ self._waiting: Deque[tuple] = deque()
self._add_waiter = self._waiting.append
self._pop_waiter = self._waiting.popleft
- def acquire(self, callback, *partial_args, **partial_kwargs):
+ def acquire(
+ self,
+ callback: Callable[P, None],
+ *partial_args: P.args,
+ **partial_kwargs: P.kwargs
+ ) -> bool:
"""Acquire semaphore.
This will immediately apply ``callback`` if
@@ -57,7 +73,7 @@ class LaxBoundedSemaphore:
callback(*partial_args, **partial_kwargs)
return True
- def release(self):
+ def release(self) -> None:
"""Release semaphore.
Note:
@@ -71,23 +87,24 @@ class LaxBoundedSemaphore:
else:
waiter(*args, **kwargs)
- def grow(self, n=1):
+ def grow(self, n: int = 1) -> None:
"""Change the size of the semaphore to accept more users."""
self.initial_value += n
self.value += n
- [self.release() for _ in range(n)]
+ for _ in range(n):
+ self.release()
- def shrink(self, n=1):
+ def shrink(self, n: int = 1) -> None:
"""Change the size of the semaphore to accept less users."""
self.initial_value = max(self.initial_value - n, 0)
self.value = max(self.value - n, 0)
- def clear(self):
+ def clear(self) -> None:
"""Reset the semaphore, which also wipes out any waiting callbacks."""
self._waiting.clear()
self.value = self.initial_value
- def __repr__(self):
+ def __repr__(self) -> str:
return '<{} at {:#x} value:{} waiting:{}>'.format(
self.__class__.__name__, id(self), self.value, len(self._waiting),
)
@@ -96,8 +113,13 @@ class LaxBoundedSemaphore:
class DummyLock:
"""Pretending to be a lock."""
- def __enter__(self):
+ def __enter__(self) -> DummyLock:
return self
- def __exit__(self, *exc_info):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
pass
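The release semantics documented above can be summarized in a short sketch (callbacks run as soon as capacity allows)::

    sem = LaxBoundedSemaphore(1)
    sem.acquire(print, 'first')    # capacity available: prints immediately
    sem.acquire(print, 'queued')   # over capacity: callback is queued
    sem.release()                  # applies the waiter: prints 'queued'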
diff --git a/kombu/asynchronous/timer.py b/kombu/asynchronous/timer.py
index 21ad37c1..f6be1346 100644
--- a/kombu/asynchronous/timer.py
+++ b/kombu/asynchronous/timer.py
@@ -1,5 +1,7 @@
"""Timer scheduling Python callbacks."""
+from __future__ import annotations
+
import heapq
import sys
from collections import namedtuple
@@ -7,6 +9,7 @@ from datetime import datetime
from functools import total_ordering
from time import monotonic
from time import time as _time
+from typing import TYPE_CHECKING
from weakref import proxy as weakrefproxy
from vine.utils import wraps
@@ -18,6 +21,9 @@ try:
except ImportError: # pragma: no cover
utc = None
+if TYPE_CHECKING:
+ from types import TracebackType
+
__all__ = ('Entry', 'Timer', 'to_timestamp')
logger = get_logger(__name__)
@@ -101,7 +107,12 @@ class Timer:
def __enter__(self):
return self
- def __exit__(self, *exc_info):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
self.stop()
def call_at(self, eta, fun, args=(), kwargs=None, priority=0):
diff --git a/kombu/clocks.py b/kombu/clocks.py
index 3c152720..d02e8b32 100644
--- a/kombu/clocks.py
+++ b/kombu/clocks.py
@@ -1,8 +1,11 @@
"""Logical Clocks and Synchronization."""
+from __future__ import annotations
+
from itertools import islice
from operator import itemgetter
from threading import Lock
+from typing import Any
__all__ = ('LamportClock', 'timetuple')
@@ -15,7 +18,7 @@ class timetuple(tuple):
Can be used as part of a heap to keep events ordered.
Arguments:
- clock (int): Event clock value.
+ clock (Optional[int]): Event clock value.
timestamp (float): Event UNIX timestamp value.
id (str): Event host id (e.g. ``hostname:pid``).
obj (Any): Optional obj to associate with this event.
@@ -23,16 +26,18 @@ class timetuple(tuple):
__slots__ = ()
- def __new__(cls, clock, timestamp, id, obj=None):
+ def __new__(
+ cls, clock: int | None, timestamp: float, id: str, obj: Any = None
+ ) -> timetuple:
return tuple.__new__(cls, (clock, timestamp, id, obj))
- def __repr__(self):
+ def __repr__(self) -> str:
return R_CLOCK.format(*self)
- def __getnewargs__(self):
+ def __getnewargs__(self) -> tuple:
return tuple(self)
- def __lt__(self, other):
+ def __lt__(self, other: tuple) -> bool:
# 0: clock 1: timestamp 3: process id
try:
A, B = self[0], other[0]
@@ -45,13 +50,13 @@ class timetuple(tuple):
except IndexError:
return NotImplemented
- def __gt__(self, other):
+ def __gt__(self, other: tuple) -> bool:
return other < self
- def __le__(self, other):
+ def __le__(self, other: tuple) -> bool:
return not other < self
- def __ge__(self, other):
+ def __ge__(self, other: tuple) -> bool:
return not self < other
clock = property(itemgetter(0))
@@ -99,21 +104,23 @@ class LamportClock:
#: The clocks current value.
value = 0
- def __init__(self, initial_value=0, Lock=Lock):
+ def __init__(
+ self, initial_value: int = 0, Lock: type[Lock] = Lock
+ ) -> None:
self.value = initial_value
self.mutex = Lock()
- def adjust(self, other):
+ def adjust(self, other: int) -> int:
with self.mutex:
value = self.value = max(self.value, other) + 1
return value
- def forward(self):
+ def forward(self) -> int:
with self.mutex:
self.value += 1
return self.value
- def sort_heap(self, h):
+ def sort_heap(self, h: list[tuple[int, str]]) -> tuple[int, str]:
"""Sort heap of events.
List of tuples containing at least two elements, representing
@@ -140,8 +147,8 @@ class LamportClock:
# clock values unique, return first item
return h[0]
- def __str__(self):
+ def __str__(self) -> str:
return str(self.value)
- def __repr__(self):
+ def __repr__(self) -> str:
return f'<LamportClock: {self.value}>'
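As a quick illustration of the Lamport semantics the annotations cover (``forward`` is used for local events, ``adjust`` when receiving a remote clock value)::

    clock = LamportClock()
    clock.forward()   # local event: 0 -> 1
    clock.adjust(5)   # received event carrying clock 5: max(1, 5) + 1 == 6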
diff --git a/kombu/common.py b/kombu/common.py
index 08bc1aff..c7b2d50a 100644
--- a/kombu/common.py
+++ b/kombu/common.py
@@ -1,5 +1,7 @@
"""Common Utilities."""
+from __future__ import annotations
+
import os
import socket
import threading
diff --git a/kombu/compat.py b/kombu/compat.py
index 1fa3f631..d90aec75 100644
--- a/kombu/compat.py
+++ b/kombu/compat.py
@@ -3,11 +3,17 @@
See https://pypi.org/project/carrot/ for documentation.
"""
+from __future__ import annotations
+
from itertools import count
+from typing import TYPE_CHECKING
from . import messaging
from .entity import Exchange, Queue
+if TYPE_CHECKING:
+ from types import TracebackType
+
__all__ = ('Publisher', 'Consumer')
# XXX compat attribute
@@ -65,7 +71,12 @@ class Publisher(messaging.Producer):
def __enter__(self):
return self
- def __exit__(self, *exc_info):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
self.close()
@property
@@ -127,7 +138,12 @@ class Consumer(messaging.Consumer):
def __enter__(self):
return self
- def __exit__(self, *exc_info):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
self.close()
def __iter__(self):
diff --git a/kombu/compression.py b/kombu/compression.py
index d9438539..f98c971b 100644
--- a/kombu/compression.py
+++ b/kombu/compression.py
@@ -1,5 +1,7 @@
"""Compression utilities."""
+from __future__ import annotations
+
import zlib
from kombu.utils.encoding import ensure_bytes
diff --git a/kombu/connection.py b/kombu/connection.py
index a63154f5..0c9779b5 100644
--- a/kombu/connection.py
+++ b/kombu/connection.py
@@ -1,11 +1,14 @@
"""Client (Connection)."""
+from __future__ import annotations
+
import os
import socket
-from collections import OrderedDict
+import sys
from contextlib import contextmanager
from itertools import count, cycle
from operator import itemgetter
+from typing import TYPE_CHECKING, Any
try:
from ssl import CERT_NONE
@@ -14,6 +17,7 @@ except ImportError: # pragma: no cover
CERT_NONE = None
ssl_available = False
+
# jython breaks on relative import for .exceptions for some reason
# (Issue #112)
from kombu import exceptions
@@ -26,6 +30,16 @@ from .utils.functional import dictfilter, lazy, retry_over_time, shufflecycle
from .utils.objects import cached_property
from .utils.url import as_url, maybe_sanitize_url, parse_url, quote, urlparse
+if TYPE_CHECKING:
+ from kombu.transport.virtual import Channel
+
+ if sys.version_info < (3, 10):
+ from typing_extensions import TypeGuard
+ else:
+ from typing import TypeGuard
+
+ from types import TracebackType
+
__all__ = ('Connection', 'ConnectionPool', 'ChannelPool')
logger = get_logger(__name__)
@@ -412,7 +426,7 @@ class Connection:
callback (Callable): Optional callback that is called for every
internal iteration (1 s).
timeout (int): Maximum amount of time in seconds to spend
- waiting for connection
+ attempting to connect, total over all retries.
"""
if self.connected:
return self._connection
@@ -468,7 +482,7 @@ class Connection:
def ensure(self, obj, fun, errback=None, max_retries=None,
interval_start=1, interval_step=1, interval_max=1,
- on_revive=None):
+ on_revive=None, retry_errors=None):
"""Ensure operation completes.
Regardless of any channel/connection errors occurring.
@@ -497,6 +511,9 @@ class Connection:
each retry.
on_revive (Callable): Optional callback called whenever
revival completes successfully
+ retry_errors (tuple): Optional list of errors to retry on
+ regardless of the connection state. Must provide max_retries
+ if this is specified.
Examples:
>>> from kombu import Connection, Producer
@@ -511,6 +528,15 @@ class Connection:
... errback=errback, max_retries=3)
>>> publish({'hello': 'world'}, routing_key='dest')
"""
+ if retry_errors is None:
+ retry_errors = tuple()
+ elif max_retries is None:
+ # If the retry_errors is specified, but max_retries is not,
+ # this could lead into an infinite loop potentially.
+ raise ValueError(
+ "max_retries must be specified if retry_errors is specified"
+ )
+
def _ensured(*args, **kwargs):
got_connection = 0
conn_errors = self.recoverable_connection_errors
@@ -522,6 +548,11 @@ class Connection:
for retries in count(0): # for infinity
try:
return fun(*args, **kwargs)
+ except retry_errors as exc:
+ if max_retries is not None and retries >= max_retries:
+ raise
+ self._debug('ensure retry policy error: %r',
+ exc, exc_info=1)
except conn_errors as exc:
if got_connection and not has_modern_errors:
# transport can not distinguish between
@@ -529,7 +560,7 @@ class Connection:
# the error if it persists after a new connection
# was successfully established.
raise
- if max_retries is not None and retries > max_retries:
+ if max_retries is not None and retries >= max_retries:
raise
self._debug('ensure connection error: %r',
exc, exc_info=1)
@@ -626,7 +657,7 @@ class Connection:
transport_cls, transport_cls)
D = self.transport.default_connection_params
- if not self.hostname:
+ if not self.hostname and D.get('hostname'):
logger.warning(
"No hostname was supplied. "
f"Reverting to default '{D.get('hostname')}'")
@@ -658,7 +689,7 @@ class Connection:
def info(self):
"""Get connection info."""
- return OrderedDict(self._info())
+ return dict(self._info())
def __eqhash__(self):
return HashedSeq(self.transport_cls, self.hostname, self.userid,
@@ -829,7 +860,12 @@ class Connection:
def __enter__(self):
return self
- def __exit__(self, *args):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
self.release()
@property
@@ -837,7 +873,7 @@ class Connection:
return self.transport.qos_semantics_matches_spec(self.connection)
def _extract_failover_opts(self):
- conn_opts = {}
+ conn_opts = {'timeout': self.connect_timeout}
transport_opts = self.transport_options
if transport_opts:
if 'max_retries' in transport_opts:
@@ -848,6 +884,9 @@ class Connection:
conn_opts['interval_step'] = transport_opts['interval_step']
if 'interval_max' in transport_opts:
conn_opts['interval_max'] = transport_opts['interval_max']
+ if 'connect_retries_timeout' in transport_opts:
+ conn_opts['timeout'] = \
+ transport_opts['connect_retries_timeout']
return conn_opts
@property
@@ -880,7 +919,7 @@ class Connection:
return self._connection
@property
- def default_channel(self):
+ def default_channel(self) -> Channel:
"""Default channel.
Created upon access and closed when the connection is closed.
@@ -932,7 +971,7 @@ class Connection:
but where the connection must be closed and re-established first.
"""
try:
- return self.transport.recoverable_connection_errors
+ return self.get_transport_cls().recoverable_connection_errors
except AttributeError:
# There were no such classification before,
# and all errors were assumed to be recoverable,
@@ -948,19 +987,19 @@ class Connection:
recovered from without re-establishing the connection.
"""
try:
- return self.transport.recoverable_channel_errors
+ return self.get_transport_cls().recoverable_channel_errors
except AttributeError:
return ()
@cached_property
def connection_errors(self):
"""List of exceptions that may be raised by the connection."""
- return self.transport.connection_errors
+ return self.get_transport_cls().connection_errors
@cached_property
def channel_errors(self):
"""List of exceptions that may be raised by the channel."""
- return self.transport.channel_errors
+ return self.get_transport_cls().channel_errors
@property
def supports_heartbeats(self):
@@ -1043,7 +1082,7 @@ class ChannelPool(Resource):
return channel
-def maybe_channel(channel):
+def maybe_channel(channel: Channel | Connection) -> Channel:
"""Get channel from object.
Return the default channel if argument is a connection instance,
@@ -1054,5 +1093,5 @@ def maybe_channel(channel):
return channel
-def is_connection(obj):
+def is_connection(obj: Any) -> TypeGuard[Connection]:
return isinstance(obj, Connection)
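A hedged usage sketch of the new ``retry_errors`` argument (the exception tuple here is illustrative; any tuple of exception classes works)::

    from kombu import Connection, Producer

    conn = Connection('memory://')
    producer = Producer(conn)
    publish = conn.ensure(
        producer, producer.publish,
        retry_errors=(TimeoutError,),  # retried regardless of connection state
        max_retries=3,                 # mandatory when retry_errors is given
    )
    publish({'hello': 'world'}, routing_key='dest')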
diff --git a/kombu/entity.py b/kombu/entity.py
index a89fabb9..2329e748 100644
--- a/kombu/entity.py
+++ b/kombu/entity.py
@@ -1,5 +1,7 @@
"""Exchange and Queue declarations."""
+from __future__ import annotations
+
import numbers
from .abstract import MaybeChannelBound, Object
diff --git a/kombu/exceptions.py b/kombu/exceptions.py
index f2501437..825baa12 100644
--- a/kombu/exceptions.py
+++ b/kombu/exceptions.py
@@ -1,9 +1,16 @@
"""Exceptions."""
+from __future__ import annotations
+
from socket import timeout as TimeoutError
+from types import TracebackType
+from typing import TYPE_CHECKING, TypeVar
from amqp import ChannelError, ConnectionError, ResourceError
+if TYPE_CHECKING:
+ from kombu.asynchronous.http import Response
+
__all__ = (
'reraise', 'KombuError', 'OperationalError',
'NotBoundError', 'MessageStateError', 'TimeoutError',
@@ -14,8 +21,14 @@ __all__ = (
'InconsistencyError',
)
+BaseExceptionType = TypeVar('BaseExceptionType', bound=BaseException)
+
-def reraise(tp, value, tb=None):
+def reraise(
+ tp: type[BaseExceptionType],
+ value: BaseExceptionType,
+ tb: TracebackType | None = None
+) -> BaseExceptionType:
"""Reraise exception."""
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
@@ -84,11 +97,16 @@ class InconsistencyError(ConnectionError):
class HttpError(Exception):
"""HTTP Client Error."""
- def __init__(self, code, message=None, response=None):
+ def __init__(
+ self,
+ code: int,
+ message: str | None = None,
+ response: Response | None = None
+ ) -> None:
self.code = code
self.message = message
self.response = response
super().__init__(code, message, response)
- def __str__(self):
+ def __str__(self) -> str:
return 'HTTP {0.code}: {0.message}'.format(self)
diff --git a/kombu/log.py b/kombu/log.py
index de77e7f3..ed8d0a50 100644
--- a/kombu/log.py
+++ b/kombu/log.py
@@ -1,5 +1,7 @@
"""Logging Utilities."""
+from __future__ import annotations
+
import logging
import numbers
import os
diff --git a/kombu/matcher.py b/kombu/matcher.py
index 7dcab8cd..a4d71bb1 100644
--- a/kombu/matcher.py
+++ b/kombu/matcher.py
@@ -1,11 +1,16 @@
"""Pattern matching registry."""
+from __future__ import annotations
+
from fnmatch import fnmatch
from re import match as rematch
+from typing import Callable, cast
from .utils.compat import entrypoints
from .utils.encoding import bytes_to_str
+MatcherFunction = Callable[[str, str], bool]
+
class MatcherNotInstalled(Exception):
"""Matcher not installed/found."""
@@ -17,15 +22,15 @@ class MatcherRegistry:
MatcherNotInstalled = MatcherNotInstalled
matcher_pattern_first = ["pcre", ]
- def __init__(self):
- self._matchers = {}
- self._default_matcher = None
+ def __init__(self) -> None:
+ self._matchers: dict[str, MatcherFunction] = {}
+ self._default_matcher: MatcherFunction | None = None
- def register(self, name, matcher):
+ def register(self, name: str, matcher: MatcherFunction) -> None:
"""Add matcher by name to the registry."""
self._matchers[name] = matcher
- def unregister(self, name):
+ def unregister(self, name: str) -> None:
"""Remove matcher by name from the registry."""
try:
self._matchers.pop(name)
@@ -34,7 +39,7 @@ class MatcherRegistry:
f'No matcher installed for {name}'
)
- def _set_default_matcher(self, name):
+ def _set_default_matcher(self, name: str) -> None:
"""Set the default matching method.
:param name: The name of the registered matching method.
@@ -51,7 +56,13 @@ class MatcherRegistry:
f'No matcher installed for {name}'
)
- def match(self, data, pattern, matcher=None, matcher_kwargs=None):
+ def match(
+ self,
+ data: bytes,
+ pattern: bytes,
+ matcher: str | None = None,
+ matcher_kwargs: dict[str, str] | None = None
+ ) -> bool:
"""Call the matcher."""
if matcher and not self._matchers.get(matcher):
raise self.MatcherNotInstalled(
@@ -97,7 +108,7 @@ match = registry.match
.. function:: register(name, matcher):
Register a new matching method.
- :param name: A convience name for the mathing method.
+ :param name: A convenient name for the matching method.
:param matcher: A method that will be passed data and pattern.
"""
register = registry.register
@@ -111,14 +122,14 @@ register = registry.register
unregister = registry.unregister
-def register_glob():
+def register_glob() -> None:
"""Register glob into default registry."""
registry.register('glob', fnmatch)
-def register_pcre():
+def register_pcre() -> None:
"""Register pcre into default registry."""
- registry.register('pcre', rematch)
+ registry.register('pcre', cast(MatcherFunction, rematch))
# Register the base matching methods.
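For reference, the registry these annotations cover is typically driven through the module-level helper; a minimal sketch (the new annotations say ``bytes``, but the default glob matcher also accepts ``str``)::

    from kombu.matcher import match

    assert match('some.topic', 'some.*', matcher='glob')  # fnmatch under the hood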
diff --git a/kombu/message.py b/kombu/message.py
index bcc90d1a..f2af1686 100644
--- a/kombu/message.py
+++ b/kombu/message.py
@@ -1,5 +1,7 @@
"""Message class."""
+from __future__ import annotations
+
import sys
from .compression import decompress
diff --git a/kombu/messaging.py b/kombu/messaging.py
index 0bed52c5..2b600224 100644
--- a/kombu/messaging.py
+++ b/kombu/messaging.py
@@ -1,6 +1,9 @@
"""Sending and receiving messages."""
+from __future__ import annotations
+
from itertools import count
+from typing import TYPE_CHECKING
from .common import maybe_declare
from .compression import compress
@@ -10,6 +13,9 @@ from .exceptions import ContentDisallowed
from .serialization import dumps, prepare_accept_content
from .utils.functional import ChannelPromise, maybe_list
+if TYPE_CHECKING:
+ from types import TracebackType
+
__all__ = ('Exchange', 'Queue', 'Producer', 'Consumer')
@@ -236,7 +242,12 @@ class Producer:
def __enter__(self):
return self
- def __exit__(self, *exc_info):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
self.release()
def release(self):
@@ -435,7 +446,12 @@ class Consumer:
self.consume()
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
if self.channel and self.channel.connection:
conn_errors = self.channel.connection.client.connection_errors
if not isinstance(exc_val, conn_errors):
diff --git a/kombu/mixins.py b/kombu/mixins.py
index b87e4b92..f1b3c1c9 100644
--- a/kombu/mixins.py
+++ b/kombu/mixins.py
@@ -1,5 +1,7 @@
"""Mixins."""
+from __future__ import annotations
+
import socket
from contextlib import contextmanager
from functools import partial
diff --git a/kombu/pidbox.py b/kombu/pidbox.py
index 7649736a..ee639b3c 100644
--- a/kombu/pidbox.py
+++ b/kombu/pidbox.py
@@ -1,5 +1,7 @@
"""Generic process mailbox."""
+from __future__ import annotations
+
import socket
import warnings
from collections import defaultdict, deque
diff --git a/kombu/pools.py b/kombu/pools.py
index 373bc06c..106be183 100644
--- a/kombu/pools.py
+++ b/kombu/pools.py
@@ -1,5 +1,7 @@
"""Public resource pools."""
+from __future__ import annotations
+
import os
from itertools import chain
diff --git a/kombu/resource.py b/kombu/resource.py
index e3617dc4..53ba1145 100644
--- a/kombu/resource.py
+++ b/kombu/resource.py
@@ -1,14 +1,20 @@
"""Generic resource pool implementation."""
+from __future__ import annotations
+
import os
from collections import deque
from queue import Empty
from queue import LifoQueue as _LifoQueue
+from typing import TYPE_CHECKING
from . import exceptions
from .utils.compat import register_after_fork
from .utils.functional import lazy
+if TYPE_CHECKING:
+ from types import TracebackType
+
def _after_fork_cleanup_resource(resource):
try:
@@ -191,7 +197,12 @@ class Resource:
def __enter__(self):
pass
- def __exit__(self, type, value, traceback):
+ def __exit__(
+ self,
+ exc_type: type,
+ exc_val: Exception,
+ exc_tb: TracebackType
+ ) -> None:
pass
resource = self._resource
diff --git a/kombu/serialization.py b/kombu/serialization.py
index 58c28717..5cddeb0b 100644
--- a/kombu/serialization.py
+++ b/kombu/serialization.py
@@ -1,5 +1,7 @@
"""Serialization utilities."""
+from __future__ import annotations
+
import codecs
import os
import pickle
@@ -382,18 +384,6 @@ register_msgpack()
# Default serializer is 'json'
registry._set_default_serializer('json')
-
-_setupfuns = {
- 'json': register_json,
- 'pickle': register_pickle,
- 'yaml': register_yaml,
- 'msgpack': register_msgpack,
- 'application/json': register_json,
- 'application/x-yaml': register_yaml,
- 'application/x-python-serialize': register_pickle,
- 'application/x-msgpack': register_msgpack,
-}
-
NOTSET = object()
diff --git a/kombu/simple.py b/kombu/simple.py
index eee037be..a33e5f9e 100644
--- a/kombu/simple.py
+++ b/kombu/simple.py
@@ -1,13 +1,19 @@
"""Simple messaging interface."""
+from __future__ import annotations
+
import socket
from collections import deque
from queue import Empty
from time import monotonic
+from typing import TYPE_CHECKING
from . import entity, messaging
from .connection import maybe_channel
+if TYPE_CHECKING:
+ from types import TracebackType
+
__all__ = ('SimpleQueue', 'SimpleBuffer')
@@ -18,7 +24,12 @@ class SimpleBase:
def __enter__(self):
return self
- def __exit__(self, *exc_info):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
self.close()
def __init__(self, channel, producer, consumer, no_ack=False):
diff --git a/kombu/transport/SLMQ.py b/kombu/transport/SLMQ.py
index 750f67bd..50efca72 100644
--- a/kombu/transport/SLMQ.py
+++ b/kombu/transport/SLMQ.py
@@ -18,6 +18,8 @@ Transport Options
*Unreviewed*
"""
+from __future__ import annotations
+
import os
import socket
import string
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py
index fb6d3780..ac199aa1 100644
--- a/kombu/transport/SQS.py
+++ b/kombu/transport/SQS.py
@@ -59,8 +59,8 @@ exist in AWS) you can tell this transport about them as follows:
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional
'backoff_tasks': ['svc.tasks.tasks.task1'] # optional
},
- 'queue-2': {
- 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb',
+ 'queue-2.fifo': {
+ 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb.fifo',
'access_key_id': 'c',
'secret_access_key': 'd',
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional
@@ -71,6 +71,9 @@ exist in AWS) you can tell this transport about them as follows:
'sts_token_timeout': 900 # optional
}
+Note that FIFO and standard queues must be named accordingly (the name of
+a FIFO queue must end with the .fifo suffix).
+
backoff_policy & backoff_tasks are optional arguments. These arguments
automatically change the message visibility timeout, in order to have
different times between specific task retries. This would apply after
@@ -119,6 +122,8 @@ Features
""" # noqa: E501
+from __future__ import annotations
+
import base64
import socket
import string
@@ -167,6 +172,10 @@ class UndefinedQueueException(Exception):
"""Predefined queues are being used and an undefined queue was used."""
+class InvalidQueueException(Exception):
+ """Predefined queues are being used and configuration is not valid."""
+
+
class QoS(virtual.QoS):
"""Quality of Service guarantees implementation for SQS."""
@@ -208,8 +217,8 @@ class QoS(virtual.QoS):
VisibilityTimeout=policy_value
)
- @staticmethod
- def extract_task_name_and_number_of_retries(message):
+ def extract_task_name_and_number_of_retries(self, delivery_tag):
+ message = self._delivered[delivery_tag]
message_headers = message.headers
task_name = message_headers['task']
number_of_retries = int(
@@ -237,6 +246,7 @@ class Channel(virtual.Channel):
if boto3 is None:
raise ImportError('boto3 is not installed')
super().__init__(*args, **kwargs)
+ self._validate_predifined_queues()
# SQS blows up if you try to create a new queue when one already
# exists but with a different visibility_timeout. This prepopulates
@@ -246,6 +256,26 @@ class Channel(virtual.Channel):
self.hub = kwargs.get('hub') or get_event_loop()
+ def _validate_predifined_queues(self):
+ """Check that standard and FIFO queues are named properly.
+
+ AWS requires FIFO queues to have a name
+ that ends with the .fifo suffix.
+ """
+ for queue_name, q in self.predefined_queues.items():
+ fifo_url = q['url'].endswith('.fifo')
+ fifo_name = queue_name.endswith('.fifo')
+ if fifo_url and not fifo_name:
+ raise InvalidQueueException(
+ "Queue with url '{}' must have a name "
+ "ending with .fifo".format(q['url'])
+ )
+ elif not fifo_url and fifo_name:
+ raise InvalidQueueException(
+ "Queue with name '{}' is not a FIFO queue: "
+ "'{}'".format(queue_name, q['url'])
+ )
+
def _update_queue_cache(self, queue_name_prefix):
if self.predefined_queues:
for queue_name, q in self.predefined_queues.items():
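An illustrative configuration the new check rejects (queue name and URL are hypothetical)::

    transport_options = {
        'predefined_queues': {
            'celery': {  # name lacks the required .fifo suffix
                'url': 'https://sqs.us-east-1.amazonaws.com/xxx/celery.fifo',
            },
        },
    }
    # Channel creation now raises InvalidQueueException, because the URL
    # ends with .fifo while the queue name does not.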
@@ -367,20 +397,28 @@ class Channel(virtual.Channel):
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q_url = self._new_queue(queue)
- kwargs = {'QueueUrl': q_url,
- 'MessageBody': AsyncMessage().encode(dumps(message))}
- if queue.endswith('.fifo'):
- if 'MessageGroupId' in message['properties']:
- kwargs['MessageGroupId'] = \
- message['properties']['MessageGroupId']
- else:
- kwargs['MessageGroupId'] = 'default'
- if 'MessageDeduplicationId' in message['properties']:
- kwargs['MessageDeduplicationId'] = \
- message['properties']['MessageDeduplicationId']
- else:
- kwargs['MessageDeduplicationId'] = str(uuid.uuid4())
+ if self.sqs_base64_encoding:
+ body = AsyncMessage().encode(dumps(message))
+ else:
+ body = dumps(message)
+ kwargs = {'QueueUrl': q_url, 'MessageBody': body}
+ if 'properties' in message:
+ if queue.endswith('.fifo'):
+ if 'MessageGroupId' in message['properties']:
+ kwargs['MessageGroupId'] = \
+ message['properties']['MessageGroupId']
+ else:
+ kwargs['MessageGroupId'] = 'default'
+ if 'MessageDeduplicationId' in message['properties']:
+ kwargs['MessageDeduplicationId'] = \
+ message['properties']['MessageDeduplicationId']
+ else:
+ kwargs['MessageDeduplicationId'] = str(uuid.uuid4())
+ else:
+ if "DelaySeconds" in message['properties']:
+ kwargs['DelaySeconds'] = \
+ message['properties']['DelaySeconds']
c = self.sqs(queue=self.canonical_queue_name(queue))
if message.get('redelivered'):
c.change_message_visibility(
@@ -392,22 +430,19 @@ class Channel(virtual.Channel):
c.send_message(**kwargs)
@staticmethod
- def __b64_encoded(byte_string):
+ def _optional_b64_decode(byte_string):
try:
- return base64.b64encode(
- base64.b64decode(byte_string)
- ) == byte_string
+ data = base64.b64decode(byte_string)
+ if base64.b64encode(data) == byte_string:
+ return data
+ # else the base64 module found some embedded base64 content
+ # that should be ignored.
except Exception: # pylint: disable=broad-except
- return False
-
- def _message_to_python(self, message, queue_name, queue):
- body = message['Body'].encode()
- try:
- if self.__b64_encoded(body):
- body = base64.b64decode(body)
- except TypeError:
pass
+ return byte_string
+ def _message_to_python(self, message, queue_name, queue):
+ body = self._optional_b64_decode(message['Body'].encode())
payload = loads(bytes_to_str(body))
if queue_name in self._noack_queues:
queue = self._new_queue(queue_name)
@@ -809,6 +844,10 @@ class Channel(virtual.Channel):
return self.transport_options.get('wait_time_seconds',
self.default_wait_time_seconds)
+ @cached_property
+ def sqs_base64_encoding(self):
+ return self.transport_options.get('sqs_base64_encoding', True)
+
class Transport(virtual.Transport):
"""SQS Transport.
diff --git a/kombu/transport/__init__.py b/kombu/transport/__init__.py
index 5fb5047b..8a217691 100644
--- a/kombu/transport/__init__.py
+++ b/kombu/transport/__init__.py
@@ -1,10 +1,12 @@
"""Built-in transports."""
+from __future__ import annotations
+
from kombu.utils.compat import _detect_environment
from kombu.utils.imports import symbol_by_name
-def supports_librabbitmq():
+def supports_librabbitmq() -> bool | None:
"""Return true if :pypi:`librabbitmq` can be used."""
if _detect_environment() == 'default':
try:
@@ -13,6 +15,7 @@ def supports_librabbitmq():
pass
else: # pragma: no cover
return True
+ return None
TRANSPORT_ALIASES = {
@@ -20,6 +23,7 @@ TRANSPORT_ALIASES = {
'amqps': 'kombu.transport.pyamqp:SSLTransport',
'pyamqp': 'kombu.transport.pyamqp:Transport',
'librabbitmq': 'kombu.transport.librabbitmq:Transport',
+ 'confluentkafka': 'kombu.transport.confluentkafka:Transport',
'memory': 'kombu.transport.memory:Transport',
'redis': 'kombu.transport.redis:Transport',
'rediss': 'kombu.transport.redis:Transport',
@@ -44,7 +48,7 @@ TRANSPORT_ALIASES = {
_transport_cache = {}
-def resolve_transport(transport=None):
+def resolve_transport(transport: str | None = None) -> str | None:
"""Get transport by name.
Arguments:
@@ -71,7 +75,7 @@ def resolve_transport(transport=None):
return transport
-def get_transport_cls(transport=None):
+def get_transport_cls(transport: str | None = None) -> str | None:
"""Get transport class by name.
The transport string is the full path to a transport class, e.g.::
diff --git a/kombu/transport/azureservicebus.py b/kombu/transport/azureservicebus.py
index 83237424..e7e2c0cc 100644
--- a/kombu/transport/azureservicebus.py
+++ b/kombu/transport/azureservicebus.py
@@ -53,9 +53,11 @@ Transport Options
* ``retry_backoff_max`` - Azure SDK retry total time. Default ``120``
"""
+from __future__ import annotations
+
import string
from queue import Empty
-from typing import Any, Dict, Optional, Set, Tuple, Union
+from typing import Any, Dict, Set
import azure.core.exceptions
import azure.servicebus.exceptions
@@ -83,10 +85,10 @@ class SendReceive:
"""Container for Sender and Receiver."""
def __init__(self,
- receiver: Optional[ServiceBusReceiver] = None,
- sender: Optional[ServiceBusSender] = None):
- self.receiver = receiver # type: ServiceBusReceiver
- self.sender = sender # type: ServiceBusSender
+ receiver: ServiceBusReceiver | None = None,
+ sender: ServiceBusSender | None = None):
+ self.receiver: ServiceBusReceiver = receiver
+ self.sender: ServiceBusSender = sender
def close(self) -> None:
if self.receiver:
@@ -100,21 +102,19 @@ class SendReceive:
class Channel(virtual.Channel):
"""Azure Service Bus channel."""
- default_wait_time_seconds = 5 # in seconds
- default_peek_lock_seconds = 60 # in seconds (default 60, max 300)
+ default_wait_time_seconds: int = 5 # in seconds
+ default_peek_lock_seconds: int = 60 # in seconds (default 60, max 300)
# in seconds (is the default from service bus repo)
- default_uamqp_keep_alive_interval = 30
+ default_uamqp_keep_alive_interval: int = 30
# number of retries (is the default from service bus repo)
- default_retry_total = 3
+ default_retry_total: int = 3
# exponential backoff factor (is the default from service bus repo)
- default_retry_backoff_factor = 0.8
+ default_retry_backoff_factor: float = 0.8
# Max time to backoff (is the default from service bus repo)
- default_retry_backoff_max = 120
- domain_format = 'kombu%(vhost)s'
- _queue_service = None # type: ServiceBusClient
- _queue_mgmt_service = None # type: ServiceBusAdministrationClient
- _queue_cache = {} # type: Dict[str, SendReceive]
- _noack_queues = set() # type: Set[str]
+ default_retry_backoff_max: int = 120
+ domain_format: str = 'kombu%(vhost)s'
+ _queue_cache: Dict[str, SendReceive] = {}
+ _noack_queues: Set[str] = set()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -160,8 +160,8 @@ class Channel(virtual.Channel):
def _add_queue_to_cache(
self, name: str,
- receiver: Optional[ServiceBusReceiver] = None,
- sender: Optional[ServiceBusSender] = None
+ receiver: ServiceBusReceiver | None = None,
+ sender: ServiceBusSender | None = None
) -> SendReceive:
if name in self._queue_cache:
obj = self._queue_cache[name]
@@ -183,7 +183,7 @@ class Channel(virtual.Channel):
def _get_asb_receiver(
self, queue: str,
recv_mode: ServiceBusReceiveMode = ServiceBusReceiveMode.PEEK_LOCK,
- queue_cache_key: Optional[str] = None) -> SendReceive:
+ queue_cache_key: str | None = None) -> SendReceive:
cache_key = queue_cache_key or queue
queue_obj = self._queue_cache.get(cache_key, None)
if queue_obj is None or queue_obj.receiver is None:
@@ -194,7 +194,7 @@ class Channel(virtual.Channel):
return queue_obj
def entity_name(
- self, name: str, table: Optional[Dict[int, int]] = None) -> str:
+ self, name: str, table: dict[int, int] | None = None) -> str:
"""Format AMQP queue name into a valid ServiceBus queue name."""
return str(safe_str(name)).translate(table or CHARS_REPLACE_TABLE)
@@ -227,7 +227,7 @@ class Channel(virtual.Channel):
"""Delete queue by name."""
queue = self.entity_name(self.queue_name_prefix + queue)
- self._queue_mgmt_service.delete_queue(queue)
+ self.queue_mgmt_service.delete_queue(queue)
send_receive_obj = self._queue_cache.pop(queue, None)
if send_receive_obj:
send_receive_obj.close()
@@ -242,8 +242,8 @@ class Channel(virtual.Channel):
def _get(
self, queue: str,
- timeout: Optional[Union[float, int]] = None
- ) -> Dict[str, Any]:
+ timeout: float | int | None = None
+ ) -> dict[str, Any]:
"""Try to retrieve a single message off ``queue``."""
# If we're not ack'ing for this queue, just change receive_mode
recv_mode = ServiceBusReceiveMode.RECEIVE_AND_DELETE \
@@ -298,7 +298,7 @@ class Channel(virtual.Channel):
return props.total_message_count
- def _purge(self, queue):
+ def _purge(self, queue) -> int:
"""Delete all current messages in a queue."""
# Azure doesn't provide a purge api yet
n = 0
@@ -337,24 +337,19 @@ class Channel(virtual.Channel):
if self.connection is not None:
self.connection.close_channel(self)
- @property
+ @cached_property
def queue_service(self) -> ServiceBusClient:
- if self._queue_service is None:
- self._queue_service = ServiceBusClient.from_connection_string(
- self._connection_string,
- retry_total=self.retry_total,
- retry_backoff_factor=self.retry_backoff_factor,
- retry_backoff_max=self.retry_backoff_max
- )
- return self._queue_service
+ return ServiceBusClient.from_connection_string(
+ self._connection_string,
+ retry_total=self.retry_total,
+ retry_backoff_factor=self.retry_backoff_factor,
+ retry_backoff_max=self.retry_backoff_max
+ )
- @property
+ @cached_property
def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
- if self._queue_mgmt_service is None:
- self._queue_mgmt_service = \
- ServiceBusAdministrationClient.from_connection_string(
+ return ServiceBusAdministrationClient.from_connection_string(
self._connection_string)
- return self._queue_mgmt_service
@property
def conninfo(self):
@@ -412,7 +407,7 @@ class Transport(virtual.Transport):
can_parse_url = True
@staticmethod
- def parse_uri(uri: str) -> Tuple[str, str, str]:
+ def parse_uri(uri: str) -> tuple[str, str, str]:
# URL like:
# azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}
# urllib parse does not work as the sas key could contain a slash
diff --git a/kombu/transport/azurestoragequeues.py b/kombu/transport/azurestoragequeues.py
index e83a20d3..16d22f0b 100644
--- a/kombu/transport/azurestoragequeues.py
+++ b/kombu/transport/azurestoragequeues.py
@@ -15,14 +15,34 @@ Features
Connection String
=================
-Connection string has the following format:
+Connection string has the following formats:
.. code-block::
- azurestoragequeues://:STORAGE_ACCOUNT_ACCESS kEY@STORAGE_ACCOUNT_NAME
+ azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL>
+ azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL>
+ azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
+ azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
-Note that if the access key for the storage account contains a slash, it will
-have to be regenerated before it can be used in the connection URL.
+Note that if the access key for the storage account contains a forward slash
+(``/``), it will have to be regenerated before it can be used in the connection
+URL.
+
+If you wish to use an `Azure Managed Identity`_ you may use the
+``DefaultAzureCredential`` format of the connection string, which will use the
+``DefaultAzureCredential`` class from the azure-identity package. You may want
+to read the `azure-identity documentation`_ for more information on how
+``DefaultAzureCredential`` works.
+
+.. _azure-identity documentation:
+   https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python
+.. _Azure Managed Identity:
+   https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview
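+
+For example, a minimal sketch (the storage account URL below is an
+illustrative assumption):
+
+.. code-block:: python
+
+    from kombu import Connection
+
+    with Connection(
+        'azurestoragequeues://DefaultAzureCredential@'
+        'https://example.queue.core.windows.net'
+    ) as conn:
+        conn.connect()
+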
Transport Options
=================
@@ -30,8 +50,13 @@ Transport Options
* ``queue_name_prefix``
"""
+from __future__ import annotations
+
import string
from queue import Empty
+from typing import Any, Optional
+
+from azure.core.exceptions import ResourceExistsError
from kombu.utils.encoding import safe_str
from kombu.utils.json import dumps, loads
@@ -40,9 +65,16 @@ from kombu.utils.objects import cached_property
from . import virtual
try:
- from azure.storage.queue import QueueService
+ from azure.storage.queue import QueueServiceClient
except ImportError: # pragma: no cover
- QueueService = None
+ QueueServiceClient = None
+
+try:
+ from azure.identity import (DefaultAzureCredential,
+ ManagedIdentityCredential)
+except ImportError:
+ DefaultAzureCredential = None
+ ManagedIdentityCredential = None
# Azure storage queues allow only alphanumeric and dashes
# so, replace everything with a dash
@@ -54,21 +86,25 @@ CHARS_REPLACE_TABLE = {
class Channel(virtual.Channel):
"""Azure Storage Queues channel."""
- domain_format = 'kombu%(vhost)s'
- _queue_service = None
- _queue_name_cache = {}
- no_ack = True
- _noack_queues = set()
+ domain_format: str = 'kombu%(vhost)s'
+ _queue_service: Optional[QueueServiceClient] = None
+ _queue_name_cache: dict[Any, Any] = {}
+ no_ack: bool = True
+ _noack_queues: set[Any] = set()
def __init__(self, *args, **kwargs):
- if QueueService is None:
+ if QueueServiceClient is None:
raise ImportError('Azure Storage Queues transport requires the '
'azure-storage-queue library')
super().__init__(*args, **kwargs)
- for queue_name in self.queue_service.list_queues():
- self._queue_name_cache[queue_name] = queue_name
+ self._credential, self._url = Transport.parse_uri(
+ self.conninfo.hostname
+ )
+
+ for queue in self.queue_service.list_queues():
+ self._queue_name_cache[queue['name']] = queue
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
@@ -77,7 +113,7 @@ class Channel(virtual.Channel):
return super().basic_consume(queue, no_ack,
*args, **kwargs)
- def entity_name(self, name, table=CHARS_REPLACE_TABLE):
+ def entity_name(self, name, table=CHARS_REPLACE_TABLE) -> str:
"""Format AMQP queue name into a valid Azure Storage Queue name."""
return str(safe_str(name)).translate(table)
@@ -85,61 +121,64 @@ class Channel(virtual.Channel):
"""Ensure a queue exists."""
queue = self.entity_name(self.queue_name_prefix + queue)
try:
- return self._queue_name_cache[queue]
+            q = self.queue_service.get_queue_client(
+ queue=self._queue_name_cache[queue]
+ )
except KeyError:
- self.queue_service.create_queue(queue, fail_on_exist=False)
- q = self._queue_name_cache[queue] = queue
- return q
+ try:
+ q = self.queue_service.create_queue(queue)
+ except ResourceExistsError:
+                q = self.queue_service.get_queue_client(queue=queue)
+
+ self._queue_name_cache[queue] = q.get_queue_properties()
+ return q
def _delete(self, queue, *args, **kwargs):
"""Delete queue by name."""
queue_name = self.entity_name(queue)
self._queue_name_cache.pop(queue_name, None)
self.queue_service.delete_queue(queue_name)
- super()._delete(queue_name)
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q = self._ensure_queue(queue)
encoded_message = dumps(message)
- self.queue_service.put_message(q, encoded_message)
+ q.send_message(encoded_message)
def _get(self, queue, timeout=None):
"""Try to retrieve a single message off ``queue``."""
q = self._ensure_queue(queue)
- messages = self.queue_service.get_messages(q, num_messages=1,
- timeout=timeout)
- if not messages:
+ messages = q.receive_messages(messages_per_page=1, timeout=timeout)
+ try:
+ message = next(messages)
+ except StopIteration:
raise Empty()
- message = messages[0]
- raw_content = self.queue_service.decode_function(message.content)
- content = loads(raw_content)
+ content = loads(message.content)
- self.queue_service.delete_message(q, message.id, message.pop_receipt)
+ q.delete_message(message=message)
return content
def _size(self, queue):
"""Return the number of messages in a queue."""
q = self._ensure_queue(queue)
- metadata = self.queue_service.get_queue_metadata(q)
- return metadata.approximate_message_count
+ return q.get_queue_properties().approximate_message_count
def _purge(self, queue):
"""Delete all current messages in a queue."""
q = self._ensure_queue(queue)
- n = self._size(q)
- self.queue_service.clear_messages(q)
+ n = self._size(q.queue_name)
+ q.clear_messages()
return n
@property
- def queue_service(self):
+ def queue_service(self) -> QueueServiceClient:
if self._queue_service is None:
- self._queue_service = QueueService(
- account_name=self.conninfo.hostname,
- account_key=self.conninfo.password)
+ self._queue_service = QueueServiceClient(
+ account_url=self._url, credential=self._credential
+ )
return self._queue_service
@@ -152,7 +191,7 @@ class Channel(virtual.Channel):
return self.connection.client.transport_options
@cached_property
- def queue_name_prefix(self):
+ def queue_name_prefix(self) -> str:
return self.transport_options.get('queue_name_prefix', '')
@@ -161,5 +200,64 @@ class Transport(virtual.Transport):
Channel = Channel
- polling_interval = 1
- default_port = None
+ polling_interval: int = 1
+ default_port: Optional[int] = None
+ can_parse_url: bool = True
+
+ @staticmethod
+ def parse_uri(uri: str) -> tuple[str | dict, str]:
+ # URL like:
+ # azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL>
+ # azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL>
+ # azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL>
+ # azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL>
+
+ # urllib parse does not work as the sas key could contain a slash
+ # e.g.: azurestoragequeues://some/key@someurl
+
+ try:
+ # > 'some/key@url'
+ uri = uri.replace('azurestoragequeues://', '')
+ # > 'some/key', 'url'
+ credential, url = uri.rsplit('@', 1)
+
+ if "DefaultAzureCredential".lower() == credential.lower():
+ if DefaultAzureCredential is None:
+ raise ImportError('Azure Storage Queues transport with a '
+ 'DefaultAzureCredential requires the '
+ 'azure-identity library')
+ credential = DefaultAzureCredential()
+ elif "ManagedIdentityCredential".lower() == credential.lower():
+ if ManagedIdentityCredential is None:
+ raise ImportError('Azure Storage Queues transport with a '
+ 'ManagedIdentityCredential requires the '
+ 'azure-identity library')
+ credential = ManagedIdentityCredential()
+ elif "devstoreaccount1" in url and ".core.windows.net" not in url:
+ # parse credential as a dict if Azurite is being used
+ credential = {
+ "account_name": "devstoreaccount1",
+ "account_key": credential,
+ }
+
+ # Validate parameters
+ assert all([credential, url])
+ except Exception:
+ raise ValueError(
+ 'Need a URI like '
+ 'azurestoragequeues://{SAS or access key}@{URL}, '
+                'azurestoragequeues://DefaultAzureCredential@{URL}, '
+                'or '
+ 'azurestoragequeues://ManagedIdentityCredential@{URL}'
+ )
+
+ return credential, url
+
+ @classmethod
+ def as_uri(
+ cls, uri: str, include_password: bool = False, mask: str = "**"
+ ) -> str:
+ credential, url = cls.parse_uri(uri)
+ return "azurestoragequeues://{}@{}".format(
+ credential if include_password else mask, url
+ )
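+
+
+# Example usage of ``as_uri`` (the account URL below is an illustrative
+# assumption):
+#
+#   >>> Transport.as_uri(
+#   ...     'azurestoragequeues://some/key@https://acc.queue.core.windows.net')
+#   'azurestoragequeues://**@https://acc.queue.core.windows.net'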
diff --git a/kombu/transport/base.py b/kombu/transport/base.py
index 3083acf4..ec4c0aca 100644
--- a/kombu/transport/base.py
+++ b/kombu/transport/base.py
@@ -2,8 +2,11 @@
# flake8: noqa
+from __future__ import annotations
+
import errno
import socket
+from typing import TYPE_CHECKING
from amqp.exceptions import RecoverableConnectionError
@@ -13,6 +16,9 @@ from kombu.utils.functional import dictfilter
from kombu.utils.objects import cached_property
from kombu.utils.time import maybe_s_to_ms
+if TYPE_CHECKING:
+ from types import TracebackType
+
__all__ = ('Message', 'StdChannel', 'Management', 'Transport')
RABBITMQ_QUEUE_ARGUMENTS = {
@@ -100,7 +106,12 @@ class StdChannel:
def __enter__(self):
return self
- def __exit__(self, *exc_info):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
self.close()
diff --git a/kombu/transport/confluentkafka.py b/kombu/transport/confluentkafka.py
new file mode 100644
index 00000000..5332a310
--- /dev/null
+++ b/kombu/transport/confluentkafka.py
@@ -0,0 +1,379 @@
+"""confluent-kafka transport module for Kombu.
+
+Kafka transport using confluent-kafka library.
+
+**References**
+
+- http://docs.confluent.io/current/clients/confluent-kafka-python
+
+**Limitations**
+
+The confluent-kafka transport does not support the PyPy environment.
+
+Features
+========
+* Type: Virtual
+* Supports Direct: Yes
+* Supports Topic: Yes
+* Supports Fanout: No
+* Supports Priority: No
+* Supports TTL: No
+
+Connection String
+=================
+Connection string has the following format:
+
+.. code-block::
+
+ confluentkafka://[USER:PASSWORD@]KAFKA_ADDRESS[:PORT]
+
+Transport Options
+=================
+* ``connection_wait_time_seconds`` - Time in seconds to wait for connection
+ to succeed. Default ``5``
+* ``wait_time_seconds`` - Time in seconds to wait to receive messages.
+ Default ``5``
+* ``security_protocol`` - Protocol used to communicate with broker.
+ Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for
+ an explanation of valid values. Default ``plaintext``
+* ``sasl_mechanism`` - SASL mechanism to use for authentication.
+ Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for
+ an explanation of valid values.
+* ``num_partitions`` - Number of partitions to create. Default ``1``
+* ``replication_factor`` - Replication factor of partitions. Default ``1``
+* ``topic_config`` - Topic configuration. Must be a dict whose key-value pairs
+  correspond with the topic-level configuration properties listed at
+  http://kafka.apache.org/documentation.html#topicconfigs.
+* ``kafka_common_config`` - Configuration applied to producer, consumer and
+  admin client. Must be a dict whose key-value pairs correspond with the
+  properties listed at
+  https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
+* ``kafka_producer_config`` - Producer configuration. Must be a dict whose
+  key-value pairs correspond with the properties listed at
+  https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
+* ``kafka_consumer_config`` - Consumer configuration. Must be a dict whose
+  key-value pairs correspond with the properties listed at
+  https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
+* ``kafka_admin_config`` - Admin client configuration. Must be a dict whose
+  key-value pairs correspond with the properties listed at
+  https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md.
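+
+As an illustrative sketch (the broker address and option values below are
+assumptions, not defaults):
+
+.. code-block:: python
+
+    from kombu import Connection
+
+    with Connection(
+        'confluentkafka://localhost:9092',
+        transport_options={'wait_time_seconds': 2, 'num_partitions': 3},
+    ) as conn:
+        queue = conn.SimpleQueue('events')
+        queue.put({'hello': 'world'})
+        queue.close()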
+"""
+
+from __future__ import annotations
+
+from queue import Empty
+
+from kombu.transport import virtual
+from kombu.utils import cached_property
+from kombu.utils.encoding import str_to_bytes
+from kombu.utils.json import dumps, loads
+
+try:
+ import confluent_kafka
+ from confluent_kafka import Consumer, Producer, TopicPartition
+ from confluent_kafka.admin import AdminClient, NewTopic
+
+ KAFKA_CONNECTION_ERRORS = ()
+ KAFKA_CHANNEL_ERRORS = ()
+
+except ImportError:
+ confluent_kafka = None
+ KAFKA_CONNECTION_ERRORS = KAFKA_CHANNEL_ERRORS = ()
+
+from kombu.log import get_logger
+
+logger = get_logger(__name__)
+
+DEFAULT_PORT = 9092
+
+
+class NoBrokersAvailable(confluent_kafka.KafkaException):
+ """Kafka broker is not available exception."""
+
+ retriable = True
+
+
+class Message(virtual.Message):
+ """Message object."""
+
+ def __init__(self, payload, channel=None, **kwargs):
+ self.topic = payload.get('topic')
+ super().__init__(payload, channel=channel, **kwargs)
+
+
+class QoS(virtual.QoS):
+ """Quality of Service guarantees."""
+
+ _not_yet_acked = {}
+
+ def can_consume(self):
+ """Return true if the channel can be consumed from.
+
+ :returns: True, if this QoS object can accept a message.
+ :rtype: bool
+ """
+ return not self.prefetch_count or len(self._not_yet_acked) < self \
+ .prefetch_count
+
+ def can_consume_max_estimate(self):
+ if self.prefetch_count:
+ return self.prefetch_count - len(self._not_yet_acked)
+ else:
+ return 1
+
+ def append(self, message, delivery_tag):
+ self._not_yet_acked[delivery_tag] = message
+
+ def get(self, delivery_tag):
+ return self._not_yet_acked[delivery_tag]
+
+ def ack(self, delivery_tag):
+ if delivery_tag not in self._not_yet_acked:
+ return
+ message = self._not_yet_acked.pop(delivery_tag)
+ consumer = self.channel._get_consumer(message.topic)
+ consumer.commit()
+
+ def reject(self, delivery_tag, requeue=False):
+ """Reject a message by delivery tag.
+
+ If requeue is True, then the last consumed message is reverted so
+ it'll be refetched on the next attempt.
+ If False, that message is consumed and ignored.
+ """
+ if requeue:
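+            # Seek every assigned partition back to its last committed
+            # offset so the rejected message is fetched again.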
+ message = self._not_yet_acked.pop(delivery_tag)
+ consumer = self.channel._get_consumer(message.topic)
+ for assignment in consumer.assignment():
+ topic_partition = TopicPartition(message.topic,
+ assignment.partition)
+ [committed_offset] = consumer.committed([topic_partition])
+ consumer.seek(committed_offset)
+ else:
+ self.ack(delivery_tag)
+
+ def restore_unacked_once(self, stderr=None):
+ pass
+
+
+class Channel(virtual.Channel):
+ """Kafka Channel."""
+
+ QoS = QoS
+ Message = Message
+
+ default_wait_time_seconds = 5
+ default_connection_wait_time_seconds = 5
+ _client = None
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self._kafka_consumers = {}
+ self._kafka_producers = {}
+
+ self._client = self._open()
+
+ def sanitize_queue_name(self, queue):
+ """Need to sanitize the name, celery sometimes pushes in @ signs."""
+ return str(queue).replace('@', '')
+
+ def _get_producer(self, queue):
+ """Create/get a producer instance for the given topic/queue."""
+ queue = self.sanitize_queue_name(queue)
+ producer = self._kafka_producers.get(queue, None)
+ if producer is None:
+ producer = Producer({
+ **self.common_config,
+ **(self.options.get('kafka_producer_config') or {}),
+ })
+ self._kafka_producers[queue] = producer
+
+ return producer
+
+ def _get_consumer(self, queue):
+ """Create/get a consumer instance for the given topic/queue."""
+ queue = self.sanitize_queue_name(queue)
+ consumer = self._kafka_consumers.get(queue, None)
+ if consumer is None:
+ consumer = Consumer({
+ 'group.id': f'{queue}-consumer-group',
+ 'auto.offset.reset': 'earliest',
+ 'enable.auto.commit': False,
+ **self.common_config,
+ **(self.options.get('kafka_consumer_config') or {}),
+ })
+ consumer.subscribe([queue])
+ self._kafka_consumers[queue] = consumer
+
+ return consumer
+
+ def _put(self, queue, message, **kwargs):
+ """Put a message on the topic/queue."""
+ queue = self.sanitize_queue_name(queue)
+ producer = self._get_producer(queue)
+ producer.produce(queue, str_to_bytes(dumps(message)))
+ producer.flush()
+
+ def _get(self, queue, **kwargs):
+ """Get a message from the topic/queue."""
+ queue = self.sanitize_queue_name(queue)
+ consumer = self._get_consumer(queue)
+ message = None
+
+ try:
+ message = consumer.poll(self.wait_time_seconds)
+ except StopIteration:
+ pass
+
+ if not message:
+ raise Empty()
+
+ error = message.error()
+ if error:
+ logger.error(error)
+ raise Empty()
+
+ return {**loads(message.value()), 'topic': message.topic()}
+
+ def _delete(self, queue, *args, **kwargs):
+ """Delete a queue/topic."""
+ queue = self.sanitize_queue_name(queue)
+ self._kafka_consumers[queue].close()
+ self._kafka_consumers.pop(queue)
+ self.client.delete_topics([queue])
+
+ def _size(self, queue):
+ """Get the number of pending messages in the topic/queue."""
+ queue = self.sanitize_queue_name(queue)
+
+ consumer = self._kafka_consumers.get(queue, None)
+ if consumer is None:
+ return 0
+
+ size = 0
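+        # Pending messages = high-watermark offset minus the committed
+        # offset, summed over the partitions this consumer is assigned.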
+ for assignment in consumer.assignment():
+ topic_partition = TopicPartition(queue, assignment.partition)
+ (_, end_offset) = consumer.get_watermark_offsets(topic_partition)
+ [committed_offset] = consumer.committed([topic_partition])
+ size += end_offset - committed_offset.offset
+ return size
+
+ def _new_queue(self, queue, **kwargs):
+ """Create a new topic if it does not exist."""
+ queue = self.sanitize_queue_name(queue)
+ if queue in self.client.list_topics().topics:
+ return
+
+ topic = NewTopic(
+ queue,
+ num_partitions=self.options.get('num_partitions', 1),
+ replication_factor=self.options.get('replication_factor', 1),
+ config=self.options.get('topic_config', {})
+ )
+ self.client.create_topics(new_topics=[topic])
+
+ def _has_queue(self, queue, **kwargs):
+ """Check if a topic already exists."""
+ queue = self.sanitize_queue_name(queue)
+ return queue in self.client.list_topics().topics
+
+ def _open(self):
+ client = AdminClient({
+ **self.common_config,
+ **(self.options.get('kafka_admin_config') or {}),
+ })
+
+ try:
+ # seems to be the only way to check connection
+ client.list_topics(timeout=self.wait_time_seconds)
+ except confluent_kafka.KafkaException as e:
+ raise NoBrokersAvailable(e)
+
+ return client
+
+ @property
+ def client(self):
+ if self._client is None:
+ self._client = self._open()
+ return self._client
+
+ @property
+ def options(self):
+ return self.connection.client.transport_options
+
+ @property
+ def conninfo(self):
+ return self.connection.client
+
+ @cached_property
+ def wait_time_seconds(self):
+ return self.options.get(
+ 'wait_time_seconds', self.default_wait_time_seconds
+ )
+
+ @cached_property
+ def connection_wait_time_seconds(self):
+ return self.options.get(
+ 'connection_wait_time_seconds',
+ self.default_connection_wait_time_seconds,
+ )
+
+ @cached_property
+ def common_config(self):
+ conninfo = self.connection.client
+ config = {
+ 'bootstrap.servers':
+ f'{conninfo.hostname}:{int(conninfo.port) or DEFAULT_PORT}',
+ }
+ security_protocol = self.options.get('security_protocol', 'plaintext')
+ if security_protocol.lower() != 'plaintext':
+ config.update({
+ 'security.protocol': security_protocol,
+ 'sasl.username': conninfo.userid,
+ 'sasl.password': conninfo.password,
+ 'sasl.mechanism': self.options.get('sasl_mechanism'),
+ })
+
+ config.update(self.options.get('kafka_common_config') or {})
+ return config
+
+ def close(self):
+ super().close()
+ self._kafka_producers = {}
+
+ for consumer in self._kafka_consumers.values():
+ consumer.close()
+
+ self._kafka_consumers = {}
+
+
+class Transport(virtual.Transport):
+ """Kafka Transport."""
+
+    Channel = Channel
+
+    default_port = DEFAULT_PORT
+
+    driver_type = 'kafka'
+    driver_name = 'confluentkafka'
+
+    def as_uri(self, uri: str, include_password=False, mask='**') -> str:
+        pass
+
+ recoverable_connection_errors = (
+ NoBrokersAvailable,
+ )
+
+ def __init__(self, client, **kwargs):
+ if confluent_kafka is None:
+ raise ImportError('The confluent-kafka library is not installed')
+ super().__init__(client, **kwargs)
+
+ def driver_version(self):
+ return confluent_kafka.__version__
+
+ def establish_connection(self):
+ return super().establish_connection()
+
+ def close_connection(self, connection):
+ return super().close_connection(connection)
diff --git a/kombu/transport/consul.py b/kombu/transport/consul.py
index ea275c95..7ace52f6 100644
--- a/kombu/transport/consul.py
+++ b/kombu/transport/consul.py
@@ -27,6 +27,8 @@ Connection string has the following format:
"""
+from __future__ import annotations
+
import socket
import uuid
from collections import defaultdict
@@ -276,24 +278,25 @@ class Transport(virtual.Transport):
driver_type = 'consul'
driver_name = 'consul'
- def __init__(self, *args, **kwargs):
- if consul is None:
- raise ImportError('Missing python-consul library')
-
- super().__init__(*args, **kwargs)
-
- self.connection_errors = (
+ if consul:
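+        # Evaluated when the class body runs (python-consul importable),
+        # so the error tuples exist before any Transport is instantiated.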
+ connection_errors = (
virtual.Transport.connection_errors + (
consul.ConsulException, consul.base.ConsulException
)
)
- self.channel_errors = (
+ channel_errors = (
virtual.Transport.channel_errors + (
consul.ConsulException, consul.base.ConsulException
)
)
+ def __init__(self, *args, **kwargs):
+ if consul is None:
+ raise ImportError('Missing python-consul library')
+
+ super().__init__(*args, **kwargs)
+
def verify_connection(self, connection):
port = connection.client.port or self.default_port
host = connection.client.hostname or DEFAULT_HOST
diff --git a/kombu/transport/etcd.py b/kombu/transport/etcd.py
index 4d0b0364..2ab85841 100644
--- a/kombu/transport/etcd.py
+++ b/kombu/transport/etcd.py
@@ -24,6 +24,8 @@ Connection string has the following format:
"""
+from __future__ import annotations
+
import os
import socket
from collections import defaultdict
@@ -242,6 +244,15 @@ class Transport(virtual.Transport):
implements = virtual.Transport.implements.extend(
exchange_type=frozenset(['direct']))
+ if etcd:
+ connection_errors = (
+ virtual.Transport.connection_errors + (etcd.EtcdException, )
+ )
+
+ channel_errors = (
+ virtual.Transport.channel_errors + (etcd.EtcdException, )
+ )
+
def __init__(self, *args, **kwargs):
"""Create a new instance of etcd.Transport."""
if etcd is None:
@@ -249,14 +260,6 @@ class Transport(virtual.Transport):
super().__init__(*args, **kwargs)
- self.connection_errors = (
- virtual.Transport.connection_errors + (etcd.EtcdException, )
- )
-
- self.channel_errors = (
- virtual.Transport.channel_errors + (etcd.EtcdException, )
- )
-
def verify_connection(self, connection):
"""Verify the connection works."""
port = connection.client.port or self.default_port
diff --git a/kombu/transport/filesystem.py b/kombu/transport/filesystem.py
index d66c42d6..9d2b3581 100644
--- a/kombu/transport/filesystem.py
+++ b/kombu/transport/filesystem.py
@@ -65,7 +65,7 @@ Features
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: Yes
-* Supports Fanout: No
+* Supports Fanout: Yes
* Supports Priority: No
* Supports TTL: No
@@ -86,22 +86,26 @@ Transport Options
* ``store_processed`` - if set to True, all processed messages are backed up to
``processed_folder``.
* ``processed_folder`` - directory where processed files are backed up.
+* ``control_folder`` - directory where the exchange-queue table is stored.
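+
+A fanout sketch under these options (all paths below are illustrative
+assumptions):
+
+.. code-block:: python
+
+    from kombu import Connection, Exchange, Queue
+
+    exchange = Exchange('events', type='fanout')
+    with Connection('filesystem://', transport_options={
+        'data_folder_in': '/tmp/kombu-data',
+        'data_folder_out': '/tmp/kombu-data',
+        'control_folder': '/tmp/kombu-control',
+    }) as conn:
+        Queue('worker-1', exchange)(conn).declare()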
"""
+from __future__ import annotations
+
import os
import shutil
import tempfile
import uuid
+from collections import namedtuple
+from pathlib import Path
from queue import Empty
from time import monotonic
from kombu.exceptions import ChannelError
+from kombu.transport import virtual
from kombu.utils.encoding import bytes_to_str, str_to_bytes
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
-from . import virtual
-
VERSION = (1, 0, 0)
__version__ = '.'.join(map(str, VERSION))
@@ -128,10 +132,11 @@ if os.name == 'nt':
hfile = win32file._get_osfhandle(file.fileno())
win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped)
+
elif os.name == 'posix':
import fcntl
- from fcntl import LOCK_EX, LOCK_NB, LOCK_SH # noqa
+ from fcntl import LOCK_EX, LOCK_SH
def lock(file, flags):
"""Create file lock."""
@@ -140,14 +145,66 @@ elif os.name == 'posix':
def unlock(file):
"""Remove file lock."""
fcntl.flock(file.fileno(), fcntl.LOCK_UN)
+
+
else:
raise RuntimeError(
'Filesystem plugin only defined for NT and POSIX platforms')
+exchange_queue_t = namedtuple("exchange_queue_t",
+ ["routing_key", "pattern", "queue"])
+
+
class Channel(virtual.Channel):
"""Filesystem Channel."""
+ supports_fanout = True
+
+ def get_table(self, exchange):
+ file = self.control_folder / f"{exchange}.exchange"
+ try:
+ f_obj = file.open("r")
+ try:
+ lock(f_obj, LOCK_SH)
+ exchange_table = loads(bytes_to_str(f_obj.read()))
+ return [exchange_queue_t(*q) for q in exchange_table]
+ finally:
+ unlock(f_obj)
+ f_obj.close()
+ except FileNotFoundError:
+ return []
+ except OSError:
+ raise ChannelError(f"Cannot open {file}")
+
+ def _queue_bind(self, exchange, routing_key, pattern, queue):
+ file = self.control_folder / f"{exchange}.exchange"
+ self.control_folder.mkdir(exist_ok=True)
+ queue_val = exchange_queue_t(routing_key or "", pattern or "",
+ queue or "")
+ try:
+ if file.exists():
+ f_obj = file.open("rb+", buffering=0)
+ lock(f_obj, LOCK_EX)
+ exchange_table = loads(bytes_to_str(f_obj.read()))
+ queues = [exchange_queue_t(*q) for q in exchange_table]
+ if queue_val not in queues:
+ queues.insert(0, queue_val)
+ f_obj.seek(0)
+ f_obj.write(str_to_bytes(dumps(queues)))
+ else:
+ f_obj = file.open("wb", buffering=0)
+ lock(f_obj, LOCK_EX)
+ queues = [queue_val]
+ f_obj.write(str_to_bytes(dumps(queues)))
+ finally:
+ unlock(f_obj)
+ f_obj.close()
+
+ def _put_fanout(self, exchange, payload, routing_key, **kwargs):
+ for q in self.get_table(exchange):
+ self._put(q.queue, payload, **kwargs)
+
def _put(self, queue, payload, **kwargs):
"""Put `message` onto `queue`."""
filename = '{}_{}.{}.msg'.format(int(round(monotonic() * 1000)),
@@ -155,7 +212,7 @@ class Channel(virtual.Channel):
filename = os.path.join(self.data_folder_out, filename)
try:
- f = open(filename, 'wb')
+ f = open(filename, 'wb', buffering=0)
lock(f, LOCK_EX)
f.write(str_to_bytes(dumps(payload)))
except OSError:
@@ -187,7 +244,8 @@ class Channel(virtual.Channel):
shutil.move(os.path.join(self.data_folder_in, filename),
processed_folder)
except OSError:
- pass # file could be locked, or removed in meantime so ignore
+            # file could be locked or removed in the meantime; ignore it
+ continue
filename = os.path.join(processed_folder, filename)
try:
@@ -266,10 +324,19 @@ class Channel(virtual.Channel):
def processed_folder(self):
return self.transport_options.get('processed_folder', 'processed')
+ @property
+ def control_folder(self):
+ return Path(self.transport_options.get('control_folder', 'control'))
+
class Transport(virtual.Transport):
"""Filesystem Transport."""
+ implements = virtual.Transport.implements.extend(
+ asynchronous=False,
+ exchange_type=frozenset(['direct', 'topic', 'fanout'])
+ )
+
Channel = Channel
# filesystem backend state is global.
global_state = virtual.BrokerState()
diff --git a/kombu/transport/librabbitmq.py b/kombu/transport/librabbitmq.py
index dec50ccf..37015b18 100644
--- a/kombu/transport/librabbitmq.py
+++ b/kombu/transport/librabbitmq.py
@@ -3,6 +3,8 @@
.. _`librabbitmq`: https://pypi.org/project/librabbitmq/
"""
+from __future__ import annotations
+
import os
import socket
import warnings
diff --git a/kombu/transport/memory.py b/kombu/transport/memory.py
index 3073d1cf..9bfaff8d 100644
--- a/kombu/transport/memory.py
+++ b/kombu/transport/memory.py
@@ -22,6 +22,8 @@ Connection string is in the following format:
"""
+from __future__ import annotations
+
from collections import defaultdict
from queue import Queue
diff --git a/kombu/transport/mongodb.py b/kombu/transport/mongodb.py
index db758c18..b923f5f4 100644
--- a/kombu/transport/mongodb.py
+++ b/kombu/transport/mongodb.py
@@ -33,6 +33,8 @@ Transport Options
* ``calc_queue_size``,
"""
+from __future__ import annotations
+
import datetime
from queue import Empty
@@ -63,11 +65,10 @@ class BroadcastCursor:
def __init__(self, cursor):
self._cursor = cursor
-
self.purge(rewind=False)
def get_size(self):
- return self._cursor.count() - self._offset
+ return self._cursor.collection.count_documents({}) - self._offset
def close(self):
self._cursor.close()
@@ -77,7 +78,7 @@ class BroadcastCursor:
self._cursor.rewind()
# Fast forward the cursor past old events
- self._offset = self._cursor.count()
+ self._offset = self._cursor.collection.count_documents({})
self._cursor = self._cursor.skip(self._offset)
def __iter__(self):
@@ -149,11 +150,17 @@ class Channel(virtual.Channel):
def _new_queue(self, queue, **kwargs):
if self.ttl:
- self.queues.update(
+ self.queues.update_one(
{'_id': queue},
- {'_id': queue,
- 'options': kwargs,
- 'expire_at': self._get_expire(kwargs, 'x-expires')},
+ {
+ '$set': {
+ '_id': queue,
+ 'options': kwargs,
+ 'expire_at': self._get_queue_expire(
+ kwargs, 'x-expires'
+ ),
+ },
+ },
upsert=True)
def _get(self, queue):
@@ -163,10 +170,9 @@ class Channel(virtual.Channel):
except StopIteration:
msg = None
else:
- msg = self.messages.find_and_modify(
- query={'queue': queue},
+ msg = self.messages.find_one_and_delete(
+ {'queue': queue},
sort=[('priority', pymongo.ASCENDING)],
- remove=True,
)
if self.ttl:
@@ -186,7 +192,7 @@ class Channel(virtual.Channel):
if queue in self._fanout_queues:
return self._get_broadcast_cursor(queue).get_size()
- return self.messages.find({'queue': queue}).count()
+ return self.messages.count_documents({'queue': queue})
def _put(self, queue, message, **kwargs):
data = {
@@ -196,13 +202,18 @@ class Channel(virtual.Channel):
}
if self.ttl:
- data['expire_at'] = self._get_expire(queue, 'x-message-ttl')
+ data['expire_at'] = self._get_queue_expire(queue, 'x-message-ttl')
+ msg_expire = self._get_message_expire(message)
+ if msg_expire is not None and (
+ data['expire_at'] is None or msg_expire < data['expire_at']
+ ):
+ data['expire_at'] = msg_expire
- self.messages.insert(data)
+ self.messages.insert_one(data)
def _put_fanout(self, exchange, message, routing_key, **kwargs):
- self.broadcast.insert({'payload': dumps(message),
- 'queue': exchange})
+ self.broadcast.insert_one({'payload': dumps(message),
+ 'queue': exchange})
def _purge(self, queue):
size = self._size(queue)
@@ -241,9 +252,9 @@ class Channel(virtual.Channel):
data = lookup.copy()
if self.ttl:
- data['expire_at'] = self._get_expire(queue, 'x-expires')
+ data['expire_at'] = self._get_queue_expire(queue, 'x-expires')
- self.routing.update(lookup, data, upsert=True)
+ self.routing.update_one(lookup, {'$set': data}, upsert=True)
def queue_delete(self, queue, **kwargs):
self.routing.remove({'queue': queue})
@@ -346,7 +357,7 @@ class Channel(virtual.Channel):
def _create_broadcast(self, database):
"""Create capped collection for broadcast messages."""
- if self.broadcast_collection in database.collection_names():
+ if self.broadcast_collection in database.list_collection_names():
return
database.create_collection(self.broadcast_collection,
@@ -356,20 +367,20 @@ class Channel(virtual.Channel):
def _ensure_indexes(self, database):
"""Ensure indexes on collections."""
messages = database[self.messages_collection]
- messages.ensure_index(
+ messages.create_index(
[('queue', 1), ('priority', 1), ('_id', 1)], background=True,
)
- database[self.broadcast_collection].ensure_index([('queue', 1)])
+ database[self.broadcast_collection].create_index([('queue', 1)])
routing = database[self.routing_collection]
- routing.ensure_index([('queue', 1), ('exchange', 1)])
+ routing.create_index([('queue', 1), ('exchange', 1)])
if self.ttl:
- messages.ensure_index([('expire_at', 1)], expireAfterSeconds=0)
- routing.ensure_index([('expire_at', 1)], expireAfterSeconds=0)
+ messages.create_index([('expire_at', 1)], expireAfterSeconds=0)
+ routing.create_index([('expire_at', 1)], expireAfterSeconds=0)
- database[self.queues_collection].ensure_index(
+ database[self.queues_collection].create_index(
[('expire_at', 1)], expireAfterSeconds=0)
def _create_client(self):
@@ -427,7 +438,12 @@ class Channel(virtual.Channel):
ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor)
return ret
- def _get_expire(self, queue, argument):
+ def _get_message_expire(self, message):
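+        # The AMQP 'expiration' message property is a TTL in milliseconds,
+        # measured from the time the message is stored.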
+ value = message.get('properties', {}).get('expiration')
+ if value is not None:
+ return self.get_now() + datetime.timedelta(milliseconds=int(value))
+
+ def _get_queue_expire(self, queue, argument):
"""Get expiration header named `argument` of queue definition.
Note:
@@ -452,15 +468,15 @@ class Channel(virtual.Channel):
def _update_queues_expire(self, queue):
"""Update expiration field on queues documents."""
- expire_at = self._get_expire(queue, 'x-expires')
+ expire_at = self._get_queue_expire(queue, 'x-expires')
if not expire_at:
return
- self.routing.update(
- {'queue': queue}, {'$set': {'expire_at': expire_at}}, multi=True)
- self.queues.update(
- {'_id': queue}, {'$set': {'expire_at': expire_at}}, multi=True)
+ self.routing.update_many(
+ {'queue': queue}, {'$set': {'expire_at': expire_at}})
+ self.queues.update_many(
+ {'_id': queue}, {'$set': {'expire_at': expire_at}})
def get_now(self):
"""Return current time in UTC."""
diff --git a/kombu/transport/pyamqp.py b/kombu/transport/pyamqp.py
index f230f911..c8fd3c86 100644
--- a/kombu/transport/pyamqp.py
+++ b/kombu/transport/pyamqp.py
@@ -68,6 +68,8 @@ hostname from broker URL. This is useful when failover is used to fill
"""
+from __future__ import annotations
+
import amqp
from kombu.utils.amq_manager import get_manager
diff --git a/kombu/transport/pyro.py b/kombu/transport/pyro.py
index 833d9792..7b27cb61 100644
--- a/kombu/transport/pyro.py
+++ b/kombu/transport/pyro.py
@@ -32,6 +32,8 @@ Transport Options
"""
+from __future__ import annotations
+
import sys
from queue import Empty, Queue
diff --git a/kombu/transport/qpid.py b/kombu/transport/qpid.py
index b0f8df13..cfd864d8 100644
--- a/kombu/transport/qpid.py
+++ b/kombu/transport/qpid.py
@@ -86,13 +86,14 @@ Celery, this can be accomplished by setting the
*BROKER_TRANSPORT_OPTIONS* Celery option.
"""
+from __future__ import annotations
+
import os
import select
import socket
import ssl
import sys
import uuid
-from collections import OrderedDict
from gettext import gettext as _
from queue import Empty
from time import monotonic
@@ -189,7 +190,7 @@ class QoS:
def __init__(self, session, prefetch_count=1):
self.session = session
self.prefetch_count = 1
- self._not_yet_acked = OrderedDict()
+ self._not_yet_acked = {}
def can_consume(self):
"""Return True if the :class:`Channel` can consume more messages.
@@ -229,8 +230,8 @@ class QoS:
"""Append message to the list of un-ACKed messages.
Add a message, referenced by the delivery_tag, for ACKing,
- rejecting, or getting later. Messages are saved into an
- :class:`collections.OrderedDict` by delivery_tag.
+ rejecting, or getting later. Messages are saved into a
+ dict by delivery_tag.
:param message: A received message that has not yet been ACKed.
:type message: qpid.messaging.Message
diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py
index 103a8466..6cbfbdcf 100644
--- a/kombu/transport/redis.py
+++ b/kombu/transport/redis.py
@@ -51,6 +51,8 @@ Transport Options
* ``priority_steps``
"""
+from __future__ import annotations
+
import functools
import numbers
import socket
@@ -189,6 +191,7 @@ class GlobalKeyPrefixMixin:
PREFIXED_SIMPLE_COMMANDS = [
"HDEL",
"HGET",
+ "HLEN",
"HSET",
"LLEN",
"LPUSH",
@@ -208,6 +211,7 @@ class GlobalKeyPrefixMixin:
"DEL": {"args_start": 0, "args_end": None},
"BRPOP": {"args_start": 0, "args_end": -1},
"EVALSHA": {"args_start": 2, "args_end": 3},
+ "WATCH": {"args_start": 0, "args_end": None},
}
def _prefix_args(self, args):
@@ -216,8 +220,7 @@ class GlobalKeyPrefixMixin:
if command in self.PREFIXED_SIMPLE_COMMANDS:
args[0] = self.global_keyprefix + str(args[0])
-
- if command in self.PREFIXED_COMPLEX_COMMANDS.keys():
+ elif command in self.PREFIXED_COMPLEX_COMMANDS:
args_start = self.PREFIXED_COMPLEX_COMMANDS[command]["args_start"]
args_end = self.PREFIXED_COMPLEX_COMMANDS[command]["args_end"]
@@ -267,6 +270,13 @@ class PrefixedStrictRedis(GlobalKeyPrefixMixin, redis.Redis):
self.global_keyprefix = kwargs.pop('global_keyprefix', '')
redis.Redis.__init__(self, *args, **kwargs)
+ def pubsub(self, **kwargs):
+ return PrefixedRedisPubSub(
+ self.connection_pool,
+ global_keyprefix=self.global_keyprefix,
+ **kwargs,
+ )
+
class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.client.Pipeline):
"""Custom Redis pipeline that takes global_keyprefix into consideration.
@@ -281,6 +291,58 @@ class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.client.Pipeline):
redis.client.Pipeline.__init__(self, *args, **kwargs)
+class PrefixedRedisPubSub(redis.client.PubSub):
+ """Redis pubsub client that takes global_keyprefix into consideration."""
+
+ PUBSUB_COMMANDS = (
+ "SUBSCRIBE",
+ "UNSUBSCRIBE",
+ "PSUBSCRIBE",
+ "PUNSUBSCRIBE",
+ )
+
+ def __init__(self, *args, **kwargs):
+ self.global_keyprefix = kwargs.pop('global_keyprefix', '')
+ super().__init__(*args, **kwargs)
+
+ def _prefix_args(self, args):
+ args = list(args)
+ command = args.pop(0)
+
+ if command in self.PUBSUB_COMMANDS:
+ args = [
+ self.global_keyprefix + str(arg)
+ for arg in args
+ ]
+
+ return [command, *args]
+
+ def parse_response(self, *args, **kwargs):
+ """Parse a response from the Redis server.
+
+ Method wraps ``PubSub.parse_response()`` to remove prefixes of keys
+ returned by redis command.
+ """
+ ret = super().parse_response(*args, **kwargs)
+ if ret is None:
+ return ret
+
+ # response formats
+ # SUBSCRIBE and UNSUBSCRIBE
+ # -> [message type, channel, message]
+ # PSUBSCRIBE and PUNSUBSCRIBE
+ # -> [message type, pattern, channel, message]
+ message_type, *channels, message = ret
+ return [
+ message_type,
+ *[channel[len(self.global_keyprefix):] for channel in channels],
+ message,
+ ]
+
+ def execute_command(self, *args, **kwargs):
+ return super().execute_command(*self._prefix_args(args), **kwargs)
+
+
class QoS(virtual.QoS):
"""Redis Ack Emulation."""
@@ -353,13 +415,17 @@ class QoS(virtual.QoS):
pass
def restore_by_tag(self, tag, client=None, leftmost=False):
- with self.channel.conn_or_acquire(client) as client:
- with client.pipeline() as pipe:
- p, _, _ = self._remove_from_indices(
- tag, pipe.hget(self.unacked_key, tag)).execute()
+
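+        # client.transaction() WATCHes unacked_key and re-runs this callback
+        # if the key changes before EXEC, making the restore atomic.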
+ def restore_transaction(pipe):
+ p = pipe.hget(self.unacked_key, tag)
+ pipe.multi()
+ self._remove_from_indices(tag, pipe)
if p:
M, EX, RK = loads(bytes_to_str(p)) # json is unicode
- self.channel._do_restore_message(M, EX, RK, client, leftmost)
+ self.channel._do_restore_message(M, EX, RK, pipe, leftmost)
+
+ with self.channel.conn_or_acquire(client) as client:
+ client.transaction(restore_transaction, self.unacked_key)
@cached_property
def unacked_key(self):
@@ -709,32 +775,35 @@ class Channel(virtual.Channel):
self.connection.cycle._on_connection_disconnect(connection)
def _do_restore_message(self, payload, exchange, routing_key,
- client=None, leftmost=False):
- with self.conn_or_acquire(client) as client:
+ pipe, leftmost=False):
+ try:
try:
- try:
- payload['headers']['redelivered'] = True
- except KeyError:
- pass
- for queue in self._lookup(exchange, routing_key):
- (client.lpush if leftmost else client.rpush)(
- queue, dumps(payload),
- )
- except Exception:
- crit('Could not restore message: %r', payload, exc_info=True)
+ payload['headers']['redelivered'] = True
+ payload['properties']['delivery_info']['redelivered'] = True
+ except KeyError:
+ pass
+ for queue in self._lookup(exchange, routing_key):
+ (pipe.lpush if leftmost else pipe.rpush)(
+ queue, dumps(payload),
+ )
+ except Exception:
+ crit('Could not restore message: %r', payload, exc_info=True)
def _restore(self, message, leftmost=False):
if not self.ack_emulation:
return super()._restore(message)
tag = message.delivery_tag
- with self.conn_or_acquire() as client:
- with client.pipeline() as pipe:
- P, _ = pipe.hget(self.unacked_key, tag) \
- .hdel(self.unacked_key, tag) \
- .execute()
+
+ def restore_transaction(pipe):
+ P = pipe.hget(self.unacked_key, tag)
+ pipe.multi()
+ pipe.hdel(self.unacked_key, tag)
if P:
M, EX, RK = loads(bytes_to_str(P)) # json is unicode
- self._do_restore_message(M, EX, RK, client, leftmost)
+ self._do_restore_message(M, EX, RK, pipe, leftmost)
+
+ with self.conn_or_acquire() as client:
+ client.transaction(restore_transaction, self.unacked_key)
def _restore_at_beginning(self, message):
return self._restore(message, leftmost=True)
@@ -1116,8 +1185,8 @@ class Channel(virtual.Channel):
if asynchronous:
class Connection(connection_cls):
- def disconnect(self):
- super().disconnect()
+ def disconnect(self, *args):
+ super().disconnect(*args)
channel._on_connection_disconnect(self)
connection_cls = Connection
@@ -1208,13 +1277,14 @@ class Transport(virtual.Transport):
exchange_type=frozenset(['direct', 'topic', 'fanout'])
)
+ if redis:
+ connection_errors, channel_errors = get_redis_error_classes()
+
def __init__(self, *args, **kwargs):
if redis is None:
raise ImportError('Missing redis library (pip install redis)')
super().__init__(*args, **kwargs)
- # Get redis-py exceptions.
- self.connection_errors, self.channel_errors = self._get_errors()
# All channels share the same poller.
self.cycle = MultiChannelPoller()
@@ -1231,6 +1301,14 @@ class Transport(virtual.Transport):
def _on_disconnect(connection):
if connection._sock:
loop.remove(connection._sock)
+
+ # must have started polling or this will break reconnection
+ if cycle.fds:
+ # stop polling in the event loop
+ try:
+ loop.on_tick.remove(on_poll_start)
+ except KeyError:
+ pass
cycle._on_connection_disconnect = _on_disconnect
def on_poll_start():
@@ -1251,10 +1329,6 @@ class Transport(virtual.Transport):
"""Handle AIO event for one of our file descriptors."""
self.cycle.on_readable(fileno)
- def _get_errors(self):
- """Utility to import redis-py's exceptions at runtime."""
- return get_redis_error_classes()
-
if sentinel:
class SentinelManagedSSLConnection(
diff --git a/kombu/transport/sqlalchemy/__init__.py b/kombu/transport/sqlalchemy/__init__.py
index 91f87a86..a61c8ea8 100644
--- a/kombu/transport/sqlalchemy/__init__.py
+++ b/kombu/transport/sqlalchemy/__init__.py
@@ -50,15 +50,13 @@ Transport Options
Moreover parameters of :func:`sqlalchemy.create_engine()` function can be passed as transport options.
"""
-# SQLAlchemy overrides != False to have special meaning and pep8 complains
-# flake8: noqa
-
+from __future__ import annotations
+
import threading
from json import dumps, loads
from queue import Empty
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
@@ -71,6 +69,13 @@ from .models import ModelBase
from .models import Queue as QueueBase
from .models import class_registry, metadata
+# SQLAlchemy overrides != False to have special meaning and pep8 complains
+# flake8: noqa
+
+
VERSION = (1, 4, 1)
__version__ = '.'.join(map(str, VERSION))
@@ -164,7 +169,7 @@ class Channel(virtual.Channel):
def _get(self, queue):
obj = self._get_or_create(queue)
if self.session.bind.name == 'sqlite':
- self.session.execute('BEGIN IMMEDIATE TRANSACTION')
+ self.session.execute(text('BEGIN IMMEDIATE TRANSACTION'))
try:
msg = self.session.query(self.message_cls) \
.with_for_update() \
diff --git a/kombu/transport/sqlalchemy/models.py b/kombu/transport/sqlalchemy/models.py
index 45863852..edff572a 100644
--- a/kombu/transport/sqlalchemy/models.py
+++ b/kombu/transport/sqlalchemy/models.py
@@ -1,10 +1,12 @@
"""Kombu transport using SQLAlchemy as the message store."""
+from __future__ import annotations
+
import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer,
Sequence, SmallInteger, String, Text)
-from sqlalchemy.orm import relation
+from sqlalchemy.orm import relationship
from sqlalchemy.schema import MetaData
try:
@@ -35,7 +37,7 @@ class Queue:
@declared_attr
def messages(cls):
- return relation('Message', backref='queue', lazy='noload')
+ return relationship('Message', backref='queue', lazy='noload')
class Message:
diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py
index 7ab11772..54e84665 100644
--- a/kombu/transport/virtual/__init__.py
+++ b/kombu/transport/virtual/__init__.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from .base import (AbstractChannel, Base64, BrokerState, Channel, Empty,
Management, Message, NotEquivalentError, QoS, Transport,
UndeliverableWarning, binding_key_t, queue_binding_t)
diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py
index 95e539ac..552ebec7 100644
--- a/kombu/transport/virtual/base.py
+++ b/kombu/transport/virtual/base.py
@@ -3,6 +3,8 @@
Emulates the AMQ API for non-AMQ transports.
"""
+from __future__ import annotations
+
import base64
import socket
import sys
@@ -13,6 +15,7 @@ from itertools import count
from multiprocessing.util import Finalize
from queue import Empty
from time import monotonic, sleep
+from typing import TYPE_CHECKING
from amqp.protocol import queue_declare_ok_t
@@ -26,6 +29,9 @@ from kombu.utils.uuid import uuid
from .exchange import STANDARD_EXCHANGE_TYPES
+if TYPE_CHECKING:
+ from types import TracebackType
+
ARRAY_TYPE_H = 'H'
UNDELIVERABLE_FMT = """\
@@ -177,6 +183,8 @@ class QoS:
self.channel = channel
self.prefetch_count = prefetch_count or 0
+ # Standard Python dictionaries do not support setting attributes
+ # on the object, hence the use of OrderedDict
self._delivered = OrderedDict()
self._delivered.restored = False
self._dirty = set()
@@ -462,14 +470,7 @@ class Channel(AbstractChannel, base.StdChannel):
typ: cls(self) for typ, cls in self.exchange_types.items()
}
- try:
- self.channel_id = self.connection._avail_channel_ids.pop()
- except IndexError:
- raise ResourceError(
- 'No free channel ids, current={}, channel_max={}'.format(
- len(self.connection.channels),
- self.connection.channel_max), (20, 10),
- )
+ self.channel_id = self._get_free_channel_id()
topts = self.connection.client.transport_options
for opt_name in self.from_transport_options:
@@ -727,7 +728,8 @@ class Channel(AbstractChannel, base.StdChannel):
message = message.serializable()
message['redelivered'] = True
for queue in self._lookup(
- delivery_info['exchange'], delivery_info['routing_key']):
+ delivery_info['exchange'],
+ delivery_info['routing_key']):
self._put(queue, message)
def _restore_at_beginning(self, message):
@@ -804,7 +806,12 @@ class Channel(AbstractChannel, base.StdChannel):
def __enter__(self):
return self
- def __exit__(self, *exc_info):
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None
+ ) -> None:
self.close()
@property
@@ -844,6 +851,22 @@ class Channel(AbstractChannel, base.StdChannel):
return (self.max_priority - priority) if reverse else priority
+ def _get_free_channel_id(self):
+ # Cast to a set for fast lookups, and keep stored as an array
+ # for lower memory usage.
+ used_channel_ids = set(self.connection._used_channel_ids)
+
+ for channel_id in range(1, self.connection.channel_max + 1):
+ if channel_id not in used_channel_ids:
+ self.connection._used_channel_ids.append(channel_id)
+ return channel_id
+
+ raise ResourceError(
+ 'No free channel ids, current={}, channel_max={}'.format(
+ len(self.connection.channels),
+ self.connection.channel_max), (20, 10),
+ )
+
class Management(base.Management):
"""Base class for the AMQP management API."""
@@ -907,9 +930,7 @@ class Transport(base.Transport):
polling_interval = client.transport_options.get('polling_interval')
if polling_interval is not None:
self.polling_interval = polling_interval
- self._avail_channel_ids = array(
- ARRAY_TYPE_H, range(self.channel_max, 0, -1),
- )
+ self._used_channel_ids = array(ARRAY_TYPE_H)
def create_channel(self, connection):
try:
@@ -921,7 +942,11 @@ class Transport(base.Transport):
def close_channel(self, channel):
try:
- self._avail_channel_ids.append(channel.channel_id)
+ try:
+ self._used_channel_ids.remove(channel.channel_id)
+ except ValueError:
+ # channel id already removed
+ pass
try:
self.channels.remove(channel)
except ValueError:
@@ -934,7 +959,7 @@ class Transport(base.Transport):
# this channel is then used as the next requested channel.
# (returned by ``create_channel``).
self._avail_channels.append(self.create_channel(self))
- return self # for drain events
+ return self # for drain events
def close_connection(self, connection):
self.cycle.close()
diff --git a/kombu/transport/virtual/exchange.py b/kombu/transport/virtual/exchange.py
index c6b6161c..b70544cd 100644
--- a/kombu/transport/virtual/exchange.py
+++ b/kombu/transport/virtual/exchange.py
@@ -4,6 +4,8 @@ Implementations of the standard exchanges defined
by the AMQ protocol (excluding the `headers` exchange).
"""
+from __future__ import annotations
+
import re
from kombu.utils.text import escape_regex
diff --git a/kombu/transport/zookeeper.py b/kombu/transport/zookeeper.py
index 1a2ab63c..c72ce2f5 100644
--- a/kombu/transport/zookeeper.py
+++ b/kombu/transport/zookeeper.py
@@ -42,6 +42,8 @@ Transport Options
"""
+from __future__ import annotations
+
import os
import socket
from queue import Empty
diff --git a/kombu/utils/__init__.py b/kombu/utils/__init__.py
index 304e2dfa..94bb3cdf 100644
--- a/kombu/utils/__init__.py
+++ b/kombu/utils/__init__.py
@@ -1,5 +1,7 @@
"""DEPRECATED - Import from modules below."""
+from __future__ import annotations
+
from .collections import EqualityDict
from .compat import fileno, maybe_fileno, nested, register_after_fork
from .div import emergency_dump_state
diff --git a/kombu/utils/amq_manager.py b/kombu/utils/amq_manager.py
index 7491bb25..f3e429fd 100644
--- a/kombu/utils/amq_manager.py
+++ b/kombu/utils/amq_manager.py
@@ -1,6 +1,9 @@
"""AMQP Management API utilities."""
+from __future__ import annotations
+
+
def get_manager(client, hostname=None, port=None, userid=None,
password=None):
"""Get pyrabbit manager."""
diff --git a/kombu/utils/collections.py b/kombu/utils/collections.py
index 77781047..1a0a6d0d 100644
--- a/kombu/utils/collections.py
+++ b/kombu/utils/collections.py
@@ -1,6 +1,9 @@
"""Custom maps, sequences, etc."""
+from __future__ import annotations
+
+
class HashedSeq(list):
"""Hashed Sequence.
diff --git a/kombu/utils/compat.py b/kombu/utils/compat.py
index ffc224c1..e1b22f66 100644
--- a/kombu/utils/compat.py
+++ b/kombu/utils/compat.py
@@ -1,5 +1,7 @@
"""Python Compatibility Utilities."""
+from __future__ import annotations
+
import numbers
import sys
from contextlib import contextmanager
@@ -77,9 +79,18 @@ def detect_environment():
def entrypoints(namespace):
"""Return setuptools entrypoints for namespace."""
+    if sys.version_info >= (3, 10):
+ entry_points = importlib_metadata.entry_points(group=namespace)
+ else:
+ entry_points = importlib_metadata.entry_points()
+ try:
+ entry_points = entry_points.get(namespace, [])
+ except AttributeError:
+ entry_points = entry_points.select(group=namespace)
+
return (
(ep, ep.load())
- for ep in importlib_metadata.entry_points().get(namespace, [])
+ for ep in entry_points
)
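+
+
+# Example (the namespace below is an illustrative assumption, not a group
+# this module defines):
+#
+#     for ep, obj in entrypoints('kombu.transport'):
+#         print(ep.name, obj)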
diff --git a/kombu/utils/debug.py b/kombu/utils/debug.py
index acc2d60b..bd20948f 100644
--- a/kombu/utils/debug.py
+++ b/kombu/utils/debug.py
@@ -1,5 +1,7 @@
"""Debugging support."""
+from __future__ import annotations
+
import logging
from vine.utils import wraps
diff --git a/kombu/utils/div.py b/kombu/utils/div.py
index 45be7f94..439b6639 100644
--- a/kombu/utils/div.py
+++ b/kombu/utils/div.py
@@ -1,5 +1,7 @@
"""Div. Utilities."""
+from __future__ import annotations
+
import sys
from .encoding import default_encode
diff --git a/kombu/utils/encoding.py b/kombu/utils/encoding.py
index 5f58f0fa..42bf2ce9 100644
--- a/kombu/utils/encoding.py
+++ b/kombu/utils/encoding.py
@@ -5,6 +5,8 @@ applications without crashing from the infamous
:exc:`UnicodeDecodeError` exception.
"""
+from __future__ import annotations
+
import sys
import traceback
diff --git a/kombu/utils/eventio.py b/kombu/utils/eventio.py
index 48260a48..f8d89d45 100644
--- a/kombu/utils/eventio.py
+++ b/kombu/utils/eventio.py
@@ -1,5 +1,7 @@
"""Selector Utilities."""
+from __future__ import annotations
+
import errno
import math
import select as __select__
diff --git a/kombu/utils/functional.py b/kombu/utils/functional.py
index 366a0b99..6beb17d7 100644
--- a/kombu/utils/functional.py
+++ b/kombu/utils/functional.py
@@ -1,5 +1,7 @@
"""Functional Utilities."""
+from __future__ import annotations
+
import inspect
import random
import threading
diff --git a/kombu/utils/imports.py b/kombu/utils/imports.py
index fd4482a8..8752fa1a 100644
--- a/kombu/utils/imports.py
+++ b/kombu/utils/imports.py
@@ -1,5 +1,7 @@
"""Import related utilities."""
+from __future__ import annotations
+
import importlib
import sys
diff --git a/kombu/utils/json.py b/kombu/utils/json.py
index cedaa793..ec6269e2 100644
--- a/kombu/utils/json.py
+++ b/kombu/utils/json.py
@@ -1,75 +1,75 @@
"""JSON Serialization Utilities."""
-import datetime
-import decimal
-import json as stdjson
+from __future__ import annotations
+
+import base64
+import json
import uuid
+from datetime import date, datetime, time
+from decimal import Decimal
+from typing import Any, Callable, TypeVar
-try:
- from django.utils.functional import Promise as DjangoPromise
-except ImportError: # pragma: no cover
- class DjangoPromise:
- """Dummy object."""
+textual_types = ()
try:
- import json
- _json_extra_kwargs = {}
-
- class _DecodeError(Exception):
- pass
-except ImportError: # pragma: no cover
- import simplejson as json
- from simplejson.decoder import JSONDecodeError as _DecodeError
- _json_extra_kwargs = {
- 'use_decimal': False,
- 'namedtuple_as_object': False,
- }
-
+ from django.utils.functional import Promise
-_encoder_cls = type(json._default_encoder)
-_default_encoder = None # ... set to JSONEncoder below.
+ textual_types += (Promise,)
+except ImportError:
+ pass
-class JSONEncoder(_encoder_cls):
+class JSONEncoder(json.JSONEncoder):
"""Kombu custom json encoder."""
- def default(self, o,
- dates=(datetime.datetime, datetime.date),
- times=(datetime.time,),
- textual=(decimal.Decimal, uuid.UUID, DjangoPromise),
- isinstance=isinstance,
- datetime=datetime.datetime,
- text_t=str):
- reducer = getattr(o, '__json__', None)
+ def default(self, o):
+ reducer = getattr(o, "__json__", None)
if reducer is not None:
return reducer()
- else:
- if isinstance(o, dates):
- if not isinstance(o, datetime):
- o = datetime(o.year, o.month, o.day, 0, 0, 0, 0)
- r = o.isoformat()
- if r.endswith("+00:00"):
- r = r[:-6] + "Z"
- return r
- elif isinstance(o, times):
- return o.isoformat()
- elif isinstance(o, textual):
- return text_t(o)
- return super().default(o)
+ if isinstance(o, textual_types):
+ return str(o)
+
+ for t, (marker, encoder) in _encoders.items():
+ if isinstance(o, t):
+ return _as(marker, encoder(o))
+
+ # Bytes is slightly trickier, so we cannot put them directly
+ # into _encoders, because we use two formats: bytes, and base64.
+ if isinstance(o, bytes):
+ try:
+ return _as("bytes", o.decode("utf-8"))
+ except UnicodeDecodeError:
+ return _as("base64", base64.b64encode(o).decode("utf-8"))
-_default_encoder = JSONEncoder
+ return super().default(o)
-def dumps(s, _dumps=json.dumps, cls=None, default_kwargs=None, **kwargs):
+def _as(t: str, v: Any):
+ return {"__type__": t, "__value__": v}
+
+
+def dumps(
+ s, _dumps=json.dumps, cls=JSONEncoder, default_kwargs=None, **kwargs
+):
"""Serialize object to json string."""
- if not default_kwargs:
- default_kwargs = _json_extra_kwargs
- return _dumps(s, cls=cls or _default_encoder,
- **dict(default_kwargs, **kwargs))
+ default_kwargs = default_kwargs or {}
+ return _dumps(s, cls=cls, **dict(default_kwargs, **kwargs))
+
+
+def object_hook(o: dict):
+ """Hook function to perform custom deserialization."""
+ if o.keys() == {"__type__", "__value__"}:
+ decoder = _decoders.get(o["__type__"])
+ if decoder:
+ return decoder(o["__value__"])
+ else:
+ raise ValueError("Unsupported type", type, o)
+ else:
+ return o
-def loads(s, _loads=json.loads, decode_bytes=True):
+def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook):
"""Deserialize json from string."""
# None of the json implementations supports decoding from
# a buffer/memoryview, or even reading from a stream
@@ -78,14 +78,51 @@ def loads(s, _loads=json.loads, decode_bytes=True):
# over. Note that pickle does support buffer/memoryview
# </rant>
if isinstance(s, memoryview):
- s = s.tobytes().decode('utf-8')
+ s = s.tobytes().decode("utf-8")
elif isinstance(s, bytearray):
- s = s.decode('utf-8')
+ s = s.decode("utf-8")
elif decode_bytes and isinstance(s, bytes):
- s = s.decode('utf-8')
-
- try:
- return _loads(s)
- except _DecodeError:
- # catch "Unpaired high surrogate" error
- return stdjson.loads(s)
+ s = s.decode("utf-8")
+
+ return _loads(s, object_hook=object_hook)
+
+
+DecoderT = EncoderT = Callable[[Any], Any]
+T = TypeVar("T")
+EncodedT = TypeVar("EncodedT")
+
+
+def register_type(
+ t: type[T],
+ marker: str,
+ encoder: Callable[[T], EncodedT],
+ decoder: Callable[[EncodedT], T],
+):
+ """Add support for serializing/deserializing native python type."""
+ _encoders[t] = (marker, encoder)
+ _decoders[marker] = decoder
+
+
+_encoders: dict[type, tuple[str, EncoderT]] = {}
+_decoders: dict[str, DecoderT] = {
+ "bytes": lambda o: o.encode("utf-8"),
+ "base64": lambda o: base64.b64decode(o.encode("utf-8")),
+}
+
+# NOTE: datetime should be registered before date,
+# because datetime is also an instance of date.
+register_type(datetime, "datetime", datetime.isoformat, datetime.fromisoformat)
+register_type(
+ date,
+ "date",
+ lambda o: o.isoformat(),
+ lambda o: datetime.fromisoformat(o).date(),
+)
+register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat)
+register_type(Decimal, "decimal", str, Decimal)
+register_type(
+ uuid.UUID,
+ "uuid",
+ lambda o: {"hex": o.hex, "version": o.version},
+ lambda o: uuid.UUID(**o),
+)
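
The rewritten json module wraps each non-native value in a
{"__type__": marker, "__value__": encoded} envelope on encode and unwraps
it in object_hook on decode, so dumps/loads round-trips are lossless for
every registered type, and register_type() opens the registry to
applications. A usage sketch under those assumptions (the Point class is
hypothetical, not kombu API; the rest follows the code above):

    from datetime import datetime
    from decimal import Decimal
    from uuid import uuid4

    from kombu.utils.json import dumps, loads, register_type

    # Registered types survive a round trip; on the wire each value is
    # wrapped as {"__type__": marker, "__value__": encoded}.
    payload = {"when": datetime(2022, 1, 2, 3, 4, 5),
               "price": Decimal("9.99"),
               "id": uuid4()}
    assert loads(dumps(payload)) == payload

    # Extending the registry with an application type:
    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

        def __eq__(self, other):
            return (self.x, self.y) == (other.x, other.y)

    register_type(
        Point,
        "point",                      # stored under "__type__"
        lambda p: [p.x, p.y],         # encoder: Point -> JSON-safe value
        lambda v: Point(v[0], v[1]),  # decoder: "__value__" -> Point
    )
    assert loads(dumps(Point(1, 2))) == Point(1, 2)

Note that an envelope whose marker has no registered decoder now raises
ValueError in object_hook instead of being passed through silently.
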
diff --git a/kombu/utils/limits.py b/kombu/utils/limits.py
index d82884f5..36d11f1f 100644
--- a/kombu/utils/limits.py
+++ b/kombu/utils/limits.py
@@ -1,5 +1,7 @@
"""Token bucket implementation for rate limiting."""
+from __future__ import annotations
+
from collections import deque
from time import monotonic
diff --git a/kombu/utils/objects.py b/kombu/utils/objects.py
index 7fef4a2f..eb4dfc2a 100644
--- a/kombu/utils/objects.py
+++ b/kombu/utils/objects.py
@@ -1,5 +1,7 @@
"""Object Utilities."""
+from __future__ import annotations
+
__all__ = ('cached_property',)
try:
diff --git a/kombu/utils/scheduling.py b/kombu/utils/scheduling.py
index 1875fce4..94286be8 100644
--- a/kombu/utils/scheduling.py
+++ b/kombu/utils/scheduling.py
@@ -1,5 +1,7 @@
"""Scheduling Utilities."""
+from __future__ import annotations
+
from itertools import count
from .imports import symbol_by_name
diff --git a/kombu/utils/text.py b/kombu/utils/text.py
index 1d5fb9de..fea53347 100644
--- a/kombu/utils/text.py
+++ b/kombu/utils/text.py
@@ -2,7 +2,10 @@
# flake8: noqa
+from __future__ import annotations
+
from difflib import SequenceMatcher
+from typing import Iterable, Iterator
from kombu import version_info_t
@@ -16,8 +19,7 @@ def escape_regex(p, white=''):
for c in p)
-def fmatch_iter(needle, haystack, min_ratio=0.6):
- # type: (str, Sequence[str], float) -> Iterator[Tuple[float, str]]
+def fmatch_iter(needle: str, haystack: Iterable[str], min_ratio: float = 0.6) -> Iterator[tuple[float, str]]:
"""Fuzzy match: iteratively.
Yields:
@@ -29,19 +31,17 @@ def fmatch_iter(needle, haystack, min_ratio=0.6):
yield ratio, key
-def fmatch_best(needle, haystack, min_ratio=0.6):
- # type: (str, Sequence[str], float) -> str
+def fmatch_best(needle: str, haystack: Iterable[str], min_ratio: float = 0.6) -> str | None:
"""Fuzzy match - Find best match (scalar)."""
try:
return sorted(
fmatch_iter(needle, haystack, min_ratio), reverse=True,
)[0][1]
except IndexError:
- pass
+ return None
-def version_string_as_tuple(s):
- # type: (str) -> version_info_t
+def version_string_as_tuple(s: str) -> version_info_t:
"""Convert version string to version info tuple."""
v = _unpack_version(*s.split('.'))
# X.Y.3a1 -> (X, Y, 3, 'a1')
@@ -53,13 +53,17 @@ def version_string_as_tuple(s):
return v
-def _unpack_version(major, minor=0, micro=0, releaselevel='', serial=''):
- # type: (int, int, int, str, str) -> version_info_t
+def _unpack_version(
+ major: str,
+ minor: str | int = 0,
+ micro: str | int = 0,
+ releaselevel: str = '',
+ serial: str = ''
+) -> version_info_t:
return version_info_t(int(major), int(minor), micro, releaselevel, serial)
-def _splitmicro(micro, releaselevel='', serial=''):
- # type: (int, str, str) -> Tuple[int, str, str]
+def _splitmicro(micro: str, releaselevel: str = '', serial: str = '') -> tuple[int, str, str]:
for index, char in enumerate(micro):
if not char.isdigit():
break
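
The text helpers keep their behavior but gain real annotations, and
fmatch_best() now returns None explicitly instead of falling off the end
of the except block. A quick behavior sketch (the example strings are
illustrative only):

    from kombu.utils.text import fmatch_best, version_string_as_tuple

    names = ["connection", "channel", "consumer"]
    # A close match wins; when nothing reaches min_ratio the result is
    # an explicit None rather than an implicit fall-through.
    assert fmatch_best("conection", names) == "connection"
    assert fmatch_best("zzz", names) is None

    # version_string_as_tuple() splits the release level out of micro:
    v = version_string_as_tuple("5.3.0b3")
    assert (v.major, v.minor, v.micro, v.releaselevel) == (5, 3, 0, "b3")
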
diff --git a/kombu/utils/time.py b/kombu/utils/time.py
index 863f4017..8228d2be 100644
--- a/kombu/utils/time.py
+++ b/kombu/utils/time.py
@@ -1,11 +1,9 @@
"""Time Utilities."""
-# flake8: noqa
-
+from __future__ import annotations
__all__ = ('maybe_s_to_ms',)
-def maybe_s_to_ms(v):
- # type: (Optional[Union[int, float]]) -> int
+def maybe_s_to_ms(v: int | float | None) -> int | None:
"""Convert seconds to milliseconds, but return None for None."""
return int(float(v) * 1000.0) if v is not None else v
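
maybe_s_to_ms() drops its flake8 escape hatch in favor of a precise
signature; the None pass-through is the point, since callers hand it
optional timeouts. Behavior sketch:

    from kombu.utils.time import maybe_s_to_ms

    # Seconds (int or float) become integral milliseconds; None passes
    # through untouched, matching the int | None return annotation.
    assert maybe_s_to_ms(1) == 1000
    assert maybe_s_to_ms(0.5) == 500
    assert maybe_s_to_ms(None) is None
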
diff --git a/kombu/utils/url.py b/kombu/utils/url.py
index de3a9139..f5f47701 100644
--- a/kombu/utils/url.py
+++ b/kombu/utils/url.py
@@ -2,6 +2,8 @@
# flake8: noqa
+from __future__ import annotations
+
from collections.abc import Mapping
from functools import partial
from typing import NamedTuple
diff --git a/kombu/utils/uuid.py b/kombu/utils/uuid.py
index 010b3440..9f77dad9 100644
--- a/kombu/utils/uuid.py
+++ b/kombu/utils/uuid.py
@@ -1,9 +1,11 @@
"""UUID utilities."""
+from __future__ import annotations
-from uuid import uuid4
+from typing import Callable
+from uuid import UUID, uuid4
-def uuid(_uuid=uuid4):
+def uuid(_uuid: Callable[[], UUID] = uuid4) -> str:
"""Generate unique id in UUID4 format.
See Also: