diff options
author | Asif Saif Uddin <auvipy@gmail.com> | 2023-04-08 22:45:08 +0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-08 22:45:08 +0600 |
commit | 973dc3790ac25b9da7b6d2641ac72d95470f6ed8 (patch) | |
tree | 9e7ba02d8520994a06efc37dde05fba722138189 | |
parent | 7ceb675bb69917fae182ebdaf9a2298a308c3fa4 (diff) | |
parent | 2de7f9f038dd62e097e490cb3fa609067c1c3c36 (diff) | |
download | kombu-py310.tar.gz |
Merge branch 'main' into py310py310
201 files changed, 3493 insertions, 767 deletions
diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 2d98ba68..600e72c0 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 5.2.0 +current_version = 5.3.0b3 commit = True tag = True parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?P<releaselevel>[a-z]+)? diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index cfa66c6d..9ba62dba 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,8 +1,7 @@ # These are supported funding model platforms -github: auvipy patreon: auvipy open_collective: celery ko_fi: # Replace with a single Ko-fi username -tidelift: pypi/kombu +tidelift: "pypi/kombu" custom: # Replace with a single custom sponsorship URL diff --git a/.github/tidelift.yml b/.github/tidelift.yml new file mode 100644 index 00000000..3df65f56 --- /dev/null +++ b/.github/tidelift.yml @@ -0,0 +1,18 @@ +name: Tidelift Alignment +on: + push: + + +jobs: + build: + name: Run Tidelift to ensure approved open source packages are in use + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Alignment + uses: tidelift/alignment-action@main + env: + TIDELIFT_API_KEY: ${{ secrets.TIDELIFT_API_KEY }} + TIDELIFT_ORGANIZATION: ${{ secrets.TIDELIFT_ORGANIZATION }} + TIDELIFT_PROJECT: ${{ secrets.TIDELIFT_PROJECT }} diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index dadcb5a8..e0e2a15b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -6,12 +6,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7,3.8,3.9,"3.11"] + python-version: [3.7,3.8,3.9,"3.10","3.11"] + steps: - name: Install system packages run: sudo apt update && sudo apt-get install libcurl4-openssl-dev libssl-dev - name: Check out code from GitHub - uses: actions/checkout@v2.3.5 + uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} id: python uses: actions/setup-python@main @@ -29,16 +30,16 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7,3.8,3.9,"3.11"] + python-version: [3.7,3.8,3.9,"3.10","3.11"] experimental: [false] include: - - python-version: pypy3 + - python-version: pypy-3.9 experimental: true steps: - name: Install system packages run: sudo apt update && sudo apt-get install libcurl4-openssl-dev libssl-dev - name: Check out code from GitHub - uses: actions/checkout@v2.3.5 + uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} id: python uses: actions/setup-python@main @@ -46,10 +47,18 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: pip install --upgrade pip setuptools wheel tox tox-docker + # Tox fails if a Python versions contains a hyphen, this changes "pypy-3.9" to "pypy3.9". + - name: Determine Python version + run: echo PYTHON_VERSION=$(echo ${{ matrix.python-version }} | sed s/-//) >> $GITHUB_ENV - name: Run AMQP integration tests - run: tox -v -e ${{ matrix.python-version }}-linux-integration-py-amqp -- -v + run: tox -v -e ${{ env.PYTHON_VERSION }}-linux-integration-py-amqp -- -v - name: Run redis integration tests - run: tox -v -e ${{ matrix.python-version }}-linux-integration-py-redis -- -v + run: tox -v -e ${{ env.PYTHON_VERSION }}-linux-integration-py-redis -- -v + - name: Run MongoDB integration tests + run: tox -v -e ${{ env.PYTHON_VERSION }}-linux-integration-py-mongodb -- -v + - name: Run kafka integration tests + if: ${{ env.PYTHON_VERSION != 'pypy3.9'}} + run: tox -v -e ${{ env.PYTHON_VERSION }}-linux-integration-py-kafka -- -v #################### Linters and checkers #################### lint: @@ -57,12 +66,12 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9] + python-version: ["3.10"] steps: - name: Install system packages run: sudo apt update && sudo apt-get install libcurl4-openssl-dev libssl-dev - name: Check out code from GitHub - uses: actions/checkout@v2.3.5 + uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} id: python uses: actions/setup-python@main @@ -76,3 +85,5 @@ jobs: run: tox -v -e pydocstyle -- -v - name: Run apicheck run: tox -v -e apicheck -- -v + - name: Run mypy + run: tox -v -e mypy -- -v diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 00000000..b9a40588 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,68 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ main ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ main ] + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'python' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] + # Learn more about CodeQL language support at https://git.io/codeql-language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + # ℹ️ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 @@ -26,6 +26,7 @@ kombu/tests/coverage.xml .coverage dump.rdb .idea/ +.vscode/ .cache/ .pytest_cache/ htmlcov/ @@ -34,3 +35,4 @@ coverage.xml venv/ env .eggs +.python-version diff --git a/.landscape.yml b/.landscape.yml deleted file mode 100644 index f90444af..00000000 --- a/.landscape.yml +++ /dev/null @@ -1,3 +0,0 @@ -pylint: - disable: - - cyclic-import diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f1dfb37c..e2b8fac1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,22 +1,28 @@ repos: - repo: https://github.com/asottile/pyupgrade - rev: v2.29.0 + rev: v3.3.1 hooks: - id: pyupgrade - args: ["--py36-plus"] + args: ["--py37-plus", "--keep-runtime-typing"] + + - repo: https://github.com/PyCQA/autoflake + rev: v2.0.2 + hooks: + - id: autoflake + args: ["--in-place", "--ignore-pass-after-docstring", "--imports"] - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 6.0.0 hooks: - id: flake8 - repo: https://github.com/asottile/yesqa - rev: v1.3.0 + rev: v1.4.0 hooks: - id: yesqa - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.4.0 hooks: - id: check-merge-conflict - id: check-toml @@ -24,6 +30,6 @@ repos: - id: mixed-line-ending - repo: https://github.com/pycqa/isort - rev: 5.9.3 + rev: 5.12.0 hooks: - id: isort diff --git a/.pyup.yml b/.pyup.yml new file mode 100644 index 00000000..bdd9a62e --- /dev/null +++ b/.pyup.yml @@ -0,0 +1,5 @@ +# autogenerated pyup.io config file +# see https://pyup.io/docs/configuration/ for all available options + +schedule: '' +update: false @@ -18,6 +18,7 @@ Anthony Lukach <anthonylukach@gmail.com> Antoine Legrand <antoine.legrand@smartjog.com> Anton Gyllenberg <anton@iki.fi> Ask Solem <ask@celeryproject.org> +Asif Saif Uddin <auvipy@gmail.com> Basil Mironenko <bmironenko@ddn.com> Bobby Beever <bobby.beever@yahoo.com> Brian Bernstein @@ -88,6 +89,7 @@ Lorenzo Mancini <lmancini@develer.com> Luyun Xie <2304310@qq.com> Mads Jensen <https://github.com/atombrella> Mahendra M <Mahendra_M@infosys.com> +Manuel Vazquez Acosta <https://github.com/mvaled> Marcin Lulek (ergo) <info@webreactor.eu> Marcin Puhacz <marcin.puhacz@gmail.com> Mark Lavin <mlavin@caktusgroup.com> diff --git a/Changelog.rst b/Changelog.rst index 14ff80ce..6a9890e4 100644 --- a/Changelog.rst +++ b/Changelog.rst @@ -4,11 +4,148 @@ Change history ================ +.. _version-5.3.0b3: + +5.3.0b3 +======= +:release-date: 20 Mar, 2023 +:release-by: Asif Saif Uddin + +- Use SPDX license expression in project metadata. +- Allowing Connection.ensure() to retry on specific exceptions given by policy (#1629). +- Redis==4.3.4 temporarilly in an attempt to avoid BC (#1634). +- Add managed identity support to azure storage queue (#1631). +- Support sqla v2.0 (#1651). +- Switch to Pyro5 (#1655). +- Remove unused _setupfuns from serialization.py. +- Refactor: Refactor utils/json (#1659). +- Adapt the mock to correctly mock the behaviors as implemented on Python 3.10. (Ref #1663). + + +.. _version-5.3.0b2: + +5.3.0b2 +======= +:release-date: 19 Oct, 2022 +:release-by: Asif Saif Uddin + +- fix: save QueueProperties to _queue_name_cache instead of QueueClient. +- hub: tick delay fix (#1587). +- Fix incompatibility with redis in disconnect() (#1589). +- Solve Kombu filesystem transport not thread safe. +- importlib_metadata remove deprecated entry point interfaces (#1601). +- Allow azurestoragequeues transport to be used with Azurite emulator in docker-compose (#1611). + + +.. _version-5.3.0b1: + +5.3.0b1 +======= +:release-date: 1 Aug, 2022 +:release-by: Asif Saif Uddin + +- Add ext.py files to setup.cfg. +- Add support to SQS DelaySeconds (#1567). +- Add WATCH to prefixed complex commands. +- Avoid losing type of UUID when serializing/deserializing (#1575). +- chore: add confluentkafka to extras. + +.. _version-5.3.0a1: + +5.3.0a1 +======= +:release-date: 29 Jun, 2022 +:release-by: Asif Saif Uddin + +- Add fanout to filesystem (#1499). +- Protect set of ready tasks by lock to avoid concurrent updates. (#1489). +- Correct documentation stating kombu uses pickle protocol version 2. +- Use new entry_points interface. +- Add mypy to the pipeline (#1512). +- Added possibility to serialize and deserialize binary messages in json (#1516). +- Bump pyupgrade version and add __future__.annotations import. +- json.py cleaning from outdated libs (#1533). +- bump new py-amqp to 5.1.1 (#1534). +- add GitHub URL for PyPi. +- Upgrade pytest to ~=7.1.1. +- Support pymongo 4.x (#1536). +- Initial Kafka support (#1506). +- Upgrade Azure Storage Queues transport to version 12 (#1539). +- move to consul2 (#1544). +- Datetime serialization and deserialization fixed (#1515). +- Bump redis>=4.2.2 (#1546). +- Update sqs dependencies (#1547). +- Added HLEN to the list of prefixed redis commands (#1540). +- Added some type annotations. + + +.. _version-5.2.4: + +5.2.4 +===== +:release-date: 06 Mar, 2022 +:release-by: Asif Saif Uddin + +- Allow getting recoverable_connection_errors without an active transport. +- Prevent KeyError: 'purelib' by removing INSTALLED_SCHEME hack from setup.py. +- Revert "try pining setuptools (#1466)" (#1481). +- Fix issue #789: Async http code not allowing for proxy config (#790). +- Fix The incorrect times of retrying. +- Set redelivered property for Celery with Redis (#1484). +- Remove use of OrderedDict in various places (#1483). +- Warn about missing hostname only when default one is available (#1488). +- All supported versions of Python define __package__. +- Added global_keyprefix support for pubsub clients (#1495). +- try pytest 7 (#1497). +- Add an option to not base64-encode SQS messages. +- Fix SQS extract_task_name message reference. + + +.. _version-5.2.3: + +5.2.3 +===== +:release-date: 29 Dec, 2021 +:release-by: Asif Saif Uddin + +- Allow redis >= 4.0.2. +- Fix PyPy CI jobs. +- SQS transport: detect FIFO queue properly by checking queue URL (#1450). +- Ensure that restore is atomic in redis transport (#1444). +- Restrict setuptools>=59.1.1,<59.7.0. +- Bump minimum py-amqp to v5.0.9 (#1462). +- Reduce memory usage of Transport (#1470). +- Prevent event loop polling on closed redis transports (and causing leak). +- Respect connection timeout (#1458) +- prevent redis event loop stopping on 'consumer: Cannot connect' (#1477). + + +.. _version-5.2.2: + +5.2.2 +===== +:release-date: 16 Nov, 2021 +:release-by: Asif Saif Uddin + +- Pin redis version to >= 3.4.1<4.0.0 as it is not fully compatible yet. + + +.. _version-5.2.1: + +5.2.1 +===== +:release-date: 8 Nov, 2021 +:release-by: Asif Saif Uddin + +- Bump redis version to >= 3.4.1. +- try latest sqs dependencies ti fix security warning. +- Tests & dependency updates + .. _version-5.2.0: 5.2.0 ===== -:release-date: soon +:release-date: 5 Nov, 2021 :release-by: Naomi Elstein - v 1.4.x (#1338). @@ -4,7 +4,7 @@ |build-status| |coverage| |license| |wheel| |pyversion| |pyimp| |downloads| -:Version: 5.2.0 +:Version: 5.3.0b3 :Documentation: https://kombu.readthedocs.io/ :Download: https://pypi.org/project/kombu/ :Source: https://github.com/celery/kombu/ @@ -127,7 +127,7 @@ Quick overview video_queue = Queue('video', exchange=media_exchange, routing_key='video') def process_media(body, message): - print body + print(body) message.ack() # connections @@ -305,7 +305,7 @@ Mailing list Join the `celery-users`_ mailing list. -.. _`celery-users`: https://groups.google.com/group/celery-users/ +.. _`kombu forum`: https://github.com/celery/kombu/discussions Bug tracker =========== @@ -328,12 +328,12 @@ This software is licensed under the `New BSD License`. See the `LICENSE` file in the top distribution directory for the full license text. -.. |build-status| image:: https://api.travis-ci.com/celery/kombu.png?branch=master +.. |build-status| image:: https://github.com/celery/kombu/actions/workflows/ci.yaml/badge.svg :alt: Build status - :target: https://travis-ci.com/celery/kombu + :target: https://github.com/celery/kombu/actions/workflows/ci.yml -.. |coverage| image:: https://codecov.io/github/celery/kombu/coverage.svg?branch=master - :target: https://codecov.io/github/celery/kombu?branch=master +.. |coverage| image:: https://codecov.io/github/celery/kombu/coverage.svg?branch=main + :target: https://codecov.io/github/celery/kombu?branch=main .. |license| image:: https://img.shields.io/pypi/l/kombu.svg :alt: BSD License diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..ec793b59 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,17 @@ +# Security Policy + +## Supported Versions + + +| Version | Supported | +| ------- | ------------------ | +| 5.2.x | :white_check_mark: | +| 5.0.x | :x: | +| 5.1.x | :white_check_mark: | +| < 5.0 | :x: | + +## Reporting a Vulnerability + +Please report vulnerability issues directly to auvipy@gmail.com + + diff --git a/conftest.py b/conftest.py index 3fc0a687..842d39a2 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest diff --git a/docs/conf.py b/docs/conf.py index 1f781486..0163fe39 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sphinx_celery import conf globals().update(conf.build_config( diff --git a/docs/includes/introduction.txt b/docs/includes/introduction.txt index f79ef193..36eca4bd 100644 --- a/docs/includes/introduction.txt +++ b/docs/includes/introduction.txt @@ -1,4 +1,4 @@ -:Version: 5.2.0 +:Version: 5.3.0b3 :Web: https://kombu.readthedocs.io/ :Download: https://pypi.org/project/kombu/ :Source: https://github.com/celery/kombu/ @@ -23,15 +23,7 @@ Features * Allows application authors to support several message server solutions by using pluggable transports. - * AMQP transport using the `py-amqp`_, `librabbitmq`_, or `qpid-python`_ libraries. - - * High performance AMQP transport written in C - when using `librabbitmq`_ - - This is automatically enabled if librabbitmq is installed: - - .. code-block:: console - - $ pip install librabbitmq + * AMQP transport using the `py-amqp`_, `redis`_, or `SQS`_ libraries. * Virtual transports makes it really easy to add support for non-AMQP transports. There is already built-in support for `Redis`_, diff --git a/docs/reference/kombu.serialization.rst b/docs/reference/kombu.serialization.rst index 04d7b84a..07ca9bfb 100644 --- a/docs/reference/kombu.serialization.rst +++ b/docs/reference/kombu.serialization.rst @@ -44,8 +44,6 @@ .. autodata:: registry -.. _`cjson`: https://pypi.org/project/python-cjson/ -.. _`simplejson`: https://github.com/simplejson/simplejson .. _`Python 2.7+`: https://docs.python.org/library/json.html .. _`PyYAML`: https://pyyaml.org/ .. _`msgpack`: https://msgpack.org/ diff --git a/docs/reference/kombu.transport.confluentkafka.rst b/docs/reference/kombu.transport.confluentkafka.rst new file mode 100644 index 00000000..3b171a28 --- /dev/null +++ b/docs/reference/kombu.transport.confluentkafka.rst @@ -0,0 +1,31 @@ +========================================================= + confluent-kafka Transport - ``kombu.transport.confluentkafka`` +========================================================= + +.. currentmodule:: kombu.transport.confluentkafka + +.. automodule:: kombu.transport.confluentkafka + + .. contents:: + :local: + + Transport + --------- + + .. autoclass:: Transport + :members: + :undoc-members: + + Channel + ------- + + .. autoclass:: Channel + :members: + :undoc-members: + + Message + ------- + + .. autoclass:: Message + :members: + :undoc-members: diff --git a/docs/templates/readme.txt b/docs/templates/readme.txt index 05b8edb0..55175ca2 100644 --- a/docs/templates/readme.txt +++ b/docs/templates/readme.txt @@ -10,12 +10,12 @@ .. include:: ../includes/resources.txt -.. |build-status| image:: https://secure.travis-ci.org/celery/kombu.png?branch=master +.. |build-status| image:: https://github.com/celery/kombu/actions/workflows/ci.yaml/badge.svg :alt: Build status - :target: https://travis-ci.org/celery/kombu + :target: https://github.com/celery/kombu/actions/workflows/ci.yml -.. |coverage| image:: https://codecov.io/github/celery/kombu/coverage.svg?branch=master - :target: https://codecov.io/github/celery/kombu?branch=master +.. |coverage| image:: https://codecov.io/github/celery/kombu/coverage.svg?branch=main + :target: https://codecov.io/github/celery/kombu?branch=main .. |license| image:: https://img.shields.io/pypi/l/kombu.svg :alt: BSD License diff --git a/docs/userguide/serialization.rst b/docs/userguide/serialization.rst index 711bedd9..f8523e69 100644 --- a/docs/userguide/serialization.rst +++ b/docs/userguide/serialization.rst @@ -32,23 +32,45 @@ The accept argument can also include MIME-types. Each option has its advantages and disadvantages. -`json` -- JSON is supported in many programming languages, is now - a standard part of Python (since 2.6), and is fairly fast to - decode using the modern Python libraries such as `cjson` or - `simplejson`. +`json` -- JSON is supported in many programming languages, is + a standard part of Python, and is fairly fast to + decode. The primary disadvantage to `JSON` is that it limits you to the following data types: strings, Unicode, floats, boolean, - dictionaries, and lists. Decimals and dates are notably missing. + dictionaries, lists, decimals, DjangoPromise, datetimes, dates, + time, bytes and UUIDs. + + For dates, datetimes, UUIDs and bytes the serializer will generate + a dict that will later instruct the deserializer how to produce + the right type. Also, binary data will be transferred using Base64 encoding, which will cause the transferred data to be around 34% larger than an - encoding which supports native binary types. + encoding which supports native binary types. This will only happen + if the bytes object can't be decoded into utf8. However, if your data fits inside the above constraints and you need cross-language support, the default setting of `JSON` is probably your best choice. + If you need support for custom types, you can write serialize/deserialize + functions and register them as follows: + + .. code-block:: python + + from kombu.utils.json import register_type + from django.db.models import Model + from django.apps import apps + + # Allow serialization of django models: + register_type( + Model, + "model", + lambda o: [o._meta.label, o.pk], + lambda o: apps.get_model(o[0]).objects.get(pk=o[1]), + ) + `pickle` -- If you have no desire to support any language other than Python, then using the `pickle` encoding will gain you the support of all built-in Python data types (except class instances), @@ -67,7 +89,7 @@ Each option has its advantages and disadvantages. to limit access to the broker so that untrusted parties do not have the ability to send messages! - By default Kombu uses pickle protocol 2, but this can be changed + By default Kombu uses pickle protocol 4, but this can be changed using the :envvar:`PICKLE_PROTOCOL` environment variable or by changing the global :data:`kombu.serialization.pickle_protocol` flag. diff --git a/docs/userguide/simple.rst b/docs/userguide/simple.rst index 8d86711c..41daae52 100644 --- a/docs/userguide/simple.rst +++ b/docs/userguide/simple.rst @@ -61,7 +61,7 @@ to produce and consume logging messages: from kombu import Connection - class Logger(object): + class Logger: def __init__(self, connection, queue_name='log_queue', serializer='json', compression=None): diff --git a/examples/complete_receive.py b/examples/complete_receive.py index 21903c59..1f1fb5e2 100644 --- a/examples/complete_receive.py +++ b/examples/complete_receive.py @@ -4,6 +4,8 @@ and exits. """ +from __future__ import annotations + from pprint import pformat from kombu import Connection, Consumer, Exchange, Queue, eventloop diff --git a/examples/complete_send.py b/examples/complete_send.py index d4643e59..ea7ab65c 100644 --- a/examples/complete_send.py +++ b/examples/complete_send.py @@ -6,6 +6,8 @@ You can use `complete_receive.py` to receive the message sent. """ +from __future__ import annotations + from kombu import Connection, Exchange, Producer, Queue #: By default messages sent to exchanges are persistent (delivery_mode=2), diff --git a/examples/experimental/async_consume.py b/examples/experimental/async_consume.py index 5f70b4ed..55126bda 100644 --- a/examples/experimental/async_consume.py +++ b/examples/experimental/async_consume.py @@ -1,5 +1,7 @@ #!/usr/bin/env python +from __future__ import annotations + from kombu import Connection, Consumer, Exchange, Producer, Queue from kombu.asynchronous import Hub diff --git a/examples/hello_consumer.py b/examples/hello_consumer.py index 3a0d9f06..38b2b61d 100644 --- a/examples/hello_consumer.py +++ b/examples/hello_consumer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from kombu import Connection with Connection('amqp://guest:guest@localhost:5672//') as conn: diff --git a/examples/hello_publisher.py b/examples/hello_publisher.py index 48342800..dfb2df8e 100644 --- a/examples/hello_publisher.py +++ b/examples/hello_publisher.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime from kombu import Connection diff --git a/examples/memory_transport.py b/examples/memory_transport.py index 8d667b81..e25cc898 100644 --- a/examples/memory_transport.py +++ b/examples/memory_transport.py @@ -1,6 +1,8 @@ """ Example that use memory transport for message produce. """ +from __future__ import annotations + import time from kombu import Connection, Consumer, Exchange, Queue diff --git a/examples/rpc-tut6/rpc_client.py b/examples/rpc-tut6/rpc_client.py index 6b1d509b..a899bd5c 100644 --- a/examples/rpc-tut6/rpc_client.py +++ b/examples/rpc-tut6/rpc_client.py @@ -1,5 +1,7 @@ #!/usr/bin/env python +from __future__ import annotations + from kombu import Connection, Consumer, Producer, Queue, uuid diff --git a/examples/rpc-tut6/rpc_server.py b/examples/rpc-tut6/rpc_server.py index 761630ca..ce13ea83 100644 --- a/examples/rpc-tut6/rpc_server.py +++ b/examples/rpc-tut6/rpc_server.py @@ -1,5 +1,7 @@ #!/usr/bin/env python +from __future__ import annotations + from kombu import Connection, Queue from kombu.mixins import ConsumerProducerMixin diff --git a/examples/simple_eventlet_receive.py b/examples/simple_eventlet_receive.py index 33f09d80..703203ad 100644 --- a/examples/simple_eventlet_receive.py +++ b/examples/simple_eventlet_receive.py @@ -7,6 +7,8 @@ message sent. """ +from __future__ import annotations + import eventlet from kombu import Connection diff --git a/examples/simple_eventlet_send.py b/examples/simple_eventlet_send.py index 9a753e7a..ad0c1f2b 100644 --- a/examples/simple_eventlet_send.py +++ b/examples/simple_eventlet_send.py @@ -7,6 +7,8 @@ message sent. """ +from __future__ import annotations + import eventlet from kombu import Connection diff --git a/examples/simple_receive.py b/examples/simple_receive.py index 90c2d0ae..c2512cca 100644 --- a/examples/simple_receive.py +++ b/examples/simple_receive.py @@ -3,6 +3,8 @@ Example receiving a message using the SimpleQueue interface. """ +from __future__ import annotations + from kombu import Connection #: Create connection diff --git a/examples/simple_send.py b/examples/simple_send.py index d2eacea4..194fb810 100644 --- a/examples/simple_send.py +++ b/examples/simple_send.py @@ -7,6 +7,8 @@ message sent. """ +from __future__ import annotations + from kombu import Connection #: Create connection diff --git a/examples/simple_task_queue/client.py b/examples/simple_task_queue/client.py index bc2d180f..aaa18f5c 100644 --- a/examples/simple_task_queue/client.py +++ b/examples/simple_task_queue/client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from kombu.pools import producers from .queues import task_exchange diff --git a/examples/simple_task_queue/queues.py b/examples/simple_task_queue/queues.py index 602c2b0e..a545b4c2 100644 --- a/examples/simple_task_queue/queues.py +++ b/examples/simple_task_queue/queues.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from kombu import Exchange, Queue task_exchange = Exchange('tasks', type='direct') diff --git a/examples/simple_task_queue/tasks.py b/examples/simple_task_queue/tasks.py index 2810f7a8..02edda07 100644 --- a/examples/simple_task_queue/tasks.py +++ b/examples/simple_task_queue/tasks.py @@ -1,2 +1,5 @@ +from __future__ import annotations + + def hello_task(who='world'): print(f'Hello {who}') diff --git a/examples/simple_task_queue/worker.py b/examples/simple_task_queue/worker.py index 66ad9c30..4d16e309 100644 --- a/examples/simple_task_queue/worker.py +++ b/examples/simple_task_queue/worker.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from kombu.log import get_logger from kombu.mixins import ConsumerMixin from kombu.utils.functional import reprcall diff --git a/kombu/__init__.py b/kombu/__init__.py index da4ed466..999cf9da 100644 --- a/kombu/__init__.py +++ b/kombu/__init__.py @@ -1,11 +1,14 @@ """Messaging library for Python.""" +from __future__ import annotations + import os import re import sys from collections import namedtuple +from typing import Any, cast -__version__ = '5.2.0' +__version__ = '5.3.0b3' __author__ = 'Ask Solem' __contact__ = 'auvipy@gmail.com, ask@celeryproject.org' __homepage__ = 'https://kombu.readthedocs.io' @@ -19,12 +22,12 @@ version_info_t = namedtuple('version_info_t', ( # bumpversion can only search for {current_version} # so we have to parse the version here. -_temp = re.match( - r'(\d+)\.(\d+).(\d+)(.+)?', __version__).groups() +_temp = cast(re.Match, re.match( + r'(\d+)\.(\d+).(\d+)(.+)?', __version__)).groups() VERSION = version_info = version_info_t( int(_temp[0]), int(_temp[1]), int(_temp[2]), _temp[3] or '', '') -del(_temp) -del(re) +del _temp +del re STATICA_HACK = True globals()['kcah_acitats'[::-1].upper()] = False @@ -61,15 +64,15 @@ all_by_module = { } object_origins = {} -for module, items in all_by_module.items(): +for _module, items in all_by_module.items(): for item in items: - object_origins[item] = module + object_origins[item] = _module class module(ModuleType): """Customized Python module.""" - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name in object_origins: module = __import__(object_origins[name], None, None, [name]) for extra_name in all_by_module[module.__name__]: @@ -77,7 +80,7 @@ class module(ModuleType): return getattr(module, name) return ModuleType.__getattribute__(self, name) - def __dir__(self): + def __dir__(self) -> list[str]: result = list(new_module.__all__) result.extend(('__file__', '__path__', '__doc__', '__all__', '__docformat__', '__name__', '__path__', 'VERSION', @@ -86,12 +89,6 @@ class module(ModuleType): return result -# 2.5 does not define __package__ -try: - package = __package__ -except NameError: # pragma: no cover - package = 'kombu' - # keep a reference to this module so that it's not garbage collected old_module = sys.modules[__name__] @@ -106,7 +103,7 @@ new_module.__dict__.update({ '__contact__': __contact__, '__homepage__': __homepage__, '__docformat__': __docformat__, - '__package__': package, + '__package__': __package__, 'version_info_t': version_info_t, 'version_info': version_info, 'VERSION': VERSION diff --git a/kombu/abstract.py b/kombu/abstract.py index 38cff010..48a917c9 100644 --- a/kombu/abstract.py +++ b/kombu/abstract.py @@ -1,19 +1,35 @@ """Object utilities.""" +from __future__ import annotations + from copy import copy +from typing import TYPE_CHECKING, Any, Callable, TypeVar from .connection import maybe_channel from .exceptions import NotBoundError from .utils.functional import ChannelPromise +if TYPE_CHECKING: + from kombu.connection import Connection + from kombu.transport.virtual import Channel + + __all__ = ('Object', 'MaybeChannelBound') +_T = TypeVar("_T") +_ObjectType = TypeVar("_ObjectType", bound="Object") +_MaybeChannelBoundType = TypeVar( + "_MaybeChannelBoundType", bound="MaybeChannelBound" +) + -def unpickle_dict(cls, kwargs): +def unpickle_dict( + cls: type[_ObjectType], kwargs: dict[str, Any] +) -> _ObjectType: return cls(**kwargs) -def _any(v): +def _any(v: _T) -> _T: return v @@ -23,9 +39,9 @@ class Object: Supports automatic kwargs->attributes handling, and cloning. """ - attrs = () + attrs: tuple[tuple[str, Any], ...] = () - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: for name, type_ in self.attrs: value = kwargs.get(name) if value is not None: @@ -36,8 +52,8 @@ class Object: except AttributeError: setattr(self, name, None) - def as_dict(self, recurse=False): - def f(obj, type): + def as_dict(self, recurse: bool = False) -> dict[str, Any]: + def f(obj: Any, type: Callable[[Any], Any]) -> Any: if recurse and isinstance(obj, Object): return obj.as_dict(recurse=True) return type(obj) if type and obj is not None else obj @@ -45,31 +61,40 @@ class Object: attr: f(getattr(self, attr), type) for attr, type in self.attrs } - def __reduce__(self): + def __reduce__(self: _ObjectType) -> tuple[ + Callable[[type[_ObjectType], dict[str, Any]], _ObjectType], + tuple[type[_ObjectType], dict[str, Any]] + ]: return unpickle_dict, (self.__class__, self.as_dict()) - def __copy__(self): + def __copy__(self: _ObjectType) -> _ObjectType: return self.__class__(**self.as_dict()) class MaybeChannelBound(Object): """Mixin for classes that can be bound to an AMQP channel.""" - _channel = None + _channel: Channel | None = None _is_bound = False #: Defines whether maybe_declare can skip declaring this entity twice. can_cache_declaration = False - def __call__(self, channel): + def __call__( + self: _MaybeChannelBoundType, channel: (Channel | Connection) + ) -> _MaybeChannelBoundType: """`self(channel) -> self.bind(channel)`.""" return self.bind(channel) - def bind(self, channel): + def bind( + self: _MaybeChannelBoundType, channel: (Channel | Connection) + ) -> _MaybeChannelBoundType: """Create copy of the instance that is bound to a channel.""" return copy(self).maybe_bind(channel) - def maybe_bind(self, channel): + def maybe_bind( + self: _MaybeChannelBoundType, channel: (Channel | Connection) + ) -> _MaybeChannelBoundType: """Bind instance to channel if not already bound.""" if not self.is_bound and channel: self._channel = maybe_channel(channel) @@ -77,7 +102,7 @@ class MaybeChannelBound(Object): self._is_bound = True return self - def revive(self, channel): + def revive(self, channel: Channel) -> None: """Revive channel after the connection has been re-established. Used by :meth:`~kombu.Connection.ensure`. @@ -87,13 +112,13 @@ class MaybeChannelBound(Object): self._channel = channel self.when_bound() - def when_bound(self): + def when_bound(self) -> None: """Callback called when the class is bound.""" - def __repr__(self): + def __repr__(self) -> str: return self._repr_entity(type(self).__name__) - def _repr_entity(self, item=''): + def _repr_entity(self, item: str = '') -> str: item = item or type(self).__name__ if self.is_bound: return '<{} bound to chan:{}>'.format( @@ -101,12 +126,12 @@ class MaybeChannelBound(Object): return f'<unbound {item}>' @property - def is_bound(self): + def is_bound(self) -> bool: """Flag set if the channel is bound.""" return self._is_bound and self._channel is not None @property - def channel(self): + def channel(self) -> Channel: """Current channel if the object is bound.""" channel = self._channel if channel is None: diff --git a/kombu/asynchronous/__init__.py b/kombu/asynchronous/__init__.py index fb264aa5..53060753 100644 --- a/kombu/asynchronous/__init__.py +++ b/kombu/asynchronous/__init__.py @@ -1,5 +1,7 @@ """Event loop.""" +from __future__ import annotations + from kombu.utils.eventio import ERR, READ, WRITE from .hub import Hub, get_event_loop, set_event_loop diff --git a/kombu/asynchronous/aws/__init__.py b/kombu/asynchronous/aws/__init__.py index d8423c23..cbeb050f 100644 --- a/kombu/asynchronous/aws/__init__.py +++ b/kombu/asynchronous/aws/__init__.py @@ -1,4 +1,15 @@ -def connect_sqs(aws_access_key_id=None, aws_secret_access_key=None, **kwargs): +from __future__ import annotations + +from typing import Any + +from kombu.asynchronous.aws.sqs.connection import AsyncSQSConnection + + +def connect_sqs( + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + **kwargs: Any +) -> AsyncSQSConnection: """Return async connection to Amazon SQS.""" from .sqs.connection import AsyncSQSConnection return AsyncSQSConnection( diff --git a/kombu/asynchronous/aws/connection.py b/kombu/asynchronous/aws/connection.py index f3926388..887ab40c 100644 --- a/kombu/asynchronous/aws/connection.py +++ b/kombu/asynchronous/aws/connection.py @@ -1,5 +1,7 @@ """Amazon AWS Connection.""" +from __future__ import annotations + from email import message_from_bytes from email.mime.message import MIMEMessage diff --git a/kombu/asynchronous/aws/ext.py b/kombu/asynchronous/aws/ext.py index 2dedc812..1fa4a57e 100644 --- a/kombu/asynchronous/aws/ext.py +++ b/kombu/asynchronous/aws/ext.py @@ -1,5 +1,7 @@ """Amazon boto3 interface.""" +from __future__ import annotations + try: import boto3 from botocore import exceptions diff --git a/kombu/asynchronous/aws/sqs/connection.py b/kombu/asynchronous/aws/sqs/connection.py index 9db2523b..20b56344 100644 --- a/kombu/asynchronous/aws/sqs/connection.py +++ b/kombu/asynchronous/aws/sqs/connection.py @@ -1,5 +1,7 @@ """Amazon SQS Connection.""" +from __future__ import annotations + from vine import transform from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection diff --git a/kombu/asynchronous/aws/sqs/ext.py b/kombu/asynchronous/aws/sqs/ext.py index f6630936..72268b5d 100644 --- a/kombu/asynchronous/aws/sqs/ext.py +++ b/kombu/asynchronous/aws/sqs/ext.py @@ -1,6 +1,8 @@ """Amazon SQS boto3 interface.""" +from __future__ import annotations + try: import boto3 except ImportError: diff --git a/kombu/asynchronous/aws/sqs/message.py b/kombu/asynchronous/aws/sqs/message.py index 9425ff2d..52727bb7 100644 --- a/kombu/asynchronous/aws/sqs/message.py +++ b/kombu/asynchronous/aws/sqs/message.py @@ -1,5 +1,7 @@ """Amazon SQS message implementation.""" +from __future__ import annotations + import base64 from kombu.message import Message diff --git a/kombu/asynchronous/aws/sqs/queue.py b/kombu/asynchronous/aws/sqs/queue.py index 50b0be55..7ca78f75 100644 --- a/kombu/asynchronous/aws/sqs/queue.py +++ b/kombu/asynchronous/aws/sqs/queue.py @@ -1,5 +1,7 @@ """Amazon SQS queue implementation.""" +from __future__ import annotations + from vine import transform from .message import AsyncMessage @@ -12,7 +14,7 @@ def list_first(rs): return rs[0] if len(rs) == 1 else None -class AsyncQueue(): +class AsyncQueue: """Async SQS Queue.""" def __init__(self, connection=None, url=None, message_class=AsyncMessage): diff --git a/kombu/asynchronous/debug.py b/kombu/asynchronous/debug.py index 4fabb452..7c1e45c7 100644 --- a/kombu/asynchronous/debug.py +++ b/kombu/asynchronous/debug.py @@ -1,5 +1,7 @@ """Event-loop debugging tools.""" +from __future__ import annotations + from kombu.utils.eventio import ERR, READ, WRITE from kombu.utils.functional import reprcall diff --git a/kombu/asynchronous/http/__init__.py b/kombu/asynchronous/http/__init__.py index 1c45ebca..67d8b219 100644 --- a/kombu/asynchronous/http/__init__.py +++ b/kombu/asynchronous/http/__init__.py @@ -1,17 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from kombu.asynchronous import get_event_loop +from kombu.asynchronous.http.base import Headers, Request, Response +from kombu.asynchronous.hub import Hub -from .base import Headers, Request, Response +if TYPE_CHECKING: + from kombu.asynchronous.http.curl import CurlClient __all__ = ('Client', 'Headers', 'Response', 'Request') -def Client(hub=None, **kwargs): +def Client(hub: Hub | None = None, **kwargs: int) -> CurlClient: """Create new HTTP client.""" from .curl import CurlClient return CurlClient(hub, **kwargs) -def get_client(hub=None, **kwargs): +def get_client(hub: Hub | None = None, **kwargs: int) -> CurlClient: """Get or create HTTP client bound to the current event loop.""" hub = hub or get_event_loop() try: diff --git a/kombu/asynchronous/http/base.py b/kombu/asynchronous/http/base.py index e8d5043b..89be531f 100644 --- a/kombu/asynchronous/http/base.py +++ b/kombu/asynchronous/http/base.py @@ -1,7 +1,10 @@ """Base async HTTP client implementation.""" +from __future__ import annotations + import sys from http.client import responses +from typing import TYPE_CHECKING from vine import Thenable, maybe_promise, promise @@ -10,6 +13,9 @@ from kombu.utils.compat import coro from kombu.utils.encoding import bytes_to_str from kombu.utils.functional import maybe_list, memoize +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('Headers', 'Response', 'Request') PYPY = hasattr(sys, 'pypy_version_info') @@ -61,7 +67,7 @@ class Request: auth_password (str): Password for HTTP authentication. auth_mode (str): Type of HTTP authentication (``basic`` or ``digest``). user_agent (str): Custom user agent for this request. - network_interace (str): Network interface to use for this request. + network_interface (str): Network interface to use for this request. on_ready (Callable): Callback to be called when the response has been received. Must accept single ``response`` argument. on_stream (Callable): Optional callback to be called every time body @@ -253,5 +259,10 @@ class BaseClient: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() diff --git a/kombu/asynchronous/http/curl.py b/kombu/asynchronous/http/curl.py index ee70f3c6..6f879fa9 100644 --- a/kombu/asynchronous/http/curl.py +++ b/kombu/asynchronous/http/curl.py @@ -1,11 +1,13 @@ """HTTP Client using pyCurl.""" +from __future__ import annotations + from collections import deque from functools import partial from io import BytesIO from time import time -from kombu.asynchronous.hub import READ, WRITE, get_event_loop +from kombu.asynchronous.hub import READ, WRITE, Hub, get_event_loop from kombu.exceptions import HttpError from kombu.utils.encoding import bytes_to_str @@ -36,7 +38,7 @@ class CurlClient(BaseClient): Curl = Curl - def __init__(self, hub=None, max_clients=10): + def __init__(self, hub: Hub | None = None, max_clients: int = 10): if pycurl is None: raise ImportError('The curl client requires the pycurl library.') hub = hub or get_event_loop() @@ -231,9 +233,6 @@ class CurlClient(BaseClient): if request.proxy_username: setopt(_pycurl.PROXYUSERPWD, '{}:{}'.format( request.proxy_username, request.proxy_password or '')) - else: - setopt(_pycurl.PROXY, '') - curl.unsetopt(_pycurl.PROXYUSERPWD) setopt(_pycurl.SSL_VERIFYPEER, 1 if request.validate_cert else 0) setopt(_pycurl.SSL_VERIFYHOST, 2 if request.validate_cert else 0) @@ -253,7 +252,7 @@ class CurlClient(BaseClient): setopt(meth, True) if request.method in ('POST', 'PUT'): - body = request.body.encode('utf-8') if request.body else bytes() + body = request.body.encode('utf-8') if request.body else b'' reqbuffer = BytesIO(body) setopt(_pycurl.READFUNCTION, reqbuffer.read) if request.method == 'POST': diff --git a/kombu/asynchronous/hub.py b/kombu/asynchronous/hub.py index b1f7e241..e5b1163c 100644 --- a/kombu/asynchronous/hub.py +++ b/kombu/asynchronous/hub.py @@ -1,6 +1,9 @@ """Event loop implementation.""" +from __future__ import annotations + import errno +import threading from contextlib import contextmanager from queue import Empty from time import sleep @@ -18,7 +21,7 @@ from .timer import Timer __all__ = ('Hub', 'get_event_loop', 'set_event_loop') logger = get_logger(__name__) -_current_loop = None +_current_loop: Hub | None = None W_UNKNOWN_EVENT = """\ Received unknown event %r for fd %r, please contact support!\ @@ -38,12 +41,12 @@ def _dummy_context(*args, **kwargs): yield -def get_event_loop(): +def get_event_loop() -> Hub | None: """Get current event loop object.""" return _current_loop -def set_event_loop(loop): +def set_event_loop(loop: Hub | None) -> Hub | None: """Set the current event loop object.""" global _current_loop _current_loop = loop @@ -78,6 +81,7 @@ class Hub: self.on_tick = set() self.on_close = set() self._ready = set() + self._ready_lock = threading.Lock() self._running = False self._loop = None @@ -198,7 +202,8 @@ class Hub: def call_soon(self, callback, *args): if not isinstance(callback, Thenable): callback = promise(callback, args) - self._ready.add(callback) + with self._ready_lock: + self._ready.add(callback) return callback def call_later(self, delay, callback, *args): @@ -242,6 +247,12 @@ class Hub: except (AttributeError, KeyError, OSError): pass + def _pop_ready(self): + with self._ready_lock: + ready = self._ready + self._ready = set() + return ready + def close(self, *args): [self._unregister(fd) for fd in self.readers] self.readers.clear() @@ -257,8 +268,7 @@ class Hub: # To avoid infinite loop where one of the callables adds items # to self._ready (via call_soon or otherwise). # we create new list with current self._ready - todos = list(self._ready) - self._ready = set() + todos = self._pop_ready() for item in todos: item() @@ -288,17 +298,17 @@ class Hub: propagate = self.propagate_errors while 1: - todo = self._ready - self._ready = set() - - for tick_callback in on_tick: - tick_callback() + todo = self._pop_ready() for item in todo: if item: item() poll_timeout = fire_timers(propagate=propagate) if scheduled else 1 + + for tick_callback in on_tick: + tick_callback() + # print('[[[HUB]]]: %s' % (self.repr_active(),)) if readers or writers: to_consolidate = [] diff --git a/kombu/asynchronous/semaphore.py b/kombu/asynchronous/semaphore.py index 9fe34a04..07fb8a09 100644 --- a/kombu/asynchronous/semaphore.py +++ b/kombu/asynchronous/semaphore.py @@ -1,9 +1,23 @@ """Semaphores and concurrency primitives.""" +from __future__ import annotations +import sys from collections import deque +from typing import TYPE_CHECKING, Callable, Deque + +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec + +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('DummyLock', 'LaxBoundedSemaphore') +P = ParamSpec("P") + class LaxBoundedSemaphore: """Asynchronous Bounded Semaphore. @@ -12,18 +26,15 @@ class LaxBoundedSemaphore: range even if released more times than it was acquired. Example: - >>> from future import print_statement as printf - # ^ ignore: just fooling stupid pyflakes - >>> x = LaxBoundedSemaphore(2) - >>> x.acquire(printf, 'HELLO 1') + >>> x.acquire(print, 'HELLO 1') HELLO 1 - >>> x.acquire(printf, 'HELLO 2') + >>> x.acquire(print, 'HELLO 2') HELLO 2 - >>> x.acquire(printf, 'HELLO 3') + >>> x.acquire(print, 'HELLO 3') >>> x._waiters # private, do not access directly [print, ('HELLO 3',)] @@ -31,13 +42,18 @@ class LaxBoundedSemaphore: HELLO 3 """ - def __init__(self, value): + def __init__(self, value: int) -> None: self.initial_value = self.value = value - self._waiting = deque() + self._waiting: Deque[tuple] = deque() self._add_waiter = self._waiting.append self._pop_waiter = self._waiting.popleft - def acquire(self, callback, *partial_args, **partial_kwargs): + def acquire( + self, + callback: Callable[P, None], + *partial_args: P.args, + **partial_kwargs: P.kwargs + ) -> bool: """Acquire semaphore. This will immediately apply ``callback`` if @@ -57,7 +73,7 @@ class LaxBoundedSemaphore: callback(*partial_args, **partial_kwargs) return True - def release(self): + def release(self) -> None: """Release semaphore. Note: @@ -71,23 +87,24 @@ class LaxBoundedSemaphore: else: waiter(*args, **kwargs) - def grow(self, n=1): + def grow(self, n: int = 1) -> None: """Change the size of the semaphore to accept more users.""" self.initial_value += n self.value += n - [self.release() for _ in range(n)] + for _ in range(n): + self.release() - def shrink(self, n=1): + def shrink(self, n: int = 1) -> None: """Change the size of the semaphore to accept less users.""" self.initial_value = max(self.initial_value - n, 0) self.value = max(self.value - n, 0) - def clear(self): + def clear(self) -> None: """Reset the semaphore, which also wipes out any waiting callbacks.""" self._waiting.clear() self.value = self.initial_value - def __repr__(self): + def __repr__(self) -> str: return '<{} at {:#x} value:{} waiting:{}>'.format( self.__class__.__name__, id(self), self.value, len(self._waiting), ) @@ -96,8 +113,13 @@ class LaxBoundedSemaphore: class DummyLock: """Pretending to be a lock.""" - def __enter__(self): + def __enter__(self) -> DummyLock: return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: pass diff --git a/kombu/asynchronous/timer.py b/kombu/asynchronous/timer.py index 21ad37c1..f6be1346 100644 --- a/kombu/asynchronous/timer.py +++ b/kombu/asynchronous/timer.py @@ -1,5 +1,7 @@ """Timer scheduling Python callbacks.""" +from __future__ import annotations + import heapq import sys from collections import namedtuple @@ -7,6 +9,7 @@ from datetime import datetime from functools import total_ordering from time import monotonic from time import time as _time +from typing import TYPE_CHECKING from weakref import proxy as weakrefproxy from vine.utils import wraps @@ -18,6 +21,9 @@ try: except ImportError: # pragma: no cover utc = None +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('Entry', 'Timer', 'to_timestamp') logger = get_logger(__name__) @@ -101,7 +107,12 @@ class Timer: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.stop() def call_at(self, eta, fun, args=(), kwargs=None, priority=0): diff --git a/kombu/clocks.py b/kombu/clocks.py index 3c152720..d02e8b32 100644 --- a/kombu/clocks.py +++ b/kombu/clocks.py @@ -1,8 +1,11 @@ """Logical Clocks and Synchronization.""" +from __future__ import annotations + from itertools import islice from operator import itemgetter from threading import Lock +from typing import Any __all__ = ('LamportClock', 'timetuple') @@ -15,7 +18,7 @@ class timetuple(tuple): Can be used as part of a heap to keep events ordered. Arguments: - clock (int): Event clock value. + clock (Optional[int]): Event clock value. timestamp (float): Event UNIX timestamp value. id (str): Event host id (e.g. ``hostname:pid``). obj (Any): Optional obj to associate with this event. @@ -23,16 +26,18 @@ class timetuple(tuple): __slots__ = () - def __new__(cls, clock, timestamp, id, obj=None): + def __new__( + cls, clock: int | None, timestamp: float, id: str, obj: Any = None + ) -> timetuple: return tuple.__new__(cls, (clock, timestamp, id, obj)) - def __repr__(self): + def __repr__(self) -> str: return R_CLOCK.format(*self) - def __getnewargs__(self): + def __getnewargs__(self) -> tuple: return tuple(self) - def __lt__(self, other): + def __lt__(self, other: tuple) -> bool: # 0: clock 1: timestamp 3: process id try: A, B = self[0], other[0] @@ -45,13 +50,13 @@ class timetuple(tuple): except IndexError: return NotImplemented - def __gt__(self, other): + def __gt__(self, other: tuple) -> bool: return other < self - def __le__(self, other): + def __le__(self, other: tuple) -> bool: return not other < self - def __ge__(self, other): + def __ge__(self, other: tuple) -> bool: return not self < other clock = property(itemgetter(0)) @@ -99,21 +104,23 @@ class LamportClock: #: The clocks current value. value = 0 - def __init__(self, initial_value=0, Lock=Lock): + def __init__( + self, initial_value: int = 0, Lock: type[Lock] = Lock + ) -> None: self.value = initial_value self.mutex = Lock() - def adjust(self, other): + def adjust(self, other: int) -> int: with self.mutex: value = self.value = max(self.value, other) + 1 return value - def forward(self): + def forward(self) -> int: with self.mutex: self.value += 1 return self.value - def sort_heap(self, h): + def sort_heap(self, h: list[tuple[int, str]]) -> tuple[int, str]: """Sort heap of events. List of tuples containing at least two elements, representing @@ -140,8 +147,8 @@ class LamportClock: # clock values unique, return first item return h[0] - def __str__(self): + def __str__(self) -> str: return str(self.value) - def __repr__(self): + def __repr__(self) -> str: return f'<LamportClock: {self.value}>' diff --git a/kombu/common.py b/kombu/common.py index 08bc1aff..c7b2d50a 100644 --- a/kombu/common.py +++ b/kombu/common.py @@ -1,5 +1,7 @@ """Common Utilities.""" +from __future__ import annotations + import os import socket import threading diff --git a/kombu/compat.py b/kombu/compat.py index 1fa3f631..d90aec75 100644 --- a/kombu/compat.py +++ b/kombu/compat.py @@ -3,11 +3,17 @@ See https://pypi.org/project/carrot/ for documentation. """ +from __future__ import annotations + from itertools import count +from typing import TYPE_CHECKING from . import messaging from .entity import Exchange, Queue +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('Publisher', 'Consumer') # XXX compat attribute @@ -65,7 +71,12 @@ class Publisher(messaging.Producer): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() @property @@ -127,7 +138,12 @@ class Consumer(messaging.Consumer): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() def __iter__(self): diff --git a/kombu/compression.py b/kombu/compression.py index d9438539..f98c971b 100644 --- a/kombu/compression.py +++ b/kombu/compression.py @@ -1,5 +1,7 @@ """Compression utilities.""" +from __future__ import annotations + import zlib from kombu.utils.encoding import ensure_bytes diff --git a/kombu/connection.py b/kombu/connection.py index a63154f5..0c9779b5 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -1,11 +1,14 @@ """Client (Connection).""" +from __future__ import annotations + import os import socket -from collections import OrderedDict +import sys from contextlib import contextmanager from itertools import count, cycle from operator import itemgetter +from typing import TYPE_CHECKING, Any try: from ssl import CERT_NONE @@ -14,6 +17,7 @@ except ImportError: # pragma: no cover CERT_NONE = None ssl_available = False + # jython breaks on relative import for .exceptions for some reason # (Issue #112) from kombu import exceptions @@ -26,6 +30,16 @@ from .utils.functional import dictfilter, lazy, retry_over_time, shufflecycle from .utils.objects import cached_property from .utils.url import as_url, maybe_sanitize_url, parse_url, quote, urlparse +if TYPE_CHECKING: + from kombu.transport.virtual import Channel + + if sys.version_info < (3, 10): + from typing_extensions import TypeGuard + else: + from typing import TypeGuard + + from types import TracebackType + __all__ = ('Connection', 'ConnectionPool', 'ChannelPool') logger = get_logger(__name__) @@ -412,7 +426,7 @@ class Connection: callback (Callable): Optional callback that is called for every internal iteration (1 s). timeout (int): Maximum amount of time in seconds to spend - waiting for connection + attempting to connect, total over all retries. """ if self.connected: return self._connection @@ -468,7 +482,7 @@ class Connection: def ensure(self, obj, fun, errback=None, max_retries=None, interval_start=1, interval_step=1, interval_max=1, - on_revive=None): + on_revive=None, retry_errors=None): """Ensure operation completes. Regardless of any channel/connection errors occurring. @@ -497,6 +511,9 @@ class Connection: each retry. on_revive (Callable): Optional callback called whenever revival completes successfully + retry_errors (tuple): Optional list of errors to retry on + regardless of the connection state. Must provide max_retries + if this is specified. Examples: >>> from kombu import Connection, Producer @@ -511,6 +528,15 @@ class Connection: ... errback=errback, max_retries=3) >>> publish({'hello': 'world'}, routing_key='dest') """ + if retry_errors is None: + retry_errors = tuple() + elif max_retries is None: + # If the retry_errors is specified, but max_retries is not, + # this could lead into an infinite loop potentially. + raise ValueError( + "max_retries must be specified if retry_errors is specified" + ) + def _ensured(*args, **kwargs): got_connection = 0 conn_errors = self.recoverable_connection_errors @@ -522,6 +548,11 @@ class Connection: for retries in count(0): # for infinity try: return fun(*args, **kwargs) + except retry_errors as exc: + if max_retries is not None and retries >= max_retries: + raise + self._debug('ensure retry policy error: %r', + exc, exc_info=1) except conn_errors as exc: if got_connection and not has_modern_errors: # transport can not distinguish between @@ -529,7 +560,7 @@ class Connection: # the error if it persists after a new connection # was successfully established. raise - if max_retries is not None and retries > max_retries: + if max_retries is not None and retries >= max_retries: raise self._debug('ensure connection error: %r', exc, exc_info=1) @@ -626,7 +657,7 @@ class Connection: transport_cls, transport_cls) D = self.transport.default_connection_params - if not self.hostname: + if not self.hostname and D.get('hostname'): logger.warning( "No hostname was supplied. " f"Reverting to default '{D.get('hostname')}'") @@ -658,7 +689,7 @@ class Connection: def info(self): """Get connection info.""" - return OrderedDict(self._info()) + return dict(self._info()) def __eqhash__(self): return HashedSeq(self.transport_cls, self.hostname, self.userid, @@ -829,7 +860,12 @@ class Connection: def __enter__(self): return self - def __exit__(self, *args): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.release() @property @@ -837,7 +873,7 @@ class Connection: return self.transport.qos_semantics_matches_spec(self.connection) def _extract_failover_opts(self): - conn_opts = {} + conn_opts = {'timeout': self.connect_timeout} transport_opts = self.transport_options if transport_opts: if 'max_retries' in transport_opts: @@ -848,6 +884,9 @@ class Connection: conn_opts['interval_step'] = transport_opts['interval_step'] if 'interval_max' in transport_opts: conn_opts['interval_max'] = transport_opts['interval_max'] + if 'connect_retries_timeout' in transport_opts: + conn_opts['timeout'] = \ + transport_opts['connect_retries_timeout'] return conn_opts @property @@ -880,7 +919,7 @@ class Connection: return self._connection @property - def default_channel(self): + def default_channel(self) -> Channel: """Default channel. Created upon access and closed when the connection is closed. @@ -932,7 +971,7 @@ class Connection: but where the connection must be closed and re-established first. """ try: - return self.transport.recoverable_connection_errors + return self.get_transport_cls().recoverable_connection_errors except AttributeError: # There were no such classification before, # and all errors were assumed to be recoverable, @@ -948,19 +987,19 @@ class Connection: recovered from without re-establishing the connection. """ try: - return self.transport.recoverable_channel_errors + return self.get_transport_cls().recoverable_channel_errors except AttributeError: return () @cached_property def connection_errors(self): """List of exceptions that may be raised by the connection.""" - return self.transport.connection_errors + return self.get_transport_cls().connection_errors @cached_property def channel_errors(self): """List of exceptions that may be raised by the channel.""" - return self.transport.channel_errors + return self.get_transport_cls().channel_errors @property def supports_heartbeats(self): @@ -1043,7 +1082,7 @@ class ChannelPool(Resource): return channel -def maybe_channel(channel): +def maybe_channel(channel: Channel | Connection) -> Channel: """Get channel from object. Return the default channel if argument is a connection instance, @@ -1054,5 +1093,5 @@ def maybe_channel(channel): return channel -def is_connection(obj): +def is_connection(obj: Any) -> TypeGuard[Connection]: return isinstance(obj, Connection) diff --git a/kombu/entity.py b/kombu/entity.py index a89fabb9..2329e748 100644 --- a/kombu/entity.py +++ b/kombu/entity.py @@ -1,5 +1,7 @@ """Exchange and Queue declarations.""" +from __future__ import annotations + import numbers from .abstract import MaybeChannelBound, Object diff --git a/kombu/exceptions.py b/kombu/exceptions.py index f2501437..825baa12 100644 --- a/kombu/exceptions.py +++ b/kombu/exceptions.py @@ -1,9 +1,16 @@ """Exceptions.""" +from __future__ import annotations + from socket import timeout as TimeoutError +from types import TracebackType +from typing import TYPE_CHECKING, TypeVar from amqp import ChannelError, ConnectionError, ResourceError +if TYPE_CHECKING: + from kombu.asynchronous.http import Response + __all__ = ( 'reraise', 'KombuError', 'OperationalError', 'NotBoundError', 'MessageStateError', 'TimeoutError', @@ -14,8 +21,14 @@ __all__ = ( 'InconsistencyError', ) +BaseExceptionType = TypeVar('BaseExceptionType', bound=BaseException) + -def reraise(tp, value, tb=None): +def reraise( + tp: type[BaseExceptionType], + value: BaseExceptionType, + tb: TracebackType | None = None +) -> BaseExceptionType: """Reraise exception.""" if value.__traceback__ is not tb: raise value.with_traceback(tb) @@ -84,11 +97,16 @@ class InconsistencyError(ConnectionError): class HttpError(Exception): """HTTP Client Error.""" - def __init__(self, code, message=None, response=None): + def __init__( + self, + code: int, + message: str | None = None, + response: Response | None = None + ) -> None: self.code = code self.message = message self.response = response super().__init__(code, message, response) - def __str__(self): + def __str__(self) -> str: return 'HTTP {0.code}: {0.message}'.format(self) diff --git a/kombu/log.py b/kombu/log.py index de77e7f3..ed8d0a50 100644 --- a/kombu/log.py +++ b/kombu/log.py @@ -1,5 +1,7 @@ """Logging Utilities.""" +from __future__ import annotations + import logging import numbers import os diff --git a/kombu/matcher.py b/kombu/matcher.py index 7dcab8cd..a4d71bb1 100644 --- a/kombu/matcher.py +++ b/kombu/matcher.py @@ -1,11 +1,16 @@ """Pattern matching registry.""" +from __future__ import annotations + from fnmatch import fnmatch from re import match as rematch +from typing import Callable, cast from .utils.compat import entrypoints from .utils.encoding import bytes_to_str +MatcherFunction = Callable[[str, str], bool] + class MatcherNotInstalled(Exception): """Matcher not installed/found.""" @@ -17,15 +22,15 @@ class MatcherRegistry: MatcherNotInstalled = MatcherNotInstalled matcher_pattern_first = ["pcre", ] - def __init__(self): - self._matchers = {} - self._default_matcher = None + def __init__(self) -> None: + self._matchers: dict[str, MatcherFunction] = {} + self._default_matcher: MatcherFunction | None = None - def register(self, name, matcher): + def register(self, name: str, matcher: MatcherFunction) -> None: """Add matcher by name to the registry.""" self._matchers[name] = matcher - def unregister(self, name): + def unregister(self, name: str) -> None: """Remove matcher by name from the registry.""" try: self._matchers.pop(name) @@ -34,7 +39,7 @@ class MatcherRegistry: f'No matcher installed for {name}' ) - def _set_default_matcher(self, name): + def _set_default_matcher(self, name: str) -> None: """Set the default matching method. :param name: The name of the registered matching method. @@ -51,7 +56,13 @@ class MatcherRegistry: f'No matcher installed for {name}' ) - def match(self, data, pattern, matcher=None, matcher_kwargs=None): + def match( + self, + data: bytes, + pattern: bytes, + matcher: str | None = None, + matcher_kwargs: dict[str, str] | None = None + ) -> bool: """Call the matcher.""" if matcher and not self._matchers.get(matcher): raise self.MatcherNotInstalled( @@ -97,7 +108,7 @@ match = registry.match .. function:: register(name, matcher): Register a new matching method. - :param name: A convience name for the mathing method. + :param name: A convenient name for the mathing method. :param matcher: A method that will be passed data and pattern. """ register = registry.register @@ -111,14 +122,14 @@ register = registry.register unregister = registry.unregister -def register_glob(): +def register_glob() -> None: """Register glob into default registry.""" registry.register('glob', fnmatch) -def register_pcre(): +def register_pcre() -> None: """Register pcre into default registry.""" - registry.register('pcre', rematch) + registry.register('pcre', cast(MatcherFunction, rematch)) # Register the base matching methods. diff --git a/kombu/message.py b/kombu/message.py index bcc90d1a..f2af1686 100644 --- a/kombu/message.py +++ b/kombu/message.py @@ -1,5 +1,7 @@ """Message class.""" +from __future__ import annotations + import sys from .compression import decompress diff --git a/kombu/messaging.py b/kombu/messaging.py index 0bed52c5..2b600224 100644 --- a/kombu/messaging.py +++ b/kombu/messaging.py @@ -1,6 +1,9 @@ """Sending and receiving messages.""" +from __future__ import annotations + from itertools import count +from typing import TYPE_CHECKING from .common import maybe_declare from .compression import compress @@ -10,6 +13,9 @@ from .exceptions import ContentDisallowed from .serialization import dumps, prepare_accept_content from .utils.functional import ChannelPromise, maybe_list +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('Exchange', 'Queue', 'Producer', 'Consumer') @@ -236,7 +242,12 @@ class Producer: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.release() def release(self): @@ -435,7 +446,12 @@ class Consumer: self.consume() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: if self.channel and self.channel.connection: conn_errors = self.channel.connection.client.connection_errors if not isinstance(exc_val, conn_errors): diff --git a/kombu/mixins.py b/kombu/mixins.py index b87e4b92..f1b3c1c9 100644 --- a/kombu/mixins.py +++ b/kombu/mixins.py @@ -1,5 +1,7 @@ """Mixins.""" +from __future__ import annotations + import socket from contextlib import contextmanager from functools import partial diff --git a/kombu/pidbox.py b/kombu/pidbox.py index 7649736a..ee639b3c 100644 --- a/kombu/pidbox.py +++ b/kombu/pidbox.py @@ -1,5 +1,7 @@ """Generic process mailbox.""" +from __future__ import annotations + import socket import warnings from collections import defaultdict, deque diff --git a/kombu/pools.py b/kombu/pools.py index 373bc06c..106be183 100644 --- a/kombu/pools.py +++ b/kombu/pools.py @@ -1,5 +1,7 @@ """Public resource pools.""" +from __future__ import annotations + import os from itertools import chain diff --git a/kombu/resource.py b/kombu/resource.py index e3617dc4..53ba1145 100644 --- a/kombu/resource.py +++ b/kombu/resource.py @@ -1,14 +1,20 @@ """Generic resource pool implementation.""" +from __future__ import annotations + import os from collections import deque from queue import Empty from queue import LifoQueue as _LifoQueue +from typing import TYPE_CHECKING from . import exceptions from .utils.compat import register_after_fork from .utils.functional import lazy +if TYPE_CHECKING: + from types import TracebackType + def _after_fork_cleanup_resource(resource): try: @@ -191,7 +197,12 @@ class Resource: def __enter__(self): pass - def __exit__(self, type, value, traceback): + def __exit__( + self, + exc_type: type, + exc_val: Exception, + exc_tb: TracebackType + ) -> None: pass resource = self._resource diff --git a/kombu/serialization.py b/kombu/serialization.py index 58c28717..5cddeb0b 100644 --- a/kombu/serialization.py +++ b/kombu/serialization.py @@ -1,5 +1,7 @@ """Serialization utilities.""" +from __future__ import annotations + import codecs import os import pickle @@ -382,18 +384,6 @@ register_msgpack() # Default serializer is 'json' registry._set_default_serializer('json') - -_setupfuns = { - 'json': register_json, - 'pickle': register_pickle, - 'yaml': register_yaml, - 'msgpack': register_msgpack, - 'application/json': register_json, - 'application/x-yaml': register_yaml, - 'application/x-python-serialize': register_pickle, - 'application/x-msgpack': register_msgpack, -} - NOTSET = object() diff --git a/kombu/simple.py b/kombu/simple.py index eee037be..a33e5f9e 100644 --- a/kombu/simple.py +++ b/kombu/simple.py @@ -1,13 +1,19 @@ """Simple messaging interface.""" +from __future__ import annotations + import socket from collections import deque from queue import Empty from time import monotonic +from typing import TYPE_CHECKING from . import entity, messaging from .connection import maybe_channel +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('SimpleQueue', 'SimpleBuffer') @@ -18,7 +24,12 @@ class SimpleBase: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() def __init__(self, channel, producer, consumer, no_ack=False): diff --git a/kombu/transport/SLMQ.py b/kombu/transport/SLMQ.py index 750f67bd..50efca72 100644 --- a/kombu/transport/SLMQ.py +++ b/kombu/transport/SLMQ.py @@ -18,6 +18,8 @@ Transport Options *Unreviewed* """ +from __future__ import annotations + import os import socket import string diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index fb6d3780..ac199aa1 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -59,8 +59,8 @@ exist in AWS) you can tell this transport about them as follows: 'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional 'backoff_tasks': ['svc.tasks.tasks.task1'] # optional }, - 'queue-2': { - 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb', + 'queue-2.fifo': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/bbb.fifo', 'access_key_id': 'c', 'secret_access_key': 'd', 'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}, # optional @@ -71,6 +71,9 @@ exist in AWS) you can tell this transport about them as follows: 'sts_token_timeout': 900 # optional } +Note that FIFO and standard queues must be named accordingly (the name of +a FIFO queue must end with the .fifo suffix). + backoff_policy & backoff_tasks are optional arguments. These arguments automatically change the message visibility timeout, in order to have different times between specific task retries. This would apply after @@ -119,6 +122,8 @@ Features """ # noqa: E501 +from __future__ import annotations + import base64 import socket import string @@ -167,6 +172,10 @@ class UndefinedQueueException(Exception): """Predefined queues are being used and an undefined queue was used.""" +class InvalidQueueException(Exception): + """Predefined queues are being used and configuration is not valid.""" + + class QoS(virtual.QoS): """Quality of Service guarantees implementation for SQS.""" @@ -208,8 +217,8 @@ class QoS(virtual.QoS): VisibilityTimeout=policy_value ) - @staticmethod - def extract_task_name_and_number_of_retries(message): + def extract_task_name_and_number_of_retries(self, delivery_tag): + message = self._delivered[delivery_tag] message_headers = message.headers task_name = message_headers['task'] number_of_retries = int( @@ -237,6 +246,7 @@ class Channel(virtual.Channel): if boto3 is None: raise ImportError('boto3 is not installed') super().__init__(*args, **kwargs) + self._validate_predifined_queues() # SQS blows up if you try to create a new queue when one already # exists but with a different visibility_timeout. This prepopulates @@ -246,6 +256,26 @@ class Channel(virtual.Channel): self.hub = kwargs.get('hub') or get_event_loop() + def _validate_predifined_queues(self): + """Check that standard and FIFO queues are named properly. + + AWS requires FIFO queues to have a name + that ends with the .fifo suffix. + """ + for queue_name, q in self.predefined_queues.items(): + fifo_url = q['url'].endswith('.fifo') + fifo_name = queue_name.endswith('.fifo') + if fifo_url and not fifo_name: + raise InvalidQueueException( + "Queue with url '{}' must have a name " + "ending with .fifo".format(q['url']) + ) + elif not fifo_url and fifo_name: + raise InvalidQueueException( + "Queue with name '{}' is not a FIFO queue: " + "'{}'".format(queue_name, q['url']) + ) + def _update_queue_cache(self, queue_name_prefix): if self.predefined_queues: for queue_name, q in self.predefined_queues.items(): @@ -367,20 +397,28 @@ class Channel(virtual.Channel): def _put(self, queue, message, **kwargs): """Put message onto queue.""" q_url = self._new_queue(queue) - kwargs = {'QueueUrl': q_url, - 'MessageBody': AsyncMessage().encode(dumps(message))} - if queue.endswith('.fifo'): - if 'MessageGroupId' in message['properties']: - kwargs['MessageGroupId'] = \ - message['properties']['MessageGroupId'] - else: - kwargs['MessageGroupId'] = 'default' - if 'MessageDeduplicationId' in message['properties']: - kwargs['MessageDeduplicationId'] = \ - message['properties']['MessageDeduplicationId'] - else: - kwargs['MessageDeduplicationId'] = str(uuid.uuid4()) + if self.sqs_base64_encoding: + body = AsyncMessage().encode(dumps(message)) + else: + body = dumps(message) + kwargs = {'QueueUrl': q_url, 'MessageBody': body} + if 'properties' in message: + if queue.endswith('.fifo'): + if 'MessageGroupId' in message['properties']: + kwargs['MessageGroupId'] = \ + message['properties']['MessageGroupId'] + else: + kwargs['MessageGroupId'] = 'default' + if 'MessageDeduplicationId' in message['properties']: + kwargs['MessageDeduplicationId'] = \ + message['properties']['MessageDeduplicationId'] + else: + kwargs['MessageDeduplicationId'] = str(uuid.uuid4()) + else: + if "DelaySeconds" in message['properties']: + kwargs['DelaySeconds'] = \ + message['properties']['DelaySeconds'] c = self.sqs(queue=self.canonical_queue_name(queue)) if message.get('redelivered'): c.change_message_visibility( @@ -392,22 +430,19 @@ class Channel(virtual.Channel): c.send_message(**kwargs) @staticmethod - def __b64_encoded(byte_string): + def _optional_b64_decode(byte_string): try: - return base64.b64encode( - base64.b64decode(byte_string) - ) == byte_string + data = base64.b64decode(byte_string) + if base64.b64encode(data) == byte_string: + return data + # else the base64 module found some embedded base64 content + # that should be ignored. except Exception: # pylint: disable=broad-except - return False - - def _message_to_python(self, message, queue_name, queue): - body = message['Body'].encode() - try: - if self.__b64_encoded(body): - body = base64.b64decode(body) - except TypeError: pass + return byte_string + def _message_to_python(self, message, queue_name, queue): + body = self._optional_b64_decode(message['Body'].encode()) payload = loads(bytes_to_str(body)) if queue_name in self._noack_queues: queue = self._new_queue(queue_name) @@ -809,6 +844,10 @@ class Channel(virtual.Channel): return self.transport_options.get('wait_time_seconds', self.default_wait_time_seconds) + @cached_property + def sqs_base64_encoding(self): + return self.transport_options.get('sqs_base64_encoding', True) + class Transport(virtual.Transport): """SQS Transport. diff --git a/kombu/transport/__init__.py b/kombu/transport/__init__.py index 5fb5047b..8a217691 100644 --- a/kombu/transport/__init__.py +++ b/kombu/transport/__init__.py @@ -1,10 +1,12 @@ """Built-in transports.""" +from __future__ import annotations + from kombu.utils.compat import _detect_environment from kombu.utils.imports import symbol_by_name -def supports_librabbitmq(): +def supports_librabbitmq() -> bool | None: """Return true if :pypi:`librabbitmq` can be used.""" if _detect_environment() == 'default': try: @@ -13,6 +15,7 @@ def supports_librabbitmq(): pass else: # pragma: no cover return True + return None TRANSPORT_ALIASES = { @@ -20,6 +23,7 @@ TRANSPORT_ALIASES = { 'amqps': 'kombu.transport.pyamqp:SSLTransport', 'pyamqp': 'kombu.transport.pyamqp:Transport', 'librabbitmq': 'kombu.transport.librabbitmq:Transport', + 'confluentkafka': 'kombu.transport.confluentkafka:Transport', 'memory': 'kombu.transport.memory:Transport', 'redis': 'kombu.transport.redis:Transport', 'rediss': 'kombu.transport.redis:Transport', @@ -44,7 +48,7 @@ TRANSPORT_ALIASES = { _transport_cache = {} -def resolve_transport(transport=None): +def resolve_transport(transport: str | None = None) -> str | None: """Get transport by name. Arguments: @@ -71,7 +75,7 @@ def resolve_transport(transport=None): return transport -def get_transport_cls(transport=None): +def get_transport_cls(transport: str | None = None) -> str | None: """Get transport class by name. The transport string is the full path to a transport class, e.g.:: diff --git a/kombu/transport/azureservicebus.py b/kombu/transport/azureservicebus.py index 83237424..e7e2c0cc 100644 --- a/kombu/transport/azureservicebus.py +++ b/kombu/transport/azureservicebus.py @@ -53,9 +53,11 @@ Transport Options * ``retry_backoff_max`` - Azure SDK retry total time. Default ``120`` """ +from __future__ import annotations + import string from queue import Empty -from typing import Any, Dict, Optional, Set, Tuple, Union +from typing import Any, Dict, Set import azure.core.exceptions import azure.servicebus.exceptions @@ -83,10 +85,10 @@ class SendReceive: """Container for Sender and Receiver.""" def __init__(self, - receiver: Optional[ServiceBusReceiver] = None, - sender: Optional[ServiceBusSender] = None): - self.receiver = receiver # type: ServiceBusReceiver - self.sender = sender # type: ServiceBusSender + receiver: ServiceBusReceiver | None = None, + sender: ServiceBusSender | None = None): + self.receiver: ServiceBusReceiver = receiver + self.sender: ServiceBusSender = sender def close(self) -> None: if self.receiver: @@ -100,21 +102,19 @@ class SendReceive: class Channel(virtual.Channel): """Azure Service Bus channel.""" - default_wait_time_seconds = 5 # in seconds - default_peek_lock_seconds = 60 # in seconds (default 60, max 300) + default_wait_time_seconds: int = 5 # in seconds + default_peek_lock_seconds: int = 60 # in seconds (default 60, max 300) # in seconds (is the default from service bus repo) - default_uamqp_keep_alive_interval = 30 + default_uamqp_keep_alive_interval: int = 30 # number of retries (is the default from service bus repo) - default_retry_total = 3 + default_retry_total: int = 3 # exponential backoff factor (is the default from service bus repo) - default_retry_backoff_factor = 0.8 + default_retry_backoff_factor: float = 0.8 # Max time to backoff (is the default from service bus repo) - default_retry_backoff_max = 120 - domain_format = 'kombu%(vhost)s' - _queue_service = None # type: ServiceBusClient - _queue_mgmt_service = None # type: ServiceBusAdministrationClient - _queue_cache = {} # type: Dict[str, SendReceive] - _noack_queues = set() # type: Set[str] + default_retry_backoff_max: int = 120 + domain_format: str = 'kombu%(vhost)s' + _queue_cache: Dict[str, SendReceive] = {} + _noack_queues: Set[str] = set() def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -160,8 +160,8 @@ class Channel(virtual.Channel): def _add_queue_to_cache( self, name: str, - receiver: Optional[ServiceBusReceiver] = None, - sender: Optional[ServiceBusSender] = None + receiver: ServiceBusReceiver | None = None, + sender: ServiceBusSender | None = None ) -> SendReceive: if name in self._queue_cache: obj = self._queue_cache[name] @@ -183,7 +183,7 @@ class Channel(virtual.Channel): def _get_asb_receiver( self, queue: str, recv_mode: ServiceBusReceiveMode = ServiceBusReceiveMode.PEEK_LOCK, - queue_cache_key: Optional[str] = None) -> SendReceive: + queue_cache_key: str | None = None) -> SendReceive: cache_key = queue_cache_key or queue queue_obj = self._queue_cache.get(cache_key, None) if queue_obj is None or queue_obj.receiver is None: @@ -194,7 +194,7 @@ class Channel(virtual.Channel): return queue_obj def entity_name( - self, name: str, table: Optional[Dict[int, int]] = None) -> str: + self, name: str, table: dict[int, int] | None = None) -> str: """Format AMQP queue name into a valid ServiceBus queue name.""" return str(safe_str(name)).translate(table or CHARS_REPLACE_TABLE) @@ -227,7 +227,7 @@ class Channel(virtual.Channel): """Delete queue by name.""" queue = self.entity_name(self.queue_name_prefix + queue) - self._queue_mgmt_service.delete_queue(queue) + self.queue_mgmt_service.delete_queue(queue) send_receive_obj = self._queue_cache.pop(queue, None) if send_receive_obj: send_receive_obj.close() @@ -242,8 +242,8 @@ class Channel(virtual.Channel): def _get( self, queue: str, - timeout: Optional[Union[float, int]] = None - ) -> Dict[str, Any]: + timeout: float | int | None = None + ) -> dict[str, Any]: """Try to retrieve a single message off ``queue``.""" # If we're not ack'ing for this queue, just change receive_mode recv_mode = ServiceBusReceiveMode.RECEIVE_AND_DELETE \ @@ -298,7 +298,7 @@ class Channel(virtual.Channel): return props.total_message_count - def _purge(self, queue): + def _purge(self, queue) -> int: """Delete all current messages in a queue.""" # Azure doesn't provide a purge api yet n = 0 @@ -337,24 +337,19 @@ class Channel(virtual.Channel): if self.connection is not None: self.connection.close_channel(self) - @property + @cached_property def queue_service(self) -> ServiceBusClient: - if self._queue_service is None: - self._queue_service = ServiceBusClient.from_connection_string( - self._connection_string, - retry_total=self.retry_total, - retry_backoff_factor=self.retry_backoff_factor, - retry_backoff_max=self.retry_backoff_max - ) - return self._queue_service + return ServiceBusClient.from_connection_string( + self._connection_string, + retry_total=self.retry_total, + retry_backoff_factor=self.retry_backoff_factor, + retry_backoff_max=self.retry_backoff_max + ) - @property + @cached_property def queue_mgmt_service(self) -> ServiceBusAdministrationClient: - if self._queue_mgmt_service is None: - self._queue_mgmt_service = \ - ServiceBusAdministrationClient.from_connection_string( + return ServiceBusAdministrationClient.from_connection_string( self._connection_string) - return self._queue_mgmt_service @property def conninfo(self): @@ -412,7 +407,7 @@ class Transport(virtual.Transport): can_parse_url = True @staticmethod - def parse_uri(uri: str) -> Tuple[str, str, str]: + def parse_uri(uri: str) -> tuple[str, str, str]: # URL like: # azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace} # urllib parse does not work as the sas key could contain a slash diff --git a/kombu/transport/azurestoragequeues.py b/kombu/transport/azurestoragequeues.py index e83a20d3..16d22f0b 100644 --- a/kombu/transport/azurestoragequeues.py +++ b/kombu/transport/azurestoragequeues.py @@ -15,14 +15,34 @@ Features Connection String ================= -Connection string has the following format: +Connection string has the following formats: .. code-block:: - azurestoragequeues://:STORAGE_ACCOUNT_ACCESS kEY@STORAGE_ACCOUNT_NAME + azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL> + azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL> + azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL> + azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL> -Note that if the access key for the storage account contains a slash, it will -have to be regenerated before it can be used in the connection URL. +Note that if the access key for the storage account contains a forward slash +(``/``), it will have to be regenerated before it can be used in the connection +URL. + +.. code-block:: + + azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL> + azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL> + +If you wish to use an `Azure Managed Identity` you may use the +``DefaultAzureCredential`` format of the connection string which will use +``DefaultAzureCredential`` class in the azure-identity package. You may want to +read the `azure-identity documentation` for more information on how the +``DefaultAzureCredential`` works. + +.. _azure-identity documentation: +https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python +.. _Azure Managed Identity: +https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview Transport Options ================= @@ -30,8 +50,13 @@ Transport Options * ``queue_name_prefix`` """ +from __future__ import annotations + import string from queue import Empty +from typing import Any, Optional + +from azure.core.exceptions import ResourceExistsError from kombu.utils.encoding import safe_str from kombu.utils.json import dumps, loads @@ -40,9 +65,16 @@ from kombu.utils.objects import cached_property from . import virtual try: - from azure.storage.queue import QueueService + from azure.storage.queue import QueueServiceClient except ImportError: # pragma: no cover - QueueService = None + QueueServiceClient = None + +try: + from azure.identity import (DefaultAzureCredential, + ManagedIdentityCredential) +except ImportError: + DefaultAzureCredential = None + ManagedIdentityCredential = None # Azure storage queues allow only alphanumeric and dashes # so, replace everything with a dash @@ -54,21 +86,25 @@ CHARS_REPLACE_TABLE = { class Channel(virtual.Channel): """Azure Storage Queues channel.""" - domain_format = 'kombu%(vhost)s' - _queue_service = None - _queue_name_cache = {} - no_ack = True - _noack_queues = set() + domain_format: str = 'kombu%(vhost)s' + _queue_service: Optional[QueueServiceClient] = None + _queue_name_cache: dict[Any, Any] = {} + no_ack: bool = True + _noack_queues: set[Any] = set() def __init__(self, *args, **kwargs): - if QueueService is None: + if QueueServiceClient is None: raise ImportError('Azure Storage Queues transport requires the ' 'azure-storage-queue library') super().__init__(*args, **kwargs) - for queue_name in self.queue_service.list_queues(): - self._queue_name_cache[queue_name] = queue_name + self._credential, self._url = Transport.parse_uri( + self.conninfo.hostname + ) + + for queue in self.queue_service.list_queues(): + self._queue_name_cache[queue['name']] = queue def basic_consume(self, queue, no_ack, *args, **kwargs): if no_ack: @@ -77,7 +113,7 @@ class Channel(virtual.Channel): return super().basic_consume(queue, no_ack, *args, **kwargs) - def entity_name(self, name, table=CHARS_REPLACE_TABLE): + def entity_name(self, name, table=CHARS_REPLACE_TABLE) -> str: """Format AMQP queue name into a valid Azure Storage Queue name.""" return str(safe_str(name)).translate(table) @@ -85,61 +121,64 @@ class Channel(virtual.Channel): """Ensure a queue exists.""" queue = self.entity_name(self.queue_name_prefix + queue) try: - return self._queue_name_cache[queue] + q = self._queue_service.get_queue_client( + queue=self._queue_name_cache[queue] + ) except KeyError: - self.queue_service.create_queue(queue, fail_on_exist=False) - q = self._queue_name_cache[queue] = queue - return q + try: + q = self.queue_service.create_queue(queue) + except ResourceExistsError: + q = self._queue_service.get_queue_client(queue=queue) + + self._queue_name_cache[queue] = q.get_queue_properties() + return q def _delete(self, queue, *args, **kwargs): """Delete queue by name.""" queue_name = self.entity_name(queue) self._queue_name_cache.pop(queue_name, None) self.queue_service.delete_queue(queue_name) - super()._delete(queue_name) def _put(self, queue, message, **kwargs): """Put message onto queue.""" q = self._ensure_queue(queue) encoded_message = dumps(message) - self.queue_service.put_message(q, encoded_message) + q.send_message(encoded_message) def _get(self, queue, timeout=None): """Try to retrieve a single message off ``queue``.""" q = self._ensure_queue(queue) - messages = self.queue_service.get_messages(q, num_messages=1, - timeout=timeout) - if not messages: + messages = q.receive_messages(messages_per_page=1, timeout=timeout) + try: + message = next(messages) + except StopIteration: raise Empty() - message = messages[0] - raw_content = self.queue_service.decode_function(message.content) - content = loads(raw_content) + content = loads(message.content) - self.queue_service.delete_message(q, message.id, message.pop_receipt) + q.delete_message(message=message) return content def _size(self, queue): """Return the number of messages in a queue.""" q = self._ensure_queue(queue) - metadata = self.queue_service.get_queue_metadata(q) - return metadata.approximate_message_count + return q.get_queue_properties().approximate_message_count def _purge(self, queue): """Delete all current messages in a queue.""" q = self._ensure_queue(queue) - n = self._size(q) - self.queue_service.clear_messages(q) + n = self._size(q.queue_name) + q.clear_messages() return n @property - def queue_service(self): + def queue_service(self) -> QueueServiceClient: if self._queue_service is None: - self._queue_service = QueueService( - account_name=self.conninfo.hostname, - account_key=self.conninfo.password) + self._queue_service = QueueServiceClient( + account_url=self._url, credential=self._credential + ) return self._queue_service @@ -152,7 +191,7 @@ class Channel(virtual.Channel): return self.connection.client.transport_options @cached_property - def queue_name_prefix(self): + def queue_name_prefix(self) -> str: return self.transport_options.get('queue_name_prefix', '') @@ -161,5 +200,64 @@ class Transport(virtual.Transport): Channel = Channel - polling_interval = 1 - default_port = None + polling_interval: int = 1 + default_port: Optional[int] = None + can_parse_url: bool = True + + @staticmethod + def parse_uri(uri: str) -> tuple[str | dict, str]: + # URL like: + # azurestoragequeues://<STORAGE_ACCOUNT_ACCESS_KEY>@<STORAGE_ACCOUNT_URL> + # azurestoragequeues://<SAS_TOKEN>@<STORAGE_ACCOUNT_URL> + # azurestoragequeues://DefaultAzureCredential@<STORAGE_ACCOUNT_URL> + # azurestoragequeues://ManagedIdentityCredential@<STORAGE_ACCOUNT_URL> + + # urllib parse does not work as the sas key could contain a slash + # e.g.: azurestoragequeues://some/key@someurl + + try: + # > 'some/key@url' + uri = uri.replace('azurestoragequeues://', '') + # > 'some/key', 'url' + credential, url = uri.rsplit('@', 1) + + if "DefaultAzureCredential".lower() == credential.lower(): + if DefaultAzureCredential is None: + raise ImportError('Azure Storage Queues transport with a ' + 'DefaultAzureCredential requires the ' + 'azure-identity library') + credential = DefaultAzureCredential() + elif "ManagedIdentityCredential".lower() == credential.lower(): + if ManagedIdentityCredential is None: + raise ImportError('Azure Storage Queues transport with a ' + 'ManagedIdentityCredential requires the ' + 'azure-identity library') + credential = ManagedIdentityCredential() + elif "devstoreaccount1" in url and ".core.windows.net" not in url: + # parse credential as a dict if Azurite is being used + credential = { + "account_name": "devstoreaccount1", + "account_key": credential, + } + + # Validate parameters + assert all([credential, url]) + except Exception: + raise ValueError( + 'Need a URI like ' + 'azurestoragequeues://{SAS or access key}@{URL}, ' + 'azurestoragequeues://DefaultAzureCredential@{URL}, ' + ', or ' + 'azurestoragequeues://ManagedIdentityCredential@{URL}' + ) + + return credential, url + + @classmethod + def as_uri( + cls, uri: str, include_password: bool = False, mask: str = "**" + ) -> str: + credential, url = cls.parse_uri(uri) + return "azurestoragequeues://{}@{}".format( + credential if include_password else mask, url + ) diff --git a/kombu/transport/base.py b/kombu/transport/base.py index 3083acf4..ec4c0aca 100644 --- a/kombu/transport/base.py +++ b/kombu/transport/base.py @@ -2,8 +2,11 @@ # flake8: noqa +from __future__ import annotations + import errno import socket +from typing import TYPE_CHECKING from amqp.exceptions import RecoverableConnectionError @@ -13,6 +16,9 @@ from kombu.utils.functional import dictfilter from kombu.utils.objects import cached_property from kombu.utils.time import maybe_s_to_ms +if TYPE_CHECKING: + from types import TracebackType + __all__ = ('Message', 'StdChannel', 'Management', 'Transport') RABBITMQ_QUEUE_ARGUMENTS = { @@ -100,7 +106,12 @@ class StdChannel: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() diff --git a/kombu/transport/confluentkafka.py b/kombu/transport/confluentkafka.py new file mode 100644 index 00000000..5332a310 --- /dev/null +++ b/kombu/transport/confluentkafka.py @@ -0,0 +1,379 @@ +"""confluent-kafka transport module for Kombu. + +Kafka transport using confluent-kafka library. + +**References** + +- http://docs.confluent.io/current/clients/confluent-kafka-python + +**Limitations** + +The confluent-kafka transport does not support PyPy environment. + +Features +======== +* Type: Virtual +* Supports Direct: Yes +* Supports Topic: Yes +* Supports Fanout: No +* Supports Priority: No +* Supports TTL: No + +Connection String +================= +Connection string has the following format: + +.. code-block:: + + confluentkafka://[USER:PASSWORD@]KAFKA_ADDRESS[:PORT] + +Transport Options +================= +* ``connection_wait_time_seconds`` - Time in seconds to wait for connection + to succeed. Default ``5`` +* ``wait_time_seconds`` - Time in seconds to wait to receive messages. + Default ``5`` +* ``security_protocol`` - Protocol used to communicate with broker. + Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for + an explanation of valid values. Default ``plaintext`` +* ``sasl_mechanism`` - SASL mechanism to use for authentication. + Visit https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md for + an explanation of valid values. +* ``num_partitions`` - Number of partitions to create. Default ``1`` +* ``replication_factor`` - Replication factor of partitions. Default ``1`` +* ``topic_config`` - Topic configuration. Must be a dict whose key-value pairs + correspond with attributes in the + http://kafka.apache.org/documentation.html#topicconfigs. +* ``kafka_common_config`` - Configuration applied to producer, consumer and + admin client. Must be a dict whose key-value pairs correspond with attributes + in the https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md. +* ``kafka_producer_config`` - Producer configuration. Must be a dict whose + key-value pairs correspond with attributes in the + https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md. +* ``kafka_consumer_config`` - Consumer configuration. Must be a dict whose + key-value pairs correspond with attributes in the + https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md. +* ``kafka_admin_config`` - Admin client configuration. Must be a dict whose + key-value pairs correspond with attributes in the + https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md. +""" + +from __future__ import annotations + +from queue import Empty + +from kombu.transport import virtual +from kombu.utils import cached_property +from kombu.utils.encoding import str_to_bytes +from kombu.utils.json import dumps, loads + +try: + import confluent_kafka + from confluent_kafka import Consumer, Producer, TopicPartition + from confluent_kafka.admin import AdminClient, NewTopic + + KAFKA_CONNECTION_ERRORS = () + KAFKA_CHANNEL_ERRORS = () + +except ImportError: + confluent_kafka = None + KAFKA_CONNECTION_ERRORS = KAFKA_CHANNEL_ERRORS = () + +from kombu.log import get_logger + +logger = get_logger(__name__) + +DEFAULT_PORT = 9092 + + +class NoBrokersAvailable(confluent_kafka.KafkaException): + """Kafka broker is not available exception.""" + + retriable = True + + +class Message(virtual.Message): + """Message object.""" + + def __init__(self, payload, channel=None, **kwargs): + self.topic = payload.get('topic') + super().__init__(payload, channel=channel, **kwargs) + + +class QoS(virtual.QoS): + """Quality of Service guarantees.""" + + _not_yet_acked = {} + + def can_consume(self): + """Return true if the channel can be consumed from. + + :returns: True, if this QoS object can accept a message. + :rtype: bool + """ + return not self.prefetch_count or len(self._not_yet_acked) < self \ + .prefetch_count + + def can_consume_max_estimate(self): + if self.prefetch_count: + return self.prefetch_count - len(self._not_yet_acked) + else: + return 1 + + def append(self, message, delivery_tag): + self._not_yet_acked[delivery_tag] = message + + def get(self, delivery_tag): + return self._not_yet_acked[delivery_tag] + + def ack(self, delivery_tag): + if delivery_tag not in self._not_yet_acked: + return + message = self._not_yet_acked.pop(delivery_tag) + consumer = self.channel._get_consumer(message.topic) + consumer.commit() + + def reject(self, delivery_tag, requeue=False): + """Reject a message by delivery tag. + + If requeue is True, then the last consumed message is reverted so + it'll be refetched on the next attempt. + If False, that message is consumed and ignored. + """ + if requeue: + message = self._not_yet_acked.pop(delivery_tag) + consumer = self.channel._get_consumer(message.topic) + for assignment in consumer.assignment(): + topic_partition = TopicPartition(message.topic, + assignment.partition) + [committed_offset] = consumer.committed([topic_partition]) + consumer.seek(committed_offset) + else: + self.ack(delivery_tag) + + def restore_unacked_once(self, stderr=None): + pass + + +class Channel(virtual.Channel): + """Kafka Channel.""" + + QoS = QoS + Message = Message + + default_wait_time_seconds = 5 + default_connection_wait_time_seconds = 5 + _client = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._kafka_consumers = {} + self._kafka_producers = {} + + self._client = self._open() + + def sanitize_queue_name(self, queue): + """Need to sanitize the name, celery sometimes pushes in @ signs.""" + return str(queue).replace('@', '') + + def _get_producer(self, queue): + """Create/get a producer instance for the given topic/queue.""" + queue = self.sanitize_queue_name(queue) + producer = self._kafka_producers.get(queue, None) + if producer is None: + producer = Producer({ + **self.common_config, + **(self.options.get('kafka_producer_config') or {}), + }) + self._kafka_producers[queue] = producer + + return producer + + def _get_consumer(self, queue): + """Create/get a consumer instance for the given topic/queue.""" + queue = self.sanitize_queue_name(queue) + consumer = self._kafka_consumers.get(queue, None) + if consumer is None: + consumer = Consumer({ + 'group.id': f'{queue}-consumer-group', + 'auto.offset.reset': 'earliest', + 'enable.auto.commit': False, + **self.common_config, + **(self.options.get('kafka_consumer_config') or {}), + }) + consumer.subscribe([queue]) + self._kafka_consumers[queue] = consumer + + return consumer + + def _put(self, queue, message, **kwargs): + """Put a message on the topic/queue.""" + queue = self.sanitize_queue_name(queue) + producer = self._get_producer(queue) + producer.produce(queue, str_to_bytes(dumps(message))) + producer.flush() + + def _get(self, queue, **kwargs): + """Get a message from the topic/queue.""" + queue = self.sanitize_queue_name(queue) + consumer = self._get_consumer(queue) + message = None + + try: + message = consumer.poll(self.wait_time_seconds) + except StopIteration: + pass + + if not message: + raise Empty() + + error = message.error() + if error: + logger.error(error) + raise Empty() + + return {**loads(message.value()), 'topic': message.topic()} + + def _delete(self, queue, *args, **kwargs): + """Delete a queue/topic.""" + queue = self.sanitize_queue_name(queue) + self._kafka_consumers[queue].close() + self._kafka_consumers.pop(queue) + self.client.delete_topics([queue]) + + def _size(self, queue): + """Get the number of pending messages in the topic/queue.""" + queue = self.sanitize_queue_name(queue) + + consumer = self._kafka_consumers.get(queue, None) + if consumer is None: + return 0 + + size = 0 + for assignment in consumer.assignment(): + topic_partition = TopicPartition(queue, assignment.partition) + (_, end_offset) = consumer.get_watermark_offsets(topic_partition) + [committed_offset] = consumer.committed([topic_partition]) + size += end_offset - committed_offset.offset + return size + + def _new_queue(self, queue, **kwargs): + """Create a new topic if it does not exist.""" + queue = self.sanitize_queue_name(queue) + if queue in self.client.list_topics().topics: + return + + topic = NewTopic( + queue, + num_partitions=self.options.get('num_partitions', 1), + replication_factor=self.options.get('replication_factor', 1), + config=self.options.get('topic_config', {}) + ) + self.client.create_topics(new_topics=[topic]) + + def _has_queue(self, queue, **kwargs): + """Check if a topic already exists.""" + queue = self.sanitize_queue_name(queue) + return queue in self.client.list_topics().topics + + def _open(self): + client = AdminClient({ + **self.common_config, + **(self.options.get('kafka_admin_config') or {}), + }) + + try: + # seems to be the only way to check connection + client.list_topics(timeout=self.wait_time_seconds) + except confluent_kafka.KafkaException as e: + raise NoBrokersAvailable(e) + + return client + + @property + def client(self): + if self._client is None: + self._client = self._open() + return self._client + + @property + def options(self): + return self.connection.client.transport_options + + @property + def conninfo(self): + return self.connection.client + + @cached_property + def wait_time_seconds(self): + return self.options.get( + 'wait_time_seconds', self.default_wait_time_seconds + ) + + @cached_property + def connection_wait_time_seconds(self): + return self.options.get( + 'connection_wait_time_seconds', + self.default_connection_wait_time_seconds, + ) + + @cached_property + def common_config(self): + conninfo = self.connection.client + config = { + 'bootstrap.servers': + f'{conninfo.hostname}:{int(conninfo.port) or DEFAULT_PORT}', + } + security_protocol = self.options.get('security_protocol', 'plaintext') + if security_protocol.lower() != 'plaintext': + config.update({ + 'security.protocol': security_protocol, + 'sasl.username': conninfo.userid, + 'sasl.password': conninfo.password, + 'sasl.mechanism': self.options.get('sasl_mechanism'), + }) + + config.update(self.options.get('kafka_common_config') or {}) + return config + + def close(self): + super().close() + self._kafka_producers = {} + + for consumer in self._kafka_consumers.values(): + consumer.close() + + self._kafka_consumers = {} + + +class Transport(virtual.Transport): + """Kafka Transport.""" + + def as_uri(self, uri: str, include_password=False, mask='**') -> str: + pass + + Channel = Channel + + default_port = DEFAULT_PORT + + driver_type = 'kafka' + driver_name = 'confluentkafka' + + recoverable_connection_errors = ( + NoBrokersAvailable, + ) + + def __init__(self, client, **kwargs): + if confluent_kafka is None: + raise ImportError('The confluent-kafka library is not installed') + super().__init__(client, **kwargs) + + def driver_version(self): + return confluent_kafka.__version__ + + def establish_connection(self): + return super().establish_connection() + + def close_connection(self, connection): + return super().close_connection(connection) diff --git a/kombu/transport/consul.py b/kombu/transport/consul.py index ea275c95..7ace52f6 100644 --- a/kombu/transport/consul.py +++ b/kombu/transport/consul.py @@ -27,6 +27,8 @@ Connection string has the following format: """ +from __future__ import annotations + import socket import uuid from collections import defaultdict @@ -276,24 +278,25 @@ class Transport(virtual.Transport): driver_type = 'consul' driver_name = 'consul' - def __init__(self, *args, **kwargs): - if consul is None: - raise ImportError('Missing python-consul library') - - super().__init__(*args, **kwargs) - - self.connection_errors = ( + if consul: + connection_errors = ( virtual.Transport.connection_errors + ( consul.ConsulException, consul.base.ConsulException ) ) - self.channel_errors = ( + channel_errors = ( virtual.Transport.channel_errors + ( consul.ConsulException, consul.base.ConsulException ) ) + def __init__(self, *args, **kwargs): + if consul is None: + raise ImportError('Missing python-consul library') + + super().__init__(*args, **kwargs) + def verify_connection(self, connection): port = connection.client.port or self.default_port host = connection.client.hostname or DEFAULT_HOST diff --git a/kombu/transport/etcd.py b/kombu/transport/etcd.py index 4d0b0364..2ab85841 100644 --- a/kombu/transport/etcd.py +++ b/kombu/transport/etcd.py @@ -24,6 +24,8 @@ Connection string has the following format: """ +from __future__ import annotations + import os import socket from collections import defaultdict @@ -242,6 +244,15 @@ class Transport(virtual.Transport): implements = virtual.Transport.implements.extend( exchange_type=frozenset(['direct'])) + if etcd: + connection_errors = ( + virtual.Transport.connection_errors + (etcd.EtcdException, ) + ) + + channel_errors = ( + virtual.Transport.channel_errors + (etcd.EtcdException, ) + ) + def __init__(self, *args, **kwargs): """Create a new instance of etcd.Transport.""" if etcd is None: @@ -249,14 +260,6 @@ class Transport(virtual.Transport): super().__init__(*args, **kwargs) - self.connection_errors = ( - virtual.Transport.connection_errors + (etcd.EtcdException, ) - ) - - self.channel_errors = ( - virtual.Transport.channel_errors + (etcd.EtcdException, ) - ) - def verify_connection(self, connection): """Verify the connection works.""" port = connection.client.port or self.default_port diff --git a/kombu/transport/filesystem.py b/kombu/transport/filesystem.py index d66c42d6..9d2b3581 100644 --- a/kombu/transport/filesystem.py +++ b/kombu/transport/filesystem.py @@ -65,7 +65,7 @@ Features * Type: Virtual * Supports Direct: Yes * Supports Topic: Yes -* Supports Fanout: No +* Supports Fanout: Yes * Supports Priority: No * Supports TTL: No @@ -86,22 +86,26 @@ Transport Options * ``store_processed`` - if set to True, all processed messages are backed up to ``processed_folder``. * ``processed_folder`` - directory where are backed up processed files. +* ``control_folder`` - directory where are exchange-queue table stored. """ +from __future__ import annotations + import os import shutil import tempfile import uuid +from collections import namedtuple +from pathlib import Path from queue import Empty from time import monotonic from kombu.exceptions import ChannelError +from kombu.transport import virtual from kombu.utils.encoding import bytes_to_str, str_to_bytes from kombu.utils.json import dumps, loads from kombu.utils.objects import cached_property -from . import virtual - VERSION = (1, 0, 0) __version__ = '.'.join(map(str, VERSION)) @@ -128,10 +132,11 @@ if os.name == 'nt': hfile = win32file._get_osfhandle(file.fileno()) win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped) + elif os.name == 'posix': import fcntl - from fcntl import LOCK_EX, LOCK_NB, LOCK_SH # noqa + from fcntl import LOCK_EX, LOCK_SH def lock(file, flags): """Create file lock.""" @@ -140,14 +145,66 @@ elif os.name == 'posix': def unlock(file): """Remove file lock.""" fcntl.flock(file.fileno(), fcntl.LOCK_UN) + + else: raise RuntimeError( 'Filesystem plugin only defined for NT and POSIX platforms') +exchange_queue_t = namedtuple("exchange_queue_t", + ["routing_key", "pattern", "queue"]) + + class Channel(virtual.Channel): """Filesystem Channel.""" + supports_fanout = True + + def get_table(self, exchange): + file = self.control_folder / f"{exchange}.exchange" + try: + f_obj = file.open("r") + try: + lock(f_obj, LOCK_SH) + exchange_table = loads(bytes_to_str(f_obj.read())) + return [exchange_queue_t(*q) for q in exchange_table] + finally: + unlock(f_obj) + f_obj.close() + except FileNotFoundError: + return [] + except OSError: + raise ChannelError(f"Cannot open {file}") + + def _queue_bind(self, exchange, routing_key, pattern, queue): + file = self.control_folder / f"{exchange}.exchange" + self.control_folder.mkdir(exist_ok=True) + queue_val = exchange_queue_t(routing_key or "", pattern or "", + queue or "") + try: + if file.exists(): + f_obj = file.open("rb+", buffering=0) + lock(f_obj, LOCK_EX) + exchange_table = loads(bytes_to_str(f_obj.read())) + queues = [exchange_queue_t(*q) for q in exchange_table] + if queue_val not in queues: + queues.insert(0, queue_val) + f_obj.seek(0) + f_obj.write(str_to_bytes(dumps(queues))) + else: + f_obj = file.open("wb", buffering=0) + lock(f_obj, LOCK_EX) + queues = [queue_val] + f_obj.write(str_to_bytes(dumps(queues))) + finally: + unlock(f_obj) + f_obj.close() + + def _put_fanout(self, exchange, payload, routing_key, **kwargs): + for q in self.get_table(exchange): + self._put(q.queue, payload, **kwargs) + def _put(self, queue, payload, **kwargs): """Put `message` onto `queue`.""" filename = '{}_{}.{}.msg'.format(int(round(monotonic() * 1000)), @@ -155,7 +212,7 @@ class Channel(virtual.Channel): filename = os.path.join(self.data_folder_out, filename) try: - f = open(filename, 'wb') + f = open(filename, 'wb', buffering=0) lock(f, LOCK_EX) f.write(str_to_bytes(dumps(payload))) except OSError: @@ -187,7 +244,8 @@ class Channel(virtual.Channel): shutil.move(os.path.join(self.data_folder_in, filename), processed_folder) except OSError: - pass # file could be locked, or removed in meantime so ignore + # file could be locked, or removed in meantime so ignore + continue filename = os.path.join(processed_folder, filename) try: @@ -266,10 +324,19 @@ class Channel(virtual.Channel): def processed_folder(self): return self.transport_options.get('processed_folder', 'processed') + @property + def control_folder(self): + return Path(self.transport_options.get('control_folder', 'control')) + class Transport(virtual.Transport): """Filesystem Transport.""" + implements = virtual.Transport.implements.extend( + asynchronous=False, + exchange_type=frozenset(['direct', 'topic', 'fanout']) + ) + Channel = Channel # filesystem backend state is global. global_state = virtual.BrokerState() diff --git a/kombu/transport/librabbitmq.py b/kombu/transport/librabbitmq.py index dec50ccf..37015b18 100644 --- a/kombu/transport/librabbitmq.py +++ b/kombu/transport/librabbitmq.py @@ -3,6 +3,8 @@ .. _`librabbitmq`: https://pypi.org/project/librabbitmq/ """ +from __future__ import annotations + import os import socket import warnings diff --git a/kombu/transport/memory.py b/kombu/transport/memory.py index 3073d1cf..9bfaff8d 100644 --- a/kombu/transport/memory.py +++ b/kombu/transport/memory.py @@ -22,6 +22,8 @@ Connection string is in the following format: """ +from __future__ import annotations + from collections import defaultdict from queue import Queue diff --git a/kombu/transport/mongodb.py b/kombu/transport/mongodb.py index db758c18..b923f5f4 100644 --- a/kombu/transport/mongodb.py +++ b/kombu/transport/mongodb.py @@ -33,6 +33,8 @@ Transport Options * ``calc_queue_size``, """ +from __future__ import annotations + import datetime from queue import Empty @@ -63,11 +65,10 @@ class BroadcastCursor: def __init__(self, cursor): self._cursor = cursor - self.purge(rewind=False) def get_size(self): - return self._cursor.count() - self._offset + return self._cursor.collection.count_documents({}) - self._offset def close(self): self._cursor.close() @@ -77,7 +78,7 @@ class BroadcastCursor: self._cursor.rewind() # Fast forward the cursor past old events - self._offset = self._cursor.count() + self._offset = self._cursor.collection.count_documents({}) self._cursor = self._cursor.skip(self._offset) def __iter__(self): @@ -149,11 +150,17 @@ class Channel(virtual.Channel): def _new_queue(self, queue, **kwargs): if self.ttl: - self.queues.update( + self.queues.update_one( {'_id': queue}, - {'_id': queue, - 'options': kwargs, - 'expire_at': self._get_expire(kwargs, 'x-expires')}, + { + '$set': { + '_id': queue, + 'options': kwargs, + 'expire_at': self._get_queue_expire( + kwargs, 'x-expires' + ), + }, + }, upsert=True) def _get(self, queue): @@ -163,10 +170,9 @@ class Channel(virtual.Channel): except StopIteration: msg = None else: - msg = self.messages.find_and_modify( - query={'queue': queue}, + msg = self.messages.find_one_and_delete( + {'queue': queue}, sort=[('priority', pymongo.ASCENDING)], - remove=True, ) if self.ttl: @@ -186,7 +192,7 @@ class Channel(virtual.Channel): if queue in self._fanout_queues: return self._get_broadcast_cursor(queue).get_size() - return self.messages.find({'queue': queue}).count() + return self.messages.count_documents({'queue': queue}) def _put(self, queue, message, **kwargs): data = { @@ -196,13 +202,18 @@ class Channel(virtual.Channel): } if self.ttl: - data['expire_at'] = self._get_expire(queue, 'x-message-ttl') + data['expire_at'] = self._get_queue_expire(queue, 'x-message-ttl') + msg_expire = self._get_message_expire(message) + if msg_expire is not None and ( + data['expire_at'] is None or msg_expire < data['expire_at'] + ): + data['expire_at'] = msg_expire - self.messages.insert(data) + self.messages.insert_one(data) def _put_fanout(self, exchange, message, routing_key, **kwargs): - self.broadcast.insert({'payload': dumps(message), - 'queue': exchange}) + self.broadcast.insert_one({'payload': dumps(message), + 'queue': exchange}) def _purge(self, queue): size = self._size(queue) @@ -241,9 +252,9 @@ class Channel(virtual.Channel): data = lookup.copy() if self.ttl: - data['expire_at'] = self._get_expire(queue, 'x-expires') + data['expire_at'] = self._get_queue_expire(queue, 'x-expires') - self.routing.update(lookup, data, upsert=True) + self.routing.update_one(lookup, {'$set': data}, upsert=True) def queue_delete(self, queue, **kwargs): self.routing.remove({'queue': queue}) @@ -346,7 +357,7 @@ class Channel(virtual.Channel): def _create_broadcast(self, database): """Create capped collection for broadcast messages.""" - if self.broadcast_collection in database.collection_names(): + if self.broadcast_collection in database.list_collection_names(): return database.create_collection(self.broadcast_collection, @@ -356,20 +367,20 @@ class Channel(virtual.Channel): def _ensure_indexes(self, database): """Ensure indexes on collections.""" messages = database[self.messages_collection] - messages.ensure_index( + messages.create_index( [('queue', 1), ('priority', 1), ('_id', 1)], background=True, ) - database[self.broadcast_collection].ensure_index([('queue', 1)]) + database[self.broadcast_collection].create_index([('queue', 1)]) routing = database[self.routing_collection] - routing.ensure_index([('queue', 1), ('exchange', 1)]) + routing.create_index([('queue', 1), ('exchange', 1)]) if self.ttl: - messages.ensure_index([('expire_at', 1)], expireAfterSeconds=0) - routing.ensure_index([('expire_at', 1)], expireAfterSeconds=0) + messages.create_index([('expire_at', 1)], expireAfterSeconds=0) + routing.create_index([('expire_at', 1)], expireAfterSeconds=0) - database[self.queues_collection].ensure_index( + database[self.queues_collection].create_index( [('expire_at', 1)], expireAfterSeconds=0) def _create_client(self): @@ -427,7 +438,12 @@ class Channel(virtual.Channel): ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor) return ret - def _get_expire(self, queue, argument): + def _get_message_expire(self, message): + value = message.get('properties', {}).get('expiration') + if value is not None: + return self.get_now() + datetime.timedelta(milliseconds=int(value)) + + def _get_queue_expire(self, queue, argument): """Get expiration header named `argument` of queue definition. Note: @@ -452,15 +468,15 @@ class Channel(virtual.Channel): def _update_queues_expire(self, queue): """Update expiration field on queues documents.""" - expire_at = self._get_expire(queue, 'x-expires') + expire_at = self._get_queue_expire(queue, 'x-expires') if not expire_at: return - self.routing.update( - {'queue': queue}, {'$set': {'expire_at': expire_at}}, multi=True) - self.queues.update( - {'_id': queue}, {'$set': {'expire_at': expire_at}}, multi=True) + self.routing.update_many( + {'queue': queue}, {'$set': {'expire_at': expire_at}}) + self.queues.update_many( + {'_id': queue}, {'$set': {'expire_at': expire_at}}) def get_now(self): """Return current time in UTC.""" diff --git a/kombu/transport/pyamqp.py b/kombu/transport/pyamqp.py index f230f911..c8fd3c86 100644 --- a/kombu/transport/pyamqp.py +++ b/kombu/transport/pyamqp.py @@ -68,6 +68,8 @@ hostname from broker URL. This is usefull when failover is used to fill """ +from __future__ import annotations + import amqp from kombu.utils.amq_manager import get_manager diff --git a/kombu/transport/pyro.py b/kombu/transport/pyro.py index 833d9792..7b27cb61 100644 --- a/kombu/transport/pyro.py +++ b/kombu/transport/pyro.py @@ -32,6 +32,8 @@ Transport Options """ +from __future__ import annotations + import sys from queue import Empty, Queue diff --git a/kombu/transport/qpid.py b/kombu/transport/qpid.py index b0f8df13..cfd864d8 100644 --- a/kombu/transport/qpid.py +++ b/kombu/transport/qpid.py @@ -86,13 +86,14 @@ Celery, this can be accomplished by setting the *BROKER_TRANSPORT_OPTIONS* Celery option. """ +from __future__ import annotations + import os import select import socket import ssl import sys import uuid -from collections import OrderedDict from gettext import gettext as _ from queue import Empty from time import monotonic @@ -189,7 +190,7 @@ class QoS: def __init__(self, session, prefetch_count=1): self.session = session self.prefetch_count = 1 - self._not_yet_acked = OrderedDict() + self._not_yet_acked = {} def can_consume(self): """Return True if the :class:`Channel` can consume more messages. @@ -229,8 +230,8 @@ class QoS: """Append message to the list of un-ACKed messages. Add a message, referenced by the delivery_tag, for ACKing, - rejecting, or getting later. Messages are saved into an - :class:`collections.OrderedDict` by delivery_tag. + rejecting, or getting later. Messages are saved into a + dict by delivery_tag. :param message: A received message that has not yet been ACKed. :type message: qpid.messaging.Message diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index 103a8466..6cbfbdcf 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -51,6 +51,8 @@ Transport Options * ``priority_steps`` """ +from __future__ import annotations + import functools import numbers import socket @@ -189,6 +191,7 @@ class GlobalKeyPrefixMixin: PREFIXED_SIMPLE_COMMANDS = [ "HDEL", "HGET", + "HLEN", "HSET", "LLEN", "LPUSH", @@ -208,6 +211,7 @@ class GlobalKeyPrefixMixin: "DEL": {"args_start": 0, "args_end": None}, "BRPOP": {"args_start": 0, "args_end": -1}, "EVALSHA": {"args_start": 2, "args_end": 3}, + "WATCH": {"args_start": 0, "args_end": None}, } def _prefix_args(self, args): @@ -216,8 +220,7 @@ class GlobalKeyPrefixMixin: if command in self.PREFIXED_SIMPLE_COMMANDS: args[0] = self.global_keyprefix + str(args[0]) - - if command in self.PREFIXED_COMPLEX_COMMANDS.keys(): + elif command in self.PREFIXED_COMPLEX_COMMANDS: args_start = self.PREFIXED_COMPLEX_COMMANDS[command]["args_start"] args_end = self.PREFIXED_COMPLEX_COMMANDS[command]["args_end"] @@ -267,6 +270,13 @@ class PrefixedStrictRedis(GlobalKeyPrefixMixin, redis.Redis): self.global_keyprefix = kwargs.pop('global_keyprefix', '') redis.Redis.__init__(self, *args, **kwargs) + def pubsub(self, **kwargs): + return PrefixedRedisPubSub( + self.connection_pool, + global_keyprefix=self.global_keyprefix, + **kwargs, + ) + class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.client.Pipeline): """Custom Redis pipeline that takes global_keyprefix into consideration. @@ -281,6 +291,58 @@ class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.client.Pipeline): redis.client.Pipeline.__init__(self, *args, **kwargs) +class PrefixedRedisPubSub(redis.client.PubSub): + """Redis pubsub client that takes global_keyprefix into consideration.""" + + PUBSUB_COMMANDS = ( + "SUBSCRIBE", + "UNSUBSCRIBE", + "PSUBSCRIBE", + "PUNSUBSCRIBE", + ) + + def __init__(self, *args, **kwargs): + self.global_keyprefix = kwargs.pop('global_keyprefix', '') + super().__init__(*args, **kwargs) + + def _prefix_args(self, args): + args = list(args) + command = args.pop(0) + + if command in self.PUBSUB_COMMANDS: + args = [ + self.global_keyprefix + str(arg) + for arg in args + ] + + return [command, *args] + + def parse_response(self, *args, **kwargs): + """Parse a response from the Redis server. + + Method wraps ``PubSub.parse_response()`` to remove prefixes of keys + returned by redis command. + """ + ret = super().parse_response(*args, **kwargs) + if ret is None: + return ret + + # response formats + # SUBSCRIBE and UNSUBSCRIBE + # -> [message type, channel, message] + # PSUBSCRIBE and PUNSUBSCRIBE + # -> [message type, pattern, channel, message] + message_type, *channels, message = ret + return [ + message_type, + *[channel[len(self.global_keyprefix):] for channel in channels], + message, + ] + + def execute_command(self, *args, **kwargs): + return super().execute_command(*self._prefix_args(args), **kwargs) + + class QoS(virtual.QoS): """Redis Ack Emulation.""" @@ -353,13 +415,17 @@ class QoS(virtual.QoS): pass def restore_by_tag(self, tag, client=None, leftmost=False): - with self.channel.conn_or_acquire(client) as client: - with client.pipeline() as pipe: - p, _, _ = self._remove_from_indices( - tag, pipe.hget(self.unacked_key, tag)).execute() + + def restore_transaction(pipe): + p = pipe.hget(self.unacked_key, tag) + pipe.multi() + self._remove_from_indices(tag, pipe) if p: M, EX, RK = loads(bytes_to_str(p)) # json is unicode - self.channel._do_restore_message(M, EX, RK, client, leftmost) + self.channel._do_restore_message(M, EX, RK, pipe, leftmost) + + with self.channel.conn_or_acquire(client) as client: + client.transaction(restore_transaction, self.unacked_key) @cached_property def unacked_key(self): @@ -709,32 +775,35 @@ class Channel(virtual.Channel): self.connection.cycle._on_connection_disconnect(connection) def _do_restore_message(self, payload, exchange, routing_key, - client=None, leftmost=False): - with self.conn_or_acquire(client) as client: + pipe, leftmost=False): + try: try: - try: - payload['headers']['redelivered'] = True - except KeyError: - pass - for queue in self._lookup(exchange, routing_key): - (client.lpush if leftmost else client.rpush)( - queue, dumps(payload), - ) - except Exception: - crit('Could not restore message: %r', payload, exc_info=True) + payload['headers']['redelivered'] = True + payload['properties']['delivery_info']['redelivered'] = True + except KeyError: + pass + for queue in self._lookup(exchange, routing_key): + (pipe.lpush if leftmost else pipe.rpush)( + queue, dumps(payload), + ) + except Exception: + crit('Could not restore message: %r', payload, exc_info=True) def _restore(self, message, leftmost=False): if not self.ack_emulation: return super()._restore(message) tag = message.delivery_tag - with self.conn_or_acquire() as client: - with client.pipeline() as pipe: - P, _ = pipe.hget(self.unacked_key, tag) \ - .hdel(self.unacked_key, tag) \ - .execute() + + def restore_transaction(pipe): + P = pipe.hget(self.unacked_key, tag) + pipe.multi() + pipe.hdel(self.unacked_key, tag) if P: M, EX, RK = loads(bytes_to_str(P)) # json is unicode - self._do_restore_message(M, EX, RK, client, leftmost) + self._do_restore_message(M, EX, RK, pipe, leftmost) + + with self.conn_or_acquire() as client: + client.transaction(restore_transaction, self.unacked_key) def _restore_at_beginning(self, message): return self._restore(message, leftmost=True) @@ -1116,8 +1185,8 @@ class Channel(virtual.Channel): if asynchronous: class Connection(connection_cls): - def disconnect(self): - super().disconnect() + def disconnect(self, *args): + super().disconnect(*args) channel._on_connection_disconnect(self) connection_cls = Connection @@ -1208,13 +1277,14 @@ class Transport(virtual.Transport): exchange_type=frozenset(['direct', 'topic', 'fanout']) ) + if redis: + connection_errors, channel_errors = get_redis_error_classes() + def __init__(self, *args, **kwargs): if redis is None: raise ImportError('Missing redis library (pip install redis)') super().__init__(*args, **kwargs) - # Get redis-py exceptions. - self.connection_errors, self.channel_errors = self._get_errors() # All channels share the same poller. self.cycle = MultiChannelPoller() @@ -1231,6 +1301,14 @@ class Transport(virtual.Transport): def _on_disconnect(connection): if connection._sock: loop.remove(connection._sock) + + # must have started polling or this will break reconnection + if cycle.fds: + # stop polling in the event loop + try: + loop.on_tick.remove(on_poll_start) + except KeyError: + pass cycle._on_connection_disconnect = _on_disconnect def on_poll_start(): @@ -1251,10 +1329,6 @@ class Transport(virtual.Transport): """Handle AIO event for one of our file descriptors.""" self.cycle.on_readable(fileno) - def _get_errors(self): - """Utility to import redis-py's exceptions at runtime.""" - return get_redis_error_classes() - if sentinel: class SentinelManagedSSLConnection( diff --git a/kombu/transport/sqlalchemy/__init__.py b/kombu/transport/sqlalchemy/__init__.py index 91f87a86..a61c8ea8 100644 --- a/kombu/transport/sqlalchemy/__init__.py +++ b/kombu/transport/sqlalchemy/__init__.py @@ -50,15 +50,13 @@ Transport Options Moreover parameters of :func:`sqlalchemy.create_engine()` function can be passed as transport options. """ -# SQLAlchemy overrides != False to have special meaning and pep8 complains -# flake8: noqa - +from __future__ import annotations import threading from json import dumps, loads from queue import Empty -from sqlalchemy import create_engine +from sqlalchemy import create_engine, text from sqlalchemy.exc import OperationalError from sqlalchemy.orm import sessionmaker @@ -71,6 +69,13 @@ from .models import ModelBase from .models import Queue as QueueBase from .models import class_registry, metadata +# SQLAlchemy overrides != False to have special meaning and pep8 complains +# flake8: noqa + + + + + VERSION = (1, 4, 1) __version__ = '.'.join(map(str, VERSION)) @@ -164,7 +169,7 @@ class Channel(virtual.Channel): def _get(self, queue): obj = self._get_or_create(queue) if self.session.bind.name == 'sqlite': - self.session.execute('BEGIN IMMEDIATE TRANSACTION') + self.session.execute(text('BEGIN IMMEDIATE TRANSACTION')) try: msg = self.session.query(self.message_cls) \ .with_for_update() \ diff --git a/kombu/transport/sqlalchemy/models.py b/kombu/transport/sqlalchemy/models.py index 45863852..edff572a 100644 --- a/kombu/transport/sqlalchemy/models.py +++ b/kombu/transport/sqlalchemy/models.py @@ -1,10 +1,12 @@ """Kombu transport using SQLAlchemy as the message store.""" +from __future__ import annotations + import datetime from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Index, Integer, Sequence, SmallInteger, String, Text) -from sqlalchemy.orm import relation +from sqlalchemy.orm import relationship from sqlalchemy.schema import MetaData try: @@ -35,7 +37,7 @@ class Queue: @declared_attr def messages(cls): - return relation('Message', backref='queue', lazy='noload') + return relationship('Message', backref='queue', lazy='noload') class Message: diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py index 7ab11772..54e84665 100644 --- a/kombu/transport/virtual/__init__.py +++ b/kombu/transport/virtual/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .base import (AbstractChannel, Base64, BrokerState, Channel, Empty, Management, Message, NotEquivalentError, QoS, Transport, UndeliverableWarning, binding_key_t, queue_binding_t) diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py index 95e539ac..552ebec7 100644 --- a/kombu/transport/virtual/base.py +++ b/kombu/transport/virtual/base.py @@ -3,6 +3,8 @@ Emulates the AMQ API for non-AMQ transports. """ +from __future__ import annotations + import base64 import socket import sys @@ -13,6 +15,7 @@ from itertools import count from multiprocessing.util import Finalize from queue import Empty from time import monotonic, sleep +from typing import TYPE_CHECKING from amqp.protocol import queue_declare_ok_t @@ -26,6 +29,9 @@ from kombu.utils.uuid import uuid from .exchange import STANDARD_EXCHANGE_TYPES +if TYPE_CHECKING: + from types import TracebackType + ARRAY_TYPE_H = 'H' UNDELIVERABLE_FMT = """\ @@ -177,6 +183,8 @@ class QoS: self.channel = channel self.prefetch_count = prefetch_count or 0 + # Standard Python dictionaries do not support setting attributes + # on the object, hence the use of OrderedDict self._delivered = OrderedDict() self._delivered.restored = False self._dirty = set() @@ -462,14 +470,7 @@ class Channel(AbstractChannel, base.StdChannel): typ: cls(self) for typ, cls in self.exchange_types.items() } - try: - self.channel_id = self.connection._avail_channel_ids.pop() - except IndexError: - raise ResourceError( - 'No free channel ids, current={}, channel_max={}'.format( - len(self.connection.channels), - self.connection.channel_max), (20, 10), - ) + self.channel_id = self._get_free_channel_id() topts = self.connection.client.transport_options for opt_name in self.from_transport_options: @@ -727,7 +728,8 @@ class Channel(AbstractChannel, base.StdChannel): message = message.serializable() message['redelivered'] = True for queue in self._lookup( - delivery_info['exchange'], delivery_info['routing_key']): + delivery_info['exchange'], + delivery_info['routing_key']): self._put(queue, message) def _restore_at_beginning(self, message): @@ -804,7 +806,12 @@ class Channel(AbstractChannel, base.StdChannel): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.close() @property @@ -844,6 +851,22 @@ class Channel(AbstractChannel, base.StdChannel): return (self.max_priority - priority) if reverse else priority + def _get_free_channel_id(self): + # Cast to a set for fast lookups, and keep stored as an array + # for lower memory usage. + used_channel_ids = set(self.connection._used_channel_ids) + + for channel_id in range(1, self.connection.channel_max + 1): + if channel_id not in used_channel_ids: + self.connection._used_channel_ids.append(channel_id) + return channel_id + + raise ResourceError( + 'No free channel ids, current={}, channel_max={}'.format( + len(self.connection.channels), + self.connection.channel_max), (20, 10), + ) + class Management(base.Management): """Base class for the AMQP management API.""" @@ -907,9 +930,7 @@ class Transport(base.Transport): polling_interval = client.transport_options.get('polling_interval') if polling_interval is not None: self.polling_interval = polling_interval - self._avail_channel_ids = array( - ARRAY_TYPE_H, range(self.channel_max, 0, -1), - ) + self._used_channel_ids = array(ARRAY_TYPE_H) def create_channel(self, connection): try: @@ -921,7 +942,11 @@ class Transport(base.Transport): def close_channel(self, channel): try: - self._avail_channel_ids.append(channel.channel_id) + try: + self._used_channel_ids.remove(channel.channel_id) + except ValueError: + # channel id already removed + pass try: self.channels.remove(channel) except ValueError: @@ -934,7 +959,7 @@ class Transport(base.Transport): # this channel is then used as the next requested channel. # (returned by ``create_channel``). self._avail_channels.append(self.create_channel(self)) - return self # for drain events + return self # for drain events def close_connection(self, connection): self.cycle.close() diff --git a/kombu/transport/virtual/exchange.py b/kombu/transport/virtual/exchange.py index c6b6161c..b70544cd 100644 --- a/kombu/transport/virtual/exchange.py +++ b/kombu/transport/virtual/exchange.py @@ -4,6 +4,8 @@ Implementations of the standard exchanges defined by the AMQ protocol (excluding the `headers` exchange). """ +from __future__ import annotations + import re from kombu.utils.text import escape_regex diff --git a/kombu/transport/zookeeper.py b/kombu/transport/zookeeper.py index 1a2ab63c..c72ce2f5 100644 --- a/kombu/transport/zookeeper.py +++ b/kombu/transport/zookeeper.py @@ -42,6 +42,8 @@ Transport Options """ +from __future__ import annotations + import os import socket from queue import Empty diff --git a/kombu/utils/__init__.py b/kombu/utils/__init__.py index 304e2dfa..94bb3cdf 100644 --- a/kombu/utils/__init__.py +++ b/kombu/utils/__init__.py @@ -1,5 +1,7 @@ """DEPRECATED - Import from modules below.""" +from __future__ import annotations + from .collections import EqualityDict from .compat import fileno, maybe_fileno, nested, register_after_fork from .div import emergency_dump_state diff --git a/kombu/utils/amq_manager.py b/kombu/utils/amq_manager.py index 7491bb25..f3e429fd 100644 --- a/kombu/utils/amq_manager.py +++ b/kombu/utils/amq_manager.py @@ -1,6 +1,9 @@ """AMQP Management API utilities.""" +from __future__ import annotations + + def get_manager(client, hostname=None, port=None, userid=None, password=None): """Get pyrabbit manager.""" diff --git a/kombu/utils/collections.py b/kombu/utils/collections.py index 77781047..1a0a6d0d 100644 --- a/kombu/utils/collections.py +++ b/kombu/utils/collections.py @@ -1,6 +1,9 @@ """Custom maps, sequences, etc.""" +from __future__ import annotations + + class HashedSeq(list): """Hashed Sequence. diff --git a/kombu/utils/compat.py b/kombu/utils/compat.py index ffc224c1..e1b22f66 100644 --- a/kombu/utils/compat.py +++ b/kombu/utils/compat.py @@ -1,5 +1,7 @@ """Python Compatibility Utilities.""" +from __future__ import annotations + import numbers import sys from contextlib import contextmanager @@ -77,9 +79,18 @@ def detect_environment(): def entrypoints(namespace): """Return setuptools entrypoints for namespace.""" + if sys.version_info >= (3,10): + entry_points = importlib_metadata.entry_points(group=namespace) + else: + entry_points = importlib_metadata.entry_points() + try: + entry_points = entry_points.get(namespace, []) + except AttributeError: + entry_points = entry_points.select(group=namespace) + return ( (ep, ep.load()) - for ep in importlib_metadata.entry_points().get(namespace, []) + for ep in entry_points ) diff --git a/kombu/utils/debug.py b/kombu/utils/debug.py index acc2d60b..bd20948f 100644 --- a/kombu/utils/debug.py +++ b/kombu/utils/debug.py @@ -1,5 +1,7 @@ """Debugging support.""" +from __future__ import annotations + import logging from vine.utils import wraps diff --git a/kombu/utils/div.py b/kombu/utils/div.py index 45be7f94..439b6639 100644 --- a/kombu/utils/div.py +++ b/kombu/utils/div.py @@ -1,5 +1,7 @@ """Div. Utilities.""" +from __future__ import annotations + import sys from .encoding import default_encode diff --git a/kombu/utils/encoding.py b/kombu/utils/encoding.py index 5f58f0fa..42bf2ce9 100644 --- a/kombu/utils/encoding.py +++ b/kombu/utils/encoding.py @@ -5,6 +5,8 @@ applications without crashing from the infamous :exc:`UnicodeDecodeError` exception. """ +from __future__ import annotations + import sys import traceback diff --git a/kombu/utils/eventio.py b/kombu/utils/eventio.py index 48260a48..f8d89d45 100644 --- a/kombu/utils/eventio.py +++ b/kombu/utils/eventio.py @@ -1,5 +1,7 @@ """Selector Utilities.""" +from __future__ import annotations + import errno import math import select as __select__ diff --git a/kombu/utils/functional.py b/kombu/utils/functional.py index 366a0b99..6beb17d7 100644 --- a/kombu/utils/functional.py +++ b/kombu/utils/functional.py @@ -1,5 +1,7 @@ """Functional Utilities.""" +from __future__ import annotations + import inspect import random import threading diff --git a/kombu/utils/imports.py b/kombu/utils/imports.py index fd4482a8..8752fa1a 100644 --- a/kombu/utils/imports.py +++ b/kombu/utils/imports.py @@ -1,5 +1,7 @@ """Import related utilities.""" +from __future__ import annotations + import importlib import sys diff --git a/kombu/utils/json.py b/kombu/utils/json.py index cedaa793..ec6269e2 100644 --- a/kombu/utils/json.py +++ b/kombu/utils/json.py @@ -1,75 +1,75 @@ """JSON Serialization Utilities.""" -import datetime -import decimal -import json as stdjson +from __future__ import annotations + +import base64 +import json import uuid +from datetime import date, datetime, time +from decimal import Decimal +from typing import Any, Callable, TypeVar -try: - from django.utils.functional import Promise as DjangoPromise -except ImportError: # pragma: no cover - class DjangoPromise: - """Dummy object.""" +textual_types = () try: - import json - _json_extra_kwargs = {} - - class _DecodeError(Exception): - pass -except ImportError: # pragma: no cover - import simplejson as json - from simplejson.decoder import JSONDecodeError as _DecodeError - _json_extra_kwargs = { - 'use_decimal': False, - 'namedtuple_as_object': False, - } - + from django.utils.functional import Promise -_encoder_cls = type(json._default_encoder) -_default_encoder = None # ... set to JSONEncoder below. + textual_types += (Promise,) +except ImportError: + pass -class JSONEncoder(_encoder_cls): +class JSONEncoder(json.JSONEncoder): """Kombu custom json encoder.""" - def default(self, o, - dates=(datetime.datetime, datetime.date), - times=(datetime.time,), - textual=(decimal.Decimal, uuid.UUID, DjangoPromise), - isinstance=isinstance, - datetime=datetime.datetime, - text_t=str): - reducer = getattr(o, '__json__', None) + def default(self, o): + reducer = getattr(o, "__json__", None) if reducer is not None: return reducer() - else: - if isinstance(o, dates): - if not isinstance(o, datetime): - o = datetime(o.year, o.month, o.day, 0, 0, 0, 0) - r = o.isoformat() - if r.endswith("+00:00"): - r = r[:-6] + "Z" - return r - elif isinstance(o, times): - return o.isoformat() - elif isinstance(o, textual): - return text_t(o) - return super().default(o) + if isinstance(o, textual_types): + return str(o) + + for t, (marker, encoder) in _encoders.items(): + if isinstance(o, t): + return _as(marker, encoder(o)) + + # Bytes is slightly trickier, so we cannot put them directly + # into _encoders, because we use two formats: bytes, and base64. + if isinstance(o, bytes): + try: + return _as("bytes", o.decode("utf-8")) + except UnicodeDecodeError: + return _as("base64", base64.b64encode(o).decode("utf-8")) -_default_encoder = JSONEncoder + return super().default(o) -def dumps(s, _dumps=json.dumps, cls=None, default_kwargs=None, **kwargs): +def _as(t: str, v: Any): + return {"__type__": t, "__value__": v} + + +def dumps( + s, _dumps=json.dumps, cls=JSONEncoder, default_kwargs=None, **kwargs +): """Serialize object to json string.""" - if not default_kwargs: - default_kwargs = _json_extra_kwargs - return _dumps(s, cls=cls or _default_encoder, - **dict(default_kwargs, **kwargs)) + default_kwargs = default_kwargs or {} + return _dumps(s, cls=cls, **dict(default_kwargs, **kwargs)) + + +def object_hook(o: dict): + """Hook function to perform custom deserialization.""" + if o.keys() == {"__type__", "__value__"}: + decoder = _decoders.get(o["__type__"]) + if decoder: + return decoder(o["__value__"]) + else: + raise ValueError("Unsupported type", type, o) + else: + return o -def loads(s, _loads=json.loads, decode_bytes=True): +def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook): """Deserialize json from string.""" # None of the json implementations supports decoding from # a buffer/memoryview, or even reading from a stream @@ -78,14 +78,51 @@ def loads(s, _loads=json.loads, decode_bytes=True): # over. Note that pickle does support buffer/memoryview # </rant> if isinstance(s, memoryview): - s = s.tobytes().decode('utf-8') + s = s.tobytes().decode("utf-8") elif isinstance(s, bytearray): - s = s.decode('utf-8') + s = s.decode("utf-8") elif decode_bytes and isinstance(s, bytes): - s = s.decode('utf-8') - - try: - return _loads(s) - except _DecodeError: - # catch "Unpaired high surrogate" error - return stdjson.loads(s) + s = s.decode("utf-8") + + return _loads(s, object_hook=object_hook) + + +DecoderT = EncoderT = Callable[[Any], Any] +T = TypeVar("T") +EncodedT = TypeVar("EncodedT") + + +def register_type( + t: type[T], + marker: str, + encoder: Callable[[T], EncodedT], + decoder: Callable[[EncodedT], T], +): + """Add support for serializing/deserializing native python type.""" + _encoders[t] = (marker, encoder) + _decoders[marker] = decoder + + +_encoders: dict[type, tuple[str, EncoderT]] = {} +_decoders: dict[str, DecoderT] = { + "bytes": lambda o: o.encode("utf-8"), + "base64": lambda o: base64.b64decode(o.encode("utf-8")), +} + +# NOTE: datetime should be registered before date, +# because datetime is also instance of date. +register_type(datetime, "datetime", datetime.isoformat, datetime.fromisoformat) +register_type( + date, + "date", + lambda o: o.isoformat(), + lambda o: datetime.fromisoformat(o).date(), +) +register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) +register_type(Decimal, "decimal", str, Decimal) +register_type( + uuid.UUID, + "uuid", + lambda o: {"hex": o.hex, "version": o.version}, + lambda o: uuid.UUID(**o), +) diff --git a/kombu/utils/limits.py b/kombu/utils/limits.py index d82884f5..36d11f1f 100644 --- a/kombu/utils/limits.py +++ b/kombu/utils/limits.py @@ -1,5 +1,7 @@ """Token bucket implementation for rate limiting.""" +from __future__ import annotations + from collections import deque from time import monotonic diff --git a/kombu/utils/objects.py b/kombu/utils/objects.py index 7fef4a2f..eb4dfc2a 100644 --- a/kombu/utils/objects.py +++ b/kombu/utils/objects.py @@ -1,5 +1,7 @@ """Object Utilities.""" +from __future__ import annotations + __all__ = ('cached_property',) try: diff --git a/kombu/utils/scheduling.py b/kombu/utils/scheduling.py index 1875fce4..94286be8 100644 --- a/kombu/utils/scheduling.py +++ b/kombu/utils/scheduling.py @@ -1,5 +1,7 @@ """Scheduling Utilities.""" +from __future__ import annotations + from itertools import count from .imports import symbol_by_name diff --git a/kombu/utils/text.py b/kombu/utils/text.py index 1d5fb9de..fea53347 100644 --- a/kombu/utils/text.py +++ b/kombu/utils/text.py @@ -2,7 +2,10 @@ # flake8: noqa +from __future__ import annotations + from difflib import SequenceMatcher +from typing import Iterable, Iterator from kombu import version_info_t @@ -16,8 +19,7 @@ def escape_regex(p, white=''): for c in p) -def fmatch_iter(needle, haystack, min_ratio=0.6): - # type: (str, Sequence[str], float) -> Iterator[Tuple[float, str]] +def fmatch_iter(needle: str, haystack: Iterable[str], min_ratio: float = 0.6) -> Iterator[tuple[float, str]]: """Fuzzy match: iteratively. Yields: @@ -29,19 +31,17 @@ def fmatch_iter(needle, haystack, min_ratio=0.6): yield ratio, key -def fmatch_best(needle, haystack, min_ratio=0.6): - # type: (str, Sequence[str], float) -> str +def fmatch_best(needle: str, haystack: Iterable[str], min_ratio: float = 0.6) -> str | None: """Fuzzy match - Find best match (scalar).""" try: return sorted( fmatch_iter(needle, haystack, min_ratio), reverse=True, )[0][1] except IndexError: - pass + return None -def version_string_as_tuple(s): - # type: (str) -> version_info_t +def version_string_as_tuple(s: str) -> version_info_t: """Convert version string to version info tuple.""" v = _unpack_version(*s.split('.')) # X.Y.3a1 -> (X, Y, 3, 'a1') @@ -53,13 +53,17 @@ def version_string_as_tuple(s): return v -def _unpack_version(major, minor=0, micro=0, releaselevel='', serial=''): - # type: (int, int, int, str, str) -> version_info_t +def _unpack_version( + major: str, + minor: str | int = 0, + micro: str | int = 0, + releaselevel: str = '', + serial: str = '' +) -> version_info_t: return version_info_t(int(major), int(minor), micro, releaselevel, serial) -def _splitmicro(micro, releaselevel='', serial=''): - # type: (int, str, str) -> Tuple[int, str, str] +def _splitmicro(micro: str, releaselevel: str = '', serial: str = '') -> tuple[int, str, str]: for index, char in enumerate(micro): if not char.isdigit(): break diff --git a/kombu/utils/time.py b/kombu/utils/time.py index 863f4017..8228d2be 100644 --- a/kombu/utils/time.py +++ b/kombu/utils/time.py @@ -1,11 +1,9 @@ """Time Utilities.""" -# flake8: noqa - +from __future__ import annotations __all__ = ('maybe_s_to_ms',) -def maybe_s_to_ms(v): - # type: (Optional[Union[int, float]]) -> int +def maybe_s_to_ms(v: int | float | None) -> int | None: """Convert seconds to milliseconds, but return None for None.""" return int(float(v) * 1000.0) if v is not None else v diff --git a/kombu/utils/url.py b/kombu/utils/url.py index de3a9139..f5f47701 100644 --- a/kombu/utils/url.py +++ b/kombu/utils/url.py @@ -2,6 +2,8 @@ # flake8: noqa +from __future__ import annotations + from collections.abc import Mapping from functools import partial from typing import NamedTuple diff --git a/kombu/utils/uuid.py b/kombu/utils/uuid.py index 010b3440..9f77dad9 100644 --- a/kombu/utils/uuid.py +++ b/kombu/utils/uuid.py @@ -1,9 +1,11 @@ """UUID utilities.""" +from __future__ import annotations -from uuid import uuid4 +from typing import Callable +from uuid import UUID, uuid4 -def uuid(_uuid=uuid4): +def uuid(_uuid: Callable[[], UUID] = uuid4) -> str: """Generate unique id in UUID4 format. See Also: diff --git a/requirements/default.txt b/requirements/default.txt index 4d27a499..221a92b1 100644 --- a/requirements/default.txt +++ b/requirements/default.txt @@ -1,4 +1,5 @@ -importlib-metadata>=0.18; python_version<"3.8" +importlib-metadata>=3.6; python_version<"3.8" cached_property; python_version<"3.8" -amqp>=5.0.6,<6.0.0 +typing_extensions; python_version<"3.10" +amqp>=5.1.1,<6.0.0 vine diff --git a/requirements/dev.txt b/requirements/dev.txt index 4bce72b7..8a395c83 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,2 +1,2 @@ -https://github.com/celery/py-amqp/zipball/master +https://github.com/celery/py-amqp/zipball/main https://github.com/celery/vine/zipball/master diff --git a/requirements/extras/azureservicebus.txt b/requirements/extras/azureservicebus.txt index 35b96b35..8463b0a4 100644 --- a/requirements/extras/azureservicebus.txt +++ b/requirements/extras/azureservicebus.txt @@ -1 +1 @@ -azure-servicebus>=7.0.0 +azure-servicebus>=7.9.0b1 diff --git a/requirements/extras/azurestoragequeues.txt b/requirements/extras/azurestoragequeues.txt index 2424ee7e..ff2e584d 100644 --- a/requirements/extras/azurestoragequeues.txt +++ b/requirements/extras/azurestoragequeues.txt @@ -1 +1,2 @@ -azure-storage-queue +azure-storage-queue>=12.6.0 +azure-identity>=1.12.0 diff --git a/requirements/extras/brotli.txt b/requirements/extras/brotli.txt index 35b37b35..1bb6d8d7 100644 --- a/requirements/extras/brotli.txt +++ b/requirements/extras/brotli.txt @@ -1,2 +1,2 @@ brotlipy>=0.7.0;platform_python_implementation=="PyPy" -brotli>=1.0.0;platform_python_implementation=="CPython" +brotli>=1.0.9;platform_python_implementation=="CPython" diff --git a/requirements/extras/confluentkafka.txt b/requirements/extras/confluentkafka.txt new file mode 100644 index 00000000..678c2bfd --- /dev/null +++ b/requirements/extras/confluentkafka.txt @@ -0,0 +1 @@ +confluent-kafka~=1.9.0 diff --git a/requirements/extras/consul.txt b/requirements/extras/consul.txt index dd29fbef..7b85dde7 100644 --- a/requirements/extras/consul.txt +++ b/requirements/extras/consul.txt @@ -1 +1 @@ -python-consul>=0.6.0 +python-consul2 diff --git a/requirements/extras/librabbitmq.txt b/requirements/extras/librabbitmq.txt index 866d11bc..874e223c 100644 --- a/requirements/extras/librabbitmq.txt +++ b/requirements/extras/librabbitmq.txt @@ -1 +1 @@ -librabbitmq>=1.5.2 +librabbitmq>=2.0.0 diff --git a/requirements/extras/mongodb.txt b/requirements/extras/mongodb.txt index e635ba45..b6caa029 100644 --- a/requirements/extras/mongodb.txt +++ b/requirements/extras/mongodb.txt @@ -1 +1 @@ -pymongo>=3.3.0 +pymongo>=4.1.1 diff --git a/requirements/extras/pyro.txt b/requirements/extras/pyro.txt index d19b0db3..bb73cdd7 100644 --- a/requirements/extras/pyro.txt +++ b/requirements/extras/pyro.txt @@ -1 +1 @@ -pyro4 +pyro5 diff --git a/requirements/extras/redis.txt b/requirements/extras/redis.txt index 240ddab8..a3749394 100644 --- a/requirements/extras/redis.txt +++ b/requirements/extras/redis.txt @@ -1 +1 @@ -redis>=3.4.1 +redis>=4.2.2,<4.4.0 diff --git a/requirements/extras/sqlalchemy.txt b/requirements/extras/sqlalchemy.txt index 39fb2bef..668520d9 100644 --- a/requirements/extras/sqlalchemy.txt +++ b/requirements/extras/sqlalchemy.txt @@ -1 +1 @@ -sqlalchemy +sqlalchemy>=1.4.47,<2.1 diff --git a/requirements/extras/sqs.txt b/requirements/extras/sqs.txt index c836bc9e..91c76d46 100644 --- a/requirements/extras/sqs.txt +++ b/requirements/extras/sqs.txt @@ -1,3 +1,3 @@ -boto3>=1.4.4 -pycurl==7.43.0.2 # Latest build with wheels provided -urllib3<1.26 # Unittests are faiing with urllib3>=1.28 +boto3==1.26.104 +pycurl==7.43.0.5 # Latest build with wheels provided +urllib3==1.26.15 diff --git a/requirements/extras/zookeeper.txt b/requirements/extras/zookeeper.txt index 81893ea0..84e08132 100644 --- a/requirements/extras/zookeeper.txt +++ b/requirements/extras/zookeeper.txt @@ -1 +1 @@ -kazoo>=1.3.1 +kazoo>=2.8.0 diff --git a/requirements/pkgutils.txt b/requirements/pkgutils.txt index 0bc7d383..0d9c3150 100644 --- a/requirements/pkgutils.txt +++ b/requirements/pkgutils.txt @@ -1,7 +1,8 @@ -setuptools>=20.6.7 +setuptools>=47.0.0 wheel>=0.29.0 -flake8>=2.5.4 -tox>=2.3.1 +flake8==5.0.4 +tox>=4.4.8 sphinx2rst>=1.0 bumpversion pydocstyle==1.1.1 +mypy==1.1.1 diff --git a/requirements/test-ci-windows.txt b/requirements/test-ci-windows.txt index 264b39ca..62d881f0 100644 --- a/requirements/test-ci-windows.txt +++ b/requirements/test-ci-windows.txt @@ -1,5 +1,4 @@ pytest-cov -pytest-travis-fold codecov -r extras/redis.txt -r extras/yaml.txt diff --git a/requirements/test-ci.txt b/requirements/test-ci.txt index 7a52f3bf..6e461a92 100644 --- a/requirements/test-ci.txt +++ b/requirements/test-ci.txt @@ -1,7 +1,7 @@ pytest-cov -pytest-travis-fold codecov -r extras/redis.txt +-r extras/mongodb.txt -r extras/yaml.txt -r extras/msgpack.txt -r extras/azureservicebus.txt @@ -11,4 +11,4 @@ codecov -r extras/zookeeper.txt -r extras/brotli.txt -r extras/zstd.txt --r extras/sqlalchemy.txt +-r extras/sqlalchemy.txt
\ No newline at end of file diff --git a/requirements/test.txt b/requirements/test.txt index 7566efae..e11e3025 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,4 +1,6 @@ -pytz>dev -pytest~=6.2 +pytz +pytest>=7.1.1 pytest-sugar Pyro4 +hypothesis +pytest-freezer @@ -12,6 +12,41 @@ all_files = 1 # whenever it makes the code more readable. extend-ignore = W504, N806, N802, N801, N803 +[isort] +add_imports = + from __future__ import annotations + +[mypy] +warn_unused_configs = True +strict = False +follow_imports = skip +show_error_codes = True +disallow_untyped_defs = True +ignore_missing_imports = True +files = + kombu/abstract.py, + kombu/utils/time.py, + kombu/utils/uuid.py, + t/unit/utils/test_uuid.py, + kombu/utils/text.py, + kombu/exceptions.py, + t/unit/test_exceptions.py, + kombu/clocks.py, + t/unit/test_clocks.py, + kombu/__init__.py, + kombu/asynchronous/__init__.py, + kombu/asynchronous/aws/__init__.py, + kombu/asynchronous/aws/ext.py, + kombu/asynchronous/aws/sqs/__init__.py, + kombu/asynchronous/aws/sqs/ext.py, + kombu/asynchronous/http/__init__.py, + kombu/transport/__init__.py, + kombu/transport/virtual/__init__.py, + kombu/utils/__init__.py, + kombu/matcher.py, + kombu/asynchronous/semaphore.py + + [pep257] ignore = D102,D104,D203,D105,D213 @@ -1,8 +1,9 @@ #!/usr/bin/env python3 +from __future__ import annotations + import os import re import sys -from distutils.command.install import INSTALL_SCHEMES import setuptools import setuptools.command.test @@ -56,9 +57,6 @@ def fullsplit(path, result=None): return fullsplit(head, [tail] + result) -for scheme in list(INSTALL_SCHEMES.values()): - scheme['data'] = scheme['purelib'] - # if os.path.exists('README.rst'): # long_description = codecs.open('README.rst', 'r', 'utf-8').read() # else: @@ -108,9 +106,12 @@ setup( author=meta['author'], author_email=meta['contact'], url=meta['homepage'], + project_urls={ + 'Source': 'https://github.com/celery/kombu' + }, platforms=['any'], zip_safe=False, - license='BSD', + license='BSD-3-Clause', cmdclass={'test': pytest}, python_requires=">=3.7", install_requires=reqs('default.txt'), @@ -130,6 +131,7 @@ setup( 'azureservicebus': extras('azureservicebus.txt'), 'qpid': extras('qpid.txt'), 'consul': extras('consul.txt'), + 'confluentkafka': extras('confluentkafka.txt'), }, classifiers=[ 'Development Status :: 5 - Production/Stable', diff --git a/t/integration/__init__.py b/t/integration/__init__.py index 1bea4880..cfce4f85 100644 --- a/t/integration/__init__.py +++ b/t/integration/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys diff --git a/t/integration/common.py b/t/integration/common.py index 84f44dd3..82dc3f1b 100644 --- a/t/integration/common.py +++ b/t/integration/common.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket from contextlib import closing from time import sleep @@ -136,14 +138,21 @@ class BaseExchangeTypes: message.delivery_info['exchange'] == '' assert message.payload == body - def _consume(self, connection, queue): + def _create_consumer(self, connection, queue): consumer = kombu.Consumer( connection, [queue], accept=['pickle'] ) consumer.register_callback(self._callback) + return consumer + + def _consume_from(self, connection, consumer): with consumer: connection.drain_events(timeout=1) + def _consume(self, connection, queue): + with self._create_consumer(connection, queue): + connection.drain_events(timeout=1) + def _publish(self, channel, exchange, queues=None, routing_key=None): producer = kombu.Producer(channel, exchange=exchange) if routing_key: @@ -213,7 +222,6 @@ class BaseExchangeTypes: channel, ex, [test_queue1, test_queue2, test_queue3], routing_key='t.1' ) - self._consume(conn, test_queue1) self._consume(conn, test_queue2) with pytest.raises(socket.timeout): @@ -398,6 +406,47 @@ class BasePriority: assert msg.payload == data +class BaseMessage: + + def test_ack(self, connection): + with connection as conn: + with closing(conn.SimpleQueue('test_ack')) as queue: + queue.put({'Hello': 'World'}, headers={'k1': 'v1'}) + message = queue.get_nowait() + message.ack() + with pytest.raises(queue.Empty): + queue.get_nowait() + + def test_reject_no_requeue(self, connection): + with connection as conn: + with closing(conn.SimpleQueue('test_reject_no_requeue')) as queue: + queue.put({'Hello': 'World'}, headers={'k1': 'v1'}) + message = queue.get_nowait() + message.reject(requeue=False) + with pytest.raises(queue.Empty): + queue.get_nowait() + + def test_reject_requeue(self, connection): + with connection as conn: + with closing(conn.SimpleQueue('test_reject_requeue')) as queue: + queue.put({'Hello': 'World'}, headers={'k1': 'v1'}) + message = queue.get_nowait() + message.reject(requeue=True) + message2 = queue.get_nowait() + assert message.body == message2.body + message2.ack() + + def test_requeue(self, connection): + with connection as conn: + with closing(conn.SimpleQueue('test_requeue')) as queue: + queue.put({'Hello': 'World'}, headers={'k1': 'v1'}) + message = queue.get_nowait() + message.requeue() + message2 = queue.get_nowait() + assert message.body == message2.body + message2.ack() + + class BaseFailover(BasicFunctionality): def test_connect(self, failover_connection): diff --git a/t/integration/test_kafka.py b/t/integration/test_kafka.py new file mode 100644 index 00000000..2303d887 --- /dev/null +++ b/t/integration/test_kafka.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import pytest + +import kombu + +from .common import (BaseExchangeTypes, BaseFailover, BaseMessage, + BasicFunctionality) + + +def get_connection(hostname, port): + return kombu.Connection( + f'confluentkafka://{hostname}:{port}', + ) + + +def get_failover_connection(hostname, port): + return kombu.Connection( + f'confluentkafka://localhost:12345;confluentkafka://{hostname}:{port}', + connect_timeout=10, + ) + + +@pytest.fixture() +def invalid_connection(): + return kombu.Connection('confluentkafka://localhost:12345') + + +@pytest.fixture() +def connection(): + return get_connection( + hostname='localhost', + port='9092' + ) + + +@pytest.fixture() +def failover_connection(): + return get_failover_connection( + hostname='localhost', + port='9092' + ) + + +@pytest.mark.env('kafka') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_KafkaBasicFunctionality(BasicFunctionality): + pass + + +@pytest.mark.env('kafka') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_KafkaBaseExchangeTypes(BaseExchangeTypes): + + @pytest.mark.skip('fanout is not implemented') + def test_fanout(self, connection): + pass + + +@pytest.mark.env('kafka') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_KafkaFailover(BaseFailover): + pass + + +@pytest.mark.env('kafka') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_KafkaMessage(BaseMessage): + pass diff --git a/t/integration/test_mongodb.py b/t/integration/test_mongodb.py new file mode 100644 index 00000000..445f1389 --- /dev/null +++ b/t/integration/test_mongodb.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import os + +import pytest + +import kombu + +from .common import (BaseExchangeTypes, BaseMessage, BasePriority, + BasicFunctionality) + + +def get_connection(hostname, port, vhost): + return kombu.Connection( + f'mongodb://{hostname}:{port}/{vhost}', + transport_options={'ttl': True}, + ) + + +@pytest.fixture() +def invalid_connection(): + return kombu.Connection('mongodb://localhost:12345?connectTimeoutMS=1') + + +@pytest.fixture() +def connection(request): + return get_connection( + hostname=os.environ.get('MONGODB_HOST', 'localhost'), + port=os.environ.get('MONGODB_27017_TCP', '27017'), + vhost=getattr( + request.config, "slaveinput", {} + ).get("slaveid", 'tests'), + ) + + +@pytest.mark.env('mongodb') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_MongoDBBasicFunctionality(BasicFunctionality): + pass + + +@pytest.mark.env('mongodb') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_MongoDBBaseExchangeTypes(BaseExchangeTypes): + + # MongoDB consumer skips old messages upon initialization. + # Ensure that it's created before test messages are published. + + def test_fanout(self, connection): + ex = kombu.Exchange('test_fanout', type='fanout') + test_queue1 = kombu.Queue('fanout1', exchange=ex) + consumer1 = self._create_consumer(connection, test_queue1) + test_queue2 = kombu.Queue('fanout2', exchange=ex) + consumer2 = self._create_consumer(connection, test_queue2) + + with connection as conn: + with conn.channel() as channel: + self._publish(channel, ex, [test_queue1, test_queue2]) + + self._consume_from(conn, consumer1) + self._consume_from(conn, consumer2) + + +@pytest.mark.env('mongodb') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_MongoDBPriority(BasePriority): + + # drain_events() consumes only one value unlike in py-amqp. + + def test_publish_consume(self, connection): + test_queue = kombu.Queue( + 'priority_test', routing_key='priority_test', max_priority=10 + ) + + received_messages = [] + + def callback(body, message): + received_messages.append(body) + message.ack() + + with connection as conn: + with conn.channel() as channel: + producer = kombu.Producer(channel) + for msg, prio in [ + [{'msg': 'first'}, 3], + [{'msg': 'second'}, 6], + [{'msg': 'third'}, 3], + ]: + producer.publish( + msg, + retry=True, + exchange=test_queue.exchange, + routing_key=test_queue.routing_key, + declare=[test_queue], + serializer='pickle', + priority=prio + ) + consumer = kombu.Consumer( + conn, [test_queue], accept=['pickle'] + ) + consumer.register_callback(callback) + with consumer: + conn.drain_events(timeout=1) + conn.drain_events(timeout=1) + conn.drain_events(timeout=1) + # Second message must be received first + assert received_messages[0] == {'msg': 'second'} + assert received_messages[1] == {'msg': 'first'} + assert received_messages[2] == {'msg': 'third'} + + +@pytest.mark.env('mongodb') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_MongoDBMessage(BaseMessage): + pass diff --git a/t/integration/test_py_amqp.py b/t/integration/test_py_amqp.py index 88ff0ac7..260f164d 100644 --- a/t/integration/test_py_amqp.py +++ b/t/integration/test_py_amqp.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import os import pytest import kombu -from .common import (BaseExchangeTypes, BaseFailover, BasePriority, - BaseTimeToLive, BasicFunctionality) +from .common import (BaseExchangeTypes, BaseFailover, BaseMessage, + BasePriority, BaseTimeToLive, BasicFunctionality) def get_connection(hostname, port, vhost): @@ -73,3 +75,9 @@ class test_PyAMQPPriority(BasePriority): @pytest.mark.flaky(reruns=5, reruns_delay=2) class test_PyAMQPFailover(BaseFailover): pass + + +@pytest.mark.env('py-amqp') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_PyAMQPMessage(BaseMessage): + pass diff --git a/t/integration/test_redis.py b/t/integration/test_redis.py index 72ba803f..b2ae5ab8 100644 --- a/t/integration/test_redis.py +++ b/t/integration/test_redis.py @@ -1,12 +1,17 @@ +from __future__ import annotations + import os +import socket from time import sleep import pytest import redis import kombu +from kombu.transport.redis import Transport -from .common import BaseExchangeTypes, BasePriority, BasicFunctionality +from .common import (BaseExchangeTypes, BaseMessage, BasePriority, + BasicFunctionality) def get_connection( @@ -55,7 +60,11 @@ def test_failed_credentials(): @pytest.mark.env('redis') @pytest.mark.flaky(reruns=5, reruns_delay=2) class test_RedisBasicFunctionality(BasicFunctionality): - pass + def test_failed_connection__ConnectionError(self, invalid_connection): + # method raises transport exception + with pytest.raises(redis.exceptions.ConnectionError) as ex: + invalid_connection.connection + assert ex.type in Transport.connection_errors @pytest.mark.env('redis') @@ -120,3 +129,24 @@ class test_RedisPriority(BasePriority): assert received_messages[0] == {'msg': 'second'} assert received_messages[1] == {'msg': 'first'} assert received_messages[2] == {'msg': 'third'} + + +@pytest.mark.env('redis') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_RedisMessage(BaseMessage): + pass + + +@pytest.mark.env('redis') +def test_RedisConnectTimeout(monkeypatch): + # simulate a connection timeout for a new connection + def connect_timeout(self): + raise socket.timeout + monkeypatch.setattr( + redis.connection.Connection, "_connect", connect_timeout) + + # ensure the timeout raises a TimeoutError + with pytest.raises(redis.exceptions.TimeoutError): + # note the host/port here is irrelevant because + # connect will raise a socket.timeout + kombu.Connection('redis://localhost:12345').connect() @@ -1,9 +1,16 @@ +from __future__ import annotations + +import time from itertools import count +from typing import TYPE_CHECKING from unittest.mock import Mock from kombu.transport import base from kombu.utils import json +if TYPE_CHECKING: + from types import TracebackType + class _ContextMock(Mock): """Dummy class implementing __enter__ and __exit__ @@ -13,7 +20,12 @@ class _ContextMock(Mock): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: pass @@ -191,3 +203,15 @@ class Transport(base.Transport): def close_connection(self, connection): connection.connected = False + + +class TimeoutingTransport(Transport): + recoverable_connection_errors = (TimeoutError,) + + def __init__(self, connect_timeout=1, **kwargs): + self.connect_timeout = connect_timeout + super().__init__(**kwargs) + + def establish_connection(self): + time.sleep(self.connect_timeout) + raise TimeoutError('timed out') @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys import pytest diff --git a/t/unit/asynchronous/aws/case.py b/t/unit/asynchronous/aws/case.py index 56c70812..220cd700 100644 --- a/t/unit/asynchronous/aws/case.py +++ b/t/unit/asynchronous/aws/case.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest import t.skip diff --git a/t/unit/asynchronous/aws/sqs/test_connection.py b/t/unit/asynchronous/aws/sqs/test_connection.py index c3dd184b..0c5d2ac9 100644 --- a/t/unit/asynchronous/aws/sqs/test_connection.py +++ b/t/unit/asynchronous/aws/sqs/test_connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import MagicMock, Mock from kombu.asynchronous.aws.ext import boto3 diff --git a/t/unit/asynchronous/aws/sqs/test_queue.py b/t/unit/asynchronous/aws/sqs/test_queue.py index 56812831..70f10a75 100644 --- a/t/unit/asynchronous/aws/sqs/test_queue.py +++ b/t/unit/asynchronous/aws/sqs/test_queue.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest diff --git a/t/unit/asynchronous/aws/test_aws.py b/t/unit/asynchronous/aws/test_aws.py index 93d92e4b..736fdf8a 100644 --- a/t/unit/asynchronous/aws/test_aws.py +++ b/t/unit/asynchronous/aws/test_aws.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock from kombu.asynchronous.aws import connect_sqs diff --git a/t/unit/asynchronous/aws/test_connection.py b/t/unit/asynchronous/aws/test_connection.py index 68e3c746..03fc5412 100644 --- a/t/unit/asynchronous/aws/test_connection.py +++ b/t/unit/asynchronous/aws/test_connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import contextmanager from io import StringIO from unittest.mock import Mock diff --git a/t/unit/asynchronous/http/test_curl.py b/t/unit/asynchronous/http/test_curl.py index db8f5f91..51f9128e 100644 --- a/t/unit/asynchronous/http/test_curl.py +++ b/t/unit/asynchronous/http/test_curl.py @@ -1,4 +1,7 @@ -from unittest.mock import Mock, call, patch +from __future__ import annotations + +from io import BytesIO +from unittest.mock import ANY, Mock, call, patch import pytest @@ -131,3 +134,24 @@ class test_CurlClient: x._on_event.assert_called_with(fd, _pycurl.CSELECT_IN) x.on_writable(fd, _pycurl=_pycurl) x._on_event.assert_called_with(fd, _pycurl.CSELECT_OUT) + + def test_setup_request_sets_proxy_when_specified(self): + with patch('kombu.asynchronous.http.curl.pycurl') as _pycurl: + x = self.Client() + proxy_host = 'http://www.example.com' + request = Mock( + name='request', headers={}, auth_mode=None, proxy_host=None + ) + proxied_request = Mock( + name='request', headers={}, auth_mode=None, + proxy_host=proxy_host, proxy_port=123 + ) + x._setup_request( + x.Curl, request, BytesIO(), x.Headers(), _pycurl=_pycurl + ) + with pytest.raises(AssertionError): + x.Curl.setopt.assert_any_call(_pycurl.PROXY, ANY) + x._setup_request( + x.Curl, proxied_request, BytesIO(), x.Headers(), _pycurl + ) + x.Curl.setopt.assert_any_call(_pycurl.PROXY, proxy_host) diff --git a/t/unit/asynchronous/http/test_http.py b/t/unit/asynchronous/http/test_http.py index 6e6abdcb..816bf89d 100644 --- a/t/unit/asynchronous/http/test_http.py +++ b/t/unit/asynchronous/http/test_http.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from io import BytesIO from unittest.mock import Mock diff --git a/t/unit/asynchronous/test_hub.py b/t/unit/asynchronous/test_hub.py index eae25357..27b048b9 100644 --- a/t/unit/asynchronous/test_hub.py +++ b/t/unit/asynchronous/test_hub.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import errno -from unittest.mock import Mock, call, patch +from unittest.mock import ANY, Mock, call, patch import pytest from vine import promise @@ -187,6 +189,12 @@ class test_Hub: assert promise() in self.hub._ready assert ret is promise() + def test_call_soon_uses_lock(self): + callback = Mock(name='callback') + with patch.object(self.hub, '_ready_lock', autospec=True) as lock: + self.hub.call_soon(callback) + assert lock.__enter__.called_once() + def test_call_soon__promise_argument(self): callback = promise(Mock(name='callback'), (1, 2, 3)) ret = self.hub.call_soon(callback) @@ -533,3 +541,31 @@ class test_Hub: callbacks[0].assert_called_once_with() callbacks[1].assert_called_once_with() deferred.assert_not_called() + + def test_loop__no_todo_tick_delay(self): + cb = Mock(name='parent') + cb.todo, cb.tick, cb.poller = Mock(), Mock(), Mock() + cb.poller.poll.side_effect = lambda obj: () + self.hub.poller = cb.poller + self.hub.add(2, Mock(), READ) + self.hub.call_soon(cb.todo) + self.hub.on_tick = [cb.tick] + + next(self.hub.loop) + + cb.assert_has_calls([ + call.todo(), + call.tick(), + call.poller.poll(ANY), + ]) + + def test__pop_ready_pops_ready_items(self): + self.hub._ready.add(None) + ret = self.hub._pop_ready() + assert ret == {None} + assert self.hub._ready == set() + + def test__pop_ready_uses_lock(self): + with patch.object(self.hub, '_ready_lock', autospec=True) as lock: + self.hub._pop_ready() + assert lock.__enter__.called_once() diff --git a/t/unit/asynchronous/test_semaphore.py b/t/unit/asynchronous/test_semaphore.py index 8767ca91..5c41a6d8 100644 --- a/t/unit/asynchronous/test_semaphore.py +++ b/t/unit/asynchronous/test_semaphore.py @@ -1,11 +1,13 @@ +from __future__ import annotations + from kombu.asynchronous.semaphore import LaxBoundedSemaphore class test_LaxBoundedSemaphore: - def test_over_release(self): + def test_over_release(self) -> None: x = LaxBoundedSemaphore(2) - calls = [] + calls: list[int] = [] for i in range(1, 21): x.acquire(calls.append, i) x.release() diff --git a/t/unit/asynchronous/test_timer.py b/t/unit/asynchronous/test_timer.py index 20411784..531b3d2e 100644 --- a/t/unit/asynchronous/test_timer.py +++ b/t/unit/asynchronous/test_timer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from unittest.mock import Mock, patch diff --git a/t/unit/conftest.py b/t/unit/conftest.py index b798e3e5..15e31366 100644 --- a/t/unit/conftest.py +++ b/t/unit/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import atexit import builtins import io diff --git a/t/unit/test_clocks.py b/t/unit/test_clocks.py index b4392440..8f2d1340 100644 --- a/t/unit/test_clocks.py +++ b/t/unit/test_clocks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle from heapq import heappush from time import time @@ -8,7 +10,7 @@ from kombu.clocks import LamportClock, timetuple class test_LamportClock: - def test_clocks(self): + def test_clocks(self) -> None: c1 = LamportClock() c2 = LamportClock() @@ -29,12 +31,12 @@ class test_LamportClock: c1.adjust(c2.value) assert c1.value == c2.value + 1 - def test_sort(self): + def test_sort(self) -> None: c = LamportClock() pid1 = 'a.example.com:312' pid2 = 'b.example.com:311' - events = [] + events: list[tuple[int, str]] = [] m1 = (c.forward(), pid1) heappush(events, m1) @@ -56,15 +58,15 @@ class test_LamportClock: class test_timetuple: - def test_repr(self): + def test_repr(self) -> None: x = timetuple(133, time(), 'id', Mock()) assert repr(x) - def test_pickleable(self): + def test_pickleable(self) -> None: x = timetuple(133, time(), 'id', 'obj') assert pickle.loads(pickle.dumps(x)) == tuple(x) - def test_order(self): + def test_order(self) -> None: t1 = time() t2 = time() + 300 # windows clock not reliable a = timetuple(133, t1, 'A', 'obj') @@ -81,5 +83,6 @@ class test_timetuple: NotImplemented) assert timetuple(134, t2, 'A', 'obj') > timetuple(133, t1, 'A', 'obj') assert timetuple(134, t1, 'B', 'obj') > timetuple(134, t1, 'A', 'obj') - assert (timetuple(None, t2, 'B', 'obj') > - timetuple(None, t1, 'A', 'obj')) + assert ( + timetuple(None, t2, 'B', 'obj') > timetuple(None, t1, 'A', 'obj') + ) diff --git a/t/unit/test_common.py b/t/unit/test_common.py index 0f669b7d..fd20243f 100644 --- a/t/unit/test_common.py +++ b/t/unit/test_common.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import socket +from typing import TYPE_CHECKING from unittest.mock import Mock, patch import pytest @@ -10,6 +13,9 @@ from kombu.common import (PREFETCH_COUNT_MAX, Broadcast, QoS, collect_replies, maybe_declare, send_reply) from t.mocks import ContextMock, MockPool +if TYPE_CHECKING: + from types import TracebackType + def test_generate_oid(): from uuid import NAMESPACE_OID @@ -338,7 +344,12 @@ class MockConsumer: self.consumers.add(self) return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.consumers.discard(self) diff --git a/t/unit/test_compat.py b/t/unit/test_compat.py index d75ce5df..837d6f22 100644 --- a/t/unit/test_compat.py +++ b/t/unit/test_compat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock, patch import pytest @@ -115,12 +117,14 @@ class test_Publisher: pub.close() def test__enter__exit__(self): - pub = compat.Publisher(self.connection, - exchange='test_Publisher_send', - routing_key='rkey') - x = pub.__enter__() - assert x is pub - x.__exit__() + pub = compat.Publisher( + self.connection, + exchange='test_Publisher_send', + routing_key='rkey' + ) + with pub as x: + assert x is pub + assert pub._closed @@ -158,11 +162,14 @@ class test_Consumer: assert q2.exchange.auto_delete def test__enter__exit__(self, n='test__enter__exit__'): - c = compat.Consumer(self.connection, queue=n, exchange=n, - routing_key='rkey') - x = c.__enter__() - assert x is c - x.__exit__() + c = compat.Consumer( + self.connection, + queue=n, + exchange=n, + routing_key='rkey' + ) + with c as x: + assert x is c assert c._closed def test_revive(self, n='test_revive'): diff --git a/t/unit/test_compression.py b/t/unit/test_compression.py index f1f426b7..95139811 100644 --- a/t/unit/test_compression.py +++ b/t/unit/test_compression.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys import pytest diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index 0b184d3b..c2daee3b 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle import socket from copy import copy, deepcopy @@ -9,7 +11,7 @@ from kombu import Connection, Consumer, Producer, parse_url from kombu.connection import Resource from kombu.exceptions import OperationalError from kombu.utils.functional import lazy -from t.mocks import Transport +from t.mocks import TimeoutingTransport, Transport class test_connection_utils: @@ -99,6 +101,19 @@ class test_connection_utils: # see Appendix A of http://www.rabbitmq.com/uri-spec.html self.assert_info(Connection(url), **expected) + @pytest.mark.parametrize('url,expected', [ + ('sqs://user:pass@', + {'userid': None, 'password': None, 'hostname': None, + 'port': None, 'virtual_host': '/'}), + ('sqs://', + {'userid': None, 'password': None, 'hostname': None, + 'port': None, 'virtual_host': '/'}), + ]) + def test_sqs_example_urls(self, url, expected, caplog): + pytest.importorskip('boto3') + self.assert_info(Connection('sqs://'), **expected) + assert not caplog.records + @pytest.mark.skip('TODO: urllib cannot parse ipv6 urls') def test_url_IPV6(self): self.assert_info( @@ -293,7 +308,9 @@ class test_Connection: assert not c.is_evented def test_register_with_event_loop(self): - c = Connection(transport=Mock) + transport = Mock(name='transport') + transport.connection_errors = [] + c = Connection(transport=transport) loop = Mock(name='loop') c.register_with_event_loop(loop) c.transport.register_with_event_loop.assert_called_with( @@ -383,14 +400,12 @@ class test_Connection: qsms.assert_called_with(self.conn.connection) def test__enter____exit__(self): - conn = self.conn - context = conn.__enter__() - assert context is conn - conn.connect() - assert conn.connection.connected - conn.__exit__() - assert conn.connection is None - conn.close() # again + with self.conn as context: + assert context is self.conn + self.conn.connect() + assert self.conn.connection.connected + assert self.conn.connection is None + self.conn.close() # again def test_close_survives_connerror(self): @@ -477,15 +492,52 @@ class test_Connection: def publish(): raise _ConnectionError('failed connection') - self.conn.transport.connection_errors = (_ConnectionError,) + self.conn.get_transport_cls().connection_errors = (_ConnectionError,) ensured = self.conn.ensure(self.conn, publish) with pytest.raises(OperationalError): ensured() + def test_ensure_retry_errors_is_not_looping_infinitely(self): + class _MessageNacked(Exception): + pass + + def publish(): + raise _MessageNacked('NACK') + + with pytest.raises(ValueError): + self.conn.ensure( + self.conn, + publish, + retry_errors=(_MessageNacked,) + ) + + def test_ensure_retry_errors_is_limited_by_max_retries(self): + class _MessageNacked(Exception): + pass + + tries = 0 + + def publish(): + nonlocal tries + tries += 1 + if tries <= 3: + raise _MessageNacked('NACK') + # On the 4th try, we let it pass + return 'ACK' + + ensured = self.conn.ensure( + self.conn, + publish, + max_retries=3, # 3 retries + 1 initial try = 4 tries + retry_errors=(_MessageNacked,) + ) + + assert ensured() == 'ACK' + def test_autoretry(self): myfun = Mock() - self.conn.transport.connection_errors = (KeyError,) + self.conn.get_transport_cls().connection_errors = (KeyError,) def on_call(*args, **kwargs): myfun.side_effect = None @@ -571,6 +623,18 @@ class test_Connection: conn = Connection(transport=MyTransport) assert conn.channel_errors == (KeyError, ValueError) + def test_channel_errors__exception_no_cache(self): + """Ensure the channel_errors can be retrieved without an initialized + transport. + """ + + class MyTransport(Transport): + channel_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.channel_errors == (KeyError,) + def test_connection_errors(self): class MyTransport(Transport): @@ -579,6 +643,80 @@ class test_Connection: conn = Connection(transport=MyTransport) assert conn.connection_errors == (KeyError, ValueError) + def test_connection_errors__exception_no_cache(self): + """Ensure the connection_errors can be retrieved without an + initialized transport. + """ + + class MyTransport(Transport): + connection_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.connection_errors == (KeyError,) + + def test_recoverable_connection_errors(self): + + class MyTransport(Transport): + recoverable_connection_errors = (KeyError, ValueError) + + conn = Connection(transport=MyTransport) + assert conn.recoverable_connection_errors == (KeyError, ValueError) + + def test_recoverable_connection_errors__fallback(self): + """Ensure missing recoverable_connection_errors on the Transport does + not cause a fatal error. + """ + + class MyTransport(Transport): + connection_errors = (KeyError,) + channel_errors = (ValueError,) + + conn = Connection(transport=MyTransport) + assert conn.recoverable_connection_errors == (KeyError, ValueError) + + def test_recoverable_connection_errors__exception_no_cache(self): + """Ensure the recoverable_connection_errors can be retrieved without + an initialized transport. + """ + + class MyTransport(Transport): + recoverable_connection_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.recoverable_connection_errors == (KeyError,) + + def test_recoverable_channel_errors(self): + + class MyTransport(Transport): + recoverable_channel_errors = (KeyError, ValueError) + + conn = Connection(transport=MyTransport) + assert conn.recoverable_channel_errors == (KeyError, ValueError) + + def test_recoverable_channel_errors__fallback(self): + """Ensure missing recoverable_channel_errors on the Transport does not + cause a fatal error. + """ + + class MyTransport(Transport): + pass + + conn = Connection(transport=MyTransport) + assert conn.recoverable_channel_errors == () + + def test_recoverable_channel_errors__exception_no_cache(self): + """Ensure the recoverable_channel_errors can be retrieved without an + initialized transport. + """ + class MyTransport(Transport): + recoverable_channel_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.recoverable_channel_errors == (KeyError,) + def test_multiple_urls_hostname(self): conn = Connection(['example.com;amqp://example.com']) assert conn.as_uri() == 'amqp://guest:**@example.com:5672//' @@ -587,6 +725,47 @@ class test_Connection: conn = Connection('example.com;example.com;') assert conn.as_uri() == 'amqp://guest:**@example.com:5672//' + def test_connection_respect_its_timeout(self): + invalid_port = 1222 + with Connection( + f'amqp://guest:guest@localhost:{invalid_port}//', + transport_options={'max_retries': 2}, + connect_timeout=1 + ) as conn: + with pytest.raises(OperationalError): + conn.default_channel + + def test_connection_failover_without_total_timeout(self): + with Connection( + ['server1', 'server2'], + transport=TimeoutingTransport, + connect_timeout=1, + transport_options={'interval_start': 0, 'interval_step': 0}, + ) as conn: + conn._establish_connection = Mock( + side_effect=conn._establish_connection + ) + with pytest.raises(OperationalError): + conn.default_channel + # Never retried, because `retry_over_time` `timeout` is equal + # to `connect_timeout` + conn._establish_connection.assert_called_once() + + def test_connection_failover_with_total_timeout(self): + with Connection( + ['server1', 'server2'], + transport=TimeoutingTransport, + connect_timeout=1, + transport_options={'connect_retries_timeout': 2, + 'interval_start': 0, 'interval_step': 0}, + ) as conn: + conn._establish_connection = Mock( + side_effect=conn._establish_connection + ) + with pytest.raises(OperationalError): + conn.default_channel + assert conn._establish_connection.call_count == 2 + class test_Connection_with_transport_options: diff --git a/t/unit/test_entity.py b/t/unit/test_entity.py index 52c42b2b..fcb0afb9 100644 --- a/t/unit/test_entity.py +++ b/t/unit/test_entity.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle from unittest.mock import Mock, call @@ -10,13 +12,13 @@ from kombu.serialization import registry from t.mocks import Transport -def get_conn(): +def get_conn() -> Connection: return Connection(transport=Transport) class test_binding: - def test_constructor(self): + def test_constructor(self) -> None: x = binding( Exchange('foo'), 'rkey', arguments={'barg': 'bval'}, @@ -27,31 +29,31 @@ class test_binding: assert x.arguments == {'barg': 'bval'} assert x.unbind_arguments == {'uarg': 'uval'} - def test_declare(self): + def test_declare(self) -> None: chan = get_conn().channel() x = binding(Exchange('foo'), 'rkey') x.declare(chan) assert 'exchange_declare' in chan - def test_declare_no_exchange(self): + def test_declare_no_exchange(self) -> None: chan = get_conn().channel() x = binding() x.declare(chan) assert 'exchange_declare' not in chan - def test_bind(self): + def test_bind(self) -> None: chan = get_conn().channel() x = binding(Exchange('foo')) x.bind(Exchange('bar')(chan)) assert 'exchange_bind' in chan - def test_unbind(self): + def test_unbind(self) -> None: chan = get_conn().channel() x = binding(Exchange('foo')) x.unbind(Exchange('bar')(chan)) assert 'exchange_unbind' in chan - def test_repr(self): + def test_repr(self) -> None: b = binding(Exchange('foo'), 'rkey') assert 'foo' in repr(b) assert 'rkey' in repr(b) @@ -59,7 +61,7 @@ class test_binding: class test_Exchange: - def test_bound(self): + def test_bound(self) -> None: exchange = Exchange('foo', 'direct') assert not exchange.is_bound assert '<unbound' in repr(exchange) @@ -70,11 +72,11 @@ class test_Exchange: assert bound.channel is chan assert f'bound to chan:{chan.channel_id!r}' in repr(bound) - def test_hash(self): + def test_hash(self) -> None: assert hash(Exchange('a')) == hash(Exchange('a')) assert hash(Exchange('a')) != hash(Exchange('b')) - def test_can_cache_declaration(self): + def test_can_cache_declaration(self) -> None: assert Exchange('a', durable=True).can_cache_declaration assert Exchange('a', durable=False).can_cache_declaration assert not Exchange('a', auto_delete=True).can_cache_declaration @@ -82,12 +84,12 @@ class test_Exchange: 'a', durable=True, auto_delete=True, ).can_cache_declaration - def test_pickle(self): + def test_pickle(self) -> None: e1 = Exchange('foo', 'direct') e2 = pickle.loads(pickle.dumps(e1)) assert e1 == e2 - def test_eq(self): + def test_eq(self) -> None: e1 = Exchange('foo', 'direct') e2 = Exchange('foo', 'direct') assert e1 == e2 @@ -97,7 +99,7 @@ class test_Exchange: assert e1.__eq__(True) == NotImplemented - def test_revive(self): + def test_revive(self) -> None: exchange = Exchange('foo', 'direct') conn = get_conn() chan = conn.channel() @@ -116,7 +118,7 @@ class test_Exchange: assert bound.is_bound assert bound._channel is chan2 - def test_assert_is_bound(self): + def test_assert_is_bound(self) -> None: exchange = Exchange('foo', 'direct') with pytest.raises(NotBoundError): exchange.declare() @@ -126,80 +128,80 @@ class test_Exchange: exchange.bind(chan).declare() assert 'exchange_declare' in chan - def test_set_transient_delivery_mode(self): + def test_set_transient_delivery_mode(self) -> None: exc = Exchange('foo', 'direct', delivery_mode='transient') assert exc.delivery_mode == Exchange.TRANSIENT_DELIVERY_MODE - def test_set_passive_mode(self): + def test_set_passive_mode(self) -> None: exc = Exchange('foo', 'direct', passive=True) assert exc.passive - def test_set_persistent_delivery_mode(self): + def test_set_persistent_delivery_mode(self) -> None: exc = Exchange('foo', 'direct', delivery_mode='persistent') assert exc.delivery_mode == Exchange.PERSISTENT_DELIVERY_MODE - def test_bind_at_instantiation(self): + def test_bind_at_instantiation(self) -> None: assert Exchange('foo', channel=get_conn().channel()).is_bound - def test_create_message(self): + def test_create_message(self) -> None: chan = get_conn().channel() Exchange('foo', channel=chan).Message({'foo': 'bar'}) assert 'prepare_message' in chan - def test_publish(self): + def test_publish(self) -> None: chan = get_conn().channel() Exchange('foo', channel=chan).publish('the quick brown fox') assert 'basic_publish' in chan - def test_delete(self): + def test_delete(self) -> None: chan = get_conn().channel() Exchange('foo', channel=chan).delete() assert 'exchange_delete' in chan - def test__repr__(self): + def test__repr__(self) -> None: b = Exchange('foo', 'topic') assert 'foo(topic)' in repr(b) assert 'Exchange' in repr(b) - def test_bind_to(self): + def test_bind_to(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic') bar = Exchange('bar', 'topic') foo(chan).bind_to(bar) assert 'exchange_bind' in chan - def test_bind_to_by_name(self): + def test_bind_to_by_name(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic') foo(chan).bind_to('bar') assert 'exchange_bind' in chan - def test_unbind_from(self): + def test_unbind_from(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic') bar = Exchange('bar', 'topic') foo(chan).unbind_from(bar) assert 'exchange_unbind' in chan - def test_unbind_from_by_name(self): + def test_unbind_from_by_name(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic') foo(chan).unbind_from('bar') assert 'exchange_unbind' in chan - def test_declare__no_declare(self): + def test_declare__no_declare(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic', no_declare=True) foo(chan).declare() assert 'exchange_declare' not in chan - def test_declare__internal_exchange(self): + def test_declare__internal_exchange(self) -> None: chan = get_conn().channel() foo = Exchange('amq.rabbitmq.trace', 'topic') foo(chan).declare() assert 'exchange_declare' not in chan - def test_declare(self): + def test_declare(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic', no_declare=False) foo(chan).declare() @@ -208,33 +210,33 @@ class test_Exchange: class test_Queue: - def setup(self): + def setup(self) -> None: self.exchange = Exchange('foo', 'direct') - def test_constructor_with_actual_exchange(self): + def test_constructor_with_actual_exchange(self) -> None: exchange = Exchange('exchange_name', 'direct') queue = Queue(name='queue_name', exchange=exchange) assert queue.exchange == exchange - def test_constructor_with_string_exchange(self): + def test_constructor_with_string_exchange(self) -> None: exchange_name = 'exchange_name' queue = Queue(name='queue_name', exchange=exchange_name) assert queue.exchange == Exchange(exchange_name) - def test_constructor_with_default_exchange(self): + def test_constructor_with_default_exchange(self) -> None: queue = Queue(name='queue_name') assert queue.exchange == Exchange('') - def test_hash(self): + def test_hash(self) -> None: assert hash(Queue('a')) == hash(Queue('a')) assert hash(Queue('a')) != hash(Queue('b')) - def test_repr_with_bindings(self): + def test_repr_with_bindings(self) -> None: ex = Exchange('foo') x = Queue('foo', bindings=[ex.binding('A'), ex.binding('B')]) assert repr(x) - def test_anonymous(self): + def test_anonymous(self) -> None: chan = Mock() x = Queue(bindings=[binding(Exchange('foo'), 'rkey')]) chan.queue_declare.return_value = 'generated', 0, 0 @@ -242,7 +244,7 @@ class test_Queue: xx.declare() assert xx.name == 'generated' - def test_basic_get__accept_disallowed(self): + def test_basic_get__accept_disallowed(self) -> None: conn = Connection('memory://') q = Queue('foo', exchange=self.exchange) p = Producer(conn) @@ -257,7 +259,7 @@ class test_Queue: with pytest.raises(q.ContentDisallowed): message.decode() - def test_basic_get__accept_allowed(self): + def test_basic_get__accept_allowed(self) -> None: conn = Connection('memory://') q = Queue('foo', exchange=self.exchange) p = Producer(conn) @@ -272,12 +274,12 @@ class test_Queue: payload = message.decode() assert payload['complex'] - def test_when_bound_but_no_exchange(self): + def test_when_bound_but_no_exchange(self) -> None: q = Queue('a') q.exchange = None assert q.when_bound() is None - def test_declare_but_no_exchange(self): + def test_declare_but_no_exchange(self) -> None: q = Queue('a') q.queue_declare = Mock() q.queue_bind = Mock() @@ -287,7 +289,7 @@ class test_Queue: q.queue_declare.assert_called_with( channel=None, nowait=False, passive=False) - def test_declare__no_declare(self): + def test_declare__no_declare(self) -> None: q = Queue('a', no_declare=True) q.queue_declare = Mock() q.queue_bind = Mock() @@ -297,19 +299,19 @@ class test_Queue: q.queue_declare.assert_not_called() q.queue_bind.assert_not_called() - def test_bind_to_when_name(self): + def test_bind_to_when_name(self) -> None: chan = Mock() q = Queue('a') q(chan).bind_to('ex') chan.queue_bind.assert_called() - def test_get_when_no_m2p(self): + def test_get_when_no_m2p(self) -> None: chan = Mock() q = Queue('a')(chan) chan.message_to_python = None assert q.get() - def test_multiple_bindings(self): + def test_multiple_bindings(self) -> None: chan = Mock() q = Queue('mul', [ binding(Exchange('mul1'), 'rkey1'), @@ -327,14 +329,14 @@ class test_Queue: durable=True, ) in chan.exchange_declare.call_args_list - def test_can_cache_declaration(self): + def test_can_cache_declaration(self) -> None: assert Queue('a', durable=True).can_cache_declaration assert Queue('a', durable=False).can_cache_declaration assert not Queue( 'a', queue_arguments={'x-expires': 100} ).can_cache_declaration - def test_eq(self): + def test_eq(self) -> None: q1 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx') q2 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx') assert q1 == q2 @@ -343,14 +345,14 @@ class test_Queue: q3 = Queue('yyy', Exchange('xxx', 'direct'), 'xxx') assert q1 != q3 - def test_exclusive_implies_auto_delete(self): + def test_exclusive_implies_auto_delete(self) -> None: assert Queue('foo', self.exchange, exclusive=True).auto_delete - def test_binds_at_instantiation(self): + def test_binds_at_instantiation(self) -> None: assert Queue('foo', self.exchange, channel=get_conn().channel()).is_bound - def test_also_binds_exchange(self): + def test_also_binds_exchange(self) -> None: chan = get_conn().channel() b = Queue('foo', self.exchange) assert not b.is_bound @@ -361,7 +363,7 @@ class test_Queue: assert b.channel is b.exchange.channel assert b.exchange is not self.exchange - def test_declare(self): + def test_declare(self) -> None: chan = get_conn().channel() b = Queue('foo', self.exchange, 'foo', channel=chan) assert b.is_bound @@ -370,49 +372,49 @@ class test_Queue: assert 'queue_declare' in chan assert 'queue_bind' in chan - def test_get(self): + def test_get(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.get() assert 'basic_get' in b.channel - def test_purge(self): + def test_purge(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.purge() assert 'queue_purge' in b.channel - def test_consume(self): + def test_consume(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.consume('fifafo', None) assert 'basic_consume' in b.channel - def test_cancel(self): + def test_cancel(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.cancel('fifafo') assert 'basic_cancel' in b.channel - def test_delete(self): + def test_delete(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.delete() assert 'queue_delete' in b.channel - def test_queue_unbind(self): + def test_queue_unbind(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.queue_unbind() assert 'queue_unbind' in b.channel - def test_as_dict(self): + def test_as_dict(self) -> None: q = Queue('foo', self.exchange, 'rk') d = q.as_dict(recurse=True) assert d['exchange']['name'] == self.exchange.name - def test_queue_dump(self): + def test_queue_dump(self) -> None: b = binding(self.exchange, 'rk') q = Queue('foo', self.exchange, 'rk', bindings=[b]) d = q.as_dict(recurse=True) assert d['bindings'][0]['routing_key'] == 'rk' registry.dumps(d) - def test__repr__(self): + def test__repr__(self) -> None: b = Queue('foo', self.exchange, 'foo') assert 'foo' in repr(b) assert 'Queue' in repr(b) @@ -420,5 +422,5 @@ class test_Queue: class test_MaybeChannelBound: - def test_repr(self): + def test_repr(self) -> None: assert repr(MaybeChannelBound()) diff --git a/t/unit/test_exceptions.py b/t/unit/test_exceptions.py index bba72a83..7e67fc6f 100644 --- a/t/unit/test_exceptions.py +++ b/t/unit/test_exceptions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock from kombu.exceptions import HttpError @@ -5,5 +7,5 @@ from kombu.exceptions import HttpError class test_HttpError: - def test_str(self): + def test_str(self) -> None: assert str(HttpError(200, 'msg', Mock(name='response'))) diff --git a/t/unit/test_log.py b/t/unit/test_log.py index 4a8cd94c..30c6796f 100644 --- a/t/unit/test_log.py +++ b/t/unit/test_log.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import sys from unittest.mock import ANY, Mock, patch diff --git a/t/unit/test_matcher.py b/t/unit/test_matcher.py index 2100fa74..37ae5207 100644 --- a/t/unit/test_matcher.py +++ b/t/unit/test_matcher.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from kombu.matcher import (MatcherNotInstalled, fnmatch, match, register, diff --git a/t/unit/test_message.py b/t/unit/test_message.py index 5b0833dd..4c53cac2 100644 --- a/t/unit/test_message.py +++ b/t/unit/test_message.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys from unittest.mock import Mock, patch diff --git a/t/unit/test_messaging.py b/t/unit/test_messaging.py index f8ed437c..4bd467c2 100644 --- a/t/unit/test_messaging.py +++ b/t/unit/test_messaging.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle import sys from collections import defaultdict @@ -188,9 +190,8 @@ class test_Producer: def test_enter_exit(self): p = self.connection.Producer() p.release = Mock() - - assert p.__enter__() is p - p.__exit__() + with p as x: + assert x is p p.release.assert_called_with() def test_connection_property_handles_AttributeError(self): diff --git a/t/unit/test_mixins.py b/t/unit/test_mixins.py index 04a56a6c..39b7370f 100644 --- a/t/unit/test_mixins.py +++ b/t/unit/test_mixins.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket from unittest.mock import Mock, patch diff --git a/t/unit/test_pidbox.py b/t/unit/test_pidbox.py index fac46139..cf8a748a 100644 --- a/t/unit/test_pidbox.py +++ b/t/unit/test_pidbox.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket import warnings from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor diff --git a/t/unit/test_pools.py b/t/unit/test_pools.py index eb2a556e..1557da95 100644 --- a/t/unit/test_pools.py +++ b/t/unit/test_pools.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest @@ -139,7 +141,7 @@ class test_PoolGroup: def test_delitem(self): g = self.MyGroup() g['foo'] - del(g['foo']) + del g['foo'] assert 'foo' not in g def test_Connections(self): diff --git a/t/unit/test_serialization.py b/t/unit/test_serialization.py index 14952e5e..d3fd5c20 100644 --- a/t/unit/test_serialization.py +++ b/t/unit/test_serialization.py @@ -1,5 +1,7 @@ #!/usr/bin/python +from __future__ import annotations + from base64 import b64decode from unittest.mock import call, patch diff --git a/t/unit/test_simple.py b/t/unit/test_simple.py index a5cd899a..50ea880b 100644 --- a/t/unit/test_simple.py +++ b/t/unit/test_simple.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest @@ -91,9 +93,8 @@ class SimpleBase: def test_enter_exit(self): q = self.Queue('test_enter_exit') q.close = Mock() - - assert q.__enter__() is q - q.__exit__() + with q as x: + assert x is q q.close.assert_called_with() def test_qsize(self): diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py index 944728f1..2b1219fc 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -4,6 +4,8 @@ NOTE: The SQSQueueMock and SQSConnectionMock classes originally come from http://github.com/pcsforeducation/sqs-mock-python. They have been patched slightly. """ +from __future__ import annotations + import base64 import os import random @@ -38,6 +40,11 @@ example_predefined_queues = { 'access_key_id': 'c', 'secret_access_key': 'd', }, + 'queue-3.fifo': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-3.fifo', + 'access_key_id': 'e', + 'secret_access_key': 'f', + } } @@ -151,6 +158,7 @@ class test_Channel: predefined_queues_sqs_conn_mocks = { 'queue-1': SQSClientMock(QueueName='queue-1'), 'queue-2': SQSClientMock(QueueName='queue-2'), + 'queue-3.fifo': SQSClientMock(QueueName='queue-3.fifo') } def mock_sqs(): @@ -330,13 +338,13 @@ class test_Channel: with pytest.raises(Empty): self.channel._get_bulk(self.queue_name) - def test_is_base64_encoded(self): + def test_optional_b64_decode(self): raw = b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' \ b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' # noqa b64_enc = base64.b64encode(raw) - assert self.channel._Channel__b64_encoded(b64_enc) - assert not self.channel._Channel__b64_encoded(raw) - assert not self.channel._Channel__b64_encoded(b"test123") + assert self.channel._optional_b64_decode(b64_enc) == raw + assert self.channel._optional_b64_decode(raw) == raw + assert self.channel._optional_b64_decode(b"test123") == b"test123" def test_messages_to_python(self): from kombu.asynchronous.aws.sqs.message import Message @@ -738,6 +746,77 @@ class test_Channel: QueueUrl='https://sqs.us-east-1.amazonaws.com/xxx/queue-1', ReceiptHandle='test_message_id', VisibilityTimeout=20) + def test_predefined_queues_put_to_fifo_queue(self): + connection = Connection(transport=SQS.Transport, transport_options={ + 'predefined_queues': example_predefined_queues, + }) + channel = connection.channel() + + queue_name = 'queue-3.fifo' + + exchange = Exchange('test_SQS', type='direct') + p = messaging.Producer(channel, exchange, routing_key=queue_name) + + queue = Queue(queue_name, exchange, queue_name) + queue(channel).declare() + + channel.sqs = Mock() + sqs_queue_mock = Mock() + channel.sqs.return_value = sqs_queue_mock + p.publish('message') + + sqs_queue_mock.send_message.assert_called_once() + assert 'MessageGroupId' in sqs_queue_mock.send_message.call_args[1] + assert 'MessageDeduplicationId' in \ + sqs_queue_mock.send_message.call_args[1] + + def test_predefined_queues_put_to_queue(self): + connection = Connection(transport=SQS.Transport, transport_options={ + 'predefined_queues': example_predefined_queues, + }) + channel = connection.channel() + + queue_name = 'queue-2' + + exchange = Exchange('test_SQS', type='direct') + p = messaging.Producer(channel, exchange, routing_key=queue_name) + + queue = Queue(queue_name, exchange, queue_name) + queue(channel).declare() + + channel.sqs = Mock() + sqs_queue_mock = Mock() + channel.sqs.return_value = sqs_queue_mock + p.publish('message', DelaySeconds=10) + + sqs_queue_mock.send_message.assert_called_once() + + assert 'DelaySeconds' in sqs_queue_mock.send_message.call_args[1] + assert sqs_queue_mock.send_message.call_args[1]['DelaySeconds'] == 10 + + @pytest.mark.parametrize('predefined_queues', ( + { + 'invalid-fifo-queue-name': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue.fifo', + 'access_key_id': 'a', + 'secret_access_key': 'b' + } + }, + { + 'standard-queue.fifo': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue', + 'access_key_id': 'a', + 'secret_access_key': 'b' + } + } + )) + def test_predefined_queues_invalid_configuration(self, predefined_queues): + connection = Connection(transport=SQS.Transport, transport_options={ + 'predefined_queues': predefined_queues, + }) + with pytest.raises(SQS.InvalidQueueException): + connection.channel() + def test_sts_new_session(self): # Arrange connection = Connection(transport=SQS.Transport, transport_options={ diff --git a/t/unit/transport/test_azureservicebus.py b/t/unit/transport/test_azureservicebus.py index 97775d06..5de93c2f 100644 --- a/t/unit/transport/test_azureservicebus.py +++ b/t/unit/transport/test_azureservicebus.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json import random @@ -201,14 +203,35 @@ MockQueue = namedtuple( ) +@pytest.fixture(autouse=True) +def sbac_class_patch(): + with patch('kombu.transport.azureservicebus.ServiceBusAdministrationClient') as sbac: # noqa + yield sbac + + +@pytest.fixture(autouse=True) +def sbc_class_patch(): + with patch('kombu.transport.azureservicebus.ServiceBusClient') as sbc: # noqa + yield sbc + + +@pytest.fixture(autouse=True) +def mock_clients( + sbc_class_patch, + sbac_class_patch, + mock_asb, + mock_asb_management +): + sbc_class_patch.from_connection_string.return_value = mock_asb + sbac_class_patch.from_connection_string.return_value = mock_asb_management + + @pytest.fixture def mock_queue(mock_asb, mock_asb_management, random_queue) -> MockQueue: exchange = Exchange('test_servicebus', type='direct') queue = Queue(random_queue, exchange, random_queue) conn = Connection(URL_CREDS, transport=azureservicebus.Transport) channel = conn.channel() - channel._queue_service = mock_asb - channel._queue_mgmt_service = mock_asb_management queue(channel).declare() producer = messaging.Producer(channel, exchange, routing_key=random_queue) diff --git a/t/unit/transport/test_azurestoragequeues.py b/t/unit/transport/test_azurestoragequeues.py new file mode 100644 index 00000000..0c9ef32a --- /dev/null +++ b/t/unit/transport/test_azurestoragequeues.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from unittest.mock import patch + +import pytest +from azure.identity import DefaultAzureCredential, ManagedIdentityCredential + +from kombu import Connection + +pytest.importorskip('azure.storage.queue') +from kombu.transport import azurestoragequeues # noqa + +URL_NOCREDS = 'azurestoragequeues://' +URL_CREDS = 'azurestoragequeues://sas/key%@https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/' # noqa +AZURITE_CREDS = 'azurestoragequeues://Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==@http://localhost:10001/devstoreaccount1' # noqa +AZURITE_CREDS_DOCKER_COMPOSE = 'azurestoragequeues://Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==@http://azurite:10001/devstoreaccount1' # noqa +DEFAULT_AZURE_URL_CREDS = 'azurestoragequeues://DefaultAzureCredential@https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/' # noqa +MANAGED_IDENTITY_URL_CREDS = 'azurestoragequeues://ManagedIdentityCredential@https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/' # noqa + + +def test_queue_service_nocredentials(): + conn = Connection(URL_NOCREDS, transport=azurestoragequeues.Transport) + with pytest.raises( + ValueError, + match='Need a URI like azurestoragequeues://{SAS or access key}@{URL}' + ): + conn.channel() + + +def test_queue_service(): + # Test gettings queue service without credentials + conn = Connection(URL_CREDS, transport=azurestoragequeues.Transport) + with patch('kombu.transport.azurestoragequeues.QueueServiceClient'): + channel = conn.channel() + + # Check the SAS token "sas/key%" has been parsed from the url correctly + assert channel._credential == 'sas/key%' + assert channel._url == 'https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/' # noqa + + +@pytest.mark.parametrize( + "creds, hostname", + [ + (AZURITE_CREDS, 'localhost'), + (AZURITE_CREDS_DOCKER_COMPOSE, 'azurite'), + ] +) +def test_queue_service_works_for_azurite(creds, hostname): + conn = Connection(creds, transport=azurestoragequeues.Transport) + with patch('kombu.transport.azurestoragequeues.QueueServiceClient'): + channel = conn.channel() + + assert channel._credential == { + 'account_name': 'devstoreaccount1', + 'account_key': 'Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==' # noqa + } + assert channel._url == f'http://{hostname}:10001/devstoreaccount1' # noqa + + +def test_queue_service_works_for_default_azure_credentials(): + conn = Connection( + DEFAULT_AZURE_URL_CREDS, transport=azurestoragequeues.Transport + ) + with patch("kombu.transport.azurestoragequeues.QueueServiceClient"): + channel = conn.channel() + + assert isinstance(channel._credential, DefaultAzureCredential) + assert ( + channel._url + == "https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/" + ) + + +def test_queue_service_works_for_managed_identity_credentials(): + conn = Connection( + MANAGED_IDENTITY_URL_CREDS, transport=azurestoragequeues.Transport + ) + with patch("kombu.transport.azurestoragequeues.QueueServiceClient"): + channel = conn.channel() + + assert isinstance(channel._credential, ManagedIdentityCredential) + assert ( + channel._url + == "https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/" + ) diff --git a/t/unit/transport/test_base.py b/t/unit/transport/test_base.py index 7df12c9e..5beae3c6 100644 --- a/t/unit/transport/test_base.py +++ b/t/unit/transport/test_base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest diff --git a/t/unit/transport/test_consul.py b/t/unit/transport/test_consul.py index ce6c4fcb..ff110e11 100644 --- a/t/unit/transport/test_consul.py +++ b/t/unit/transport/test_consul.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from array import array from queue import Empty from unittest.mock import Mock @@ -12,6 +15,8 @@ class test_Consul: def setup(self): self.connection = Mock() + self.connection._used_channel_ids = array('H') + self.connection.channel_max = 65535 self.connection.client.transport_options = {} self.connection.client.port = 303 self.consul = self.patching('consul.Consul').return_value diff --git a/t/unit/transport/test_etcd.py b/t/unit/transport/test_etcd.py index 6c75a033..f3fad035 100644 --- a/t/unit/transport/test_etcd.py +++ b/t/unit/transport/test_etcd.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from queue import Empty from unittest.mock import Mock, patch diff --git a/t/unit/transport/test_filesystem.py b/t/unit/transport/test_filesystem.py index a8d1708b..20c7f47a 100644 --- a/t/unit/transport/test_filesystem.py +++ b/t/unit/transport/test_filesystem.py @@ -1,4 +1,9 @@ +from __future__ import annotations + import tempfile +from fcntl import LOCK_EX, LOCK_SH +from queue import Empty +from unittest.mock import call, patch import pytest @@ -138,3 +143,162 @@ class test_FilesystemTransport: assert self.q2(consumer_channel).get() self.q2(consumer_channel).purge() assert self.q2(consumer_channel).get() is None + + +@t.skip.if_win32 +class test_FilesystemFanout: + def setup(self): + try: + data_folder_in = tempfile.mkdtemp() + data_folder_out = tempfile.mkdtemp() + control_folder = tempfile.mkdtemp() + except Exception: + pytest.skip("filesystem transport: cannot create tempfiles") + + self.consumer_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_in, + "data_folder_out": data_folder_out, + "control_folder": control_folder, + }, + ) + self.consume_channel = self.consumer_connection.channel() + self.produce_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_out, + "data_folder_out": data_folder_in, + "control_folder": control_folder, + }, + ) + self.producer_channel = self.produce_connection.channel() + self.exchange = Exchange("filesystem_exchange_fanout", type="fanout") + self.q1 = Queue("queue1", exchange=self.exchange) + self.q2 = Queue("queue2", exchange=self.exchange) + + def teardown(self): + # make sure we don't attempt to restore messages at shutdown. + for channel in [self.producer_channel, self.consumer_connection]: + try: + channel._qos._dirty.clear() + except AttributeError: + pass + try: + channel._qos._delivered.clear() + except AttributeError: + pass + + def test_produce_consume(self): + + producer = Producer(self.producer_channel, self.exchange) + consumer1 = Consumer(self.consume_channel, self.q1) + consumer2 = Consumer(self.consume_channel, self.q2) + self.q2(self.consume_channel).declare() + + for i in range(10): + producer.publish({"foo": i}) + + _received1 = [] + _received2 = [] + + def callback1(message_data, message): + _received1.append(message) + message.ack() + + def callback2(message_data, message): + _received2.append(message) + message.ack() + + consumer1.register_callback(callback1) + consumer2.register_callback(callback2) + + consumer1.consume() + consumer2.consume() + + while 1: + try: + self.consume_channel.drain_events() + except Empty: + break + + assert len(_received1) + len(_received2) == 20 + + # queue.delete + for i in range(10): + producer.publish({"foo": i}) + assert self.q1(self.consume_channel).get() + self.q1(self.consume_channel).delete() + self.q1(self.consume_channel).declare() + assert self.q1(self.consume_channel).get() is None + + # queue.purge + assert self.q2(self.consume_channel).get() + self.q2(self.consume_channel).purge() + assert self.q2(self.consume_channel).get() is None + + +@t.skip.if_win32 +class test_FilesystemLock: + def setup(self): + try: + data_folder_in = tempfile.mkdtemp() + data_folder_out = tempfile.mkdtemp() + control_folder = tempfile.mkdtemp() + except Exception: + pytest.skip("filesystem transport: cannot create tempfiles") + + self.consumer_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_in, + "data_folder_out": data_folder_out, + "control_folder": control_folder, + }, + ) + self.consume_channel = self.consumer_connection.channel() + self.produce_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_out, + "data_folder_out": data_folder_in, + "control_folder": control_folder, + }, + ) + self.producer_channel = self.produce_connection.channel() + self.exchange = Exchange("filesystem_exchange_lock", type="fanout") + self.q = Queue("queue1", exchange=self.exchange) + + def teardown(self): + # make sure we don't attempt to restore messages at shutdown. + for channel in [self.producer_channel, self.consumer_connection]: + try: + channel._qos._dirty.clear() + except AttributeError: + pass + try: + channel._qos._delivered.clear() + except AttributeError: + pass + + def test_lock_during_process(self): + producer = Producer(self.producer_channel, self.exchange) + + with patch("kombu.transport.filesystem.lock") as lock_m, patch( + "kombu.transport.filesystem.unlock" + ) as unlock_m: + Consumer(self.consume_channel, self.q) + assert unlock_m.call_count == 1 + lock_m.assert_called_once_with(unlock_m.call_args[0][0], LOCK_EX) + + self.q(self.consume_channel).declare() + with patch("kombu.transport.filesystem.lock") as lock_m, patch( + "kombu.transport.filesystem.unlock" + ) as unlock_m: + producer.publish({"foo": 1}) + assert unlock_m.call_count == 2 + assert lock_m.call_count == 2 + exchange_file_obj = unlock_m.call_args_list[0][0][0] + msg_file_obj = unlock_m.call_args_list[1][0][0] + assert lock_m.call_args_list == [call(exchange_file_obj, LOCK_SH), + call(msg_file_obj, LOCK_EX)] diff --git a/t/unit/transport/test_librabbitmq.py b/t/unit/transport/test_librabbitmq.py index 58ee7e1e..84f8691e 100644 --- a/t/unit/transport/test_librabbitmq.py +++ b/t/unit/transport/test_librabbitmq.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock, patch import pytest diff --git a/t/unit/transport/test_memory.py b/t/unit/transport/test_memory.py index 2c1fe83f..c707d34c 100644 --- a/t/unit/transport/test_memory.py +++ b/t/unit/transport/test_memory.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket import pytest @@ -131,8 +133,8 @@ class test_MemoryTransport: with pytest.raises(socket.timeout): self.c.drain_events(timeout=0.1) - del(c1) # so pyflakes doesn't complain. - del(c2) + del c1 # so pyflakes doesn't complain. + del c2 def test_drain_events_unregistered_queue(self): c1 = self.c.channel() diff --git a/t/unit/transport/test_mongodb.py b/t/unit/transport/test_mongodb.py index 39976988..6bb5f1f9 100644 --- a/t/unit/transport/test_mongodb.py +++ b/t/unit/transport/test_mongodb.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime from queue import Empty from unittest.mock import MagicMock, call, patch @@ -151,16 +153,15 @@ class test_mongodb_channel(BaseMongoDBChannelCase): def test_get(self): - self.set_operation_return_value('messages', 'find_and_modify', { + self.set_operation_return_value('messages', 'find_one_and_delete', { '_id': 'docId', 'payload': '{"some": "data"}', }) event = self.channel._get('foobar') self.assert_collection_accessed('messages') self.assert_operation_called_with( - 'messages', 'find_and_modify', - query={'queue': 'foobar'}, - remove=True, + 'messages', 'find_one_and_delete', + {'queue': 'foobar'}, sort=[ ('priority', pymongo.ASCENDING), ], @@ -168,7 +169,11 @@ class test_mongodb_channel(BaseMongoDBChannelCase): assert event == {'some': 'data'} - self.set_operation_return_value('messages', 'find_and_modify', None) + self.set_operation_return_value( + 'messages', + 'find_one_and_delete', + None, + ) with pytest.raises(Empty): self.channel._get('foobar') @@ -188,7 +193,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase): self.channel._put('foobar', {'some': 'data'}) self.assert_collection_accessed('messages') - self.assert_operation_called_with('messages', 'insert', { + self.assert_operation_called_with('messages', 'insert_one', { 'queue': 'foobar', 'priority': 9, 'payload': '{"some": "data"}', @@ -200,17 +205,17 @@ class test_mongodb_channel(BaseMongoDBChannelCase): self.channel._put_fanout('foobar', {'some': 'data'}, 'foo') self.assert_collection_accessed('messages.broadcast') - self.assert_operation_called_with('broadcast', 'insert', { + self.assert_operation_called_with('broadcast', 'insert_one', { 'queue': 'foobar', 'payload': '{"some": "data"}', }) def test_size(self): - self.set_operation_return_value('messages', 'find.count', 77) + self.set_operation_return_value('messages', 'count_documents', 77) result = self.channel._size('foobar') self.assert_collection_accessed('messages') self.assert_operation_called_with( - 'messages', 'find', {'queue': 'foobar'}, + 'messages', 'count_documents', {'queue': 'foobar'}, ) assert result == 77 @@ -227,7 +232,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase): assert result == 77 def test_purge(self): - self.set_operation_return_value('messages', 'find.count', 77) + self.set_operation_return_value('messages', 'count_documents', 77) result = self.channel._purge('foobar') self.assert_collection_accessed('messages') @@ -276,11 +281,11 @@ class test_mongodb_channel(BaseMongoDBChannelCase): self.channel._queue_bind('test_exchange', 'foo', '*', 'foo') self.assert_collection_accessed('messages.routing') self.assert_operation_called_with( - 'routing', 'update', - {'queue': 'foo', 'pattern': '*', - 'routing_key': 'foo', 'exchange': 'test_exchange'}, + 'routing', 'update_one', {'queue': 'foo', 'pattern': '*', 'routing_key': 'foo', 'exchange': 'test_exchange'}, + {'$set': {'queue': 'foo', 'pattern': '*', + 'routing_key': 'foo', 'exchange': 'test_exchange'}}, upsert=True, ) @@ -317,16 +322,16 @@ class test_mongodb_channel(BaseMongoDBChannelCase): self.channel._ensure_indexes(self.channel.client) self.assert_operation_called_with( - 'messages', 'ensure_index', + 'messages', 'create_index', [('queue', 1), ('priority', 1), ('_id', 1)], background=True, ) self.assert_operation_called_with( - 'broadcast', 'ensure_index', + 'broadcast', 'create_index', [('queue', 1)], ) self.assert_operation_called_with( - 'routing', 'ensure_index', [('queue', 1), ('exchange', 1)], + 'routing', 'create_index', [('queue', 1), ('exchange', 1)], ) def test_create_broadcast_cursor(self): @@ -381,9 +386,9 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.channel._new_queue('foobar') self.assert_operation_called_with( - 'queues', 'update', + 'queues', 'update_one', {'_id': 'foobar'}, - {'_id': 'foobar', 'options': {}, 'expire_at': None}, + {'$set': {'_id': 'foobar', 'options': {}, 'expire_at': None}}, upsert=True, ) @@ -393,25 +398,23 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): '_id': 'docId', 'options': {'arguments': {'x-expires': 777}}, }) - self.set_operation_return_value('messages', 'find_and_modify', { + self.set_operation_return_value('messages', 'find_one_and_delete', { '_id': 'docId', 'payload': '{"some": "data"}', }) self.channel._get('foobar') self.assert_collection_accessed('messages', 'messages.queues') self.assert_operation_called_with( - 'messages', 'find_and_modify', - query={'queue': 'foobar'}, - remove=True, + 'messages', 'find_one_and_delete', + {'queue': 'foobar'}, sort=[ ('priority', pymongo.ASCENDING), ], ) self.assert_operation_called_with( - 'routing', 'update', + 'routing', 'update_many', {'queue': 'foobar'}, {'$set': {'expire_at': self.expire_at}}, - multi=True, ) def test_put(self): @@ -422,7 +425,7 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.channel._put('foobar', {'some': 'data'}) self.assert_collection_accessed('messages') - self.assert_operation_called_with('messages', 'insert', { + self.assert_operation_called_with('messages', 'insert_one', { 'queue': 'foobar', 'priority': 9, 'payload': '{"some": "data"}', @@ -437,12 +440,14 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.channel._queue_bind('test_exchange', 'foo', '*', 'foo') self.assert_collection_accessed('messages.routing') self.assert_operation_called_with( - 'routing', 'update', + 'routing', 'update_one', {'queue': 'foo', 'pattern': '*', 'routing_key': 'foo', 'exchange': 'test_exchange'}, - {'queue': 'foo', 'pattern': '*', - 'routing_key': 'foo', 'exchange': 'test_exchange', - 'expire_at': self.expire_at}, + {'$set': { + 'queue': 'foo', 'pattern': '*', + 'routing_key': 'foo', 'exchange': 'test_exchange', + 'expire_at': self.expire_at + }}, upsert=True, ) @@ -456,18 +461,18 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.channel._ensure_indexes(self.channel.client) self.assert_operation_called_with( - 'messages', 'ensure_index', [('expire_at', 1)], + 'messages', 'create_index', [('expire_at', 1)], expireAfterSeconds=0) self.assert_operation_called_with( - 'routing', 'ensure_index', [('expire_at', 1)], + 'routing', 'create_index', [('expire_at', 1)], expireAfterSeconds=0) self.assert_operation_called_with( - 'queues', 'ensure_index', [('expire_at', 1)], expireAfterSeconds=0) + 'queues', 'create_index', [('expire_at', 1)], expireAfterSeconds=0) - def test_get_expire(self): - result = self.channel._get_expire( + def test_get_queue_expire(self): + result = self.channel._get_queue_expire( {'arguments': {'x-expires': 777}}, 'x-expires') self.channel.client.assert_not_called() @@ -478,9 +483,15 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): '_id': 'docId', 'options': {'arguments': {'x-expires': 777}}, }) - result = self.channel._get_expire('foobar', 'x-expires') + result = self.channel._get_queue_expire('foobar', 'x-expires') assert result == self.expire_at + def test_get_message_expire(self): + assert self.channel._get_message_expire({ + 'properties': {'expiration': 777}, + }) == self.expire_at + assert self.channel._get_message_expire({}) is None + def test_update_queues_expire(self): self.set_operation_return_value('queues', 'find_one', { '_id': 'docId', 'options': {'arguments': {'x-expires': 777}}, @@ -489,16 +500,14 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.assert_collection_accessed('messages.routing', 'messages.queues') self.assert_operation_called_with( - 'routing', 'update', + 'routing', 'update_many', {'queue': 'foobar'}, {'$set': {'expire_at': self.expire_at}}, - multi=True, ) self.assert_operation_called_with( - 'queues', 'update', + 'queues', 'update_many', {'_id': 'foobar'}, {'$set': {'expire_at': self.expire_at}}, - multi=True, ) @@ -515,7 +524,7 @@ class test_mongodb_channel_calc_queue_size(BaseMongoDBChannelCase): # Tests def test_size(self): - self.set_operation_return_value('messages', 'find.count', 77) + self.set_operation_return_value('messages', 'count_documents', 77) result = self.channel._size('foobar') diff --git a/t/unit/transport/test_pyamqp.py b/t/unit/transport/test_pyamqp.py index d5f6d7e2..bd402395 100644 --- a/t/unit/transport/test_pyamqp.py +++ b/t/unit/transport/test_pyamqp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys from itertools import count from unittest.mock import MagicMock, Mock, patch diff --git a/t/unit/transport/test_pyro.py b/t/unit/transport/test_pyro.py index 325f81ce..258abc9e 100644 --- a/t/unit/transport/test_pyro.py +++ b/t/unit/transport/test_pyro.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket import pytest @@ -59,8 +61,8 @@ class test_PyroTransport: with pytest.raises(socket.timeout): self.c.drain_events(timeout=0.1) - del(c1) # so pyflakes doesn't complain. - del(c2) + del c1 # so pyflakes doesn't complain. + del c2 @pytest.mark.skip("requires running Pyro nameserver and Kombu Broker") def test_drain_events_unregistered_queue(self): diff --git a/t/unit/transport/test_qpid.py b/t/unit/transport/test_qpid.py index 351a929b..0048fd38 100644 --- a/t/unit/transport/test_qpid.py +++ b/t/unit/transport/test_qpid.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import select import socket import ssl import sys import time import uuid -from collections import OrderedDict from collections.abc import Callable from itertools import count from queue import Empty @@ -33,7 +34,7 @@ class QpidException(Exception): """ def __init__(self, code=None, text=None): - super(Exception, self).__init__(self) + super().__init__(self) self.code = code self.text = text @@ -57,7 +58,7 @@ class test_QoS__init__: assert qos_limit_two.prefetch_count == 1 def test__init___not_yet_acked_is_initialized(self): - assert isinstance(self.qos._not_yet_acked, OrderedDict) + assert isinstance(self.qos._not_yet_acked, dict) @pytest.mark.skip(reason='Not supported in Python3') diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py index c7ea8f67..b14408a6 100644 --- a/t/unit/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -1,9 +1,14 @@ +from __future__ import annotations + +import base64 +import copy import socket import types from collections import defaultdict from itertools import count from queue import Empty from queue import Queue as _Queue +from typing import TYPE_CHECKING from unittest.mock import ANY, Mock, call, patch import pytest @@ -13,7 +18,9 @@ from kombu.exceptions import VersionMismatch from kombu.transport import virtual from kombu.utils import eventio # patch poll from kombu.utils.json import dumps -from t.mocks import ContextMock + +if TYPE_CHECKING: + from types import TracebackType def _redis_modules(): @@ -230,7 +237,12 @@ class Pipeline: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: pass def __getattr__(self, key): @@ -270,9 +282,8 @@ class Channel(redis.Channel): class Transport(redis.Transport): Channel = Channel - - def _get_errors(self): - return ((KeyError,), (IndexError,)) + connection_errors = (KeyError,) + channel_errors = (IndexError,) class test_Channel: @@ -401,38 +412,117 @@ class test_Channel: ) crit.assert_called() - def test_restore(self): + def test_do_restore_message_celery(self): + # Payload value from real Celery project + payload = { + "body": base64.b64encode(dumps([ + [], + {}, + { + "callbacks": None, + "errbacks": None, + "chain": None, + "chord": None, + }, + ]).encode()).decode(), + "content-encoding": "utf-8", + "content-type": "application/json", + "headers": { + "lang": "py", + "task": "common.tasks.test_task", + "id": "980ad2bf-104c-4ce0-8643-67d1947173f6", + "shadow": None, + "eta": None, + "expires": None, + "group": None, + "group_index": None, + "retries": 0, + "timelimit": [None, None], + "root_id": "980ad2bf-104c-4ce0-8643-67d1947173f6", + "parent_id": None, + "argsrepr": "()", + "kwargsrepr": "{}", + "origin": "gen3437@Desktop", + "ignore_result": False, + }, + "properties": { + "correlation_id": "980ad2bf-104c-4ce0-8643-67d1947173f6", + "reply_to": "512f2489-ca40-3585-bc10-9b801a981782", + "delivery_mode": 2, + "delivery_info": { + "exchange": "", + "routing_key": "celery", + }, + "priority": 0, + "body_encoding": "base64", + "delivery_tag": "badb725e-9c3e-45be-b0a4-07e44630519f", + }, + } + result_payload = copy.deepcopy(payload) + result_payload['headers']['redelivered'] = True + result_payload['properties']['delivery_info']['redelivered'] = True + queue = 'celery' + + client = Mock(name='client') + lookup = self.channel._lookup = Mock(name='_lookup') + lookup.return_value = [queue] + + self.channel._do_restore_message( + payload, 'exchange', 'routing_key', client, + ) + + client.rpush.assert_called_with(queue, dumps(result_payload)) + + def test_restore_no_messages(self): message = Mock(name='message') + with patch('kombu.transport.redis.loads') as loads: - loads.return_value = 'M', 'EX', 'RK' + def transaction_handler(restore_transaction, unacked_key): + assert unacked_key == self.channel.unacked_key + pipe = Mock(name='pipe') + pipe.hget.return_value = None + + restore_transaction(pipe) + + pipe.multi.assert_called_once_with() + pipe.hdel.assert_called_once_with( + unacked_key, message.delivery_tag) + loads.assert_not_called() + client = self.channel._create_client = Mock(name='client') client = client() - client.pipeline = ContextMock() - restore = self.channel._do_restore_message = Mock( - name='_do_restore_message', - ) - pipe = client.pipeline.return_value - pipe_hget = Mock(name='pipe.hget') - pipe.hget.return_value = pipe_hget - pipe_hget_hdel = Mock(name='pipe.hget.hdel') - pipe_hget.hdel.return_value = pipe_hget_hdel - result = Mock(name='result') - pipe_hget_hdel.execute.return_value = None, None - + client.transaction.side_effect = transaction_handler self.channel._restore(message) - client.pipeline.assert_called_with() - unacked_key = self.channel.unacked_key - loads.assert_not_called() + client.transaction.assert_called() + + def test_restore_messages(self): + message = Mock(name='message') + + with patch('kombu.transport.redis.loads') as loads: + + def transaction_handler(restore_transaction, unacked_key): + assert unacked_key == self.channel.unacked_key + restore = self.channel._do_restore_message = Mock( + name='_do_restore_message', + ) + result = Mock(name='result') + loads.return_value = 'M', 'EX', 'RK' + pipe = Mock(name='pipe') + pipe.hget.return_value = result + + restore_transaction(pipe) - tag = message.delivery_tag - pipe.hget.assert_called_with(unacked_key, tag) - pipe_hget.hdel.assert_called_with(unacked_key, tag) - pipe_hget_hdel.execute.assert_called_with() + loads.assert_called_with(result) + pipe.multi.assert_called_once_with() + pipe.hdel.assert_called_once_with( + unacked_key, message.delivery_tag) + loads.assert_called() + restore.assert_called_with('M', 'EX', 'RK', pipe, False) - pipe_hget_hdel.execute.return_value = result, None + client = self.channel._create_client = Mock(name='client') + client = client() + client.transaction.side_effect = transaction_handler self.channel._restore(message) - loads.assert_called_with(result) - restore.assert_called_with('M', 'EX', 'RK', client, False) def test_qos_restore_visible(self): client = self.channel._create_client = Mock(name='client') @@ -837,6 +927,26 @@ class test_Channel: call(13, transport.on_readable, 13), ]) + @pytest.mark.parametrize('fds', [{12: 'LISTEN', 13: 'BRPOP'}, {}]) + def test_register_with_event_loop__on_disconnect__loop_cleanup(self, fds): + """Ensure event loop polling stops on disconnect (if started).""" + transport = self.connection.transport + self.connection._sock = None + transport.cycle = Mock(name='cycle') + transport.cycle.fds = fds + conn = Mock(name='conn') + conn.client = Mock(name='client', transport_options={}) + loop = Mock(name='loop') + loop.on_tick = set() + redis.Transport.register_with_event_loop(transport, conn, loop) + assert len(loop.on_tick) == 1 + transport.cycle._on_connection_disconnect(self.connection) + if fds: + assert len(loop.on_tick) == 0 + else: + # on_tick shouldn't be cleared when polling hasn't started + assert len(loop.on_tick) == 1 + def test_configurable_health_check(self): transport = self.connection.transport transport.cycle = Mock(name='cycle') @@ -870,15 +980,22 @@ class test_Channel: redis.Transport.on_readable(transport, 13) cycle.on_readable.assert_called_with(13) - def test_transport_get_errors(self): - assert redis.Transport._get_errors(self.connection.transport) + def test_transport_connection_errors(self): + """Ensure connection_errors are populated.""" + assert redis.Transport.connection_errors + + def test_transport_channel_errors(self): + """Ensure connection_errors are populated.""" + assert redis.Transport.channel_errors def test_transport_driver_version(self): assert redis.Transport.driver_version(self.connection.transport) - def test_transport_get_errors_when_InvalidData_used(self): + def test_transport_errors_when_InvalidData_used(self): from redis import exceptions + from kombu.transport.redis import get_redis_error_classes + class ID(Exception): pass @@ -887,7 +1004,7 @@ class test_Channel: exceptions.InvalidData = ID exceptions.DataError = None try: - errors = redis.Transport._get_errors(self.connection.transport) + errors = get_redis_error_classes() assert errors assert ID in errors[1] finally: @@ -1008,6 +1125,57 @@ class test_Channel: '\x06\x16\x06\x16queue' ) + @patch("redis.client.PubSub.execute_command") + def test_global_keyprefix_pubsub(self, mock_execute_command): + from kombu.transport.redis import PrefixedStrictRedis + + with Connection(transport=Transport) as conn: + client = PrefixedStrictRedis(global_keyprefix='foo_') + + channel = conn.channel() + channel.global_keyprefix = 'foo_' + channel._create_client = Mock() + channel._create_client.return_value = client + channel.subclient.connection = Mock() + channel.active_fanout_queues.add('a') + + channel._subscribe() + mock_execute_command.assert_called_with( + 'PSUBSCRIBE', + 'foo_/{db}.a', + ) + + @patch("redis.client.Pipeline.execute_command") + def test_global_keyprefix_transaction(self, mock_execute_command): + from kombu.transport.redis import PrefixedStrictRedis + + with Connection(transport=Transport) as conn: + def pipeline(transaction=True, shard_hint=None): + pipeline_obj = original_pipeline( + transaction=transaction, shard_hint=shard_hint + ) + mock_execute_command.side_effect = [ + None, None, pipeline_obj, pipeline_obj + ] + return pipeline_obj + + client = PrefixedStrictRedis(global_keyprefix='foo_') + original_pipeline = client.pipeline + client.pipeline = pipeline + + channel = conn.channel() + channel._create_client = Mock() + channel._create_client.return_value = client + + channel.qos.restore_by_tag('test-tag') + assert mock_execute_command is not None + assert mock_execute_command.mock_calls == [ + call('WATCH', 'foo_unacked'), + call('HGET', 'foo_unacked', 'test-tag'), + call('ZREM', 'foo_unacked_index', 'test-tag'), + call('HDEL', 'foo_unacked', 'test-tag') + ] + class test_Redis: diff --git a/t/unit/transport/test_sqlalchemy.py b/t/unit/transport/test_sqlalchemy.py index 5ddca5ac..aa0907f7 100644 --- a/t/unit/transport/test_sqlalchemy.py +++ b/t/unit/transport/test_sqlalchemy.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import patch import pytest diff --git a/t/unit/transport/test_transport.py b/t/unit/transport/test_transport.py index ca84dd80..b5b5e6eb 100644 --- a/t/unit/transport/test_transport.py +++ b/t/unit/transport/test_transport.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock, patch from kombu import transport diff --git a/t/unit/transport/test_zookeeper.py b/t/unit/transport/test_zookeeper.py index 21fcac42..8b6d159c 100644 --- a/t/unit/transport/test_zookeeper.py +++ b/t/unit/transport/test_zookeeper.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from kombu import Connection diff --git a/t/unit/transport/virtual/test_base.py b/t/unit/transport/virtual/test_base.py index 681841a0..124e19dd 100644 --- a/t/unit/transport/virtual/test_base.py +++ b/t/unit/transport/virtual/test_base.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import io import socket import warnings +from array import array from time import monotonic from unittest.mock import MagicMock, Mock, patch @@ -178,13 +181,19 @@ class test_Channel: if self.channel._qos is not None: self.channel._qos._on_collect.cancel() - def test_exceeds_channel_max(self): - c = client() - t = c.transport - avail = t._avail_channel_ids = Mock(name='_avail_channel_ids') - avail.pop.side_effect = IndexError() + def test_get_free_channel_id(self): + conn = client() + channel = conn.channel() + assert channel.channel_id == 1 + assert channel._get_free_channel_id() == 2 + + def test_get_free_channel_id__exceeds_channel_max(self): + conn = client() + conn.transport.channel_max = 2 + channel = conn.channel() + channel._get_free_channel_id() with pytest.raises(ResourceError): - virtual.Channel(t) + channel._get_free_channel_id() def test_exchange_bind_interface(self): with pytest.raises(NotImplementedError): @@ -455,9 +464,8 @@ class test_Channel: assert 'could not be delivered' in log[0].message.args[0] def test_context(self): - x = self.channel.__enter__() - assert x is self.channel - x.__exit__() + with self.channel as x: + assert x is self.channel assert x.closed def test_cycle_property(self): @@ -574,8 +582,25 @@ class test_Transport: assert len(self.transport.channels) == 2 self.transport.close_connection(self.transport) assert not self.transport.channels - del(c1) # so pyflakes doesn't complain - del(c2) + del c1 # so pyflakes doesn't complain + del c2 + + def test_create_channel(self): + """Ensure create_channel can create channels successfully.""" + assert self.transport.channels == [] + created_channel = self.transport.create_channel(self.transport) + assert self.transport.channels == [created_channel] + + def test_close_channel(self): + """Ensure close_channel actually removes the channel and updates + _used_channel_ids. + """ + assert self.transport._used_channel_ids == array('H') + created_channel = self.transport.create_channel(self.transport) + assert self.transport._used_channel_ids == array('H', (1,)) + self.transport.close_channel(created_channel) + assert self.transport.channels == [] + assert self.transport._used_channel_ids == array('H') def test_drain_channel(self): channel = self.transport.create_channel(self.transport) diff --git a/t/unit/transport/virtual/test_exchange.py b/t/unit/transport/virtual/test_exchange.py index 55741445..5e5a61d7 100644 --- a/t/unit/transport/virtual/test_exchange.py +++ b/t/unit/transport/virtual/test_exchange.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest diff --git a/t/unit/utils/test_amq_manager.py b/t/unit/utils/test_amq_manager.py index ca6adb6e..22fb9355 100644 --- a/t/unit/utils/test_amq_manager.py +++ b/t/unit/utils/test_amq_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import patch import pytest diff --git a/t/unit/utils/test_compat.py b/t/unit/utils/test_compat.py index d3159b76..d1fa0055 100644 --- a/t/unit/utils/test_compat.py +++ b/t/unit/utils/test_compat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket import sys import types @@ -14,10 +16,14 @@ def test_entrypoints(): 'kombu.utils.compat.importlib_metadata.entry_points', create=True ) as iterep: eps = [Mock(), Mock()] - iterep.return_value = {'kombu.test': eps} + iterep.return_value = ( + {'kombu.test': eps} if sys.version_info < (3, 10) else eps) assert list(entrypoints('kombu.test')) - iterep.assert_called_with() + if sys.version_info < (3, 10): + iterep.assert_called_with() + else: + iterep.assert_called_with(group='kombu.test') eps[0].load.assert_called_with() eps[1].load.assert_called_with() diff --git a/t/unit/utils/test_debug.py b/t/unit/utils/test_debug.py index 020bc849..a4955507 100644 --- a/t/unit/utils/test_debug.py +++ b/t/unit/utils/test_debug.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from unittest.mock import Mock, patch diff --git a/t/unit/utils/test_div.py b/t/unit/utils/test_div.py index b29b6119..a6e988e8 100644 --- a/t/unit/utils/test_div.py +++ b/t/unit/utils/test_div.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle from io import BytesIO, StringIO diff --git a/t/unit/utils/test_encoding.py b/t/unit/utils/test_encoding.py index 26e3ef36..81358a7a 100644 --- a/t/unit/utils/test_encoding.py +++ b/t/unit/utils/test_encoding.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys from contextlib import contextmanager from unittest.mock import patch diff --git a/t/unit/utils/test_functional.py b/t/unit/utils/test_functional.py index 73a98e52..26f28733 100644 --- a/t/unit/utils/test_functional.py +++ b/t/unit/utils/test_functional.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle from itertools import count from unittest.mock import Mock diff --git a/t/unit/utils/test_imports.py b/t/unit/utils/test_imports.py index 8a4873df..8f515bd8 100644 --- a/t/unit/utils/test_imports.py +++ b/t/unit/utils/test_imports.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest diff --git a/t/unit/utils/test_json.py b/t/unit/utils/test_json.py index 6af1c13b..8dcc7e32 100644 --- a/t/unit/utils/test_json.py +++ b/t/unit/utils/test_json.py @@ -1,14 +1,17 @@ +from __future__ import annotations + +import uuid from collections import namedtuple from datetime import datetime from decimal import Decimal -from unittest.mock import MagicMock, Mock -from uuid import uuid4 import pytest import pytz +from hypothesis import given, settings +from hypothesis import strategies as st from kombu.utils.encoding import str_to_bytes -from kombu.utils.json import _DecodeError, dumps, loads +from kombu.utils.json import dumps, loads class Custom: @@ -21,35 +24,54 @@ class Custom: class test_JSONEncoder: - + @pytest.mark.freeze_time("2015-10-21") def test_datetime(self): now = datetime.utcnow() now_utc = now.replace(tzinfo=pytz.utc) - stripped = datetime(*now.timetuple()[:3]) - serialized = loads(dumps({ + + original = { 'datetime': now, 'tz': now_utc, 'date': now.date(), - 'time': now.time()}, - )) + 'time': now.time(), + } + + serialized = loads(dumps(original)) + + assert serialized == original + + @given(message=st.binary()) + @settings(print_blob=True) + def test_binary(self, message): + serialized = loads(dumps({ + 'args': (message,), + })) assert serialized == { - 'datetime': now.isoformat(), - 'tz': '{}Z'.format(now_utc.isoformat().split('+', 1)[0]), - 'time': now.time().isoformat(), - 'date': stripped.isoformat(), + 'args': [message], } def test_Decimal(self): - d = Decimal('3314132.13363235235324234123213213214134') - assert loads(dumps({'d': d})), {'d': str(d)} + original = {'d': Decimal('3314132.13363235235324234123213213214134')} + serialized = loads(dumps(original)) + + assert serialized == original def test_namedtuple(self): Foo = namedtuple('Foo', ['bar']) assert loads(dumps(Foo(123))) == [123] def test_UUID(self): - id = uuid4() - assert loads(dumps({'u': id})), {'u': str(id)} + constructors = [ + uuid.uuid1, + lambda: uuid.uuid3(uuid.NAMESPACE_URL, "https://example.org"), + uuid.uuid4, + lambda: uuid.uuid5(uuid.NAMESPACE_URL, "https://example.org"), + ] + for constructor in constructors: + id = constructor() + loaded_value = loads(dumps({'u': id})) + assert loaded_value == {'u': id} + assert loaded_value["u"].version == id.version def test_default(self): with pytest.raises(TypeError): @@ -81,9 +103,3 @@ class test_dumps_loads: assert loads( str_to_bytes(dumps({'x': 'z'})), decode_bytes=True) == {'x': 'z'} - - def test_loads_DecodeError(self): - _loads = Mock(name='_loads') - _loads.side_effect = _DecodeError( - MagicMock(), MagicMock(), MagicMock()) - assert loads(dumps({'x': 'z'}), _loads=_loads) == {'x': 'z'} diff --git a/t/unit/utils/test_objects.py b/t/unit/utils/test_objects.py index 93a88b4f..b9f1484a 100644 --- a/t/unit/utils/test_objects.py +++ b/t/unit/utils/test_objects.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from kombu.utils.objects import cached_property diff --git a/t/unit/utils/test_scheduling.py b/t/unit/utils/test_scheduling.py index 44cf01a2..7bc76b96 100644 --- a/t/unit/utils/test_scheduling.py +++ b/t/unit/utils/test_scheduling.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest diff --git a/t/unit/utils/test_time.py b/t/unit/utils/test_time.py index 660ae8ec..a8f7de0f 100644 --- a/t/unit/utils/test_time.py +++ b/t/unit/utils/test_time.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from kombu.utils.time import maybe_s_to_ms diff --git a/t/unit/utils/test_url.py b/t/unit/utils/test_url.py index 71ea0f9b..f219002b 100644 --- a/t/unit/utils/test_url.py +++ b/t/unit/utils/test_url.py @@ -1,3 +1,5 @@ +from __future__ import annotations + try: from urllib.parse import urlencode except ImportError: diff --git a/t/unit/utils/test_utils.py b/t/unit/utils/test_utils.py index d118d46e..08f95083 100644 --- a/t/unit/utils/test_utils.py +++ b/t/unit/utils/test_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from kombu import version_info_t diff --git a/t/unit/utils/test_uuid.py b/t/unit/utils/test_uuid.py index 05d89125..bc69474a 100644 --- a/t/unit/utils/test_uuid.py +++ b/t/unit/utils/test_uuid.py @@ -1,12 +1,14 @@ +from __future__ import annotations + from kombu.utils.uuid import uuid class test_UUID: - def test_uuid4(self): + def test_uuid4(self) -> None: assert uuid() != uuid() - def test_uuid(self): + def test_uuid(self) -> None: i1 = uuid() i2 = uuid() assert isinstance(i1, str) @@ -1,15 +1,24 @@ [tox] envlist = - {pypy3,3.7,3.8,3.9,3.10}-unit - {pypy3,3.7,3.8,3.9,3.10}-linux-integration-py-amqp - {pypy3,3.7,3.8,3.9,3.10}-linux-integration-redis + {pypy3.9,3.7,3.8,3.9,3.10,3.11}-unit + {pypy3.9,3.7,3.8,3.9,3.10,3.11}-linux-integration-py-amqp + {pypy3.9,3.7,3.8,3.9,3.10,3.11}-linux-integration-redis + {pypy3.9,3.7,3.8,3.9,3.10,3.11}-linux-integration-mongodb + {3.7,3.8,3.9,3.10,3.11}-linux-integration-kafka flake8 - flakeplus apicheck pydocstyle requires = tox-docker>=3.0 +[gh-actions] +python = + 3.7: py37 + 3.8: py38 + 3.9: py39 + 3.10: py310, mypy + 3.11: py311 + [testenv] sitepackages = False setenv = C_DEBUG_TEST = 1 @@ -17,30 +26,38 @@ passenv = DISTUTILS_USE_SDK deps= -r{toxinidir}/requirements/dev.txt - apicheck,pypy3,3.7,3.8,3.9,3.10: -r{toxinidir}/requirements/default.txt - apicheck,pypy3,3.7,3.8,3.9,3.10: -r{toxinidir}/requirements/test.txt - apicheck,pypy3,3.7-linux,3.8-linux,3.9-linux,3.10-linux: -r{toxinidir}/requirements/test-ci.txt - 3.7-windows,3.8-windows,3.9-windows,3.10-windows: -r{toxinidir}/requirements/test-ci-windows.txt + apicheck,pypy3.9,3.7,3.8,3.9,3.10,3.11: -r{toxinidir}/requirements/default.txt + apicheck,pypy3.9,3.7,3.8,3.9,3.10,3.11: -r{toxinidir}/requirements/test.txt + apicheck,pypy3.9,3.7-linux,3.8-linux,3.9-linux,3.10-linux,3.11-linux: -r{toxinidir}/requirements/test-ci.txt + apicheck,3.7-linux,3.8-linux,3.9-linux,3.10-linux,3.11-linux: -r{toxinidir}/requirements/extras/confluentkafka.txt + 3.8-windows,3.9-windows,3.10-windows,3.11-windows: -r{toxinidir}/requirements/test-ci-windows.txt apicheck,linkcheck: -r{toxinidir}/requirements/docs.txt - flake8,flakeplus,pydocstyle: -r{toxinidir}/requirements/pkgutils.txt + flake8,pydocstyle,mypy: -r{toxinidir}/requirements/pkgutils.txt commands = unit: python -bb -m pytest -rxs -xv --cov=kombu --cov-report=xml --no-cov-on-fail {posargs} - integration-py-amqp: py.test -xv -E py-amqp t/integration {posargs:-n2} - integration-redis: py.test -xv -E redis t/integration {posargs:-n2} + integration-py-amqp: pytest -xv -E py-amqp t/integration {posargs:-n2} + integration-redis: pytest -xv -E redis t/integration {posargs:-n2} + integration-mongodb: pytest -xv -E mongodb t/integration {posargs:-n2} + integration-kafka: pytest -xv -E kafka t/integration {posargs:-n2} basepython = + pypy3.9: pypy3.9 + pypy3.8: pypy3.8 3.7: python3.7 - 3.8: python3.8 + 3.8,mypy: python3.8 3.9,apicheck,pydocstyle,flake8,linkcheck,cov: python3.9 - pypy3: pypy3.7 3.10: python3.10 + 3.11: python3.11 install_command = python -m pip --disable-pip-version-check install {opts} {packages} docker = integration-py-amqp: rabbitmq integration-redis: redis + integration-mongodb: mongodb + integration-kafka: zookeeper + integration-kafka: kafka dockerenv = PYAMQP_INTEGRATION_INSTANCE=1 @@ -64,6 +81,42 @@ healthcheck_timeout = 10 healthcheck_retries = 30 healthcheck_start_period = 5 +[docker:mongodb] +image = mongo +ports = 27017:27017/tcp +healthcheck_cmd = /usr/bin/mongosh --eval 'db.runCommand("ping")' +healthcheck_interval = 10 +healthcheck_timeout = 10 +healthcheck_retries = 30 +healthcheck_start_period = 5 + +[docker:zookeeper] +image = bitnami/zookeeper:latest +ports = 2181:2181/tcp +healthcheck_interval = 10 +healthcheck_timeout = 10 +healthcheck_retries = 30 +healthcheck_start_period = 5 +environment = ALLOW_ANONYMOUS_LOGIN=yes + +[docker:kafka] +image = bitnami/kafka:latest +ports = + 9092:9092/tcp +healthcheck_cmd = /bin/bash -c 'kafka-topics.sh --list --bootstrap-server 127.0.0.1:9092' +healthcheck_interval = 10 +healthcheck_timeout = 10 +healthcheck_retries = 30 +healthcheck_start_period = 5 +links = + zookeeper:zookeeper +environment = + KAFKA_BROKER_ID=1 + KAFKA_CFG_LISTENERS=PLAINTEXT://:9092 + KAFKA_CFG_ADVERTISED_LISTENERS=PLAINTEXT://127.0.0.1:9092 + KAFKA_CFG_ZOOKEEPER_CONNECT=zookeeper:2181 + ALLOW_PLAINTEXT_LISTENER=yes + [testenv:apicheck] commands = pip install -U -r{toxinidir}/requirements/dev.txt sphinx-build -j2 -b apicheck -d {envtmpdir}/doctrees docs docs/_build/apicheck @@ -79,3 +132,6 @@ commands = [testenv:pydocstyle] commands = pydocstyle {toxinidir}/kombu + +[testenv:mypy] +commands = python -m mypy --config-file setup.cfg |