diff options
| -rw-r--r-- | doc/build/changelog/unreleased_14/7937.rst | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/_concurrency_py3k.py | 43 | ||||
| -rw-r--r-- | test/base/test_concurrency_py3k.py | 38 |
3 files changed, 59 insertions, 30 deletions
diff --git a/doc/build/changelog/unreleased_14/7937.rst b/doc/build/changelog/unreleased_14/7937.rst new file mode 100644 index 000000000..96d80d6cd --- /dev/null +++ b/doc/build/changelog/unreleased_14/7937.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, asyncio + :tickets: 7937 + + Allow setting contextvar values inside async adapted event handlers. + Previously the value set to the contextvar would not be properly + propagated. + diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 6ad099eef..167c42140 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -7,26 +7,28 @@ from __future__ import annotations import asyncio -from contextvars import copy_context as _copy_context +from contextvars import Context import sys import typing from typing import Any from typing import Awaitable from typing import Callable from typing import Coroutine +from typing import Optional from typing import TypeVar from .langhelpers import memoized_property from .. import exc from ..util.typing import Protocol -_T = TypeVar("_T", bound=Any) +_T = TypeVar("_T") if typing.TYPE_CHECKING: class greenlet(Protocol): dead: bool + gr_context: Optional[Context] def __init__(self, fn: Callable[..., Any], driver: "greenlet"): ... @@ -45,15 +47,10 @@ else: from greenlet import greenlet -if not typing.TYPE_CHECKING: - try: - - # If greenlet.gr_context is present in current version of greenlet, - # it will be set with a copy of the current context on creation. - # Refs: https://github.com/python-greenlet/greenlet/pull/198 - getattr(greenlet, "gr_context") - except (ImportError, AttributeError): - _copy_context = None # noqa +# If greenlet.gr_context is present in current version of greenlet, +# it will be set with the current context on creation. +# Refs: https://github.com/python-greenlet/greenlet/pull/198 +_has_gr_context = hasattr(getcurrent(), "gr_context") def is_exit_exception(e: BaseException) -> bool: @@ -75,15 +72,15 @@ class _AsyncIoGreenlet(greenlet): # type: ignore def __init__(self, fn: Callable[..., Any], driver: greenlet): greenlet.__init__(self, fn, driver) self.driver = driver - if _copy_context is not None: - self.gr_context = _copy_context() + if _has_gr_context: + self.gr_context = driver.gr_context def await_only(awaitable: Awaitable[_T]) -> _T: """Awaits an async function in a sync method. The sync method must be inside a :func:`greenlet_spawn` context. - :func:`await_` calls cannot be nested. + :func:`await_only` calls cannot be nested. :param awaitable: The coroutine to call. @@ -92,8 +89,8 @@ def await_only(awaitable: Awaitable[_T]) -> _T: current = getcurrent() if not isinstance(current, _AsyncIoGreenlet): raise exc.MissingGreenlet( - "greenlet_spawn has not been called; can't call await_() here. " - "Was IO attempted in an unexpected place?" + "greenlet_spawn has not been called; can't call await_only() " + "here. Was IO attempted in an unexpected place?" ) # returns the control to the driver greenlet passing it @@ -107,7 +104,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T: """Awaits an async function in a sync method. The sync method must be inside a :func:`greenlet_spawn` context. - :func:`await_` calls cannot be nested. + :func:`await_fallback` calls cannot be nested. :param awaitable: The coroutine to call. @@ -120,7 +117,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T: if loop.is_running(): raise exc.MissingGreenlet( "greenlet_spawn has not been called and asyncio event " - "loop is already running; can't call await_() here. " + "loop is already running; can't call await_fallback() here. " "Was IO attempted in an unexpected place?" ) return loop.run_until_complete(awaitable) # type: ignore[no-any-return] # noqa: E501 @@ -136,7 +133,7 @@ async def greenlet_spawn( ) -> _T: """Runs a sync function ``fn`` in a new greenlet. - The sync function can then use :func:`await_` to wait for async + The sync function can then use :func:`await_only` to wait for async functions. :param fn: The sync callable to call. @@ -144,10 +141,10 @@ async def greenlet_spawn( :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable. """ - result: _T + result: Any context = _AsyncIoGreenlet(fn, getcurrent()) # runs the function synchronously in gl greenlet. If the execution - # is interrupted by await_, context is not dead and result is a + # is interrupted by await_only, context is not dead and result is a # coroutine to wait. If the context is dead the function has # returned, and its result can be returned. switch_occurred = False @@ -156,7 +153,7 @@ async def greenlet_spawn( while not context.dead: switch_occurred = True try: - # wait for a coroutine from await_ and then return its + # wait for a coroutine from await_only and then return its # result back to it. value = await result except BaseException: @@ -175,7 +172,7 @@ async def greenlet_spawn( "detected. This will usually happen when using a non compatible " "DBAPI driver. Please ensure that an async DBAPI is used." ) - return result + return result # type: ignore[no-any-return] class AsyncAdaptedLock: diff --git a/test/base/test_concurrency_py3k.py b/test/base/test_concurrency_py3k.py index 79601019e..6a3098a6a 100644 --- a/test/base/test_concurrency_py3k.py +++ b/test/base/test_concurrency_py3k.py @@ -1,4 +1,6 @@ import asyncio +import contextvars +import random import threading from sqlalchemy import exc @@ -88,7 +90,8 @@ class TestAsyncioCompat(fixtures.TestBase): to_await = run1() with expect_raises_message( exc.MissingGreenlet, - r"greenlet_spawn has not been called; can't call await_\(\) here.", + "greenlet_spawn has not been called; " + r"can't call await_only\(\) here.", ): await_only(to_await) @@ -133,7 +136,8 @@ class TestAsyncioCompat(fixtures.TestBase): with expect_raises_message( exc.InvalidRequestError, - r"greenlet_spawn has not been called; can't call await_\(\) here.", + "greenlet_spawn has not been called; " + r"can't call await_only\(\) here.", ): await greenlet_spawn(go) @@ -141,24 +145,44 @@ class TestAsyncioCompat(fixtures.TestBase): @async_test async def test_contextvars(self): - import asyncio - import contextvars - var = contextvars.ContextVar("var") - concurrency = 5 + concurrency = 500 + # NOTE: sleep here is not necessary. It's used to simulate IO + # ensuring that task are not run sequentially async def async_inner(val): + await asyncio.sleep(random.uniform(0.005, 0.015)) eq_(val, var.get()) return var.get() + async def async_set(val): + await asyncio.sleep(random.uniform(0.005, 0.015)) + var.set(val) + def inner(val): retval = await_only(async_inner(val)) eq_(val, var.get()) eq_(retval, val) + + # set the value in a sync function + newval = val + concurrency + var.set(newval) + syncset = await_only(async_inner(newval)) + eq_(newval, var.get()) + eq_(syncset, newval) + + # set the value in an async function + retval = val + 2 * concurrency + await_only(async_set(retval)) + eq_(var.get(), retval) + eq_(await_only(async_inner(retval)), retval) + return retval async def task(val): + await asyncio.sleep(random.uniform(0.005, 0.015)) var.set(val) + await asyncio.sleep(random.uniform(0.005, 0.015)) return await greenlet_spawn(inner, val) values = { @@ -167,7 +191,7 @@ class TestAsyncioCompat(fixtures.TestBase): [task(i) for i in range(concurrency)] ) } - eq_(values, set(range(concurrency))) + eq_(values, set(range(concurrency * 2, concurrency * 3))) @async_test async def test_require_await(self): |
