diff options
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 73 |
1 files changed, 48 insertions, 25 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index fa3bae835..c0de1902f 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -8,14 +8,20 @@ """High level utilities which build upon other modules here. """ - from collections import deque from itertools import chain +import typing +from typing import Any +from typing import cast +from typing import Optional from . import coercions from . import operators from . import roles from . import visitors +from ._typing import _ExecuteParams +from ._typing import _MultiExecuteParams +from ._typing import _SingleExecuteParams from .annotation import _deep_annotate # noqa from .annotation import _deep_deannotate # noqa from .annotation import _shallow_annotate # noqa @@ -45,6 +51,9 @@ from .selectable import TableClause from .. import exc from .. import util +if typing.TYPE_CHECKING: + from ..engine.row import Row + def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None): """Create a join condition between two tables or selectables. @@ -488,13 +497,13 @@ def _quote_ddl_expr(element): class _repr_base: - _LIST = 0 - _TUPLE = 1 - _DICT = 2 + _LIST: int = 0 + _TUPLE: int = 1 + _DICT: int = 2 __slots__ = ("max_chars",) - def trunc(self, value): + def trunc(self, value: Any) -> str: rep = repr(value) lenrep = len(rep) if lenrep > self.max_chars: @@ -515,11 +524,11 @@ class _repr_row(_repr_base): __slots__ = ("row",) - def __init__(self, row, max_chars=300): + def __init__(self, row: "Row", max_chars: int = 300): self.row = row self.max_chars = max_chars - def __repr__(self): + def __repr__(self) -> str: trunc = self.trunc return "(%s%s)" % ( ", ".join(trunc(value) for value in self.row), @@ -537,13 +546,19 @@ class _repr_params(_repr_base): __slots__ = "params", "batches", "ismulti" - def __init__(self, params, batches, max_chars=300, ismulti=None): - self.params = params + def __init__( + self, + params: _ExecuteParams, + batches: int, + max_chars: int = 300, + ismulti: Optional[bool] = None, + ): + self.params: _ExecuteParams = params self.ismulti = ismulti self.batches = batches self.max_chars = max_chars - def __repr__(self): + def __repr__(self) -> str: if self.ismulti is None: return self.trunc(self.params) @@ -557,23 +572,31 @@ class _repr_params(_repr_base): else: return self.trunc(self.params) - if self.ismulti and len(self.params) > self.batches: - msg = " ... displaying %i of %i total bound parameter sets ... " - return " ".join( - ( - self._repr_multi(self.params[: self.batches - 2], typ)[ - 0:-1 - ], - msg % (self.batches, len(self.params)), - self._repr_multi(self.params[-2:], typ)[1:], + if self.ismulti: + multi_params = cast(_MultiExecuteParams, self.params) + + if len(self.params) > self.batches: + msg = ( + " ... displaying %i of %i total bound parameter sets ... " ) - ) - elif self.ismulti: - return self._repr_multi(self.params, typ) + return " ".join( + ( + self._repr_multi( + multi_params[: self.batches - 2], + typ, + )[0:-1], + msg % (self.batches, len(self.params)), + self._repr_multi(multi_params[-2:], typ)[1:], + ) + ) + else: + return self._repr_multi(multi_params, typ) else: - return self._repr_params(self.params, typ) + return self._repr_params( + cast(_SingleExecuteParams, self.params), typ + ) - def _repr_multi(self, multi_params, typ): + def _repr_multi(self, multi_params: _MultiExecuteParams, typ) -> str: if multi_params: if isinstance(multi_params[0], list): elem_type = self._LIST @@ -597,7 +620,7 @@ class _repr_params(_repr_base): else: return "(%s)" % elements - def _repr_params(self, params, typ): + def _repr_params(self, params: _SingleExecuteParams, typ: int) -> str: trunc = self.trunc if typ is self._DICT: return "{%s}" % ( |
