diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/_selectable_constructors.py | 166 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/_typing.py | 78 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 25 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 40 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 376 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 36 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/roles.py | 57 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 287 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 4 |
15 files changed, 939 insertions, 204 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 84913225d..c3ebb4596 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -121,7 +121,6 @@ def __go(lcls: Any) -> None: coercions.lambdas = lambdas coercions.schema = schema coercions.selectable = selectable - coercions.traversals = traversals from .annotation import _prepare_annotations from .annotation import Annotated diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 37d44976a..f89e8f578 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -9,12 +9,16 @@ from __future__ import annotations from typing import Any from typing import Optional +from typing import overload +from typing import Tuple from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import coercions from . import roles from ._typing import _ColumnsClauseArgument +from ._typing import _no_kw from .elements import ColumnClause from .selectable import Alias from .selectable import CompoundSelect @@ -34,6 +38,17 @@ if TYPE_CHECKING: from ._typing import _FromClauseArgument from ._typing import _OnClauseArgument from ._typing import _SelectStatementForCompoundArgument + from ._typing import _T0 + from ._typing import _T1 + from ._typing import _T2 + from ._typing import _T3 + from ._typing import _T4 + from ._typing import _T5 + from ._typing import _T6 + from ._typing import _T7 + from ._typing import _T8 + from ._typing import _T9 + from ._typing import _TypedColumnClauseArgument as _TCCA from .functions import Function from .selectable import CTE from .selectable import HasCTE @@ -41,6 +56,9 @@ if TYPE_CHECKING: from .selectable import SelectBase +_T = TypeVar("_T", bound=Any) + + def alias( selectable: FromClause, name: Optional[str] = None, flat: bool = False ) -> NamedFromClause: @@ -89,7 +107,9 @@ def cte( ) -def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def except_( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``EXCEPT`` of multiple selectables. The returned object is an instance of @@ -119,7 +139,7 @@ def except_all( def exists( __argument: Optional[ - Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]] ] = None, ) -> Exists: """Construct a new :class:`_expression.Exists` construct. @@ -162,7 +182,9 @@ def exists( return Exists(__argument) -def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def intersect( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``INTERSECT`` of multiple selectables. The returned object is an instance of @@ -306,7 +328,129 @@ def outerjoin( return Join(left, right, onclause, isouter=True, full=full) -def select(*entities: _ColumnsClauseArgument) -> Select: +# START OVERLOADED FUNCTIONS select Select 1-10 + +# code within this block is **programmatically, +# statically generated** by tools/generate_tuple_map_overloads.py + + +@overload +def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: + ... + + +@overload +def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] +) -> Select[Tuple[_T0, _T1, _T2]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], +) -> Select[Tuple[_T0, _T1, _T2, _T3]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], + __ent9: _TCCA[_T9], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: + ... + + +# END OVERLOADED FUNCTIONS select + + +@overload +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: + ... + + +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: r"""Construct a new :class:`_expression.Select`. @@ -343,7 +487,11 @@ def select(*entities: _ColumnsClauseArgument) -> Select: given, as well as ORM-mapped classes. """ - + # the keyword args are a necessary element in order for the typing + # to work out w/ the varargs vs. having named "keyword" arguments that + # aren't always present. + if __kw: + raise _no_kw() return Select(*entities) @@ -425,7 +573,9 @@ def tablesample( return TableSample._factory(selectable, sampling, name=name, seed=seed) -def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def union( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return a ``UNION`` of multiple selectables. The returned object is an instance of @@ -445,7 +595,9 @@ def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: return CompoundSelect._create_union(*selects) -def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def union_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return a ``UNION ALL`` of multiple selectables. The returned object is an instance of diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 53d29b628..1df530dbd 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -5,18 +5,27 @@ from typing import Any from typing import Callable from typing import Dict from typing import Set +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import roles +from .. import exc from .. import util from ..inspection import Inspectable from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: + from datetime import date + from datetime import datetime + from datetime import time + from datetime import timedelta + from decimal import Decimal + from uuid import UUID + from .base import Executable from .compiler import Compiled from .compiler import DDLCompiler @@ -26,17 +35,15 @@ if TYPE_CHECKING: from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement + from .elements import KeyedColumnElement from .elements import quoted_name - from .elements import SQLCoreOperations from .elements import TextClause from .lambdas import LambdaElement from .roles import ColumnsClauseRole from .roles import FromClauseRole from .schema import Column - from .schema import DefaultGenerator - from .schema import Sequence - from .schema import Table from .selectable import Alias + from .selectable import CTE from .selectable import FromClause from .selectable import Join from .selectable import NamedFromClause @@ -61,6 +68,30 @@ class _HasClauseElement(Protocol): ... +# match column types that are not ORM entities +_NOT_ENTITY = TypeVar( + "_NOT_ENTITY", + int, + str, + "datetime", + "date", + "time", + "timedelta", + "UUID", + float, + "Decimal", +) + +_MAYBE_ENTITY = TypeVar( + "_MAYBE_ENTITY", + roles.ColumnsClauseRole, + Literal["*", 1], + Type[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, +) + + # convention: # XYZArgument - something that the end user is passing to a public API method # XYZElement - the internal representation that we use for the thing. @@ -76,9 +107,10 @@ _TextCoercedExpressionArgument = Union[ ] _ColumnsClauseArgument = Union[ - Literal["*", 1], + roles.TypedColumnsClauseRole[_T], roles.ColumnsClauseRole, - Type[Any], + Literal["*", 1], + Type[_T], Inspectable[_HasClauseElement], _HasClauseElement, ] @@ -92,6 +124,24 @@ sets; select(...), insert().returning(...), etc. """ +_TypedColumnClauseArgument = Union[ + roles.TypedColumnsClauseRole[_T], roles.ExpressionElementRole[_T], Type[_T] +] + +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + +_T0 = TypeVar("_T0", bound=Any) +_T1 = TypeVar("_T1", bound=Any) +_T2 = TypeVar("_T2", bound=Any) +_T3 = TypeVar("_T3", bound=Any) +_T4 = TypeVar("_T4", bound=Any) +_T5 = TypeVar("_T5", bound=Any) +_T6 = TypeVar("_T6", bound=Any) +_T7 = TypeVar("_T7", bound=Any) +_T8 = TypeVar("_T8", bound=Any) +_T9 = TypeVar("_T9", bound=Any) + + _ColumnExpressionArgument = Union[ "ColumnElement[_T]", _HasClauseElement, @@ -169,6 +219,7 @@ _DMLTableArgument = Union[ "TableClause", "Join", "Alias", + "CTE", Type[Any], Inspectable[_HasClauseElement], _HasClauseElement, @@ -194,6 +245,11 @@ if TYPE_CHECKING: def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: ... + def is_keyed_column_element( + c: ClauseElement, + ) -> TypeGuard[KeyedColumnElement[Any]]: + ... + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: ... @@ -216,7 +272,7 @@ if TYPE_CHECKING: def is_select_statement( t: Union[Executable, ReturnsRows] - ) -> TypeGuard[Select]: + ) -> TypeGuard[Select[Any]]: ... def is_table(t: FromClause) -> TypeGuard[TableClause]: @@ -234,6 +290,7 @@ else: is_ddl_compiler = operator.attrgetter("is_ddl") is_named_from_clause = operator.attrgetter("named_with_column") is_column_element = operator.attrgetter("_is_column_element") + is_keyed_column_element = operator.attrgetter("_is_keyed_column_element") is_text_clause = operator.attrgetter("_is_text_clause") is_from_clause = operator.attrgetter("_is_from_clause") is_tuple_type = operator.attrgetter("_is_tuple_type") @@ -260,3 +317,10 @@ def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]: return c.is_dml and (c.is_insert or c.is_update) # type: ignore + + +def _no_kw() -> exc.ArgumentError: + return exc.ArgumentError( + "Additional keyword arguments are not accepted by this " + "function/method. The presence of **kw is for pep-484 typing purposes" + ) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f81878d55..790edefc6 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -62,10 +62,10 @@ if TYPE_CHECKING: from . import coercions from . import elements from . import type_api - from ._typing import _ColumnsClauseArgument from .elements import BindParameter - from .elements import ColumnClause + from .elements import ColumnClause # noqa from .elements import ColumnElement + from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import SQLCoreOperations from .elements import TextClause @@ -74,7 +74,6 @@ if TYPE_CHECKING: from .selectable import FromClause from ..engine import Connection from ..engine import CursorResult - from ..engine import Result from ..engine.base import _CompiledCacheType from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import _ExecuteOptions @@ -704,8 +703,11 @@ class InPlaceGenerative(HasMemoized): """Provide a method-chaining pattern in conjunction with the @_generative decorator that mutates in place.""" + __slots__ = () + def _generate(self): skip = self._memoized_keys + # note __dict__ needs to be in __slots__ if this is used for k in skip: self.__dict__.pop(k, None) return self @@ -937,7 +939,7 @@ class ExecutableOption(HasCopyInternals): SelfExecutable = TypeVar("SelfExecutable", bound="Executable") -class Executable(roles.StatementRole, Generative): +class Executable(roles.StatementRole): """Mark a :class:`_expression.ClauseElement` as supporting execution. :class:`.Executable` is a superclass for all "statement" types @@ -994,7 +996,7 @@ class Executable(roles.StatementRole, Generative): connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: ... def _execute_on_scalar( @@ -1253,7 +1255,7 @@ class SchemaVisitor(ClauseVisitor): _COLKEY = TypeVar("_COLKEY", Union[None, str], str) _COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) -_COL = TypeVar("_COL", bound="ColumnElement[Any]") +_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]") class ColumnCollection(Generic[_COLKEY, _COL_co]): @@ -1505,6 +1507,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) + self._collection[:] = cols self._colset.update(c for k, c in self._collection) self._index.update( diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 0659709ab..9b7231360 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -29,6 +29,7 @@ from typing import Union from . import operators from . import roles from . import visitors +from ._typing import is_from_clause from .base import ExecutableOption from .base import Options from .cache_key import HasCacheKey @@ -38,25 +39,18 @@ from .. import inspection from .. import util from ..util.typing import Literal -if not typing.TYPE_CHECKING: - elements = None - lambdas = None - schema = None - selectable = None - traversals = None - if typing.TYPE_CHECKING: from . import elements from . import lambdas from . import schema from . import selectable - from . import traversals from ._typing import _ColumnExpressionArgument from ._typing import _ColumnsClauseArgument from ._typing import _DDLColumnArgument from ._typing import _DMLTableArgument from ._typing import _FromClauseArgument from .dml import _DMLTableElement + from .elements import BindParameter from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement @@ -64,9 +58,7 @@ if typing.TYPE_CHECKING: from .elements import SQLCoreOperations from .schema import Column from .selectable import _ColumnsClauseElement - from .selectable import _JoinTargetElement from .selectable import _JoinTargetProtocol - from .selectable import _OnClauseElement from .selectable import FromClause from .selectable import HasCTE from .selectable import SelectBase @@ -170,6 +162,15 @@ def expect( @overload def expect( + role: Type[roles.LiteralValueRole], + element: Any, + **kw: Any, +) -> BindParameter[Any]: + ... + + +@overload +def expect( role: Type[roles.DDLReferredColumnRole], element: Any, **kw: Any, @@ -272,7 +273,7 @@ def expect( @overload def expect( role: Type[roles.ColumnsClauseRole], - element: _ColumnsClauseArgument, + element: _ColumnsClauseArgument[Any], **kw: Any, ) -> _ColumnsClauseElement: ... @@ -933,7 +934,7 @@ class GroupByImpl(ByOfImpl, RoleImpl): argname: Optional[str] = None, **kw: Any, ) -> Any: - if isinstance(resolved, roles.StrictFromClauseRole): + if is_from_clause(resolved): return elements.ClauseList(*resolved.c) else: return resolved diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c524a2602..a1b25b8a6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -80,7 +80,6 @@ from ..util.typing import Protocol from ..util.typing import TypedDict if typing.TYPE_CHECKING: - from . import roles from .annotation import _AnnotationDict from .base import _AmbiguousTableNameMap from .base import CompileState @@ -95,7 +94,6 @@ if typing.TYPE_CHECKING: from .elements import ColumnElement from .elements import Label from .functions import Function - from .selectable import Alias from .selectable import AliasedReturnsRows from .selectable import CompoundSelectState from .selectable import CTE @@ -386,7 +384,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): need_result_map_for_nested: bool need_result_map_for_compound: bool select_0: ReturnsRows - insert_from_select: Select + insert_from_select: Select[Any] class ExpandedState(NamedTuple): @@ -2834,15 +2832,31 @@ class SQLCompiler(Compiled): "unique bind parameter of the same name" % name ) elif existing._is_crud or bindparam._is_crud: - raise exc.CompileError( - "bindparam() name '%s' is reserved " - "for automatic usage in the VALUES or SET " - "clause of this " - "insert/update statement. Please use a " - "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." - % (bindparam.key, bindparam.key) - ) + if existing._is_crud and bindparam._is_crud: + # TODO: this condition is not well understood. + # see tests in test/sql/test_update.py + raise exc.CompileError( + "Encountered unsupported case when compiling an " + "INSERT or UPDATE statement. If this is a " + "multi-table " + "UPDATE statement, please provide string-named " + "arguments to the " + "values() method with distinct names; support for " + "multi-table UPDATE statements that " + "target multiple tables for UPDATE is very " + "limited", + ) + else: + raise exc.CompileError( + f"bindparam() name '{bindparam.key}' is reserved " + "for automatic usage in the VALUES or SET " + "clause of this " + "insert/update statement. Please use a " + "name other than column name when using " + "bindparam() " + "with insert() or update() (for example, " + f"'b_{bindparam.key}')." + ) self.binds[bindparam.key] = self.binds[name] = bindparam @@ -3881,7 +3895,7 @@ class SQLCompiler(Compiled): return text def _setup_select_hints( - self, select: Select + self, select: Select[Any] ) -> Tuple[str, _FromHintsType]: byfrom = dict( [ diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e4408cd31..29d7b45d7 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -22,6 +22,7 @@ from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import Union @@ -30,8 +31,10 @@ from . import coercions from . import dml from . import elements from . import roles +from .elements import ColumnClause from .schema import default_is_clause_element from .schema import default_is_sequence +from .selectable import TableClause from .. import exc from .. import util from ..util.typing import Literal @@ -41,16 +44,9 @@ if TYPE_CHECKING: from .compiler import SQLCompiler from .dml import _DMLColumnElement from .dml import DMLState - from .dml import Insert - from .dml import Update - from .dml import UpdateDMLState from .dml import ValuesBase - from .elements import ClauseElement - from .elements import ColumnClause from .elements import ColumnElement - from .elements import TextClause from .schema import _SQLExprDefault - from .schema import Column from .selectable import TableClause REQUIRED = util.symbol( @@ -68,12 +64,20 @@ values present. ) +def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]: + if not isinstance(c, ColumnClause): + raise exc.CompileError( + f"Can't create DML statement against column expression {c!r}" + ) + return c + + class _CrudParams(NamedTuple): - single_params: List[ - Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] + single_params: Sequence[ + Tuple[ColumnElement[Any], str, Optional[Union[str, _SQLExprDefault]]] ] all_multi_params: List[ - List[ + Sequence[ Tuple[ ColumnClause[Any], str, @@ -274,7 +278,7 @@ def _get_crud_params( compiler, stmt, compile_state, - cast("List[Tuple[ColumnClause[Any], str, str]]", values), + cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values), cast("Callable[..., str]", _column_as_key), kw, ) @@ -290,7 +294,7 @@ def _get_crud_params( # insert_executemany_returning mode :) values = [ ( - stmt.table.columns[0], + _as_dml_column(stmt.table.columns[0]), compiler.preparer.format_column(stmt.table.columns[0]), "DEFAULT", ) @@ -1135,10 +1139,10 @@ def _extend_values_for_multiparams( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState, - initial_values: List[Tuple[ColumnClause[Any], str, str]], + initial_values: Sequence[Tuple[ColumnClause[Any], str, str]], _column_as_key: Callable[..., str], kw: Dict[str, Any], -) -> List[List[Tuple[ColumnClause[Any], str, str]]]: +) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]: values_0 = initial_values values = [initial_values] diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 8307f6400..e0f162fc8 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -22,15 +22,19 @@ from typing import List from typing import MutableMapping from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import coercions from . import roles from . import util as sql_util +from ._typing import _no_kw +from ._typing import _TP from ._typing import is_column_element from ._typing import is_named_from_clause from .base import _entity_namespace_key @@ -42,6 +46,7 @@ from .base import ColumnCollection from .base import CompileState from .base import DialectKWArgs from .base import Executable +from .base import Generative from .base import HasCompileState from .elements import BooleanClauseList from .elements import ClauseElement @@ -49,12 +54,13 @@ from .elements import ColumnClause from .elements import ColumnElement from .elements import Null from .selectable import Alias +from .selectable import ExecutableReturnsRows from .selectable import FromClause from .selectable import HasCTE from .selectable import HasPrefixes from .selectable import Join -from .selectable import ReturnsRows from .selectable import TableClause +from .selectable import TypedReturnsRows from .sqltypes import NullType from .visitors import InternalTraversal from .. import exc @@ -66,9 +72,19 @@ if TYPE_CHECKING: from ._typing import _ColumnsClauseArgument from ._typing import _DMLColumnArgument from ._typing import _DMLTableArgument - from ._typing import _FromClauseArgument + from ._typing import _T0 # noqa + from ._typing import _T1 # noqa + from ._typing import _T2 # noqa + from ._typing import _T3 # noqa + from ._typing import _T4 # noqa + from ._typing import _T5 # noqa + from ._typing import _T6 # noqa + from ._typing import _T7 # noqa + from ._typing import _TypedColumnClauseArgument as _TCCA # noqa from .base import ReadOnlyColumnCollection from .compiler import SQLCompiler + from .elements import ColumnElement + from .elements import KeyedColumnElement from .selectable import _ColumnsClauseElement from .selectable import _SelectIterable from .selectable import Select @@ -88,6 +104,8 @@ else: isinsert = operator.attrgetter("isinsert") +_T = TypeVar("_T", bound=Any) + _DMLColumnElement = Union[str, ColumnClause[Any]] _DMLTableElement = Union[TableClause, Alias, Join] @@ -185,6 +203,11 @@ class DMLState(CompileState): "%s construct does not support " "multiple parameter sets." % statement.__visit_name__.upper() ) + else: + assert isinstance(statement, Insert) + + # which implies... + # assert isinstance(statement.table, TableClause) for parameters in statement._multi_values: multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [ @@ -291,7 +314,9 @@ class UpdateDMLState(DMLState): elif statement._multi_values: self._process_multi_values(statement) self._extra_froms = ef = self._make_extra_froms(statement) - self.is_multitable = mt = ef and self._dict_parameters + + self.is_multitable = mt = ef + self.include_table_with_column_exprs = bool( mt and compiler.render_table_with_column_in_update_from ) @@ -317,8 +342,8 @@ class UpdateBase( HasCompileState, DialectKWArgs, HasPrefixes, - ReturnsRows, - Executable, + Generative, + ExecutableReturnsRows, ClauseElement, ): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" @@ -383,8 +408,8 @@ class UpdateBase( @_generative def returning( - self: SelfUpdateBase, *cols: _ColumnsClauseArgument - ) -> SelfUpdateBase: + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> UpdateBase: r"""Add a :term:`RETURNING` or equivalent clause to this statement. e.g.: @@ -454,6 +479,8 @@ class UpdateBase( :ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial` """ # noqa: E501 + if __kw: + raise _no_kw() if self._return_defaults: raise exc.InvalidRequestError( "return_defaults() is already configured on this statement" @@ -464,7 +491,7 @@ class UpdateBase( return self def corresponding_column( - self, column: ColumnElement[Any], require_embedded: bool = False + self, column: KeyedColumnElement[Any], require_embedded: bool = False ) -> Optional[ColumnElement[Any]]: return self.exported_columns.corresponding_column( column, require_embedded=require_embedded @@ -628,7 +655,7 @@ class ValuesBase(UpdateBase): _supports_multi_parameters = False - select: Optional[Select] = None + select: Optional[Select[Any]] = None """SELECT statement for INSERT .. FROM SELECT""" _post_values_clause: Optional[ClauseElement] = None @@ -804,11 +831,15 @@ class ValuesBase(UpdateBase): ) elif isinstance(arg, collections_abc.Sequence): - if arg and isinstance(arg[0], (list, dict, tuple)): self._multi_values += (arg,) return self + if TYPE_CHECKING: + # crud.py raises during compilation if this is not the + # case + assert isinstance(self, Insert) + # tuple values arg = {c.key: value for c, value in zip(self.table.c, arg)} @@ -1010,7 +1041,7 @@ class Insert(ValuesBase): def from_select( self: SelfInsert, names: List[str], - select: Select, + select: Select[Any], include_defaults: bool = True, ) -> SelfInsert: """Return a new :class:`_expression.Insert` construct which represents @@ -1073,6 +1104,114 @@ class Insert(ValuesBase): self.select = coercions.expect(roles.DMLSelectRole, select) return self + if TYPE_CHECKING: + + # START OVERLOADED FUNCTIONS self.returning ReturningInsert 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningInsert[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningInsert[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningInsert[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningInsert[Any]: + ... + + +class ReturningInsert(Insert, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Insert` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Insert.returning` method. + + .. versionadded:: 2.0 + + """ + SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase") @@ -1264,6 +1403,113 @@ class Update(DMLWhereBase, ValuesBase): self._inline = True return self + if TYPE_CHECKING: + # START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningUpdate[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: + ... + + +class ReturningUpdate(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Update` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Update.returning` method. + + .. versionadded:: 2.0 + + """ + SelfDelete = typing.TypeVar("SelfDelete", bound="Delete") @@ -1297,3 +1543,111 @@ class Delete(DMLWhereBase, UpdateBase): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) + + if TYPE_CHECKING: + + # START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningDelete[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: + ... + + +class ReturningDelete(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Delete` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Delete.returning` method. + + .. versionadded:: 2.0 + + """ diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 34d5127ab..a29561291 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -54,6 +54,7 @@ from .base import _clone from .base import _generative from .base import _NoArg from .base import Executable +from .base import Generative from .base import HasMemoized from .base import Immutable from .base import NO_ARG @@ -94,10 +95,7 @@ if typing.TYPE_CHECKING: from .selectable import _SelectIterable from .selectable import FromClause from .selectable import NamedFromClause - from .selectable import ReturnsRows from .selectable import Select - from .selectable import TableClause - from .sqltypes import Boolean from .sqltypes import TupleType from .type_api import TypeEngine from .visitors import _CloneCallableType @@ -122,7 +120,9 @@ _NT = TypeVar("_NT", bound="_NUMERIC") _NMT = TypeVar("_NMT", bound="_NUMBER") -def literal(value, type_=None): +def literal( + value: Any, type_: Optional[_TypeEngineArgument[_T]] = None +) -> BindParameter[_T]: r"""Return a literal clause, bound to a bind parameter. Literal clauses are created automatically when non- @@ -144,7 +144,9 @@ def literal(value, type_=None): return coercions.expect(roles.LiteralValueRole, value, type_=type_) -def literal_column(text, type_=None): +def literal_column( + text: str, type_: Optional[_TypeEngineArgument[_T]] = None +) -> ColumnClause[_T]: r"""Produce a :class:`.ColumnClause` object that has the :paramref:`_expression.column.is_literal` flag set to True. @@ -316,6 +318,7 @@ class ClauseElement( is_selectable = False is_dml = False _is_column_element = False + _is_keyed_column_element = False _is_table = False _is_textual = False _is_from_clause = False @@ -342,7 +345,7 @@ class ClauseElement( if typing.TYPE_CHECKING: def get_children( - self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any + self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any ) -> Iterable[ClauseElement]: ... @@ -455,7 +458,7 @@ class ClauseElement( connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptions, - ) -> Result: + ) -> Result[Any]: if self.supports_execution: if TYPE_CHECKING: assert isinstance(self, Executable) @@ -833,13 +836,13 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def in_( self, - other: Union[Sequence[Any], BindParameter[Any], Select], + other: Union[Sequence[Any], BindParameter[Any], Select[Any]], ) -> BinaryExpression[bool]: ... def not_in( self, - other: Union[Sequence[Any], BindParameter[Any], Select], + other: Union[Sequence[Any], BindParameter[Any], Select[Any]], ) -> BinaryExpression[bool]: ... @@ -1699,6 +1702,14 @@ class ColumnElement( return self._anon_label(label, add_hash=idx) +class KeyedColumnElement(ColumnElement[_T]): + """ColumnElement where ``.key`` is non-None.""" + + _is_keyed_column_element = True + + key: str + + class WrapsColumnExpression(ColumnElement[_T]): """Mixin that defines a :class:`_expression.ColumnElement` as a wrapper with special @@ -1760,7 +1771,7 @@ class WrapsColumnExpression(ColumnElement[_T]): SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter[Any]") -class BindParameter(roles.InElementRole, ColumnElement[_T]): +class BindParameter(roles.InElementRole, KeyedColumnElement[_T]): r"""Represent a "bound expression". :class:`.BindParameter` is invoked explicitly using the @@ -2073,6 +2084,7 @@ class TextClause( roles.FromClauseRole, roles.SelectStatementRole, roles.InElementRole, + Generative, Executable, DQLDMLClauseElement, roles.BinaryElementRole[Any], @@ -4160,7 +4172,7 @@ class FunctionFilter(ColumnElement[_T]): ) -class NamedColumn(ColumnElement[_T]): +class NamedColumn(KeyedColumnElement[_T]): is_literal = False table: Optional[FromClause] = None name: str @@ -4502,7 +4514,7 @@ class ColumnClause( self.is_literal = is_literal - def get_children(self, column_tables=False, **kw): + def get_children(self, *, column_tables=False, **kw): # override base get_children() to not return the Table # or selectable that is parent to this column. Traversals # expect the columns of tables and subqueries to be leaf nodes. diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 648168235..b827df3df 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -175,7 +175,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: return connection._execute_function( self, distilled_params, execution_options ) @@ -623,7 +623,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): joins_implicitly=joins_implicitly, ) - def select(self) -> "Select": + def select(self) -> Select[Any]: """Produce a :func:`_expression.select` construct against this :class:`.FunctionElement`. @@ -632,7 +632,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): s = select(function_element) """ - s = Select(self) + s: Select[Any] = Select(self) if self._execution_options: s = s.execution_options(**self._execution_options) return s @@ -846,7 +846,7 @@ class _FunctionGenerator: @overload def __call__( - self, *c: Any, type_: TypeEngine[_T], **kwargs: Any + self, *c: Any, type_: _TypeEngineArgument[_T], **kwargs: Any ) -> Function[_T]: ... diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 231c70a5b..09d4b35ad 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -8,8 +8,6 @@ from __future__ import annotations from typing import Any from typing import Generic -from typing import Iterable -from typing import List from typing import Optional from typing import TYPE_CHECKING from typing import TypeVar @@ -19,12 +17,7 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _PropagateAttrsType - from .base import _EntityNamespace - from .base import ColumnCollection - from .base import ReadOnlyColumnCollection - from .elements import ColumnClause from .elements import Label - from .elements import NamedColumn from .selectable import _SelectIterable from .selectable import FromClause from .selectable import Subquery @@ -108,13 +101,21 @@ class TruncatedLabelRole(StringRole, SQLRole): class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): __slots__ = () - _role_name = "Column expression or FROM clause" + _role_name = ( + "Column expression, FROM clause, or other columns clause element" + ) @property def _select_iterable(self) -> _SelectIterable: raise NotImplementedError() +class TypedColumnsClauseRole(Generic[_T], SQLRole): + """element-typed form of ColumnsClauseRole""" + + __slots__ = () + + class LimitOffsetRole(SQLRole): __slots__ = () _role_name = "LIMIT / OFFSET expression" @@ -161,7 +162,7 @@ class WhereHavingRole(OnClauseRole): _role_name = "SQL expression for WHERE/HAVING role" -class ExpressionElementRole(Generic[_T], SQLRole): +class ExpressionElementRole(TypedColumnsClauseRole[_T]): # note when using generics for ExpressionElementRole, # the generic type needs to be in # sqlalchemy.sql.coercions._impl_lookup mapping also. @@ -212,39 +213,11 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole): named_with_column: bool - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - - @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - - @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: - ... - - @util.ro_non_memoized_property - def _hide_froms(self) -> Iterable[FromClause]: - ... - - @util.ro_non_memoized_property - def _from_objects(self) -> List[FromClause]: - ... - class StrictFromClauseRole(FromClauseRole): __slots__ = () # does not allow text() or select() objects - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def description(self) -> str: - ... - class AnonymizedFromClauseRole(StrictFromClauseRole): __slots__ = () @@ -317,16 +290,6 @@ class DMLTableRole(FromClauseRole): __slots__ = () _role_name = "subject table for an INSERT, UPDATE or DELETE" - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def primary_key(self) -> Iterable[NamedColumn[Any]]: - ... - - @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - class DMLColumnRole(SQLRole): __slots__ = () diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 52ba60a62..27456d2be 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -86,7 +86,6 @@ if typing.TYPE_CHECKING: from ._typing import _InfoType from ._typing import _TextCoercedExpressionArgument from ._typing import _TypeEngineArgument - from .base import ColumnCollection from .base import DedupeColumnCollection from .base import ReadOnlyColumnCollection from .compiler import DDLCompiler @@ -97,9 +96,7 @@ if typing.TYPE_CHECKING: from .visitors import anon_map from ..engine import Connection from ..engine import Engine - from ..engine.cursor import CursorResult from ..engine.interfaces import _CoreMultiExecuteParams - from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.interfaces import ExecutionContext from ..engine.mock import MockConnection @@ -2609,8 +2606,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): :class:`_schema.Table`. """ - - return table.columns.corresponding_column(self.column) + # our column is a Column, and any subquery etc. proxying us + # would be doing so via another Column, so that's what would + # be returned here + return table.columns.corresponding_column(self.column) # type: ignore @util.memoized_property def _column_tokens(self) -> Tuple[Optional[str], str, Optional[str]]: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 9d4d1d6c7..b08f13f99 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -23,6 +23,7 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import List @@ -46,6 +47,8 @@ from . import traversals from . import type_api from . import visitors from ._typing import _ColumnsClauseArgument +from ._typing import _no_kw +from ._typing import _TP from ._typing import is_column_element from ._typing import is_select_statement from ._typing import is_subquery @@ -103,9 +106,20 @@ if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _FromClauseArgument from ._typing import _JoinTargetArgument + from ._typing import _MAYBE_ENTITY + from ._typing import _NOT_ENTITY from ._typing import _OnClauseArgument from ._typing import _SelectStatementForCompoundArgument + from ._typing import _T0 + from ._typing import _T1 + from ._typing import _T2 + from ._typing import _T3 + from ._typing import _T4 + from ._typing import _T5 + from ._typing import _T6 + from ._typing import _T7 from ._typing import _TextCoercedExpressionArgument + from ._typing import _TypedColumnClauseArgument as _TCCA from ._typing import _TypeEngineArgument from .base import _AmbiguousTableNameMap from .base import ExecutableOption @@ -115,14 +129,13 @@ if TYPE_CHECKING: from .dml import Delete from .dml import Insert from .dml import Update + from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import TextClause from .functions import Function - from .schema import Column from .schema import ForeignKey from .schema import ForeignKeyConstraint from .type_api import TypeEngine - from .util import ClauseAdapter from .visitors import _CloneCallableType @@ -245,6 +258,14 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): raise NotImplementedError() +class ExecutableReturnsRows(Executable, ReturnsRows): + """base for executable statements that return rows.""" + + +class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]): + """base for executable statements that return rows.""" + + SelfSelectable = TypeVar("SelfSelectable", bound="Selectable") @@ -293,8 +314,8 @@ class Selectable(ReturnsRows): ) def corresponding_column( - self, column: ColumnElement[Any], require_embedded: bool = False - ) -> Optional[ColumnElement[Any]]: + self, column: KeyedColumnElement[Any], require_embedded: bool = False + ) -> Optional[KeyedColumnElement[Any]]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from the :attr:`_expression.Selectable.exported_columns` @@ -593,7 +614,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _use_schema_map = False - def select(self) -> Select: + def select(self) -> Select[Any]: r"""Return a SELECT of this :class:`_expression.FromClause`. @@ -795,7 +816,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): ) @util.ro_non_memoized_property - def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`. @@ -817,7 +840,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.c @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, Any]: + def columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """A named-based collection of :class:`_expression.ColumnElement` objects maintained by this :class:`_expression.FromClause`. @@ -833,7 +858,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.c @util.ro_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, Any]: + def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """ A synonym for :attr:`.FromClause.columns` @@ -1223,7 +1248,7 @@ class Join(roles.DMLTableRole, FromClause): @util.preload_module("sqlalchemy.sql.util") def _populate_column_collection(self): sqlutil = util.preloaded.sql_util - columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [ + columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [ c for c in self.right.c ] @@ -1458,7 +1483,7 @@ class Join(roles.DMLTableRole, FromClause): "join explicitly." % (a.description, b.description) ) - def select(self) -> "Select": + def select(self) -> Select[Any]: r"""Create a :class:`_expression.Select` from this :class:`_expression.Join`. @@ -2764,6 +2789,7 @@ class Subquery(AliasedReturnsRows): cls, selectable: SelectBase, name: Optional[str] = None ) -> Subquery: """Return a :class:`.Subquery` object.""" + return coercions.expect( roles.SelectStatementRole, selectable ).subquery(name=name) @@ -3216,7 +3242,6 @@ class SelectBase( roles.CompoundElementRole, roles.InElementRole, HasCTE, - Executable, SupportsCloneAnnotations, Selectable, ): @@ -3239,7 +3264,9 @@ class SelectBase( self._reset_memoizations() @util.ro_non_memoized_property - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set. @@ -3284,7 +3311,9 @@ class SelectBase( raise NotImplementedError() @property - def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`, not including @@ -3377,7 +3406,7 @@ class SelectBase( def as_scalar(self): return self.scalar_subquery() - def exists(self): + def exists(self) -> Exists: """Return an :class:`_sql.Exists` representation of this selectable, which can be used as a column expression. @@ -3394,7 +3423,7 @@ class SelectBase( """ return Exists(self) - def scalar_subquery(self): + def scalar_subquery(self) -> ScalarSelect[Any]: """Return a 'scalar' representation of this selectable, which can be used as a column expression. @@ -3607,7 +3636,7 @@ SelfGenerativeSelect = typing.TypeVar( ) -class GenerativeSelect(SelectBase): +class GenerativeSelect(SelectBase, Generative): """Base class for SELECT statements where additional elements can be added. @@ -4128,7 +4157,7 @@ class _CompoundSelectKeyword(Enum): INTERSECT_ALL = "INTERSECT ALL" -class CompoundSelect(HasCompileState, GenerativeSelect): +class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations. @@ -4293,7 +4322,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return self.selects[0]._all_selected_columns @util.ro_non_memoized_property - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -4343,7 +4374,10 @@ class SelectState(util.MemoizedSlots, CompileState): ... def __init__( - self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any + self, + statement: Select[Any], + compiler: Optional[SQLCompiler], + **kw: Any, ): self.statement = statement self.from_clauses = statement._from_obj @@ -4369,7 +4403,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def get_column_descriptions( - cls, statement: Select + cls, statement: Select[Any] ) -> List[Dict[str, Any]]: return [ { @@ -4384,12 +4418,14 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def from_statement( - cls, statement: Select, from_statement: ReturnsRows - ) -> Any: + cls, statement: Select[Any], from_statement: ExecutableReturnsRows + ) -> ExecutableReturnsRows: cls._plugin_not_implemented() @classmethod - def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]: + def get_columns_clause_froms( + cls, statement: Select[Any] + ) -> List[FromClause]: return cls._normalize_froms( itertools.chain.from_iterable( element._from_objects for element in statement._raw_columns @@ -4439,7 +4475,7 @@ class SelectState(util.MemoizedSlots, CompileState): return go - def _get_froms(self, statement: Select) -> List[FromClause]: + def _get_froms(self, statement: Select[Any]) -> List[FromClause]: ambiguous_table_name_map: _AmbiguousTableNameMap self._ambiguous_table_name_map = ambiguous_table_name_map = {} @@ -4467,7 +4503,7 @@ class SelectState(util.MemoizedSlots, CompileState): def _normalize_froms( cls, iterable_of_froms: Iterable[FromClause], - check_statement: Optional[Select] = None, + check_statement: Optional[Select[Any]] = None, ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, ) -> List[FromClause]: """given an iterable of things to select FROM, reduce them to what @@ -4615,7 +4651,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def determine_last_joined_entity( - cls, stmt: Select + cls, stmt: Select[Any] ) -> Optional[_JoinTargetElement]: if stmt._setup_joins: return stmt._setup_joins[-1][0] @@ -4623,7 +4659,7 @@ class SelectState(util.MemoizedSlots, CompileState): return None @classmethod - def all_selected_columns(cls, statement: Select) -> _SelectIterable: + def all_selected_columns(cls, statement: Select[Any]) -> _SelectIterable: return [c for c in _select_iterables(statement._raw_columns)] def _setup_joins( @@ -4876,7 +4912,7 @@ class _MemoizedSelectEntities( return c # type: ignore @classmethod - def _generate_for_statement(cls, select_stmt: Select) -> None: + def _generate_for_statement(cls, select_stmt: Select[Any]) -> None: if select_stmt._setup_joins or select_stmt._with_options: self = _MemoizedSelectEntities() self._raw_columns = select_stmt._raw_columns @@ -4888,7 +4924,7 @@ class _MemoizedSelectEntities( select_stmt._setup_joins = select_stmt._with_options = () -SelfSelect = typing.TypeVar("SelfSelect", bound="Select") +SelfSelect = typing.TypeVar("SelfSelect", bound="Select[Any]") class Select( @@ -4898,6 +4934,7 @@ class Select( HasCompileState, _SelectFromElements, GenerativeSelect, + TypedReturnsRows[_TP], ): """Represents a ``SELECT`` statement. @@ -4973,7 +5010,7 @@ class Select( _compile_state_factory: Type[SelectState] @classmethod - def _create_raw_select(cls, **kw: Any) -> Select: + def _create_raw_select(cls, **kw: Any) -> Select[Any]: """Create a :class:`.Select` using raw ``__new__`` with no coercions. Used internally to build up :class:`.Select` constructs with @@ -4985,7 +5022,7 @@ class Select( stmt.__dict__.update(kw) return stmt - def __init__(self, *entities: _ColumnsClauseArgument): + def __init__(self, *entities: _ColumnsClauseArgument[Any]): r"""Construct a new :class:`_expression.Select`. The public constructor for :class:`_expression.Select` is the @@ -5013,7 +5050,9 @@ class Select( cols = list(elem._select_iterable) return cols[0].type - def filter(self: SelfSelect, *criteria: ColumnElement[Any]) -> SelfSelect: + def filter( + self: SelfSelect, *criteria: _ColumnExpressionArgument[bool] + ) -> SelfSelect: """A synonym for the :meth:`_future.Select.where` method.""" return self.where(*criteria) @@ -5032,7 +5071,28 @@ class Select( return self._raw_columns[0] - def filter_by(self, **kwargs): + if TYPE_CHECKING: + + @overload + def scalar_subquery( + self: Select[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[Any]: + ... + + @overload + def scalar_subquery( + self: Select[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: + ... + + @overload + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def filter_by(self: SelfSelect, **kwargs: Any) -> SelfSelect: r"""apply the given filtering criterion as a WHERE clause to this select. @@ -5046,7 +5106,7 @@ class Select( return self.filter(*clauses) @property - def column_descriptions(self): + def column_descriptions(self) -> Any: """Return a :term:`plugin-enabled` 'column descriptions' structure referring to the columns which are SELECTed by this statement. @@ -5089,7 +5149,9 @@ class Select( meth = SelectState.get_plugin_class(self).get_column_descriptions return meth(self) - def from_statement(self, statement): + def from_statement( + self, statement: ExecutableReturnsRows + ) -> ExecutableReturnsRows: """Apply the columns which this :class:`.Select` would select onto another statement. @@ -5410,7 +5472,7 @@ class Select( ) @property - def inner_columns(self): + def inner_columns(self) -> _SelectIterable: """An iterator of all :class:`_expression.ColumnElement` expressions which would be rendered into the columns clause of the resulting SELECT statement. @@ -5487,18 +5549,19 @@ class Select( self._reset_memoizations() - def get_children(self, **kwargs): + def get_children(self, **kw: Any) -> Iterable[ClauseElement]: return itertools.chain( super(Select, self).get_children( - omit_attrs=("_from_obj", "_correlate", "_correlate_except") + omit_attrs=("_from_obj", "_correlate", "_correlate_except"), + **kw, ), self._iterate_from_elements(), ) @_generative def add_columns( - self: SelfSelect, *columns: _ColumnsClauseArgument - ) -> SelfSelect: + self, *columns: _ColumnsClauseArgument[Any] + ) -> Select[Any]: """Return a new :func:`_expression.select` construct with the given column expressions added to its columns clause. @@ -5523,7 +5586,7 @@ class Select( return self def _set_entities( - self, entities: Iterable[_ColumnsClauseArgument] + self, entities: Iterable[_ColumnsClauseArgument[Any]] ) -> None: self._raw_columns = [ coercions.expect( @@ -5538,7 +5601,7 @@ class Select( "be removed in a future release. Please use " ":meth:`_expression.Select.add_columns`", ) - def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect: + def column(self, column: _ColumnsClauseArgument[Any]) -> Select[Any]: """Return a new :func:`_expression.select` construct with the given column expression added to its columns clause. @@ -5555,9 +5618,7 @@ class Select( return self.add_columns(column) @util.preload_module("sqlalchemy.sql.util") - def reduce_columns( - self: SelfSelect, only_synonyms: bool = True - ) -> SelfSelect: + def reduce_columns(self, only_synonyms: bool = True) -> Select[Any]: """Return a new :func:`_expression.select` construct with redundantly named, equivalently-valued columns removed from the columns clause. @@ -5580,20 +5641,115 @@ class Select( all columns that are equivalent to another are removed. """ - return self.with_only_columns( + woc: Select[Any] + woc = self.with_only_columns( *util.preloaded.sql_util.reduce_columns( self._all_selected_columns, only_synonyms=only_synonyms, *(self._where_criteria + self._from_obj), ) ) + return woc + + # START OVERLOADED FUNCTIONS self.with_only_columns Select 8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_sel_v1_overloads.py + + @overload + def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: + ... + + @overload + def with_only_columns( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> Select[Tuple[_T0, _T1]]: + ... + + @overload + def with_only_columns( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> Select[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> Select[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.with_only_columns + + @overload + def with_only_columns( + self, + *columns: _ColumnsClauseArgument[Any], + maintain_column_froms: bool = False, + **__kw: Any, + ) -> Select[Any]: + ... @_generative def with_only_columns( - self: SelfSelect, - *columns: _ColumnsClauseArgument, + self, + *columns: _ColumnsClauseArgument[Any], maintain_column_froms: bool = False, - ) -> SelfSelect: + **__kw: Any, + ) -> Select[Any]: r"""Return a new :func:`_expression.select` construct with its columns clause replaced with the given columns. @@ -5647,6 +5803,9 @@ class Select( """ # noqa: E501 + if __kw: + raise _no_kw() + # memoizations should be cleared here as of # I95c560ffcbfa30b26644999412fb6a385125f663 , asserting this # is the case for now. @@ -5915,7 +6074,9 @@ class Select( return self @HasMemoized_ro_memoized_attribute - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -6215,7 +6376,7 @@ class ScalarSelect( by this :class:`_expression.ScalarSelect`. """ - self.element = cast(Select, self.element).where(crit) + self.element = cast("Select[Any]", self.element).where(crit) return self @overload @@ -6269,7 +6430,9 @@ class ScalarSelect( """ - self.element = cast(Select, self.element).correlate(*fromclauses) + self.element = cast("Select[Any]", self.element).correlate( + *fromclauses + ) return self @_generative @@ -6307,7 +6470,7 @@ class ScalarSelect( """ - self.element = cast(Select, self.element).correlate_except( + self.element = cast("Select[Any]", self.element).correlate_except( *fromclauses ) return self @@ -6331,12 +6494,18 @@ class Exists(UnaryExpression[bool]): def __init__( self, __argument: Optional[ - Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]] ] = None, ): + s: ScalarSelect[Any] + + # TODO: this seems like we should be using coercions for this if __argument is None: s = Select(literal_column("*")).scalar_subquery() - elif isinstance(__argument, (SelectBase, ScalarSelect)): + elif isinstance(__argument, SelectBase): + s = __argument.scalar_subquery() + s._propagate_attrs = __argument._propagate_attrs + elif isinstance(__argument, ScalarSelect): s = __argument else: s = Select(__argument).scalar_subquery() @@ -6358,7 +6527,7 @@ class Exists(UnaryExpression[bool]): element = fn(element) return element.self_group(against=operators.exists) - def select(self) -> Select: + def select(self) -> Select[Any]: r"""Return a SELECT of this :class:`_expression.Exists`. e.g.:: @@ -6452,7 +6621,7 @@ class Exists(UnaryExpression[bool]): SelfTextualSelect = typing.TypeVar("SelfTextualSelect", bound="TextualSelect") -class TextualSelect(SelectBase): +class TextualSelect(SelectBase, Executable, Generative): """Wrap a :class:`_expression.TextClause` construct within a :class:`_expression.SelectBase` interface. @@ -6503,7 +6672,9 @@ class TextualSelect(SelectBase): self.positional = positional @HasMemoized_ro_memoized_attribute - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, KeyedColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d08fef60a..8c45ba410 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -50,6 +50,7 @@ from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement from .elements import Grouping +from .elements import KeyedColumnElement from .elements import Label from .elements import Null from .elements import UnaryExpression @@ -72,9 +73,7 @@ if typing.TYPE_CHECKING: from ._typing import _EquivalentColumnMap from ._typing import _TypeEngineArgument from .elements import TextClause - from .roles import FromClauseRole from .selectable import _JoinTargetElement - from .selectable import _OnClauseElement from .selectable import _SelectIterable from .selectable import Selectable from .visitors import _TraverseCallableType @@ -569,7 +568,7 @@ class _repr_row(_repr_base): __slots__ = ("row",) - def __init__(self, row: "Row", max_chars: int = 300): + def __init__(self, row: "Row[Any]", max_chars: int = 300): self.row = row self.max_chars = max_chars @@ -1068,7 +1067,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): col = col._annotations["adapt_column"] if TYPE_CHECKING: - assert isinstance(col, ColumnElement) + assert isinstance(col, KeyedColumnElement) if self.adapt_from_selectables and col not in self.equivalents: for adp in self.adapt_from_selectables: @@ -1078,7 +1077,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): return None if TYPE_CHECKING: - assert isinstance(col, ColumnElement) + assert isinstance(col, KeyedColumnElement) if self.include_fn and not self.include_fn(col): return None diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index e0a66fbcf..88586d834 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -450,7 +450,7 @@ class HasTraverseInternals: @util.preload_module("sqlalchemy.sql.traversals") def get_children( - self, omit_attrs: Tuple[str, ...] = (), **kw: Any + self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any ) -> Iterable[HasTraverseInternals]: r"""Return immediate child :class:`.visitors.HasTraverseInternals` elements of this :class:`.visitors.HasTraverseInternals`. @@ -594,7 +594,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): if typing.TYPE_CHECKING: def get_children( - self, omit_attrs: Tuple[str, ...] = (), **kw: Any + self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any ) -> Iterable[ExternallyTraversible]: ... |
