diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2023-05-10 15:11:06 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2023-05-10 15:11:06 +0000 |
| commit | bce61160a9aec321ea0af4a59d4b83ff93a0429f (patch) | |
| tree | 3754696c2d8f97d92507309ef4c13037b7372bf3 /lib/sqlalchemy/util/_concurrency_py3k.py | |
| parent | 987285fb4b13c39bcc6b8922e618d9e830577dda (diff) | |
| parent | 60b31198311eedfa3814e7098c94d3aa29338fdd (diff) | |
| download | sqlalchemy-bce61160a9aec321ea0af4a59d4b83ff93a0429f.tar.gz | |
Merge "fix test suite warnings" into main
Diffstat (limited to 'lib/sqlalchemy/util/_concurrency_py3k.py')
| -rw-r--r-- | lib/sqlalchemy/util/_concurrency_py3k.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index b29cc2119..0e26425b2 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -17,11 +17,13 @@ from typing import Awaitable from typing import Callable from typing import Coroutine from typing import Optional +from typing import TYPE_CHECKING from typing import TypeVar from .langhelpers import memoized_property from .. import exc from ..util.typing import Protocol +from ..util.typing import TypeGuard _T = TypeVar("_T") @@ -78,6 +80,26 @@ class _AsyncIoGreenlet(greenlet): # type: ignore self.gr_context = driver.gr_context +_T_co = TypeVar("_T_co", covariant=True) + +if TYPE_CHECKING: + + def iscoroutine( + awaitable: Awaitable[_T_co], + ) -> TypeGuard[Coroutine[Any, Any, _T_co]]: + ... + +else: + iscoroutine = asyncio.iscoroutine + + +def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None: + # https://docs.python.org/3/reference/datamodel.html#coroutine.close + + if iscoroutine(awaitable): + awaitable.close() + + def await_only(awaitable: Awaitable[_T]) -> _T: """Awaits an async function in a sync method. @@ -90,6 +112,8 @@ def await_only(awaitable: Awaitable[_T]) -> _T: # this is called in the context greenlet while running fn current = getcurrent() if not isinstance(current, _AsyncIoGreenlet): + _safe_cancel_awaitable(awaitable) + raise exc.MissingGreenlet( "greenlet_spawn has not been called; can't call await_only() " "here. Was IO attempted in an unexpected place?" @@ -117,6 +141,9 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T: if not isinstance(current, _AsyncIoGreenlet): loop = get_event_loop() if loop.is_running(): + + _safe_cancel_awaitable(awaitable) + raise exc.MissingGreenlet( "greenlet_spawn has not been called and asyncio event " "loop is already running; can't call await_fallback() here. " |
