summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py98
1 files changed, 76 insertions, 22 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 284343154..d08fef60a 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -17,7 +17,9 @@ from typing import AbstractSet
from typing import Any
from typing import Callable
from typing import cast
+from typing import Collection
from typing import Dict
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
@@ -32,15 +34,15 @@ from . import coercions
from . import operators
from . import roles
from . import visitors
+from ._typing import is_text_clause
from .annotation import _deep_annotate as _deep_annotate
from .annotation import _deep_deannotate as _deep_deannotate
from .annotation import _shallow_annotate as _shallow_annotate
from .base import _expand_cloned
from .base import _from_objects
-from .base import ColumnSet
-from .cache_key import HasCacheKey # noqa
-from .ddl import sort_tables # noqa
-from .elements import _find_columns
+from .cache_key import HasCacheKey as HasCacheKey
+from .ddl import sort_tables as sort_tables
+from .elements import _find_columns as _find_columns
from .elements import _label_reference
from .elements import _textual_label_reference
from .elements import BindParameter
@@ -67,10 +69,13 @@ from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
+ from ._typing import _EquivalentColumnMap
from ._typing import _TypeEngineArgument
+ from .elements import TextClause
from .roles import FromClauseRole
from .selectable import _JoinTargetElement
from .selectable import _OnClauseElement
+ from .selectable import _SelectIterable
from .selectable import Selectable
from .visitors import _TraverseCallableType
from .visitors import ExternallyTraversible
@@ -752,7 +757,29 @@ def splice_joins(
return ret
-def reduce_columns(columns, *clauses, **kw):
+@overload
+def reduce_columns(
+ columns: Iterable[ColumnElement[Any]],
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Sequence[ColumnElement[Any]]:
+ ...
+
+
+@overload
+def reduce_columns(
+ columns: _SelectIterable,
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Sequence[Union[ColumnElement[Any], TextClause]]:
+ ...
+
+
+def reduce_columns(
+ columns: _SelectIterable,
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Collection[Union[ColumnElement[Any], TextClause]]:
r"""given a list of columns, return a 'reduced' set based on natural
equivalents.
@@ -775,12 +802,15 @@ def reduce_columns(columns, *clauses, **kw):
ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
only_synonyms = kw.pop("only_synonyms", False)
- columns = util.ordered_column_set(columns)
+ column_set = util.OrderedSet(columns)
+ cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
+ c for c in column_set if is_text_clause(c) # type: ignore
+ )
omit = util.column_set()
- for col in columns:
+ for col in cset_no_text:
for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
- for c in columns:
+ for c in cset_no_text:
if c is col:
continue
try:
@@ -810,10 +840,12 @@ def reduce_columns(columns, *clauses, **kw):
def visit_binary(binary):
if binary.operator == operators.eq:
cols = util.column_set(
- chain(*[c.proxy_set for c in columns.difference(omit)])
+ chain(
+ *[c.proxy_set for c in cset_no_text.difference(omit)]
+ )
)
if binary.left in cols and binary.right in cols:
- for c in reversed(columns):
+ for c in reversed(cset_no_text):
if c.shares_lineage(binary.right) and (
not only_synonyms or c.name == binary.left.name
):
@@ -824,7 +856,7 @@ def reduce_columns(columns, *clauses, **kw):
if clause is not None:
visitors.traverse(clause, {}, {"binary": visit_binary})
- return ColumnSet(columns.difference(omit))
+ return column_set.difference(omit)
def criterion_as_pairs(
@@ -923,9 +955,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
def __init__(
self,
selectable: Selectable,
- equivalents: Optional[
- Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
- ] = None,
+ equivalents: Optional[_EquivalentColumnMap] = None,
include_fn: Optional[Callable[[ClauseElement], bool]] = None,
exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
adapt_on_names: bool = False,
@@ -1059,9 +1089,23 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
class _ColumnLookup(Protocol):
- def __getitem__(
- self, key: ColumnElement[Any]
- ) -> Optional[ColumnElement[Any]]:
+ @overload
+ def __getitem__(self, key: None) -> None:
+ ...
+
+ @overload
+ def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]:
+ ...
+
+ @overload
+ def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]:
+ ...
+
+ @overload
+ def __getitem__(self, key: _ET) -> _ET:
+ ...
+
+ def __getitem__(self, key: Any) -> Any:
...
@@ -1101,9 +1145,7 @@ class ColumnAdapter(ClauseAdapter):
def __init__(
self,
selectable: Selectable,
- equivalents: Optional[
- Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
- ] = None,
+ equivalents: Optional[_EquivalentColumnMap] = None,
adapt_required: bool = False,
include_fn: Optional[Callable[[ClauseElement], bool]] = None,
exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
@@ -1155,7 +1197,17 @@ class ColumnAdapter(ClauseAdapter):
return ac
- def traverse(self, obj):
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
+ def traverse(self, obj: _ET) -> _ET:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
return self.columns[obj]
def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
@@ -1172,7 +1224,9 @@ class ColumnAdapter(ClauseAdapter):
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
- def adapt_check_present(self, col):
+ def adapt_check_present(
+ self, col: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
newcol = self.columns[col]
if newcol is col and self._corresponding_column(col, True) is None: