diff options
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 98 |
1 files changed, 76 insertions, 22 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 284343154..d08fef60a 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -17,7 +17,9 @@ from typing import AbstractSet from typing import Any from typing import Callable from typing import cast +from typing import Collection from typing import Dict +from typing import Iterable from typing import Iterator from typing import List from typing import Optional @@ -32,15 +34,15 @@ from . import coercions from . import operators from . import roles from . import visitors +from ._typing import is_text_clause from .annotation import _deep_annotate as _deep_annotate from .annotation import _deep_deannotate as _deep_deannotate from .annotation import _shallow_annotate as _shallow_annotate from .base import _expand_cloned from .base import _from_objects -from .base import ColumnSet -from .cache_key import HasCacheKey # noqa -from .ddl import sort_tables # noqa -from .elements import _find_columns +from .cache_key import HasCacheKey as HasCacheKey +from .ddl import sort_tables as sort_tables +from .elements import _find_columns as _find_columns from .elements import _label_reference from .elements import _textual_label_reference from .elements import BindParameter @@ -67,10 +69,13 @@ from ..util.typing import Protocol if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument + from ._typing import _EquivalentColumnMap from ._typing import _TypeEngineArgument + from .elements import TextClause from .roles import FromClauseRole from .selectable import _JoinTargetElement from .selectable import _OnClauseElement + from .selectable import _SelectIterable from .selectable import Selectable from .visitors import _TraverseCallableType from .visitors import ExternallyTraversible @@ -752,7 +757,29 @@ def splice_joins( return ret -def reduce_columns(columns, *clauses, **kw): +@overload +def reduce_columns( + columns: Iterable[ColumnElement[Any]], + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Sequence[ColumnElement[Any]]: + ... + + +@overload +def reduce_columns( + columns: _SelectIterable, + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Sequence[Union[ColumnElement[Any], TextClause]]: + ... + + +def reduce_columns( + columns: _SelectIterable, + *clauses: Optional[ClauseElement], + **kw: bool, +) -> Collection[Union[ColumnElement[Any], TextClause]]: r"""given a list of columns, return a 'reduced' set based on natural equivalents. @@ -775,12 +802,15 @@ def reduce_columns(columns, *clauses, **kw): ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False) only_synonyms = kw.pop("only_synonyms", False) - columns = util.ordered_column_set(columns) + column_set = util.OrderedSet(columns) + cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference( + c for c in column_set if is_text_clause(c) # type: ignore + ) omit = util.column_set() - for col in columns: + for col in cset_no_text: for fk in chain(*[c.foreign_keys for c in col.proxy_set]): - for c in columns: + for c in cset_no_text: if c is col: continue try: @@ -810,10 +840,12 @@ def reduce_columns(columns, *clauses, **kw): def visit_binary(binary): if binary.operator == operators.eq: cols = util.column_set( - chain(*[c.proxy_set for c in columns.difference(omit)]) + chain( + *[c.proxy_set for c in cset_no_text.difference(omit)] + ) ) if binary.left in cols and binary.right in cols: - for c in reversed(columns): + for c in reversed(cset_no_text): if c.shares_lineage(binary.right) and ( not only_synonyms or c.name == binary.left.name ): @@ -824,7 +856,7 @@ def reduce_columns(columns, *clauses, **kw): if clause is not None: visitors.traverse(clause, {}, {"binary": visit_binary}) - return ColumnSet(columns.difference(omit)) + return column_set.difference(omit) def criterion_as_pairs( @@ -923,9 +955,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): def __init__( self, selectable: Selectable, - equivalents: Optional[ - Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] - ] = None, + equivalents: Optional[_EquivalentColumnMap] = None, include_fn: Optional[Callable[[ClauseElement], bool]] = None, exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, adapt_on_names: bool = False, @@ -1059,9 +1089,23 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): class _ColumnLookup(Protocol): - def __getitem__( - self, key: ColumnElement[Any] - ) -> Optional[ColumnElement[Any]]: + @overload + def __getitem__(self, key: None) -> None: + ... + + @overload + def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: + ... + + @overload + def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: + ... + + @overload + def __getitem__(self, key: _ET) -> _ET: + ... + + def __getitem__(self, key: Any) -> Any: ... @@ -1101,9 +1145,7 @@ class ColumnAdapter(ClauseAdapter): def __init__( self, selectable: Selectable, - equivalents: Optional[ - Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] - ] = None, + equivalents: Optional[_EquivalentColumnMap] = None, adapt_required: bool = False, include_fn: Optional[Callable[[ClauseElement], bool]] = None, exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, @@ -1155,7 +1197,17 @@ class ColumnAdapter(ClauseAdapter): return ac - def traverse(self, obj): + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + @overload + def traverse(self, obj: _ET) -> _ET: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: return self.columns[obj] def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: @@ -1172,7 +1224,9 @@ class ColumnAdapter(ClauseAdapter): adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process - def adapt_check_present(self, col): + def adapt_check_present( + self, col: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: newcol = self.columns[col] if newcol is col and self._corresponding_column(col, True) is None: |
