diff options
author | Federico Caselli <cfederico87@gmail.com> | 2022-04-15 00:29:01 +0200 |
---|---|---|
committer | Federico Caselli <cfederico87@gmail.com> | 2022-04-17 10:44:42 +0200 |
commit | 73f2b5bcdcc431ec77c62ffd49bdcbf4718fbfc1 (patch) | |
tree | c0cddb5695e80213bbfc19e2d246c9be93379094 /lib/sqlalchemy/util/_concurrency_py3k.py | |
parent | c538f810bce57472c8960a0a6c4c61024b00f3ed (diff) | |
download | sqlalchemy-73f2b5bcdcc431ec77c62ffd49bdcbf4718fbfc1.tar.gz |
Allow contextvars to be set in events when using asyncio
Allow setting contextvar values inside async adapted event handlers.
Previously the value set to the contextvar would not be properly
propagated.
Fixes: #7937
Change-Id: I787aa869f8d057579e13e32c749f05f184ffd02a
Diffstat (limited to 'lib/sqlalchemy/util/_concurrency_py3k.py')
-rw-r--r-- | lib/sqlalchemy/util/_concurrency_py3k.py | 43 |
1 files changed, 20 insertions, 23 deletions
diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 6ad099eef..167c42140 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -7,26 +7,28 @@ from __future__ import annotations import asyncio -from contextvars import copy_context as _copy_context +from contextvars import Context import sys import typing from typing import Any from typing import Awaitable from typing import Callable from typing import Coroutine +from typing import Optional from typing import TypeVar from .langhelpers import memoized_property from .. import exc from ..util.typing import Protocol -_T = TypeVar("_T", bound=Any) +_T = TypeVar("_T") if typing.TYPE_CHECKING: class greenlet(Protocol): dead: bool + gr_context: Optional[Context] def __init__(self, fn: Callable[..., Any], driver: "greenlet"): ... @@ -45,15 +47,10 @@ else: from greenlet import greenlet -if not typing.TYPE_CHECKING: - try: - - # If greenlet.gr_context is present in current version of greenlet, - # it will be set with a copy of the current context on creation. - # Refs: https://github.com/python-greenlet/greenlet/pull/198 - getattr(greenlet, "gr_context") - except (ImportError, AttributeError): - _copy_context = None # noqa +# If greenlet.gr_context is present in current version of greenlet, +# it will be set with the current context on creation. +# Refs: https://github.com/python-greenlet/greenlet/pull/198 +_has_gr_context = hasattr(getcurrent(), "gr_context") def is_exit_exception(e: BaseException) -> bool: @@ -75,15 +72,15 @@ class _AsyncIoGreenlet(greenlet): # type: ignore def __init__(self, fn: Callable[..., Any], driver: greenlet): greenlet.__init__(self, fn, driver) self.driver = driver - if _copy_context is not None: - self.gr_context = _copy_context() + if _has_gr_context: + self.gr_context = driver.gr_context def await_only(awaitable: Awaitable[_T]) -> _T: """Awaits an async function in a sync method. The sync method must be inside a :func:`greenlet_spawn` context. - :func:`await_` calls cannot be nested. + :func:`await_only` calls cannot be nested. :param awaitable: The coroutine to call. @@ -92,8 +89,8 @@ def await_only(awaitable: Awaitable[_T]) -> _T: current = getcurrent() if not isinstance(current, _AsyncIoGreenlet): raise exc.MissingGreenlet( - "greenlet_spawn has not been called; can't call await_() here. " - "Was IO attempted in an unexpected place?" + "greenlet_spawn has not been called; can't call await_only() " + "here. Was IO attempted in an unexpected place?" ) # returns the control to the driver greenlet passing it @@ -107,7 +104,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T: """Awaits an async function in a sync method. The sync method must be inside a :func:`greenlet_spawn` context. - :func:`await_` calls cannot be nested. + :func:`await_fallback` calls cannot be nested. :param awaitable: The coroutine to call. @@ -120,7 +117,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T: if loop.is_running(): raise exc.MissingGreenlet( "greenlet_spawn has not been called and asyncio event " - "loop is already running; can't call await_() here. " + "loop is already running; can't call await_fallback() here. " "Was IO attempted in an unexpected place?" ) return loop.run_until_complete(awaitable) # type: ignore[no-any-return] # noqa: E501 @@ -136,7 +133,7 @@ async def greenlet_spawn( ) -> _T: """Runs a sync function ``fn`` in a new greenlet. - The sync function can then use :func:`await_` to wait for async + The sync function can then use :func:`await_only` to wait for async functions. :param fn: The sync callable to call. @@ -144,10 +141,10 @@ async def greenlet_spawn( :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable. """ - result: _T + result: Any context = _AsyncIoGreenlet(fn, getcurrent()) # runs the function synchronously in gl greenlet. If the execution - # is interrupted by await_, context is not dead and result is a + # is interrupted by await_only, context is not dead and result is a # coroutine to wait. If the context is dead the function has # returned, and its result can be returned. switch_occurred = False @@ -156,7 +153,7 @@ async def greenlet_spawn( while not context.dead: switch_occurred = True try: - # wait for a coroutine from await_ and then return its + # wait for a coroutine from await_only and then return its # result back to it. value = await result except BaseException: @@ -175,7 +172,7 @@ async def greenlet_spawn( "detected. This will usually happen when using a non compatible " "DBAPI driver. Please ensure that an async DBAPI is used." ) - return result + return result # type: ignore[no-any-return] class AsyncAdaptedLock: |