summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/_elements_constructors.py7
-rw-r--r--lib/sqlalchemy/sql/_typing.py21
-rw-r--r--lib/sqlalchemy/sql/base.py17
-rw-r--r--lib/sqlalchemy/sql/coercions.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py2
-rw-r--r--lib/sqlalchemy/sql/ddl.py3
-rw-r--r--lib/sqlalchemy/sql/elements.py37
-rw-r--r--lib/sqlalchemy/sql/lambdas.py4
-rw-r--r--lib/sqlalchemy/sql/roles.py11
-rw-r--r--lib/sqlalchemy/sql/schema.py135
-rw-r--r--lib/sqlalchemy/sql/selectable.py4
-rw-r--r--lib/sqlalchemy/sql/util.py98
-rw-r--r--lib/sqlalchemy/sql/visitors.py8
13 files changed, 172 insertions, 176 deletions
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py
index ea21e01c6..605f75ec4 100644
--- a/lib/sqlalchemy/sql/_elements_constructors.py
+++ b/lib/sqlalchemy/sql/_elements_constructors.py
@@ -389,7 +389,7 @@ def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]:
def bindparam(
- key: str,
+ key: Optional[str],
value: Any = _NoArg.NO_ARG,
type_: Optional[TypeEngine[_T]] = None,
unique: bool = False,
@@ -521,6 +521,11 @@ def bindparam(
key, or if its length is too long and truncation is
required.
+ If omitted, an "anonymous" name is generated for the bound parameter;
+ when given a value to bind, the end result is equivalent to calling upon
+ the :func:`.literal` function with a value to bind, particularly
+ if the :paramref:`.bindparam.unique` parameter is also provided.
+
:param value:
Initial value for this bind param. Will be used at statement
execution time as the value for this parameter passed to the
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index b0a717a1a..53d29b628 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -2,13 +2,14 @@ from __future__ import annotations
import operator
from typing import Any
+from typing import Callable
from typing import Dict
+from typing import Set
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from sqlalchemy.sql.base import Executable
from . import roles
from .. import util
from ..inspection import Inspectable
@@ -16,6 +17,7 @@ from ..util.typing import Literal
from ..util.typing import Protocol
if TYPE_CHECKING:
+ from .base import Executable
from .compiler import Compiled
from .compiler import DDLCompiler
from .compiler import SQLCompiler
@@ -27,17 +29,20 @@ if TYPE_CHECKING:
from .elements import quoted_name
from .elements import SQLCoreOperations
from .elements import TextClause
+ from .lambdas import LambdaElement
from .roles import ColumnsClauseRole
from .roles import FromClauseRole
from .schema import Column
from .schema import DefaultGenerator
from .schema import Sequence
+ from .schema import Table
from .selectable import Alias
from .selectable import FromClause
from .selectable import Join
from .selectable import NamedFromClause
from .selectable import ReturnsRows
from .selectable import Select
+ from .selectable import Selectable
from .selectable import SelectBase
from .selectable import Subquery
from .selectable import TableClause
@@ -46,7 +51,6 @@ if TYPE_CHECKING:
from .type_api import TypeEngine
from ..util.typing import TypeGuard
-
_T = TypeVar("_T", bound=Any)
@@ -89,7 +93,11 @@ sets; select(...), insert().returning(...), etc.
"""
_ColumnExpressionArgument = Union[
- "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T]
+ "ColumnElement[_T]",
+ _HasClauseElement,
+ roles.ExpressionElementRole[_T],
+ Callable[[], "ColumnElement[_T]"],
+ "LambdaElement",
]
"""narrower "column expression" argument.
@@ -103,6 +111,7 @@ overall which brings in the TextClause object also.
"""
+
_InfoType = Dict[Any, Any]
"""the .info dictionary accepted and used throughout Core /ORM"""
@@ -169,6 +178,8 @@ _PropagateAttrsType = util.immutabledict[str, Any]
_TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
+_EquivalentColumnMap = Dict["ColumnElement[Any]", Set["ColumnElement[Any]"]]
+
if TYPE_CHECKING:
def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]:
@@ -195,6 +206,9 @@ if TYPE_CHECKING:
def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]:
...
+ def is_selectable(t: Any) -> TypeGuard[Selectable]:
+ ...
+
def is_select_base(
t: Union[Executable, ReturnsRows]
) -> TypeGuard[SelectBase]:
@@ -224,6 +238,7 @@ else:
is_from_clause = operator.attrgetter("_is_from_clause")
is_tuple_type = operator.attrgetter("_is_tuple_type")
is_table_value_type = operator.attrgetter("_is_table_value")
+ is_selectable = operator.attrgetter("is_selectable")
is_select_base = operator.attrgetter("_is_select_base")
is_select_statement = operator.attrgetter("_is_select_statement")
is_table = operator.attrgetter("_is_table")
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index f7692dbc2..f81878d55 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -218,7 +218,7 @@ def _generative(fn: _Fn) -> _Fn:
"""
- @util.decorator
+ @util.decorator # type: ignore
def _generative(
fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any
) -> _SelfGenerativeType:
@@ -244,7 +244,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
for name in names
]
- @util.decorator
+ @util.decorator # type: ignore
def check(fn, *args, **kw):
# make pylance happy by not including "self" in the argument
# list
@@ -260,7 +260,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
raise exc.InvalidRequestError(msg)
return fn(self, *args, **kw)
- return check
+ return check # type: ignore
def _clone(element, **kw):
@@ -1750,15 +1750,14 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
self._collection.append((k, col))
self._colset.update(c for (k, c) in self._collection)
- # https://github.com/python/mypy/issues/12610
self._index.update(
- (idx, c) for idx, (k, c) in enumerate(self._collection) # type: ignore # noqa: E501
+ (idx, c) for idx, (k, c) in enumerate(self._collection)
)
for col in replace_col:
self.replace(col)
def extend(self, iter_: Iterable[_NAMEDCOL]) -> None:
- self._populate_separate_keys((col.key, col) for col in iter_) # type: ignore # noqa: E501
+ self._populate_separate_keys((col.key, col) for col in iter_)
def remove(self, column: _NAMEDCOL) -> None:
if column not in self._colset:
@@ -1772,9 +1771,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
(k, c) for (k, c) in self._collection if c is not column
]
- # https://github.com/python/mypy/issues/12610
self._index.update(
- {idx: col for idx, (k, col) in enumerate(self._collection)} # type: ignore # noqa: E501
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
)
# delete higher index
del self._index[len(self._collection)]
@@ -1827,9 +1825,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
self._index.clear()
- # https://github.com/python/mypy/issues/12610
self._index.update(
- {idx: col for idx, (k, col) in enumerate(self._collection)} # type: ignore # noqa: E501
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
)
self._index.update(self._collection)
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index 4bf45da9c..0659709ab 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -214,6 +214,7 @@ def expect(
Type[roles.ExpressionElementRole[Any]],
Type[roles.LimitOffsetRole],
Type[roles.WhereHavingRole],
+ Type[roles.OnClauseRole],
],
element: Any,
**kw: Any,
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 938be0f81..c524a2602 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1078,7 +1078,7 @@ class SQLCompiler(Compiled):
return list(self.insert_prefetch) + list(self.update_prefetch)
@util.memoized_property
- def _global_attributes(self):
+ def _global_attributes(self) -> Dict[Any, Any]:
return {}
@util.memoized_instancemethod
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 6ac7c2448..052af6ac9 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -14,6 +14,7 @@ from __future__ import annotations
import typing
from typing import Any
from typing import Callable
+from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence as typing_Sequence
@@ -1143,7 +1144,7 @@ class SchemaDropper(InvokeDDLBase):
def sort_tables(
- tables: typing_Sequence["Table"],
+ tables: Iterable["Table"],
skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None,
extra_dependencies: Optional[
typing_Sequence[Tuple["Table", "Table"]]
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index ea0fa7996..34d5127ab 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -293,11 +293,18 @@ class ClauseElement(
__visit_name__ = "clause"
- _propagate_attrs: _PropagateAttrsType = util.immutabledict()
- """like annotations, however these propagate outwards liberally
- as SQL constructs are built, and are set up at construction time.
+ if TYPE_CHECKING:
- """
+ @util.memoized_property
+ def _propagate_attrs(self) -> _PropagateAttrsType:
+ """like annotations, however these propagate outwards liberally
+ as SQL constructs are built, and are set up at construction time.
+
+ """
+ ...
+
+ else:
+ _propagate_attrs = util.EMPTY_DICT
@util.ro_memoized_property
def description(self) -> Optional[str]:
@@ -343,7 +350,9 @@ class ClauseElement(
def _from_objects(self) -> List[FromClause]:
return []
- def _set_propagate_attrs(self, values):
+ def _set_propagate_attrs(
+ self: SelfClauseElement, values: Mapping[str, Any]
+ ) -> SelfClauseElement:
# usually, self._propagate_attrs is empty here. one case where it's
# not is a subquery against ORM select, that is then pulled as a
# property of an aliased class. should all be good
@@ -526,13 +535,10 @@ class ClauseElement(
if unique:
bind._convert_to_unique()
- return cast(
- SelfClauseElement,
- cloned_traverse(
- self,
- {"maintain_key": True, "detect_subquery_cols": True},
- {"bindparam": visit_bindparam},
- ),
+ return cloned_traverse(
+ self,
+ {"maintain_key": True, "detect_subquery_cols": True},
+ {"bindparam": visit_bindparam},
)
def compare(self, other, **kw):
@@ -730,7 +736,9 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
# redefined with the specific types returned by ColumnElement hierarchies
if typing.TYPE_CHECKING:
- _propagate_attrs: _PropagateAttrsType
+ @util.non_memoized_property
+ def _propagate_attrs(self) -> _PropagateAttrsType:
+ ...
def operate(
self, op: OperatorType, *other: Any, **kwargs: Any
@@ -2064,10 +2072,11 @@ class TextClause(
roles.OrderByRole,
roles.FromClauseRole,
roles.SelectStatementRole,
- roles.BinaryElementRole[Any],
roles.InElementRole,
Executable,
DQLDMLClauseElement,
+ roles.BinaryElementRole[Any],
+ inspection.Inspectable["TextClause"],
):
"""Represent a literal SQL text fragment.
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index da15c305f..4b220188f 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -444,7 +444,7 @@ class DeferredLambdaElement(LambdaElement):
def _invoke_user_fn(self, fn, *arg):
return fn(*self.lambda_args)
- def _resolve_with_args(self, *lambda_args):
+ def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement:
assert isinstance(self._rec, AnalyzedFunction)
tracker_fn = self._rec.tracker_instrumented_fn
expr = tracker_fn(*lambda_args)
@@ -478,7 +478,7 @@ class DeferredLambdaElement(LambdaElement):
for deferred_copy_internals in self._transforms:
expr = deferred_copy_internals(expr)
- return expr
+ return expr # type: ignore
def _copy_internals(
self, clone=_clone, deferred_copy_internals=None, **kw
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 577d868fd..231c70a5b 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -22,9 +22,7 @@ if TYPE_CHECKING:
from .base import _EntityNamespace
from .base import ColumnCollection
from .base import ReadOnlyColumnCollection
- from .elements import ClauseElement
from .elements import ColumnClause
- from .elements import ColumnElement
from .elements import Label
from .elements import NamedColumn
from .selectable import _SelectIterable
@@ -271,7 +269,14 @@ class StatementRole(SQLRole):
__slots__ = ()
_role_name = "Executable SQL or text() construct"
- _propagate_attrs: _PropagateAttrsType = util.immutabledict()
+ if TYPE_CHECKING:
+
+ @util.memoized_property
+ def _propagate_attrs(self) -> _PropagateAttrsType:
+ ...
+
+ else:
+ _propagate_attrs = util.EMPTY_DICT
class SelectStatementRole(StatementRole, ReturnsRowsRole):
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 92b9cc62c..52ba60a62 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -144,9 +144,9 @@ class SchemaConst(Enum):
NULL_UNSPECIFIED = 3
"""Symbol indicating the "nullable" keyword was not passed to a Column.
- Normally we would expect None to be acceptable for this but some backends
- such as that of SQL Server place special signficance on a "nullability"
- value of None.
+ This is used to distinguish between the use case of passing
+ ``nullable=None`` to a :class:`.Column`, which has special meaning
+ on some backends such as SQL Server.
"""
@@ -308,7 +308,9 @@ class HasSchemaAttr(SchemaItem):
schema: Optional[str]
-class Table(DialectKWArgs, HasSchemaAttr, TableClause):
+class Table(
+ DialectKWArgs, HasSchemaAttr, TableClause, inspection.Inspectable["Table"]
+):
r"""Represent a table in a database.
e.g.::
@@ -1318,117 +1320,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
inherit_cache = True
key: str
- @overload
- def __init__(
- self,
- *args: SchemaEventTarget,
- autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
- doc: Optional[str] = None,
- key: Optional[str] = None,
- index: Optional[bool] = None,
- unique: Optional[bool] = None,
- info: Optional[_InfoType] = None,
- nullable: Optional[
- Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
- ] = NULL_UNSPECIFIED,
- onupdate: Optional[Any] = None,
- primary_key: bool = False,
- server_default: Optional[_ServerDefaultType] = None,
- server_onupdate: Optional[FetchedValue] = None,
- quote: Optional[bool] = None,
- system: bool = False,
- comment: Optional[str] = None,
- _proxies: Optional[Any] = None,
- **dialect_kwargs: Any,
- ):
- ...
-
- @overload
- def __init__(
- self,
- __name: str,
- *args: SchemaEventTarget,
- autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
- doc: Optional[str] = None,
- key: Optional[str] = None,
- index: Optional[bool] = None,
- unique: Optional[bool] = None,
- info: Optional[_InfoType] = None,
- nullable: Optional[
- Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
- ] = NULL_UNSPECIFIED,
- onupdate: Optional[Any] = None,
- primary_key: bool = False,
- server_default: Optional[_ServerDefaultType] = None,
- server_onupdate: Optional[FetchedValue] = None,
- quote: Optional[bool] = None,
- system: bool = False,
- comment: Optional[str] = None,
- _proxies: Optional[Any] = None,
- **dialect_kwargs: Any,
- ):
- ...
-
- @overload
def __init__(
self,
- __type: _TypeEngineArgument[_T],
- *args: SchemaEventTarget,
- autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
- doc: Optional[str] = None,
- key: Optional[str] = None,
- index: Optional[bool] = None,
- unique: Optional[bool] = None,
- info: Optional[_InfoType] = None,
- nullable: Optional[
- Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
- ] = NULL_UNSPECIFIED,
- onupdate: Optional[Any] = None,
- primary_key: bool = False,
- server_default: Optional[_ServerDefaultType] = None,
- server_onupdate: Optional[FetchedValue] = None,
- quote: Optional[bool] = None,
- system: bool = False,
- comment: Optional[str] = None,
- _proxies: Optional[Any] = None,
- **dialect_kwargs: Any,
- ):
- ...
-
- @overload
- def __init__(
- self,
- __name: str,
- __type: _TypeEngineArgument[_T],
+ __name_pos: Optional[
+ Union[str, _TypeEngineArgument[_T], SchemaEventTarget]
+ ] = None,
+ __type_pos: Optional[
+ Union[_TypeEngineArgument[_T], SchemaEventTarget]
+ ] = None,
*args: SchemaEventTarget,
- autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
- doc: Optional[str] = None,
- key: Optional[str] = None,
- index: Optional[bool] = None,
- unique: Optional[bool] = None,
- info: Optional[_InfoType] = None,
- nullable: Optional[
- Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
- ] = NULL_UNSPECIFIED,
- onupdate: Optional[Any] = None,
- primary_key: bool = False,
- server_default: Optional[_ServerDefaultType] = None,
- server_onupdate: Optional[FetchedValue] = None,
- quote: Optional[bool] = None,
- system: bool = False,
- comment: Optional[str] = None,
- _proxies: Optional[Any] = None,
- **dialect_kwargs: Any,
- ):
- ...
-
- def __init__(
- self,
- *args: Union[str, _TypeEngineArgument[_T], SchemaEventTarget],
name: Optional[str] = None,
type_: Optional[_TypeEngineArgument[_T]] = None,
autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
@@ -1440,7 +1340,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
info: Optional[_InfoType] = None,
nullable: Optional[
Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
- ] = NULL_UNSPECIFIED,
+ ] = SchemaConst.NULL_UNSPECIFIED,
onupdate: Optional[Any] = None,
primary_key: bool = False,
server_default: Optional[_ServerDefaultType] = None,
@@ -1953,7 +1853,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
""" # noqa: E501, RST201, RST202
- l_args = list(args)
+ l_args = [__name_pos, __type_pos] + list(args)
del args
if l_args:
@@ -1963,6 +1863,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
"May not pass name positionally and as a keyword."
)
name = l_args.pop(0) # type: ignore
+ elif l_args[0] is None:
+ l_args.pop(0)
if l_args:
coltype = l_args[0]
@@ -1972,6 +1874,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
"May not pass type_ positionally and as a keyword."
)
type_ = l_args.pop(0) # type: ignore
+ elif l_args[0] is None:
+ l_args.pop(0)
if name is not None:
name = quoted_name(name, quote)
@@ -1989,7 +1893,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
self.primary_key = primary_key
self._user_defined_nullable = udn = nullable
-
if udn is not NULL_UNSPECIFIED:
self.nullable = udn
else:
@@ -5128,7 +5031,7 @@ class MetaData(HasSchemaAttr):
def clear(self) -> None:
"""Clear all Table objects from this MetaData."""
- dict.clear(self.tables)
+ dict.clear(self.tables) # type: ignore
self._schemas.clear()
self._fk_memos.clear()
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index aab3c678c..9d4d1d6c7 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -1223,7 +1223,9 @@ class Join(roles.DMLTableRole, FromClause):
@util.preload_module("sqlalchemy.sql.util")
def _populate_column_collection(self):
sqlutil = util.preloaded.sql_util
- columns = [c for c in self.left.c] + [c for c in self.right.c]
+ columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [
+ c for c in self.right.c
+ ]
self.primary_key.extend( # type: ignore
sqlutil.reduce_columns(
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 284343154..d08fef60a 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -17,7 +17,9 @@ from typing import AbstractSet
from typing import Any
from typing import Callable
from typing import cast
+from typing import Collection
from typing import Dict
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
@@ -32,15 +34,15 @@ from . import coercions
from . import operators
from . import roles
from . import visitors
+from ._typing import is_text_clause
from .annotation import _deep_annotate as _deep_annotate
from .annotation import _deep_deannotate as _deep_deannotate
from .annotation import _shallow_annotate as _shallow_annotate
from .base import _expand_cloned
from .base import _from_objects
-from .base import ColumnSet
-from .cache_key import HasCacheKey # noqa
-from .ddl import sort_tables # noqa
-from .elements import _find_columns
+from .cache_key import HasCacheKey as HasCacheKey
+from .ddl import sort_tables as sort_tables
+from .elements import _find_columns as _find_columns
from .elements import _label_reference
from .elements import _textual_label_reference
from .elements import BindParameter
@@ -67,10 +69,13 @@ from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
+ from ._typing import _EquivalentColumnMap
from ._typing import _TypeEngineArgument
+ from .elements import TextClause
from .roles import FromClauseRole
from .selectable import _JoinTargetElement
from .selectable import _OnClauseElement
+ from .selectable import _SelectIterable
from .selectable import Selectable
from .visitors import _TraverseCallableType
from .visitors import ExternallyTraversible
@@ -752,7 +757,29 @@ def splice_joins(
return ret
-def reduce_columns(columns, *clauses, **kw):
+@overload
+def reduce_columns(
+ columns: Iterable[ColumnElement[Any]],
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Sequence[ColumnElement[Any]]:
+ ...
+
+
+@overload
+def reduce_columns(
+ columns: _SelectIterable,
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Sequence[Union[ColumnElement[Any], TextClause]]:
+ ...
+
+
+def reduce_columns(
+ columns: _SelectIterable,
+ *clauses: Optional[ClauseElement],
+ **kw: bool,
+) -> Collection[Union[ColumnElement[Any], TextClause]]:
r"""given a list of columns, return a 'reduced' set based on natural
equivalents.
@@ -775,12 +802,15 @@ def reduce_columns(columns, *clauses, **kw):
ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
only_synonyms = kw.pop("only_synonyms", False)
- columns = util.ordered_column_set(columns)
+ column_set = util.OrderedSet(columns)
+ cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
+ c for c in column_set if is_text_clause(c) # type: ignore
+ )
omit = util.column_set()
- for col in columns:
+ for col in cset_no_text:
for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
- for c in columns:
+ for c in cset_no_text:
if c is col:
continue
try:
@@ -810,10 +840,12 @@ def reduce_columns(columns, *clauses, **kw):
def visit_binary(binary):
if binary.operator == operators.eq:
cols = util.column_set(
- chain(*[c.proxy_set for c in columns.difference(omit)])
+ chain(
+ *[c.proxy_set for c in cset_no_text.difference(omit)]
+ )
)
if binary.left in cols and binary.right in cols:
- for c in reversed(columns):
+ for c in reversed(cset_no_text):
if c.shares_lineage(binary.right) and (
not only_synonyms or c.name == binary.left.name
):
@@ -824,7 +856,7 @@ def reduce_columns(columns, *clauses, **kw):
if clause is not None:
visitors.traverse(clause, {}, {"binary": visit_binary})
- return ColumnSet(columns.difference(omit))
+ return column_set.difference(omit)
def criterion_as_pairs(
@@ -923,9 +955,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
def __init__(
self,
selectable: Selectable,
- equivalents: Optional[
- Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
- ] = None,
+ equivalents: Optional[_EquivalentColumnMap] = None,
include_fn: Optional[Callable[[ClauseElement], bool]] = None,
exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
adapt_on_names: bool = False,
@@ -1059,9 +1089,23 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
class _ColumnLookup(Protocol):
- def __getitem__(
- self, key: ColumnElement[Any]
- ) -> Optional[ColumnElement[Any]]:
+ @overload
+ def __getitem__(self, key: None) -> None:
+ ...
+
+ @overload
+ def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]:
+ ...
+
+ @overload
+ def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]:
+ ...
+
+ @overload
+ def __getitem__(self, key: _ET) -> _ET:
+ ...
+
+ def __getitem__(self, key: Any) -> Any:
...
@@ -1101,9 +1145,7 @@ class ColumnAdapter(ClauseAdapter):
def __init__(
self,
selectable: Selectable,
- equivalents: Optional[
- Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
- ] = None,
+ equivalents: Optional[_EquivalentColumnMap] = None,
adapt_required: bool = False,
include_fn: Optional[Callable[[ClauseElement], bool]] = None,
exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
@@ -1155,7 +1197,17 @@ class ColumnAdapter(ClauseAdapter):
return ac
- def traverse(self, obj):
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
+ def traverse(self, obj: _ET) -> _ET:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
return self.columns[obj]
def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
@@ -1172,7 +1224,9 @@ class ColumnAdapter(ClauseAdapter):
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
- def adapt_check_present(self, col):
+ def adapt_check_present(
+ self, col: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
newcol = self.columns[col]
if newcol is col and self._corresponding_column(col, True) is None:
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 7363f9ddc..e0a66fbcf 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -961,12 +961,16 @@ def cloned_traverse(
...
+# a bit of controversy here, as the clone of the lead element
+# *could* in theory replace with an entirely different kind of element.
+# however this is really not how cloned_traverse is ever used internally
+# at least.
@overload
def cloned_traverse(
- obj: ExternallyTraversible,
+ obj: _ET,
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
-) -> ExternallyTraversible:
+) -> _ET:
...