diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/_elements_constructors.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/_typing.py | 21 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/ddl.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 37 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/roles.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 135 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 98 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 8 | 
13 files changed, 172 insertions, 176 deletions
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index ea21e01c6..605f75ec4 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -389,7 +389,7 @@ def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]:  def bindparam( -    key: str, +    key: Optional[str],      value: Any = _NoArg.NO_ARG,      type_: Optional[TypeEngine[_T]] = None,      unique: bool = False, @@ -521,6 +521,11 @@ def bindparam(        key, or if its length is too long and truncation is        required. +      If omitted, an "anonymous" name is generated for the bound parameter; +      when given a value to bind, the end result is equivalent to calling upon +      the :func:`.literal` function with a value to bind, particularly +      if the :paramref:`.bindparam.unique` parameter is also provided. +      :param value:        Initial value for this bind param.  Will be used at statement        execution time as the value for this parameter passed to the diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index b0a717a1a..53d29b628 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -2,13 +2,14 @@ from __future__ import annotations  import operator  from typing import Any +from typing import Callable  from typing import Dict +from typing import Set  from typing import Type  from typing import TYPE_CHECKING  from typing import TypeVar  from typing import Union -from sqlalchemy.sql.base import Executable  from . import roles  from .. import util  from ..inspection import Inspectable @@ -16,6 +17,7 @@ from ..util.typing import Literal  from ..util.typing import Protocol  if TYPE_CHECKING: +    from .base import Executable      from .compiler import Compiled      from .compiler import DDLCompiler      from .compiler import SQLCompiler @@ -27,17 +29,20 @@ if TYPE_CHECKING:      from .elements import quoted_name      from .elements import SQLCoreOperations      from .elements import TextClause +    from .lambdas import LambdaElement      from .roles import ColumnsClauseRole      from .roles import FromClauseRole      from .schema import Column      from .schema import DefaultGenerator      from .schema import Sequence +    from .schema import Table      from .selectable import Alias      from .selectable import FromClause      from .selectable import Join      from .selectable import NamedFromClause      from .selectable import ReturnsRows      from .selectable import Select +    from .selectable import Selectable      from .selectable import SelectBase      from .selectable import Subquery      from .selectable import TableClause @@ -46,7 +51,6 @@ if TYPE_CHECKING:      from .type_api import TypeEngine      from ..util.typing import TypeGuard -  _T = TypeVar("_T", bound=Any) @@ -89,7 +93,11 @@ sets; select(...), insert().returning(...), etc.  """  _ColumnExpressionArgument = Union[ -    "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T] +    "ColumnElement[_T]", +    _HasClauseElement, +    roles.ExpressionElementRole[_T], +    Callable[[], "ColumnElement[_T]"], +    "LambdaElement",  ]  """narrower "column expression" argument. @@ -103,6 +111,7 @@ overall which brings in the TextClause object also.  """ +  _InfoType = Dict[Any, Any]  """the .info dictionary accepted and used throughout Core /ORM""" @@ -169,6 +178,8 @@ _PropagateAttrsType = util.immutabledict[str, Any]  _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] +_EquivalentColumnMap = Dict["ColumnElement[Any]", Set["ColumnElement[Any]"]] +  if TYPE_CHECKING:      def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: @@ -195,6 +206,9 @@ if TYPE_CHECKING:      def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]:          ... +    def is_selectable(t: Any) -> TypeGuard[Selectable]: +        ... +      def is_select_base(          t: Union[Executable, ReturnsRows]      ) -> TypeGuard[SelectBase]: @@ -224,6 +238,7 @@ else:      is_from_clause = operator.attrgetter("_is_from_clause")      is_tuple_type = operator.attrgetter("_is_tuple_type")      is_table_value_type = operator.attrgetter("_is_table_value") +    is_selectable = operator.attrgetter("is_selectable")      is_select_base = operator.attrgetter("_is_select_base")      is_select_statement = operator.attrgetter("_is_select_statement")      is_table = operator.attrgetter("_is_table") diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f7692dbc2..f81878d55 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -218,7 +218,7 @@ def _generative(fn: _Fn) -> _Fn:      """ -    @util.decorator +    @util.decorator  # type: ignore      def _generative(          fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any      ) -> _SelfGenerativeType: @@ -244,7 +244,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:          for name in names      ] -    @util.decorator +    @util.decorator  # type: ignore      def check(fn, *args, **kw):          # make pylance happy by not including "self" in the argument          # list @@ -260,7 +260,7 @@ def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:                  raise exc.InvalidRequestError(msg)          return fn(self, *args, **kw) -    return check +    return check  # type: ignore  def _clone(element, **kw): @@ -1750,15 +1750,14 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):                  self._collection.append((k, col))          self._colset.update(c for (k, c) in self._collection) -        # https://github.com/python/mypy/issues/12610          self._index.update( -            (idx, c) for idx, (k, c) in enumerate(self._collection)  # type: ignore  # noqa: E501 +            (idx, c) for idx, (k, c) in enumerate(self._collection)          )          for col in replace_col:              self.replace(col)      def extend(self, iter_: Iterable[_NAMEDCOL]) -> None: -        self._populate_separate_keys((col.key, col) for col in iter_)  # type: ignore  # noqa: E501 +        self._populate_separate_keys((col.key, col) for col in iter_)      def remove(self, column: _NAMEDCOL) -> None:          if column not in self._colset: @@ -1772,9 +1771,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):              (k, c) for (k, c) in self._collection if c is not column          ] -        # https://github.com/python/mypy/issues/12610          self._index.update( -            {idx: col for idx, (k, col) in enumerate(self._collection)}  # type: ignore  # noqa: E501 +            {idx: col for idx, (k, col) in enumerate(self._collection)}          )          # delete higher index          del self._index[len(self._collection)] @@ -1827,9 +1825,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):          self._index.clear() -        # https://github.com/python/mypy/issues/12610          self._index.update( -            {idx: col for idx, (k, col) in enumerate(self._collection)}  # type: ignore  # noqa: E501 +            {idx: col for idx, (k, col) in enumerate(self._collection)}          )          self._index.update(self._collection) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 4bf45da9c..0659709ab 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -214,6 +214,7 @@ def expect(          Type[roles.ExpressionElementRole[Any]],          Type[roles.LimitOffsetRole],          Type[roles.WhereHavingRole], +        Type[roles.OnClauseRole],      ],      element: Any,      **kw: Any, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 938be0f81..c524a2602 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1078,7 +1078,7 @@ class SQLCompiler(Compiled):          return list(self.insert_prefetch) + list(self.update_prefetch)      @util.memoized_property -    def _global_attributes(self): +    def _global_attributes(self) -> Dict[Any, Any]:          return {}      @util.memoized_instancemethod diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 6ac7c2448..052af6ac9 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -14,6 +14,7 @@ from __future__ import annotations  import typing  from typing import Any  from typing import Callable +from typing import Iterable  from typing import List  from typing import Optional  from typing import Sequence as typing_Sequence @@ -1143,7 +1144,7 @@ class SchemaDropper(InvokeDDLBase):  def sort_tables( -    tables: typing_Sequence["Table"], +    tables: Iterable["Table"],      skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None,      extra_dependencies: Optional[          typing_Sequence[Tuple["Table", "Table"]] diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ea0fa7996..34d5127ab 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -293,11 +293,18 @@ class ClauseElement(      __visit_name__ = "clause" -    _propagate_attrs: _PropagateAttrsType = util.immutabledict() -    """like annotations, however these propagate outwards liberally -    as SQL constructs are built, and are set up at construction time. +    if TYPE_CHECKING: -    """ +        @util.memoized_property +        def _propagate_attrs(self) -> _PropagateAttrsType: +            """like annotations, however these propagate outwards liberally +            as SQL constructs are built, and are set up at construction time. + +            """ +            ... + +    else: +        _propagate_attrs = util.EMPTY_DICT      @util.ro_memoized_property      def description(self) -> Optional[str]: @@ -343,7 +350,9 @@ class ClauseElement(      def _from_objects(self) -> List[FromClause]:          return [] -    def _set_propagate_attrs(self, values): +    def _set_propagate_attrs( +        self: SelfClauseElement, values: Mapping[str, Any] +    ) -> SelfClauseElement:          # usually, self._propagate_attrs is empty here.  one case where it's          # not is a subquery against ORM select, that is then pulled as a          # property of an aliased class.   should all be good @@ -526,13 +535,10 @@ class ClauseElement(              if unique:                  bind._convert_to_unique() -        return cast( -            SelfClauseElement, -            cloned_traverse( -                self, -                {"maintain_key": True, "detect_subquery_cols": True}, -                {"bindparam": visit_bindparam}, -            ), +        return cloned_traverse( +            self, +            {"maintain_key": True, "detect_subquery_cols": True}, +            {"bindparam": visit_bindparam},          )      def compare(self, other, **kw): @@ -730,7 +736,9 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):      # redefined with the specific types returned by ColumnElement hierarchies      if typing.TYPE_CHECKING: -        _propagate_attrs: _PropagateAttrsType +        @util.non_memoized_property +        def _propagate_attrs(self) -> _PropagateAttrsType: +            ...          def operate(              self, op: OperatorType, *other: Any, **kwargs: Any @@ -2064,10 +2072,11 @@ class TextClause(      roles.OrderByRole,      roles.FromClauseRole,      roles.SelectStatementRole, -    roles.BinaryElementRole[Any],      roles.InElementRole,      Executable,      DQLDMLClauseElement, +    roles.BinaryElementRole[Any], +    inspection.Inspectable["TextClause"],  ):      """Represent a literal SQL text fragment. diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index da15c305f..4b220188f 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -444,7 +444,7 @@ class DeferredLambdaElement(LambdaElement):      def _invoke_user_fn(self, fn, *arg):          return fn(*self.lambda_args) -    def _resolve_with_args(self, *lambda_args): +    def _resolve_with_args(self, *lambda_args: Any) -> ClauseElement:          assert isinstance(self._rec, AnalyzedFunction)          tracker_fn = self._rec.tracker_instrumented_fn          expr = tracker_fn(*lambda_args) @@ -478,7 +478,7 @@ class DeferredLambdaElement(LambdaElement):          for deferred_copy_internals in self._transforms:              expr = deferred_copy_internals(expr) -        return expr +        return expr  # type: ignore      def _copy_internals(          self, clone=_clone, deferred_copy_internals=None, **kw diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 577d868fd..231c70a5b 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -22,9 +22,7 @@ if TYPE_CHECKING:      from .base import _EntityNamespace      from .base import ColumnCollection      from .base import ReadOnlyColumnCollection -    from .elements import ClauseElement      from .elements import ColumnClause -    from .elements import ColumnElement      from .elements import Label      from .elements import NamedColumn      from .selectable import _SelectIterable @@ -271,7 +269,14 @@ class StatementRole(SQLRole):      __slots__ = ()      _role_name = "Executable SQL or text() construct" -    _propagate_attrs: _PropagateAttrsType = util.immutabledict() +    if TYPE_CHECKING: + +        @util.memoized_property +        def _propagate_attrs(self) -> _PropagateAttrsType: +            ... + +    else: +        _propagate_attrs = util.EMPTY_DICT  class SelectStatementRole(StatementRole, ReturnsRowsRole): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 92b9cc62c..52ba60a62 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -144,9 +144,9 @@ class SchemaConst(Enum):      NULL_UNSPECIFIED = 3      """Symbol indicating the "nullable" keyword was not passed to a Column. -    Normally we would expect None to be acceptable for this but some backends -    such as that of SQL Server place special signficance on a "nullability" -    value of None. +    This is used to distinguish between the use case of passing +    ``nullable=None`` to a :class:`.Column`, which has special meaning +    on some backends such as SQL Server.      """ @@ -308,7 +308,9 @@ class HasSchemaAttr(SchemaItem):      schema: Optional[str] -class Table(DialectKWArgs, HasSchemaAttr, TableClause): +class Table( +    DialectKWArgs, HasSchemaAttr, TableClause, inspection.Inspectable["Table"] +):      r"""Represent a table in a database.      e.g.:: @@ -1318,117 +1320,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):      inherit_cache = True      key: str -    @overload -    def __init__( -        self, -        *args: SchemaEventTarget, -        autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", -        default: Optional[Any] = None, -        doc: Optional[str] = None, -        key: Optional[str] = None, -        index: Optional[bool] = None, -        unique: Optional[bool] = None, -        info: Optional[_InfoType] = None, -        nullable: Optional[ -            Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] -        ] = NULL_UNSPECIFIED, -        onupdate: Optional[Any] = None, -        primary_key: bool = False, -        server_default: Optional[_ServerDefaultType] = None, -        server_onupdate: Optional[FetchedValue] = None, -        quote: Optional[bool] = None, -        system: bool = False, -        comment: Optional[str] = None, -        _proxies: Optional[Any] = None, -        **dialect_kwargs: Any, -    ): -        ... - -    @overload -    def __init__( -        self, -        __name: str, -        *args: SchemaEventTarget, -        autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", -        default: Optional[Any] = None, -        doc: Optional[str] = None, -        key: Optional[str] = None, -        index: Optional[bool] = None, -        unique: Optional[bool] = None, -        info: Optional[_InfoType] = None, -        nullable: Optional[ -            Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] -        ] = NULL_UNSPECIFIED, -        onupdate: Optional[Any] = None, -        primary_key: bool = False, -        server_default: Optional[_ServerDefaultType] = None, -        server_onupdate: Optional[FetchedValue] = None, -        quote: Optional[bool] = None, -        system: bool = False, -        comment: Optional[str] = None, -        _proxies: Optional[Any] = None, -        **dialect_kwargs: Any, -    ): -        ... - -    @overload      def __init__(          self, -        __type: _TypeEngineArgument[_T], -        *args: SchemaEventTarget, -        autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", -        default: Optional[Any] = None, -        doc: Optional[str] = None, -        key: Optional[str] = None, -        index: Optional[bool] = None, -        unique: Optional[bool] = None, -        info: Optional[_InfoType] = None, -        nullable: Optional[ -            Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] -        ] = NULL_UNSPECIFIED, -        onupdate: Optional[Any] = None, -        primary_key: bool = False, -        server_default: Optional[_ServerDefaultType] = None, -        server_onupdate: Optional[FetchedValue] = None, -        quote: Optional[bool] = None, -        system: bool = False, -        comment: Optional[str] = None, -        _proxies: Optional[Any] = None, -        **dialect_kwargs: Any, -    ): -        ... - -    @overload -    def __init__( -        self, -        __name: str, -        __type: _TypeEngineArgument[_T], +        __name_pos: Optional[ +            Union[str, _TypeEngineArgument[_T], SchemaEventTarget] +        ] = None, +        __type_pos: Optional[ +            Union[_TypeEngineArgument[_T], SchemaEventTarget] +        ] = None,          *args: SchemaEventTarget, -        autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", -        default: Optional[Any] = None, -        doc: Optional[str] = None, -        key: Optional[str] = None, -        index: Optional[bool] = None, -        unique: Optional[bool] = None, -        info: Optional[_InfoType] = None, -        nullable: Optional[ -            Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] -        ] = NULL_UNSPECIFIED, -        onupdate: Optional[Any] = None, -        primary_key: bool = False, -        server_default: Optional[_ServerDefaultType] = None, -        server_onupdate: Optional[FetchedValue] = None, -        quote: Optional[bool] = None, -        system: bool = False, -        comment: Optional[str] = None, -        _proxies: Optional[Any] = None, -        **dialect_kwargs: Any, -    ): -        ... - -    def __init__( -        self, -        *args: Union[str, _TypeEngineArgument[_T], SchemaEventTarget],          name: Optional[str] = None,          type_: Optional[_TypeEngineArgument[_T]] = None,          autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", @@ -1440,7 +1340,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):          info: Optional[_InfoType] = None,          nullable: Optional[              Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] -        ] = NULL_UNSPECIFIED, +        ] = SchemaConst.NULL_UNSPECIFIED,          onupdate: Optional[Any] = None,          primary_key: bool = False,          server_default: Optional[_ServerDefaultType] = None, @@ -1953,7 +1853,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):          """  # noqa: E501, RST201, RST202 -        l_args = list(args) +        l_args = [__name_pos, __type_pos] + list(args)          del args          if l_args: @@ -1963,6 +1863,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):                          "May not pass name positionally and as a keyword."                      )                  name = l_args.pop(0)  # type: ignore +            elif l_args[0] is None: +                l_args.pop(0)          if l_args:              coltype = l_args[0] @@ -1972,6 +1874,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):                          "May not pass type_ positionally and as a keyword."                      )                  type_ = l_args.pop(0)  # type: ignore +            elif l_args[0] is None: +                l_args.pop(0)          if name is not None:              name = quoted_name(name, quote) @@ -1989,7 +1893,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):          self.primary_key = primary_key          self._user_defined_nullable = udn = nullable -          if udn is not NULL_UNSPECIFIED:              self.nullable = udn          else: @@ -5128,7 +5031,7 @@ class MetaData(HasSchemaAttr):      def clear(self) -> None:          """Clear all Table objects from this MetaData.""" -        dict.clear(self.tables) +        dict.clear(self.tables)  # type: ignore          self._schemas.clear()          self._fk_memos.clear() diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index aab3c678c..9d4d1d6c7 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1223,7 +1223,9 @@ class Join(roles.DMLTableRole, FromClause):      @util.preload_module("sqlalchemy.sql.util")      def _populate_column_collection(self):          sqlutil = util.preloaded.sql_util -        columns = [c for c in self.left.c] + [c for c in self.right.c] +        columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [ +            c for c in self.right.c +        ]          self.primary_key.extend(  # type: ignore              sqlutil.reduce_columns( diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 284343154..d08fef60a 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -17,7 +17,9 @@ from typing import AbstractSet  from typing import Any  from typing import Callable  from typing import cast +from typing import Collection  from typing import Dict +from typing import Iterable  from typing import Iterator  from typing import List  from typing import Optional @@ -32,15 +34,15 @@ from . import coercions  from . import operators  from . import roles  from . import visitors +from ._typing import is_text_clause  from .annotation import _deep_annotate as _deep_annotate  from .annotation import _deep_deannotate as _deep_deannotate  from .annotation import _shallow_annotate as _shallow_annotate  from .base import _expand_cloned  from .base import _from_objects -from .base import ColumnSet -from .cache_key import HasCacheKey  # noqa -from .ddl import sort_tables  # noqa -from .elements import _find_columns +from .cache_key import HasCacheKey as HasCacheKey +from .ddl import sort_tables as sort_tables +from .elements import _find_columns as _find_columns  from .elements import _label_reference  from .elements import _textual_label_reference  from .elements import BindParameter @@ -67,10 +69,13 @@ from ..util.typing import Protocol  if typing.TYPE_CHECKING:      from ._typing import _ColumnExpressionArgument +    from ._typing import _EquivalentColumnMap      from ._typing import _TypeEngineArgument +    from .elements import TextClause      from .roles import FromClauseRole      from .selectable import _JoinTargetElement      from .selectable import _OnClauseElement +    from .selectable import _SelectIterable      from .selectable import Selectable      from .visitors import _TraverseCallableType      from .visitors import ExternallyTraversible @@ -752,7 +757,29 @@ def splice_joins(      return ret -def reduce_columns(columns, *clauses, **kw): +@overload +def reduce_columns( +    columns: Iterable[ColumnElement[Any]], +    *clauses: Optional[ClauseElement], +    **kw: bool, +) -> Sequence[ColumnElement[Any]]: +    ... + + +@overload +def reduce_columns( +    columns: _SelectIterable, +    *clauses: Optional[ClauseElement], +    **kw: bool, +) -> Sequence[Union[ColumnElement[Any], TextClause]]: +    ... + + +def reduce_columns( +    columns: _SelectIterable, +    *clauses: Optional[ClauseElement], +    **kw: bool, +) -> Collection[Union[ColumnElement[Any], TextClause]]:      r"""given a list of columns, return a 'reduced' set based on natural      equivalents. @@ -775,12 +802,15 @@ def reduce_columns(columns, *clauses, **kw):      ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)      only_synonyms = kw.pop("only_synonyms", False) -    columns = util.ordered_column_set(columns) +    column_set = util.OrderedSet(columns) +    cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference( +        c for c in column_set if is_text_clause(c)  # type: ignore +    )      omit = util.column_set() -    for col in columns: +    for col in cset_no_text:          for fk in chain(*[c.foreign_keys for c in col.proxy_set]): -            for c in columns: +            for c in cset_no_text:                  if c is col:                      continue                  try: @@ -810,10 +840,12 @@ def reduce_columns(columns, *clauses, **kw):          def visit_binary(binary):              if binary.operator == operators.eq:                  cols = util.column_set( -                    chain(*[c.proxy_set for c in columns.difference(omit)]) +                    chain( +                        *[c.proxy_set for c in cset_no_text.difference(omit)] +                    )                  )                  if binary.left in cols and binary.right in cols: -                    for c in reversed(columns): +                    for c in reversed(cset_no_text):                          if c.shares_lineage(binary.right) and (                              not only_synonyms or c.name == binary.left.name                          ): @@ -824,7 +856,7 @@ def reduce_columns(columns, *clauses, **kw):              if clause is not None:                  visitors.traverse(clause, {}, {"binary": visit_binary}) -    return ColumnSet(columns.difference(omit)) +    return column_set.difference(omit)  def criterion_as_pairs( @@ -923,9 +955,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):      def __init__(          self,          selectable: Selectable, -        equivalents: Optional[ -            Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] -        ] = None, +        equivalents: Optional[_EquivalentColumnMap] = None,          include_fn: Optional[Callable[[ClauseElement], bool]] = None,          exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,          adapt_on_names: bool = False, @@ -1059,9 +1089,23 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):  class _ColumnLookup(Protocol): -    def __getitem__( -        self, key: ColumnElement[Any] -    ) -> Optional[ColumnElement[Any]]: +    @overload +    def __getitem__(self, key: None) -> None: +        ... + +    @overload +    def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: +        ... + +    @overload +    def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: +        ... + +    @overload +    def __getitem__(self, key: _ET) -> _ET: +        ... + +    def __getitem__(self, key: Any) -> Any:          ... @@ -1101,9 +1145,7 @@ class ColumnAdapter(ClauseAdapter):      def __init__(          self,          selectable: Selectable, -        equivalents: Optional[ -            Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] -        ] = None, +        equivalents: Optional[_EquivalentColumnMap] = None,          adapt_required: bool = False,          include_fn: Optional[Callable[[ClauseElement], bool]] = None,          exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, @@ -1155,7 +1197,17 @@ class ColumnAdapter(ClauseAdapter):          return ac -    def traverse(self, obj): +    @overload +    def traverse(self, obj: Literal[None]) -> None: +        ... + +    @overload +    def traverse(self, obj: _ET) -> _ET: +        ... + +    def traverse( +        self, obj: Optional[ExternallyTraversible] +    ) -> Optional[ExternallyTraversible]:          return self.columns[obj]      def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: @@ -1172,7 +1224,9 @@ class ColumnAdapter(ClauseAdapter):      adapt_clause = traverse      adapt_list = ClauseAdapter.copy_and_process -    def adapt_check_present(self, col): +    def adapt_check_present( +        self, col: ColumnElement[Any] +    ) -> Optional[ColumnElement[Any]]:          newcol = self.columns[col]          if newcol is col and self._corresponding_column(col, True) is None: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 7363f9ddc..e0a66fbcf 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -961,12 +961,16 @@ def cloned_traverse(      ... +# a bit of controversy here, as the clone of the lead element +# *could* in theory replace with an entirely different kind of element. +# however this is really not how cloned_traverse is ever used internally +# at least.  @overload  def cloned_traverse( -    obj: ExternallyTraversible, +    obj: _ET,      opts: Mapping[str, Any],      visitors: Mapping[str, _TraverseCallableType[Any]], -) -> ExternallyTraversible: +) -> _ET:      ...  | 
