diff options
Diffstat (limited to 'lib/sqlalchemy/engine/util.py')
| -rw-r--r-- | lib/sqlalchemy/engine/util.py | 55 |
1 files changed, 39 insertions, 16 deletions
diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index f9ee65bef..213485cc9 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -7,18 +7,30 @@ from __future__ import annotations +import typing +from typing import Any +from typing import Callable +from typing import TypeVar + from .. import exc from .. import util +from ..util._has_cy import HAS_CYEXTENSION + +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_util import _distill_params_20 as _distill_params_20 + from ._py_util import _distill_raw_params as _distill_raw_params +else: + from sqlalchemy.cyextension.util import ( + _distill_params_20 as _distill_params_20, + ) + from sqlalchemy.cyextension.util import ( + _distill_raw_params as _distill_raw_params, + ) -try: - from sqlalchemy.cyextension.util import _distill_params_20 # noqa - from sqlalchemy.cyextension.util import _distill_raw_params # noqa -except ImportError: - from ._py_util import _distill_params_20 # noqa - from ._py_util import _distill_raw_params # noqa +_C = TypeVar("_C", bound=Callable[[], Any]) -def connection_memoize(key): +def connection_memoize(key: str) -> Callable[[_C], _C]: """Decorator, memoize a function in a connection.info stash. Only applicable to functions which take no arguments other than a @@ -26,7 +38,7 @@ def connection_memoize(key): """ @util.decorator - def decorated(fn, self, connection): + def decorated(fn, self, connection): # type: ignore connection = connection.connect() try: return connection.info[key] @@ -34,7 +46,7 @@ def connection_memoize(key): connection.info[key] = val = fn(self, connection) return val - return decorated + return decorated # type: ignore[return-value] class TransactionalContext: @@ -47,13 +59,13 @@ class TransactionalContext: __slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__") - def _transaction_is_active(self): + def _transaction_is_active(self) -> bool: raise NotImplementedError() - def _transaction_is_closed(self): + def _transaction_is_closed(self) -> bool: raise NotImplementedError() - def _rollback_can_be_called(self): + def _rollback_can_be_called(self) -> bool: """indicates the object is in a state that is known to be acceptable for rollback() to be called. @@ -70,11 +82,20 @@ class TransactionalContext: """ raise NotImplementedError() - def _get_subject(self): + def _get_subject(self) -> Any: + raise NotImplementedError() + + def commit(self) -> None: + raise NotImplementedError() + + def rollback(self) -> None: + raise NotImplementedError() + + def close(self) -> None: raise NotImplementedError() @classmethod - def _trans_ctx_check(cls, subject): + def _trans_ctx_check(cls, subject: Any) -> None: trans_context = subject._trans_context_manager if trans_context: if not trans_context._transaction_is_active(): @@ -84,7 +105,7 @@ class TransactionalContext: "before emitting further commands." ) - def __enter__(self): + def __enter__(self) -> TransactionalContext: subject = self._get_subject() # none for outer transaction, may be non-None for nested @@ -96,7 +117,7 @@ class TransactionalContext: subject._trans_context_manager = self return self - def __exit__(self, type_, value, traceback): + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: subject = getattr(self, "_trans_subject", None) # simplistically we could assume that @@ -119,6 +140,7 @@ class TransactionalContext: self.rollback() finally: if not out_of_band_exit: + assert subject is not None subject._trans_context_manager = self._outer_trans_ctx self._trans_subject = self._outer_trans_ctx = None else: @@ -131,5 +153,6 @@ class TransactionalContext: self.rollback() finally: if not out_of_band_exit: + assert subject is not None subject._trans_context_manager = self._outer_trans_ctx self._trans_subject = self._outer_trans_ctx = None |
