diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-19 21:06:41 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-27 14:46:36 -0400 |
| commit | ad11c482e2233f44e8747d4d5a2b17a995fff1fa (patch) | |
| tree | 57f8ddd30928951519fd6ac0f418e9cbf8e65610 /lib/sqlalchemy/sql | |
| parent | 033d1a16e7a220555d7611a5b8cacb1bd83822ae (diff) | |
| download | sqlalchemy-ad11c482e2233f44e8747d4d5a2b17a995fff1fa.tar.gz | |
pep484 ORM / SQL result support
after some experimentation it seems mypy is more amenable
to the generic types being fully integrated rather than
having separate spin-off types. so key structures
like Result, Row, Select become generic. For DML
Insert, Update, Delete, these are spun into type-specific
subclasses ReturningInsert, ReturningUpdate, ReturningDelete,
which is fine since the "row-ness" of these constructs
doesn't happen until returning() is called in any case.
a Tuple based model is then integrated so that these
objects can carry along information about their return
types. Overloads at the .execute() level carry through
the Tuple from the invoked object to the result.
To suit the issue of AliasedClass generating attributes
that are dynamic, experimented with a custom subclass
AsAliased, but then just settled on having aliased()
lie to the type checker and return `Type[_O]`, essentially.
will need some type-related accessors for with_polymorphic()
also.
Additionally, identified an issue in Update when used
"mysql style" against a join(), it basically doesn't work
if asked to UPDATE two tables on the same column name.
added an error message to the specific condition where
it happens with a very non-specific error message that we
hit a thing we can't do right now, suggest multi-table
update as a possible cause.
Change-Id: I5eff7eefe1d6166ee74160b2785c5e6a81fa8b95
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/_selectable_constructors.py | 166 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/_typing.py | 78 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 25 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 40 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 376 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 36 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/roles.py | 57 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 287 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 4 |
15 files changed, 939 insertions, 204 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 84913225d..c3ebb4596 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -121,7 +121,6 @@ def __go(lcls: Any) -> None: coercions.lambdas = lambdas coercions.schema = schema coercions.selectable = selectable - coercions.traversals = traversals from .annotation import _prepare_annotations from .annotation import Annotated diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 37d44976a..f89e8f578 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -9,12 +9,16 @@ from __future__ import annotations from typing import Any from typing import Optional +from typing import overload +from typing import Tuple from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import coercions from . import roles from ._typing import _ColumnsClauseArgument +from ._typing import _no_kw from .elements import ColumnClause from .selectable import Alias from .selectable import CompoundSelect @@ -34,6 +38,17 @@ if TYPE_CHECKING: from ._typing import _FromClauseArgument from ._typing import _OnClauseArgument from ._typing import _SelectStatementForCompoundArgument + from ._typing import _T0 + from ._typing import _T1 + from ._typing import _T2 + from ._typing import _T3 + from ._typing import _T4 + from ._typing import _T5 + from ._typing import _T6 + from ._typing import _T7 + from ._typing import _T8 + from ._typing import _T9 + from ._typing import _TypedColumnClauseArgument as _TCCA from .functions import Function from .selectable import CTE from .selectable import HasCTE @@ -41,6 +56,9 @@ if TYPE_CHECKING: from .selectable import SelectBase +_T = TypeVar("_T", bound=Any) + + def alias( selectable: FromClause, name: Optional[str] = None, flat: bool = False ) -> NamedFromClause: @@ -89,7 +107,9 @@ def cte( ) -def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def except_( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``EXCEPT`` of multiple selectables. The returned object is an instance of @@ -119,7 +139,7 @@ def except_all( def exists( __argument: Optional[ - Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]] ] = None, ) -> Exists: """Construct a new :class:`_expression.Exists` construct. @@ -162,7 +182,9 @@ def exists( return Exists(__argument) -def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def intersect( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``INTERSECT`` of multiple selectables. The returned object is an instance of @@ -306,7 +328,129 @@ def outerjoin( return Join(left, right, onclause, isouter=True, full=full) -def select(*entities: _ColumnsClauseArgument) -> Select: +# START OVERLOADED FUNCTIONS select Select 1-10 + +# code within this block is **programmatically, +# statically generated** by tools/generate_tuple_map_overloads.py + + +@overload +def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: + ... + + +@overload +def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] +) -> Select[Tuple[_T0, _T1, _T2]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], +) -> Select[Tuple[_T0, _T1, _T2, _T3]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], + __ent9: _TCCA[_T9], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: + ... + + +# END OVERLOADED FUNCTIONS select + + +@overload +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: + ... + + +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: r"""Construct a new :class:`_expression.Select`. @@ -343,7 +487,11 @@ def select(*entities: _ColumnsClauseArgument) -> Select: given, as well as ORM-mapped classes. """ - + # the keyword args are a necessary element in order for the typing + # to work out w/ the varargs vs. having named "keyword" arguments that + # aren't always present. + if __kw: + raise _no_kw() return Select(*entities) @@ -425,7 +573,9 @@ def tablesample( return TableSample._factory(selectable, sampling, name=name, seed=seed) -def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def union( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return a ``UNION`` of multiple selectables. The returned object is an instance of @@ -445,7 +595,9 @@ def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: return CompoundSelect._create_union(*selects) -def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def union_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return a ``UNION ALL`` of multiple selectables. The returned object is an instance of diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 53d29b628..1df530dbd 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -5,18 +5,27 @@ from typing import Any from typing import Callable from typing import Dict from typing import Set +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import roles +from .. import exc from .. import util from ..inspection import Inspectable from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: + from datetime import date + from datetime import datetime + from datetime import time + from datetime import timedelta + from decimal import Decimal + from uuid import UUID + from .base import Executable from .compiler import Compiled from .compiler import DDLCompiler @@ -26,17 +35,15 @@ if TYPE_CHECKING: from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement + from .elements import KeyedColumnElement from .elements import quoted_name - from .elements import SQLCoreOperations from .elements import TextClause from .lambdas import LambdaElement from .roles import ColumnsClauseRole from .roles import FromClauseRole from .schema import Column - from .schema import DefaultGenerator - from .schema import Sequence - from .schema import Table from .selectable import Alias + from .selectable import CTE from .selectable import FromClause from .selectable import Join from .selectable import NamedFromClause @@ -61,6 +68,30 @@ class _HasClauseElement(Protocol): ... +# match column types that are not ORM entities +_NOT_ENTITY = TypeVar( + "_NOT_ENTITY", + int, + str, + "datetime", + "date", + "time", + "timedelta", + "UUID", + float, + "Decimal", +) + +_MAYBE_ENTITY = TypeVar( + "_MAYBE_ENTITY", + roles.ColumnsClauseRole, + Literal["*", 1], + Type[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, +) + + # convention: # XYZArgument - something that the end user is passing to a public API method # XYZElement - the internal representation that we use for the thing. @@ -76,9 +107,10 @@ _TextCoercedExpressionArgument = Union[ ] _ColumnsClauseArgument = Union[ - Literal["*", 1], + roles.TypedColumnsClauseRole[_T], roles.ColumnsClauseRole, - Type[Any], + Literal["*", 1], + Type[_T], Inspectable[_HasClauseElement], _HasClauseElement, ] @@ -92,6 +124,24 @@ sets; select(...), insert().returning(...), etc. """ +_TypedColumnClauseArgument = Union[ + roles.TypedColumnsClauseRole[_T], roles.ExpressionElementRole[_T], Type[_T] +] + +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + +_T0 = TypeVar("_T0", bound=Any) +_T1 = TypeVar("_T1", bound=Any) +_T2 = TypeVar("_T2", bound=Any) +_T3 = TypeVar("_T3", bound=Any) +_T4 = TypeVar("_T4", bound=Any) +_T5 = TypeVar("_T5", bound=Any) +_T6 = TypeVar("_T6", bound=Any) +_T7 = TypeVar("_T7", bound=Any) +_T8 = TypeVar("_T8", bound=Any) +_T9 = TypeVar("_T9", bound=Any) + + _ColumnExpressionArgument = Union[ "ColumnElement[_T]", _HasClauseElement, @@ -169,6 +219,7 @@ _DMLTableArgument = Union[ "TableClause", "Join", "Alias", + "CTE", Type[Any], Inspectable[_HasClauseElement], _HasClauseElement, @@ -194,6 +245,11 @@ if TYPE_CHECKING: def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: ... + def is_keyed_column_element( + c: ClauseElement, + ) -> TypeGuard[KeyedColumnElement[Any]]: + ... + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: ... @@ -216,7 +272,7 @@ if TYPE_CHECKING: def is_select_statement( t: Union[Executable, ReturnsRows] - ) -> TypeGuard[Select]: + ) -> TypeGuard[Select[Any]]: ... def is_table(t: FromClause) -> TypeGuard[TableClause]: @@ -234,6 +290,7 @@ else: is_ddl_compiler = operator.attrgetter("is_ddl") is_named_from_clause = operator.attrgetter("named_with_column") is_column_element = operator.attrgetter("_is_column_element") + is_keyed_column_element = operator.attrgetter("_is_keyed_column_element") is_text_clause = operator.attrgetter("_is_text_clause") is_from_clause = operator.attrgetter("_is_from_clause") is_tuple_type = operator.attrgetter("_is_tuple_type") @@ -260,3 +317,10 @@ def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]: return c.is_dml and (c.is_insert or c.is_update) # type: ignore + + +def _no_kw() -> exc.ArgumentError: + return exc.ArgumentError( + "Additional keyword arguments are not accepted by this " + "function/method. The presence of **kw is for pep-484 typing purposes" + ) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f81878d55..790edefc6 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -62,10 +62,10 @@ if TYPE_CHECKING: from . import coercions from . import elements from . import type_api - from ._typing import _ColumnsClauseArgument from .elements import BindParameter - from .elements import ColumnClause + from .elements import ColumnClause # noqa from .elements import ColumnElement + from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import SQLCoreOperations from .elements import TextClause @@ -74,7 +74,6 @@ if TYPE_CHECKING: from .selectable import FromClause from ..engine import Connection from ..engine import CursorResult - from ..engine import Result from ..engine.base import _CompiledCacheType from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import _ExecuteOptions @@ -704,8 +703,11 @@ class InPlaceGenerative(HasMemoized): """Provide a method-chaining pattern in conjunction with the @_generative decorator that mutates in place.""" + __slots__ = () + def _generate(self): skip = self._memoized_keys + # note __dict__ needs to be in __slots__ if this is used for k in skip: self.__dict__.pop(k, None) return self @@ -937,7 +939,7 @@ class ExecutableOption(HasCopyInternals): SelfExecutable = TypeVar("SelfExecutable", bound="Executable") -class Executable(roles.StatementRole, Generative): +class Executable(roles.StatementRole): """Mark a :class:`_expression.ClauseElement` as supporting execution. :class:`.Executable` is a superclass for all "statement" types @@ -994,7 +996,7 @@ class Executable(roles.StatementRole, Generative): connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: ... def _execute_on_scalar( @@ -1253,7 +1255,7 @@ class SchemaVisitor(ClauseVisitor): _COLKEY = TypeVar("_COLKEY", Union[None, str], str) _COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) -_COL = TypeVar("_COL", bound="ColumnElement[Any]") +_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]") class ColumnCollection(Generic[_COLKEY, _COL_co]): @@ -1505,6 +1507,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) + self._collection[:] = cols self._colset.update(c for k, c in self._collection) self._index.update( diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 0659709ab..9b7231360 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -29,6 +29,7 @@ from typing import Union from . import operators from . import roles from . import visitors +from ._typing import is_from_clause from .base import ExecutableOption from .base import Options from .cache_key import HasCacheKey @@ -38,25 +39,18 @@ from .. import inspection from .. import util from ..util.typing import Literal -if not typing.TYPE_CHECKING: - elements = None - lambdas = None - schema = None - selectable = None - traversals = None - if typing.TYPE_CHECKING: from . import elements from . import lambdas from . import schema from . import selectable - from . import traversals from ._typing import _ColumnExpressionArgument from ._typing import _ColumnsClauseArgument from ._typing import _DDLColumnArgument from ._typing import _DMLTableArgument from ._typing import _FromClauseArgument from .dml import _DMLTableElement + from .elements import BindParameter from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement @@ -64,9 +58,7 @@ if typing.TYPE_CHECKING: from .elements import SQLCoreOperations from .schema import Column from .selectable import _ColumnsClauseElement - from .selectable import _JoinTargetElement from .selectable import _JoinTargetProtocol - from .selectable import _OnClauseElement from .selectable import FromClause from .selectable import HasCTE from .selectable import SelectBase @@ -170,6 +162,15 @@ def expect( @overload def expect( + role: Type[roles.LiteralValueRole], + element: Any, + **kw: Any, +) -> BindParameter[Any]: + ... + + +@overload +def expect( role: Type[roles.DDLReferredColumnRole], element: Any, **kw: Any, @@ -272,7 +273,7 @@ def expect( @overload def expect( role: Type[roles.ColumnsClauseRole], - element: _ColumnsClauseArgument, + element: _ColumnsClauseArgument[Any], **kw: Any, ) -> _ColumnsClauseElement: ... @@ -933,7 +934,7 @@ class GroupByImpl(ByOfImpl, RoleImpl): argname: Optional[str] = None, **kw: Any, ) -> Any: - if isinstance(resolved, roles.StrictFromClauseRole): + if is_from_clause(resolved): return elements.ClauseList(*resolved.c) else: return resolved diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c524a2602..a1b25b8a6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -80,7 +80,6 @@ from ..util.typing import Protocol from ..util.typing import TypedDict if typing.TYPE_CHECKING: - from . import roles from .annotation import _AnnotationDict from .base import _AmbiguousTableNameMap from .base import CompileState @@ -95,7 +94,6 @@ if typing.TYPE_CHECKING: from .elements import ColumnElement from .elements import Label from .functions import Function - from .selectable import Alias from .selectable import AliasedReturnsRows from .selectable import CompoundSelectState from .selectable import CTE @@ -386,7 +384,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): need_result_map_for_nested: bool need_result_map_for_compound: bool select_0: ReturnsRows - insert_from_select: Select + insert_from_select: Select[Any] class ExpandedState(NamedTuple): @@ -2834,15 +2832,31 @@ class SQLCompiler(Compiled): "unique bind parameter of the same name" % name ) elif existing._is_crud or bindparam._is_crud: - raise exc.CompileError( - "bindparam() name '%s' is reserved " - "for automatic usage in the VALUES or SET " - "clause of this " - "insert/update statement. Please use a " - "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." - % (bindparam.key, bindparam.key) - ) + if existing._is_crud and bindparam._is_crud: + # TODO: this condition is not well understood. + # see tests in test/sql/test_update.py + raise exc.CompileError( + "Encountered unsupported case when compiling an " + "INSERT or UPDATE statement. If this is a " + "multi-table " + "UPDATE statement, please provide string-named " + "arguments to the " + "values() method with distinct names; support for " + "multi-table UPDATE statements that " + "target multiple tables for UPDATE is very " + "limited", + ) + else: + raise exc.CompileError( + f"bindparam() name '{bindparam.key}' is reserved " + "for automatic usage in the VALUES or SET " + "clause of this " + "insert/update statement. Please use a " + "name other than column name when using " + "bindparam() " + "with insert() or update() (for example, " + f"'b_{bindparam.key}')." + ) self.binds[bindparam.key] = self.binds[name] = bindparam @@ -3881,7 +3895,7 @@ class SQLCompiler(Compiled): return text def _setup_select_hints( - self, select: Select + self, select: Select[Any] ) -> Tuple[str, _FromHintsType]: byfrom = dict( [ diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e4408cd31..29d7b45d7 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -22,6 +22,7 @@ from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import Union @@ -30,8 +31,10 @@ from . import coercions from . import dml from . import elements from . import roles +from .elements import ColumnClause from .schema import default_is_clause_element from .schema import default_is_sequence +from .selectable import TableClause from .. import exc from .. import util from ..util.typing import Literal @@ -41,16 +44,9 @@ if TYPE_CHECKING: from .compiler import SQLCompiler from .dml import _DMLColumnElement from .dml import DMLState - from .dml import Insert - from .dml import Update - from .dml import UpdateDMLState from .dml import ValuesBase - from .elements import ClauseElement - from .elements import ColumnClause from .elements import ColumnElement - from .elements import TextClause from .schema import _SQLExprDefault - from .schema import Column from .selectable import TableClause REQUIRED = util.symbol( @@ -68,12 +64,20 @@ values present. ) +def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]: + if not isinstance(c, ColumnClause): + raise exc.CompileError( + f"Can't create DML statement against column expression {c!r}" + ) + return c + + class _CrudParams(NamedTuple): - single_params: List[ - Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] + single_params: Sequence[ + Tuple[ColumnElement[Any], str, Optional[Union[str, _SQLExprDefault]]] ] all_multi_params: List[ - List[ + Sequence[ Tuple[ ColumnClause[Any], str, @@ -274,7 +278,7 @@ def _get_crud_params( compiler, stmt, compile_state, - cast("List[Tuple[ColumnClause[Any], str, str]]", values), + cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values), cast("Callable[..., str]", _column_as_key), kw, ) @@ -290,7 +294,7 @@ def _get_crud_params( # insert_executemany_returning mode :) values = [ ( - stmt.table.columns[0], + _as_dml_column(stmt.table.columns[0]), compiler.preparer.format_column(stmt.table.columns[0]), "DEFAULT", ) @@ -1135,10 +1139,10 @@ def _extend_values_for_multiparams( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState, - initial_values: List[Tuple[ColumnClause[Any], str, str]], + initial_values: Sequence[Tuple[ColumnClause[Any], str, str]], _column_as_key: Callable[..., str], kw: Dict[str, Any], -) -> List[List[Tuple[ColumnClause[Any], str, str]]]: +) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]: values_0 = initial_values values = [initial_values] diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 8307f6400..e0f162fc8 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -22,15 +22,19 @@ from typing import List from typing import MutableMapping from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import coercions from . import roles from . import util as sql_util +from ._typing import _no_kw +from ._typing import _TP from ._typing import is_column_element from ._typing import is_named_from_clause from .base import _entity_namespace_key @@ -42,6 +46,7 @@ from .base import ColumnCollection from .base import CompileState from .base import DialectKWArgs from .base import Executable +from .base import Generative from .base import HasCompileState from .elements import BooleanClauseList from .elements import ClauseElement @@ -49,12 +54,13 @@ from .elements import ColumnClause from .elements import ColumnElement from .elements import Null from .selectable import Alias +from .selectable import ExecutableReturnsRows from .selectable import FromClause from .selectable import HasCTE from .selectable import HasPrefixes from .selectable import Join -from .selectable import ReturnsRows from .selectable import TableClause +from .selectable import TypedReturnsRows from .sqltypes import NullType from .visitors import InternalTraversal from .. import exc @@ -66,9 +72,19 @@ if TYPE_CHECKING: from ._typing import _ColumnsClauseArgument from ._typing import _DMLColumnArgument from ._typing import _DMLTableArgument - from ._typing import _FromClauseArgument + from ._typing import _T0 # noqa + from ._typing import _T1 # noqa + from ._typing import _T2 # noqa + from ._typing import _T3 # noqa + from ._typing import _T4 # noqa + from ._typing import _T5 # noqa + from ._typing import _T6 # noqa + from ._typing import _T7 # noqa + from ._typing import _TypedColumnClauseArgument as _TCCA # noqa from .base import ReadOnlyColumnCollection from .compiler import SQLCompiler + from .elements import ColumnElement + from .elements import KeyedColumnElement from .selectable import _ColumnsClauseElement from .selectable import _SelectIterable from .selectable import Select @@ -88,6 +104,8 @@ else: isinsert = operator.attrgetter("isinsert") +_T = TypeVar("_T", bound=Any) + _DMLColumnElement = Union[str, ColumnClause[Any]] _DMLTableElement = Union[TableClause, Alias, Join] @@ -185,6 +203,11 @@ class DMLState(CompileState): "%s construct does not support " "multiple parameter sets." % statement.__visit_name__.upper() ) + else: + assert isinstance(statement, Insert) + + # which implies... + # assert isinstance(statement.table, TableClause) for parameters in statement._multi_values: multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [ @@ -291,7 +314,9 @@ class UpdateDMLState(DMLState): elif statement._multi_values: self._process_multi_values(statement) self._extra_froms = ef = self._make_extra_froms(statement) - self.is_multitable = mt = ef and self._dict_parameters + + self.is_multitable = mt = ef + self.include_table_with_column_exprs = bool( mt and compiler.render_table_with_column_in_update_from ) @@ -317,8 +342,8 @@ class UpdateBase( HasCompileState, DialectKWArgs, HasPrefixes, - ReturnsRows, - Executable, + Generative, + ExecutableReturnsRows, ClauseElement, ): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" @@ -383,8 +408,8 @@ class UpdateBase( @_generative def returning( - self: SelfUpdateBase, *cols: _ColumnsClauseArgument - ) -> SelfUpdateBase: + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> UpdateBase: r"""Add a :term:`RETURNING` or equivalent clause to this statement. e.g.: @@ -454,6 +479,8 @@ class UpdateBase( :ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial` """ # noqa: E501 + if __kw: + raise _no_kw() if self._return_defaults: raise exc.InvalidRequestError( "return_defaults() is already configured on this statement" @@ -464,7 +491,7 @@ class UpdateBase( return self def corresponding_column( - self, column: ColumnElement[Any], require_embedded: bool = False + self, column: KeyedColumnElement[Any], require_embedded: bool = False ) -> Optional[ColumnElement[Any]]: return self.exported_columns.corresponding_column( column, require_embedded=require_embedded @@ -628,7 +655,7 @@ class ValuesBase(UpdateBase): _supports_multi_parameters = False - select: Optional[Select] = None + select: Optional[Select[Any]] = None """SELECT statement for INSERT .. FROM SELECT""" _post_values_clause: Optional[ClauseElement] = None @@ -804,11 +831,15 @@ class ValuesBase(UpdateBase): ) elif isinstance(arg, collections_abc.Sequence): - if arg and isinstance(arg[0], (list, dict, tuple)): self._multi_values += (arg,) return self + if TYPE_CHECKING: + # crud.py raises during compilation if this is not the + # case + assert isinstance(self, Insert) + # tuple values arg = {c.key: value for c, value in zip(self.table.c, arg)} @@ -1010,7 +1041,7 @@ class Insert(ValuesBase): def from_select( self: SelfInsert, names: List[str], - select: Select, + select: Select[Any], include_defaults: bool = True, ) -> SelfInsert: """Return a new :class:`_expression.Insert` construct which represents @@ -1073,6 +1104,114 @@ class Insert(ValuesBase): self.select = coercions.expect(roles.DMLSelectRole, select) return self + if TYPE_CHECKING: + + # START OVERLOADED FUNCTIONS self.returning ReturningInsert 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningInsert[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningInsert[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningInsert[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningInsert[Any]: + ... + + +class ReturningInsert(Insert, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Insert` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Insert.returning` method. + + .. versionadded:: 2.0 + + """ + SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase") @@ -1264,6 +1403,113 @@ class Update(DMLWhereBase, ValuesBase): self._inline = True return self + if TYPE_CHECKING: + # START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningUpdate[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: + ... + + +class ReturningUpdate(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Update` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Update.returning` method. + + .. versionadded:: 2.0 + + """ + SelfDelete = typing.TypeVar("SelfDelete", bound="Delete") @@ -1297,3 +1543,111 @@ class Delete(DMLWhereBase, UpdateBase): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) + + if TYPE_CHECKING: + + # START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningDelete[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: + ... + + +class ReturningDelete(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Delete` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Delete.returning` method. + + .. versionadded:: 2.0 + + """ diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 34d5127ab..a29561291 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -54,6 +54,7 @@ from .base import _clone from .base import _generative from .base import _NoArg from .base import Executable +from .base import Generative from .base import HasMemoized from .base import Immutable from .base import NO_ARG @@ -94,10 +95,7 @@ if typing.TYPE_CHECKING: from .selectable import _SelectIterable from .selectable import FromClause from .selectable import NamedFromClause - from .selectable import ReturnsRows from .selectable import Select - from .selectable import TableClause - from .sqltypes import Boolean from .sqltypes import TupleType from .type_api import TypeEngine from .visitors import _CloneCallableType @@ -122,7 +120,9 @@ _NT = TypeVar("_NT", bound="_NUMERIC") _NMT = TypeVar("_NMT", bound="_NUMBER") -def literal(value, type_=None): +def literal( + value: Any, type_: Optional[_TypeEngineArgument[_T]] = None +) -> BindParameter[_T]: r"""Return a literal clause, bound to a bind parameter. Literal clauses are created automatically when non- @@ -144,7 +144,9 @@ def literal(value, type_=None): return coercions.expect(roles.LiteralValueRole, value, type_=type_) -def literal_column(text, type_=None): +def literal_column( + text: str, type_: Optional[_TypeEngineArgument[_T]] = None +) -> ColumnClause[_T]: r"""Produce a :class:`.ColumnClause` object that has the :paramref:`_expression.column.is_literal` flag set to True. @@ -316,6 +318,7 @@ class ClauseElement( is_selectable = False is_dml = False _is_column_element = False + _is_keyed_column_element = False _is_table = False _is_textual = False _is_from_clause = False @@ -342,7 +345,7 @@ class ClauseElement( if typing.TYPE_CHECKING: def get_children( - self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any + self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any ) -> Iterable[ClauseElement]: ... @@ -455,7 +458,7 @@ class ClauseElement( connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptions, - ) -> Result: + ) -> Result[Any]: if self.supports_execution: if TYPE_CHECKING: assert isinstance(self, Executable) @@ -833,13 +836,13 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def in_( self, - other: Union[Sequence[Any], BindParameter[Any], Select], + other: Union[Sequence[Any], BindParameter[Any], Select[Any]], ) -> BinaryExpression[bool]: ... def not_in( self, - other: Union[Sequence[Any], BindParameter[Any], Select], + other: Union[Sequence[Any], BindParameter[Any], Select[Any]], ) -> BinaryExpression[bool]: ... @@ -1699,6 +1702,14 @@ class ColumnElement( return self._anon_label(label, add_hash=idx) +class KeyedColumnElement(ColumnElement[_T]): + """ColumnElement where ``.key`` is non-None.""" + + _is_keyed_column_element = True + + key: str + + class WrapsColumnExpression(ColumnElement[_T]): """Mixin that defines a :class:`_expression.ColumnElement` as a wrapper with special @@ -1760,7 +1771,7 @@ class WrapsColumnExpression(ColumnElement[_T]): SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter[Any]") -class BindParameter(roles.InElementRole, ColumnElement[_T]): +class BindParameter(roles.InElementRole, KeyedColumnElement[_T]): r"""Represent a "bound expression". :class:`.BindParameter` is invoked explicitly using the @@ -2073,6 +2084,7 @@ class TextClause( roles.FromClauseRole, roles.SelectStatementRole, roles.InElementRole, + Generative, Executable, DQLDMLClauseElement, roles.BinaryElementRole[Any], @@ -4160,7 +4172,7 @@ class FunctionFilter(ColumnElement[_T]): ) -class NamedColumn(ColumnElement[_T]): +class NamedColumn(KeyedColumnElement[_T]): is_literal = False table: Optional[FromClause] = None name: str @@ -4502,7 +4514,7 @@ class ColumnClause( self.is_literal = is_literal - def get_children(self, column_tables=False, **kw): + def get_children(self, *, column_tables=False, **kw): # override base get_children() to not return the Table # or selectable that is parent to this column. Traversals # expect the columns of tables and subqueries to be leaf nodes. diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 648168235..b827df3df 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -175,7 +175,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: return connection._execute_function( self, distilled_params, execution_options ) @@ -623,7 +623,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): joins_implicitly=joins_implicitly, ) - def select(self) -> "Select": + def select(self) -> Select[Any]: """Produce a :func:`_expression.select` construct against this :class:`.FunctionElement`. @@ -632,7 +632,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): s = select(function_element) """ - s = Select(self) + s: Select[Any] = Select(self) if self._execution_options: s = s.execution_options(**self._execution_options) return s @@ -846,7 +846,7 @@ class _FunctionGenerator: @overload def __call__( - self, *c: Any, type_: TypeEngine[_T], **kwargs: Any + self, *c: Any, type_: _TypeEngineArgument[_T], **kwargs: Any ) -> Function[_T]: ... diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 231c70a5b..09d4b35ad 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -8,8 +8,6 @@ from __future__ import annotations from typing import Any from typing import Generic -from typing import Iterable -from typing import List from typing import Optional from typing import TYPE_CHECKING from typing import TypeVar @@ -19,12 +17,7 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _PropagateAttrsType - from .base import _EntityNamespace - from .base import ColumnCollection - from .base import ReadOnlyColumnCollection - from .elements import ColumnClause from .elements import Label - from .elements import NamedColumn from .selectable import _SelectIterable from .selectable import FromClause from .selectable import Subquery @@ -108,13 +101,21 @@ class TruncatedLabelRole(StringRole, SQLRole): class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): __slots__ = () - _role_name = "Column expression or FROM clause" + _role_name = ( + "Column expression, FROM clause, or other columns clause element" + ) @property def _select_iterable(self) -> _SelectIterable: raise NotImplementedError() +class TypedColumnsClauseRole(Generic[_T], SQLRole): + """element-typed form of ColumnsClauseRole""" + + __slots__ = () + + class LimitOffsetRole(SQLRole): __slots__ = () _role_name = "LIMIT / OFFSET expression" @@ -161,7 +162,7 @@ class WhereHavingRole(OnClauseRole): _role_name = "SQL expression for WHERE/HAVING role" -class ExpressionElementRole(Generic[_T], SQLRole): +class ExpressionElementRole(TypedColumnsClauseRole[_T]): # note when using generics for ExpressionElementRole, # the generic type needs to be in # sqlalchemy.sql.coercions._impl_lookup mapping also. @@ -212,39 +213,11 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole): named_with_column: bool - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - - @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - - @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: - ... - - @util.ro_non_memoized_property - def _hide_froms(self) -> Iterable[FromClause]: - ... - - @util.ro_non_memoized_property - def _from_objects(self) -> List[FromClause]: - ... - class StrictFromClauseRole(FromClauseRole): __slots__ = () # does not allow text() or select() objects - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def description(self) -> str: - ... - class AnonymizedFromClauseRole(StrictFromClauseRole): __slots__ = () @@ -317,16 +290,6 @@ class DMLTableRole(FromClauseRole): __slots__ = () _role_name = "subject table for an INSERT, UPDATE or DELETE" - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def primary_key(self) -> Iterable[NamedColumn[Any]]: - ... - - @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - class DMLColumnRole(SQLRole): __slots__ = () diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 52ba60a62..27456d2be 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -86,7 +86,6 @@ if typing.TYPE_CHECKING: from ._typing import _InfoType from ._typing import _TextCoercedExpressionArgument from ._typing import _TypeEngineArgument - from .base import ColumnCollection from .base import DedupeColumnCollection from .base import ReadOnlyColumnCollection from .compiler import DDLCompiler @@ -97,9 +96,7 @@ if typing.TYPE_CHECKING: from .visitors import anon_map from ..engine import Connection from ..engine import Engine - from ..engine.cursor import CursorResult from ..engine.interfaces import _CoreMultiExecuteParams - from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.interfaces import ExecutionContext from ..engine.mock import MockConnection @@ -2609,8 +2606,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): :class:`_schema.Table`. """ - - return table.columns.corresponding_column(self.column) + # our column is a Column, and any subquery etc. proxying us + # would be doing so via another Column, so that's what would + # be returned here + return table.columns.corresponding_column(self.column) # type: ignore @util.memoized_property def _column_tokens(self) -> Tuple[Optional[str], str, Optional[str]]: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 9d4d1d6c7..b08f13f99 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -23,6 +23,7 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import List @@ -46,6 +47,8 @@ from . import traversals from . import type_api from . import visitors from ._typing import _ColumnsClauseArgument +from ._typing import _no_kw +from ._typing import _TP from ._typing import is_column_element from ._typing import is_select_statement from ._typing import is_subquery @@ -103,9 +106,20 @@ if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _FromClauseArgument from ._typing import _JoinTargetArgument + from ._typing import _MAYBE_ENTITY + from ._typing import _NOT_ENTITY from ._typing import _OnClauseArgument from ._typing import _SelectStatementForCompoundArgument + from ._typing import _T0 + from ._typing import _T1 + from ._typing import _T2 + from ._typing import _T3 + from ._typing import _T4 + from ._typing import _T5 + from ._typing import _T6 + from ._typing import _T7 from ._typing import _TextCoercedExpressionArgument + from ._typing import _TypedColumnClauseArgument as _TCCA from ._typing import _TypeEngineArgument from .base import _AmbiguousTableNameMap from .base import ExecutableOption @@ -115,14 +129,13 @@ if TYPE_CHECKING: from .dml import Delete from .dml import Insert from .dml import Update + from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import TextClause from .functions import Function - from .schema import Column from .schema import ForeignKey from .schema import ForeignKeyConstraint from .type_api import TypeEngine - from .util import ClauseAdapter from .visitors import _CloneCallableType @@ -245,6 +258,14 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): raise NotImplementedError() +class ExecutableReturnsRows(Executable, ReturnsRows): + """base for executable statements that return rows.""" + + +class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]): + """base for executable statements that return rows.""" + + SelfSelectable = TypeVar("SelfSelectable", bound="Selectable") @@ -293,8 +314,8 @@ class Selectable(ReturnsRows): ) def corresponding_column( - self, column: ColumnElement[Any], require_embedded: bool = False - ) -> Optional[ColumnElement[Any]]: + self, column: KeyedColumnElement[Any], require_embedded: bool = False + ) -> Optional[KeyedColumnElement[Any]]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from the :attr:`_expression.Selectable.exported_columns` @@ -593,7 +614,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _use_schema_map = False - def select(self) -> Select: + def select(self) -> Select[Any]: r"""Return a SELECT of this :class:`_expression.FromClause`. @@ -795,7 +816,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): ) @util.ro_non_memoized_property - def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`. @@ -817,7 +840,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.c @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, Any]: + def columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """A named-based collection of :class:`_expression.ColumnElement` objects maintained by this :class:`_expression.FromClause`. @@ -833,7 +858,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.c @util.ro_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, Any]: + def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """ A synonym for :attr:`.FromClause.columns` @@ -1223,7 +1248,7 @@ class Join(roles.DMLTableRole, FromClause): @util.preload_module("sqlalchemy.sql.util") def _populate_column_collection(self): sqlutil = util.preloaded.sql_util - columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [ + columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [ c for c in self.right.c ] @@ -1458,7 +1483,7 @@ class Join(roles.DMLTableRole, FromClause): "join explicitly." % (a.description, b.description) ) - def select(self) -> "Select": + def select(self) -> Select[Any]: r"""Create a :class:`_expression.Select` from this :class:`_expression.Join`. @@ -2764,6 +2789,7 @@ class Subquery(AliasedReturnsRows): cls, selectable: SelectBase, name: Optional[str] = None ) -> Subquery: """Return a :class:`.Subquery` object.""" + return coercions.expect( roles.SelectStatementRole, selectable ).subquery(name=name) @@ -3216,7 +3242,6 @@ class SelectBase( roles.CompoundElementRole, roles.InElementRole, HasCTE, - Executable, SupportsCloneAnnotations, Selectable, ): @@ -3239,7 +3264,9 @@ class SelectBase( self._reset_memoizations() @util.ro_non_memoized_property - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set. @@ -3284,7 +3311,9 @@ class SelectBase( raise NotImplementedError() @property - def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`, not including @@ -3377,7 +3406,7 @@ class SelectBase( def as_scalar(self): return self.scalar_subquery() - def exists(self): + def exists(self) -> Exists: """Return an :class:`_sql.Exists` representation of this selectable, which can be used as a column expression. @@ -3394,7 +3423,7 @@ class SelectBase( """ return Exists(self) - def scalar_subquery(self): + def scalar_subquery(self) -> ScalarSelect[Any]: """Return a 'scalar' representation of this selectable, which can be used as a column expression. @@ -3607,7 +3636,7 @@ SelfGenerativeSelect = typing.TypeVar( ) -class GenerativeSelect(SelectBase): +class GenerativeSelect(SelectBase, Generative): """Base class for SELECT statements where additional elements can be added. @@ -4128,7 +4157,7 @@ class _CompoundSelectKeyword(Enum): INTERSECT_ALL = "INTERSECT ALL" -class CompoundSelect(HasCompileState, GenerativeSelect): +class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations. @@ -4293,7 +4322,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return self.selects[0]._all_selected_columns @util.ro_non_memoized_property - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -4343,7 +4374,10 @@ class SelectState(util.MemoizedSlots, CompileState): ... def __init__( - self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any + self, + statement: Select[Any], + compiler: Optional[SQLCompiler], + **kw: Any, ): self.statement = statement self.from_clauses = statement._from_obj @@ -4369,7 +4403,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def get_column_descriptions( - cls, statement: Select + cls, statement: Select[Any] ) -> List[Dict[str, Any]]: return [ { @@ -4384,12 +4418,14 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def from_statement( - cls, statement: Select, from_statement: ReturnsRows - ) -> Any: + cls, statement: Select[Any], from_statement: ExecutableReturnsRows + ) -> ExecutableReturnsRows: cls._plugin_not_implemented() @classmethod - def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]: + def get_columns_clause_froms( + cls, statement: Select[Any] + ) -> List[FromClause]: return cls._normalize_froms( itertools.chain.from_iterable( element._from_objects for element in statement._raw_columns @@ -4439,7 +4475,7 @@ class SelectState(util.MemoizedSlots, CompileState): return go - def _get_froms(self, statement: Select) -> List[FromClause]: + def _get_froms(self, statement: Select[Any]) -> List[FromClause]: ambiguous_table_name_map: _AmbiguousTableNameMap self._ambiguous_table_name_map = ambiguous_table_name_map = {} @@ -4467,7 +4503,7 @@ class SelectState(util.MemoizedSlots, CompileState): def _normalize_froms( cls, iterable_of_froms: Iterable[FromClause], - check_statement: Optional[Select] = None, + check_statement: Optional[Select[Any]] = None, ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, ) -> List[FromClause]: """given an iterable of things to select FROM, reduce them to what @@ -4615,7 +4651,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def determine_last_joined_entity( - cls, stmt: Select + cls, stmt: Select[Any] ) -> Optional[_JoinTargetElement]: if stmt._setup_joins: return stmt._setup_joins[-1][0] @@ -4623,7 +4659,7 @@ class SelectState(util.MemoizedSlots, CompileState): return None @classmethod - def all_selected_columns(cls, statement: Select) -> _SelectIterable: + def all_selected_columns(cls, statement: Select[Any]) -> _SelectIterable: return [c for c in _select_iterables(statement._raw_columns)] def _setup_joins( @@ -4876,7 +4912,7 @@ class _MemoizedSelectEntities( return c # type: ignore @classmethod - def _generate_for_statement(cls, select_stmt: Select) -> None: + def _generate_for_statement(cls, select_stmt: Select[Any]) -> None: if select_stmt._setup_joins or select_stmt._with_options: self = _MemoizedSelectEntities() self._raw_columns = select_stmt._raw_columns @@ -4888,7 +4924,7 @@ class _MemoizedSelectEntities( select_stmt._setup_joins = select_stmt._with_options = () -SelfSelect = typing.TypeVar("SelfSelect", bound="Select") +SelfSelect = typing.TypeVar("SelfSelect", bound="Select[Any]") class Select( @@ -4898,6 +4934,7 @@ class Select( HasCompileState, _SelectFromElements, GenerativeSelect, + TypedReturnsRows[_TP], ): """Represents a ``SELECT`` statement. @@ -4973,7 +5010,7 @@ class Select( _compile_state_factory: Type[SelectState] @classmethod - def _create_raw_select(cls, **kw: Any) -> Select: + def _create_raw_select(cls, **kw: Any) -> Select[Any]: """Create a :class:`.Select` using raw ``__new__`` with no coercions. Used internally to build up :class:`.Select` constructs with @@ -4985,7 +5022,7 @@ class Select( stmt.__dict__.update(kw) return stmt - def __init__(self, *entities: _ColumnsClauseArgument): + def __init__(self, *entities: _ColumnsClauseArgument[Any]): r"""Construct a new :class:`_expression.Select`. The public constructor for :class:`_expression.Select` is the @@ -5013,7 +5050,9 @@ class Select( cols = list(elem._select_iterable) return cols[0].type - def filter(self: SelfSelect, *criteria: ColumnElement[Any]) -> SelfSelect: + def filter( + self: SelfSelect, *criteria: _ColumnExpressionArgument[bool] + ) -> SelfSelect: """A synonym for the :meth:`_future.Select.where` method.""" return self.where(*criteria) @@ -5032,7 +5071,28 @@ class Select( return self._raw_columns[0] - def filter_by(self, **kwargs): + if TYPE_CHECKING: + + @overload + def scalar_subquery( + self: Select[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[Any]: + ... + + @overload + def scalar_subquery( + self: Select[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: + ... + + @overload + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def filter_by(self: SelfSelect, **kwargs: Any) -> SelfSelect: r"""apply the given filtering criterion as a WHERE clause to this select. @@ -5046,7 +5106,7 @@ class Select( return self.filter(*clauses) @property - def column_descriptions(self): + def column_descriptions(self) -> Any: """Return a :term:`plugin-enabled` 'column descriptions' structure referring to the columns which are SELECTed by this statement. @@ -5089,7 +5149,9 @@ class Select( meth = SelectState.get_plugin_class(self).get_column_descriptions return meth(self) - def from_statement(self, statement): + def from_statement( + self, statement: ExecutableReturnsRows + ) -> ExecutableReturnsRows: """Apply the columns which this :class:`.Select` would select onto another statement. @@ -5410,7 +5472,7 @@ class Select( ) @property - def inner_columns(self): + def inner_columns(self) -> _SelectIterable: """An iterator of all :class:`_expression.ColumnElement` expressions which would be rendered into the columns clause of the resulting SELECT statement. @@ -5487,18 +5549,19 @@ class Select( self._reset_memoizations() - def get_children(self, **kwargs): + def get_children(self, **kw: Any) -> Iterable[ClauseElement]: return itertools.chain( super(Select, self).get_children( - omit_attrs=("_from_obj", "_correlate", "_correlate_except") + omit_attrs=("_from_obj", "_correlate", "_correlate_except"), + **kw, ), self._iterate_from_elements(), ) @_generative def add_columns( - self: SelfSelect, *columns: _ColumnsClauseArgument - ) -> SelfSelect: + self, *columns: _ColumnsClauseArgument[Any] + ) -> Select[Any]: """Return a new :func:`_expression.select` construct with the given column expressions added to its columns clause. @@ -5523,7 +5586,7 @@ class Select( return self def _set_entities( - self, entities: Iterable[_ColumnsClauseArgument] + self, entities: Iterable[_ColumnsClauseArgument[Any]] ) -> None: self._raw_columns = [ coercions.expect( @@ -5538,7 +5601,7 @@ class Select( "be removed in a future release. Please use " ":meth:`_expression.Select.add_columns`", ) - def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect: + def column(self, column: _ColumnsClauseArgument[Any]) -> Select[Any]: """Return a new :func:`_expression.select` construct with the given column expression added to its columns clause. @@ -5555,9 +5618,7 @@ class Select( return self.add_columns(column) @util.preload_module("sqlalchemy.sql.util") - def reduce_columns( - self: SelfSelect, only_synonyms: bool = True - ) -> SelfSelect: + def reduce_columns(self, only_synonyms: bool = True) -> Select[Any]: """Return a new :func:`_expression.select` construct with redundantly named, equivalently-valued columns removed from the columns clause. @@ -5580,20 +5641,115 @@ class Select( all columns that are equivalent to another are removed. """ - return self.with_only_columns( + woc: Select[Any] + woc = self.with_only_columns( *util.preloaded.sql_util.reduce_columns( self._all_selected_columns, only_synonyms=only_synonyms, *(self._where_criteria + self._from_obj), ) ) + return woc + + # START OVERLOADED FUNCTIONS self.with_only_columns Select 8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_sel_v1_overloads.py + + @overload + def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: + ... + + @overload + def with_only_columns( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> Select[Tuple[_T0, _T1]]: + ... + + @overload + def with_only_columns( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> Select[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> Select[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.with_only_columns + + @overload + def with_only_columns( + self, + *columns: _ColumnsClauseArgument[Any], + maintain_column_froms: bool = False, + **__kw: Any, + ) -> Select[Any]: + ... @_generative def with_only_columns( - self: SelfSelect, - *columns: _ColumnsClauseArgument, + self, + *columns: _ColumnsClauseArgument[Any], maintain_column_froms: bool = False, - ) -> SelfSelect: + **__kw: Any, + ) -> Select[Any]: r"""Return a new :func:`_expression.select` construct with its columns clause replaced with the given columns. @@ -5647,6 +5803,9 @@ class Select( """ # noqa: E501 + if __kw: + raise _no_kw() + # memoizations should be cleared here as of # I95c560ffcbfa30b26644999412fb6a385125f663 , asserting this # is the case for now. @@ -5915,7 +6074,9 @@ class Select( return self @HasMemoized_ro_memoized_attribute - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -6215,7 +6376,7 @@ class ScalarSelect( by this :class:`_expression.ScalarSelect`. """ - self.element = cast(Select, self.element).where(crit) + self.element = cast("Select[Any]", self.element).where(crit) return self @overload @@ -6269,7 +6430,9 @@ class ScalarSelect( """ - self.element = cast(Select, self.element).correlate(*fromclauses) + self.element = cast("Select[Any]", self.element).correlate( + *fromclauses + ) return self @_generative @@ -6307,7 +6470,7 @@ class ScalarSelect( """ - self.element = cast(Select, self.element).correlate_except( + self.element = cast("Select[Any]", self.element).correlate_except( *fromclauses ) return self @@ -6331,12 +6494,18 @@ class Exists(UnaryExpression[bool]): def __init__( self, __argument: Optional[ - Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]] ] = None, ): + s: ScalarSelect[Any] + + # TODO: this seems like we should be using coercions for this if __argument is None: s = Select(literal_column("*")).scalar_subquery() - elif isinstance(__argument, (SelectBase, ScalarSelect)): + elif isinstance(__argument, SelectBase): + s = __argument.scalar_subquery() + s._propagate_attrs = __argument._propagate_attrs + elif isinstance(__argument, ScalarSelect): s = __argument else: s = Select(__argument).scalar_subquery() @@ -6358,7 +6527,7 @@ class Exists(UnaryExpression[bool]): element = fn(element) return element.self_group(against=operators.exists) - def select(self) -> Select: + def select(self) -> Select[Any]: r"""Return a SELECT of this :class:`_expression.Exists`. e.g.:: @@ -6452,7 +6621,7 @@ class Exists(UnaryExpression[bool]): SelfTextualSelect = typing.TypeVar("SelfTextualSelect", bound="TextualSelect") -class TextualSelect(SelectBase): +class TextualSelect(SelectBase, Executable, Generative): """Wrap a :class:`_expression.TextClause` construct within a :class:`_expression.SelectBase` interface. @@ -6503,7 +6672,9 @@ class TextualSelect(SelectBase): self.positional = positional @HasMemoized_ro_memoized_attribute - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, KeyedColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d08fef60a..8c45ba410 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -50,6 +50,7 @@ from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement from .elements import Grouping +from .elements import KeyedColumnElement from .elements import Label from .elements import Null from .elements import UnaryExpression @@ -72,9 +73,7 @@ if typing.TYPE_CHECKING: from ._typing import _EquivalentColumnMap from ._typing import _TypeEngineArgument from .elements import TextClause - from .roles import FromClauseRole from .selectable import _JoinTargetElement - from .selectable import _OnClauseElement from .selectable import _SelectIterable from .selectable import Selectable from .visitors import _TraverseCallableType @@ -569,7 +568,7 @@ class _repr_row(_repr_base): __slots__ = ("row",) - def __init__(self, row: "Row", max_chars: int = 300): + def __init__(self, row: "Row[Any]", max_chars: int = 300): self.row = row self.max_chars = max_chars @@ -1068,7 +1067,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): col = col._annotations["adapt_column"] if TYPE_CHECKING: - assert isinstance(col, ColumnElement) + assert isinstance(col, KeyedColumnElement) if self.adapt_from_selectables and col not in self.equivalents: for adp in self.adapt_from_selectables: @@ -1078,7 +1077,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): return None if TYPE_CHECKING: - assert isinstance(col, ColumnElement) + assert isinstance(col, KeyedColumnElement) if self.include_fn and not self.include_fn(col): return None diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index e0a66fbcf..88586d834 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -450,7 +450,7 @@ class HasTraverseInternals: @util.preload_module("sqlalchemy.sql.traversals") def get_children( - self, omit_attrs: Tuple[str, ...] = (), **kw: Any + self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any ) -> Iterable[HasTraverseInternals]: r"""Return immediate child :class:`.visitors.HasTraverseInternals` elements of this :class:`.visitors.HasTraverseInternals`. @@ -594,7 +594,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): if typing.TYPE_CHECKING: def get_children( - self, omit_attrs: Tuple[str, ...] = (), **kw: Any + self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any ) -> Iterable[ExternallyTraversible]: ... |
