summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py73
1 files changed, 48 insertions, 25 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index fa3bae835..c0de1902f 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -8,14 +8,20 @@
"""High level utilities which build upon other modules here.
"""
-
from collections import deque
from itertools import chain
+import typing
+from typing import Any
+from typing import cast
+from typing import Optional
from . import coercions
from . import operators
from . import roles
from . import visitors
+from ._typing import _ExecuteParams
+from ._typing import _MultiExecuteParams
+from ._typing import _SingleExecuteParams
from .annotation import _deep_annotate # noqa
from .annotation import _deep_deannotate # noqa
from .annotation import _shallow_annotate # noqa
@@ -45,6 +51,9 @@ from .selectable import TableClause
from .. import exc
from .. import util
+if typing.TYPE_CHECKING:
+ from ..engine.row import Row
+
def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
"""Create a join condition between two tables or selectables.
@@ -488,13 +497,13 @@ def _quote_ddl_expr(element):
class _repr_base:
- _LIST = 0
- _TUPLE = 1
- _DICT = 2
+ _LIST: int = 0
+ _TUPLE: int = 1
+ _DICT: int = 2
__slots__ = ("max_chars",)
- def trunc(self, value):
+ def trunc(self, value: Any) -> str:
rep = repr(value)
lenrep = len(rep)
if lenrep > self.max_chars:
@@ -515,11 +524,11 @@ class _repr_row(_repr_base):
__slots__ = ("row",)
- def __init__(self, row, max_chars=300):
+ def __init__(self, row: "Row", max_chars: int = 300):
self.row = row
self.max_chars = max_chars
- def __repr__(self):
+ def __repr__(self) -> str:
trunc = self.trunc
return "(%s%s)" % (
", ".join(trunc(value) for value in self.row),
@@ -537,13 +546,19 @@ class _repr_params(_repr_base):
__slots__ = "params", "batches", "ismulti"
- def __init__(self, params, batches, max_chars=300, ismulti=None):
- self.params = params
+ def __init__(
+ self,
+ params: _ExecuteParams,
+ batches: int,
+ max_chars: int = 300,
+ ismulti: Optional[bool] = None,
+ ):
+ self.params: _ExecuteParams = params
self.ismulti = ismulti
self.batches = batches
self.max_chars = max_chars
- def __repr__(self):
+ def __repr__(self) -> str:
if self.ismulti is None:
return self.trunc(self.params)
@@ -557,23 +572,31 @@ class _repr_params(_repr_base):
else:
return self.trunc(self.params)
- if self.ismulti and len(self.params) > self.batches:
- msg = " ... displaying %i of %i total bound parameter sets ... "
- return " ".join(
- (
- self._repr_multi(self.params[: self.batches - 2], typ)[
- 0:-1
- ],
- msg % (self.batches, len(self.params)),
- self._repr_multi(self.params[-2:], typ)[1:],
+ if self.ismulti:
+ multi_params = cast(_MultiExecuteParams, self.params)
+
+ if len(self.params) > self.batches:
+ msg = (
+ " ... displaying %i of %i total bound parameter sets ... "
)
- )
- elif self.ismulti:
- return self._repr_multi(self.params, typ)
+ return " ".join(
+ (
+ self._repr_multi(
+ multi_params[: self.batches - 2],
+ typ,
+ )[0:-1],
+ msg % (self.batches, len(self.params)),
+ self._repr_multi(multi_params[-2:], typ)[1:],
+ )
+ )
+ else:
+ return self._repr_multi(multi_params, typ)
else:
- return self._repr_params(self.params, typ)
+ return self._repr_params(
+ cast(_SingleExecuteParams, self.params), typ
+ )
- def _repr_multi(self, multi_params, typ):
+ def _repr_multi(self, multi_params: _MultiExecuteParams, typ) -> str:
if multi_params:
if isinstance(multi_params[0], list):
elem_type = self._LIST
@@ -597,7 +620,7 @@ class _repr_params(_repr_base):
else:
return "(%s)" % elements
- def _repr_params(self, params, typ):
+ def _repr_params(self, params: _SingleExecuteParams, typ: int) -> str:
trunc = self.trunc
if typ is self._DICT:
return "{%s}" % (