summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine/util.py')
-rw-r--r--lib/sqlalchemy/engine/util.py55
1 files changed, 39 insertions, 16 deletions
diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py
index f9ee65bef..213485cc9 100644
--- a/lib/sqlalchemy/engine/util.py
+++ b/lib/sqlalchemy/engine/util.py
@@ -7,18 +7,30 @@
from __future__ import annotations
+import typing
+from typing import Any
+from typing import Callable
+from typing import TypeVar
+
from .. import exc
from .. import util
+from ..util._has_cy import HAS_CYEXTENSION
+
+if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
+ from ._py_util import _distill_params_20 as _distill_params_20
+ from ._py_util import _distill_raw_params as _distill_raw_params
+else:
+ from sqlalchemy.cyextension.util import (
+ _distill_params_20 as _distill_params_20,
+ )
+ from sqlalchemy.cyextension.util import (
+ _distill_raw_params as _distill_raw_params,
+ )
-try:
- from sqlalchemy.cyextension.util import _distill_params_20 # noqa
- from sqlalchemy.cyextension.util import _distill_raw_params # noqa
-except ImportError:
- from ._py_util import _distill_params_20 # noqa
- from ._py_util import _distill_raw_params # noqa
+_C = TypeVar("_C", bound=Callable[[], Any])
-def connection_memoize(key):
+def connection_memoize(key: str) -> Callable[[_C], _C]:
"""Decorator, memoize a function in a connection.info stash.
Only applicable to functions which take no arguments other than a
@@ -26,7 +38,7 @@ def connection_memoize(key):
"""
@util.decorator
- def decorated(fn, self, connection):
+ def decorated(fn, self, connection): # type: ignore
connection = connection.connect()
try:
return connection.info[key]
@@ -34,7 +46,7 @@ def connection_memoize(key):
connection.info[key] = val = fn(self, connection)
return val
- return decorated
+ return decorated # type: ignore[return-value]
class TransactionalContext:
@@ -47,13 +59,13 @@ class TransactionalContext:
__slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__")
- def _transaction_is_active(self):
+ def _transaction_is_active(self) -> bool:
raise NotImplementedError()
- def _transaction_is_closed(self):
+ def _transaction_is_closed(self) -> bool:
raise NotImplementedError()
- def _rollback_can_be_called(self):
+ def _rollback_can_be_called(self) -> bool:
"""indicates the object is in a state that is known to be acceptable
for rollback() to be called.
@@ -70,11 +82,20 @@ class TransactionalContext:
"""
raise NotImplementedError()
- def _get_subject(self):
+ def _get_subject(self) -> Any:
+ raise NotImplementedError()
+
+ def commit(self) -> None:
+ raise NotImplementedError()
+
+ def rollback(self) -> None:
+ raise NotImplementedError()
+
+ def close(self) -> None:
raise NotImplementedError()
@classmethod
- def _trans_ctx_check(cls, subject):
+ def _trans_ctx_check(cls, subject: Any) -> None:
trans_context = subject._trans_context_manager
if trans_context:
if not trans_context._transaction_is_active():
@@ -84,7 +105,7 @@ class TransactionalContext:
"before emitting further commands."
)
- def __enter__(self):
+ def __enter__(self) -> TransactionalContext:
subject = self._get_subject()
# none for outer transaction, may be non-None for nested
@@ -96,7 +117,7 @@ class TransactionalContext:
subject._trans_context_manager = self
return self
- def __exit__(self, type_, value, traceback):
+ def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
subject = getattr(self, "_trans_subject", None)
# simplistically we could assume that
@@ -119,6 +140,7 @@ class TransactionalContext:
self.rollback()
finally:
if not out_of_band_exit:
+ assert subject is not None
subject._trans_context_manager = self._outer_trans_ctx
self._trans_subject = self._outer_trans_ctx = None
else:
@@ -131,5 +153,6 @@ class TransactionalContext:
self.rollback()
finally:
if not out_of_band_exit:
+ assert subject is not None
subject._trans_context_manager = self._outer_trans_ctx
self._trans_subject = self._outer_trans_ctx = None