diff options
37 files changed, 2568 insertions, 1119 deletions
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 8f4b963eb..9db6f3f52 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -23,8 +23,8 @@ from typing import Tuple from typing import Type from typing import Union -from .util import _preloaded from .util import compat +from .util import preloaded as _preloaded if typing.TYPE_CHECKING: from .engine.interfaces import _AnyExecuteParams @@ -345,6 +345,8 @@ class MultipleResultsFound(InvalidRequestError): class NoReferenceError(InvalidRequestError): """Raised by ``ForeignKey`` to indicate a reference cannot be resolved.""" + table_name: str + class AwaitRequired(InvalidRequestError): """Error raised by the async greenlet spawn if no async operation @@ -501,10 +503,7 @@ class StatementError(SQLAlchemyError): @_preloaded.preload_module("sqlalchemy.sql.util") def _sql_message(self) -> str: - if typing.TYPE_CHECKING: - from .sql import util - else: - util = _preloaded.preloaded.sql_util + util = _preloaded.sql_util details = [self._message()] if self.statement: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 18a14012f..b9ced44d5 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -28,6 +28,8 @@ from typing import Generic from typing import Iterable from typing import List from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING from typing import TypeVar from . import exc as orm_exc @@ -78,6 +80,8 @@ from ..sql.selectable import SelectBase from ..sql.selectable import SelectStatementGrouping from ..sql.visitors import InternalTraversal +if TYPE_CHECKING: + from ..sql.selectable import _SetupJoinsElement __all__ = ["Query", "QueryContext"] @@ -134,7 +138,8 @@ class Query( _correlate = () _auto_correlate = True _from_obj = () - _setup_joins = () + _setup_joins: Tuple[_SetupJoinsElement, ...] = () + _label_style = LABEL_STYLE_LEGACY_ORM _memoized_select_entities = () diff --git a/lib/sqlalchemy/sql/_dml_constructors.py b/lib/sqlalchemy/sql/_dml_constructors.py index 835819bac..926e5257b 100644 --- a/lib/sqlalchemy/sql/_dml_constructors.py +++ b/lib/sqlalchemy/sql/_dml_constructors.py @@ -7,12 +7,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from .dml import Delete from .dml import Insert from .dml import Update +if TYPE_CHECKING: + from ._typing import _DMLTableArgument + -def insert(table): +def insert(table: _DMLTableArgument) -> Insert: """Construct an :class:`_expression.Insert` object. E.g.:: @@ -82,7 +87,7 @@ def insert(table): return Insert(table) -def update(table): +def update(table: _DMLTableArgument) -> Update: r"""Construct an :class:`_expression.Update` object. E.g.:: @@ -122,7 +127,7 @@ def update(table): return Update(table) -def delete(table): +def delete(table: _DMLTableArgument) -> Delete: r"""Construct :class:`_expression.Delete` object. E.g.:: diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index fc925a8b3..ea21e01c6 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -345,8 +345,8 @@ def between( :meth:`_expression.ColumnElement.between` """ - expr = coercions.expect(roles.ExpressionElementRole, expr) - return expr.between(lower_bound, upper_bound, symmetric=symmetric) + col_expr = coercions.expect(roles.ExpressionElementRole, expr) + return col_expr.between(lower_bound, upper_bound, symmetric=symmetric) def outparam( diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index a17ee4ce8..7896c02c2 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -9,6 +9,8 @@ from __future__ import annotations from typing import Any from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from . import coercions from . import roles @@ -17,64 +19,65 @@ from .elements import ColumnClause from .selectable import Alias from .selectable import CompoundSelect from .selectable import Exists +from .selectable import FromClause from .selectable import Join from .selectable import Lateral +from .selectable import LateralFromClause +from .selectable import NamedFromClause from .selectable import Select from .selectable import TableClause from .selectable import TableSample from .selectable import Values +if TYPE_CHECKING: + from ._typing import _ColumnsClauseArgument + from ._typing import _FromClauseArgument + from ._typing import _OnClauseArgument + from ._typing import _SelectStatementForCompoundArgument + from .functions import Function + from .selectable import CTE + from .selectable import HasCTE + from .selectable import ScalarSelect + from .selectable import SelectBase -def alias(selectable, name=None, flat=False): - """Return an :class:`_expression.Alias` object. - An :class:`_expression.Alias` represents any - :class:`_expression.FromClause` - with an alternate name assigned within SQL, typically using the ``AS`` - clause when generated, e.g. ``SELECT * FROM table AS aliasname``. +def alias( + selectable: FromClause, name: Optional[str] = None, flat: bool = False +) -> NamedFromClause: + """Return a named alias of the given :class:`.FromClause`. + + For :class:`.Table` and :class:`.Join` objects, the return type is the + :class:`_expression.Alias` object. Other kinds of :class:`.NamedFromClause` + objects may be returned for other kinds of :class:`.FromClause` objects. + + The named alias represents any :class:`_expression.FromClause` with an + alternate name assigned within SQL, typically using the ``AS`` clause when + generated, e.g. ``SELECT * FROM table AS aliasname``. - Similar functionality is available via the + Equivalent functionality is available via the :meth:`_expression.FromClause.alias` - method available on all :class:`_expression.FromClause` subclasses. - In terms of - a SELECT object as generated from the :func:`_expression.select` - function, the :meth:`_expression.SelectBase.alias` method returns an - :class:`_expression.Alias` or similar object which represents a named, - parenthesized subquery. - - When an :class:`_expression.Alias` is created from a - :class:`_schema.Table` object, - this has the effect of the table being rendered - as ``tablename AS aliasname`` in a SELECT statement. - - For :func:`_expression.select` objects, the effect is that of - creating a named subquery, i.e. ``(select ...) AS aliasname``. - - The ``name`` parameter is optional, and provides the name - to use in the rendered SQL. If blank, an "anonymous" name - will be deterministically generated at compile time. - Deterministic means the name is guaranteed to be unique against - other constructs used in the same statement, and will also be the - same name for each successive compilation of the same statement - object. + method available on all :class:`_expression.FromClause` objects. :param selectable: any :class:`_expression.FromClause` subclass, such as a table, select statement, etc. :param name: string name to be assigned as the alias. - If ``None``, a name will be deterministically generated - at compile time. + If ``None``, a name will be deterministically generated at compile + time. Deterministic means the name is guaranteed to be unique against + other constructs used in the same statement, and will also be the same + name for each successive compilation of the same statement object. :param flat: Will be passed through to if the given selectable is an instance of :class:`_expression.Join` - see - :meth:`_expression.Join.alias` - for details. + :meth:`_expression.Join.alias` for details. """ return Alias._factory(selectable, name=name, flat=flat) -def cte(selectable, name=None, recursive=False): +def cte( + selectable: HasCTE, name: Optional[str] = None, recursive: bool = False +) -> CTE: r"""Return a new :class:`_expression.CTE`, or Common Table Expression instance. @@ -86,7 +89,7 @@ def cte(selectable, name=None, recursive=False): ) -def except_(*selects): +def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: r"""Return an ``EXCEPT`` of multiple selectables. The returned object is an instance of @@ -99,7 +102,9 @@ def except_(*selects): return CompoundSelect._create_except(*selects) -def except_all(*selects): +def except_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``EXCEPT ALL`` of multiple selectables. The returned object is an instance of @@ -112,7 +117,11 @@ def except_all(*selects): return CompoundSelect._create_except_all(*selects) -def exists(__argument=None): +def exists( + __argument: Optional[ + Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + ] = None, +) -> Exists: """Construct a new :class:`_expression.Exists` construct. The :func:`_sql.exists` can be invoked by itself to produce an @@ -153,7 +162,7 @@ def exists(__argument=None): return Exists(__argument) -def intersect(*selects): +def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: r"""Return an ``INTERSECT`` of multiple selectables. The returned object is an instance of @@ -166,7 +175,9 @@ def intersect(*selects): return CompoundSelect._create_intersect(*selects) -def intersect_all(*selects): +def intersect_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``INTERSECT ALL`` of multiple selectables. The returned object is an instance of @@ -180,7 +191,13 @@ def intersect_all(*selects): return CompoundSelect._create_intersect_all(*selects) -def join(left, right, onclause=None, isouter=False, full=False): +def join( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, +) -> Join: """Produce a :class:`_expression.Join` object, given two :class:`_expression.FromClause` expressions. @@ -232,7 +249,10 @@ def join(left, right, onclause=None, isouter=False, full=False): return Join(left, right, onclause, isouter, full) -def lateral(selectable, name=None): +def lateral( + selectable: Union[SelectBase, _FromClauseArgument], + name: Optional[str] = None, +) -> LateralFromClause: """Return a :class:`_expression.Lateral` object. :class:`_expression.Lateral` is an :class:`_expression.Alias` @@ -255,7 +275,12 @@ def lateral(selectable, name=None): return Lateral._factory(selectable, name=name) -def outerjoin(left, right, onclause=None, full=False): +def outerjoin( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, +) -> Join: """Return an ``OUTER JOIN`` clause element. The returned object is an instance of :class:`_expression.Join`. @@ -349,7 +374,12 @@ def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause: return TableClause(name, *columns, **kw) -def tablesample(selectable, sampling, name=None, seed=None): +def tablesample( + selectable: _FromClauseArgument, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, +) -> TableSample: """Return a :class:`_expression.TableSample` object. :class:`_expression.TableSample` is an :class:`_expression.Alias` @@ -395,7 +425,7 @@ def tablesample(selectable, sampling, name=None, seed=None): return TableSample._factory(selectable, sampling, name=name, seed=seed) -def union(*selects, **kwargs): +def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: r"""Return a ``UNION`` of multiple selectables. The returned object is an instance of @@ -412,10 +442,10 @@ def union(*selects, **kwargs): :func:`select`. """ - return CompoundSelect._create_union(*selects, **kwargs) + return CompoundSelect._create_union(*selects) -def union_all(*selects): +def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: r"""Return a ``UNION ALL`` of multiple selectables. The returned object is an instance of diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index a5da87802..0a72a93c5 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,7 +1,7 @@ from __future__ import annotations +import operator from typing import Any -from typing import Iterable from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -24,9 +24,16 @@ if TYPE_CHECKING: from .roles import FromClauseRole from .schema import DefaultGenerator from .schema import Sequence + from .selectable import Alias from .selectable import FromClause + from .selectable import Join from .selectable import NamedFromClause + from .selectable import ReturnsRows + from .selectable import Select + from .selectable import SelectBase + from .selectable import Subquery from .selectable import TableClause + from .sqltypes import TableValueType from .sqltypes import TupleType from .type_api import TypeEngine from ..util.typing import TypeGuard @@ -47,6 +54,14 @@ class _HasClauseElement(Protocol): # the coercions system is responsible for converting from XYZArgument to # XYZElement. +_TextCoercedExpressionArgument = Union[ + str, + "TextClause", + "ColumnElement[_T]", + _HasClauseElement, + roles.ExpressionElementRole[_T], +] + _ColumnsClauseArgument = Union[ Literal["*", 1], roles.ColumnsClauseRole, @@ -54,8 +69,31 @@ _ColumnsClauseArgument = Union[ Inspectable[_HasClauseElement], _HasClauseElement, ] +"""open-ended SELECT columns clause argument. + +Includes column expressions, tables, ORM mapped entities, a few literal values. + +This type is used for lists of columns / entities to be returned in result +sets; select(...), insert().returning(...), etc. + + +""" + +_ColumnExpressionArgument = Union[ + "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T] +] +"""narrower "column expression" argument. + +This type is used for all the other "column" kinds of expressions that +typically represent a single SQL column expression, not a set of columns the +way a table or ORM entity does. + +This includes ColumnElement, or ORM-mapped attributes that will have a +`__clause_element__()` method, it also has the ExpressionElementRole +overall which brings in the TextClause object also. + +""" -_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]] _FromClauseArgument = Union[ roles.FromClauseRole, @@ -63,28 +101,99 @@ _FromClauseArgument = Union[ Inspectable[_HasClauseElement], _HasClauseElement, ] +"""A FROM clause, like we would send to select().select_from(). -_ColumnExpressionArgument = Union[ - "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T] +Also accommodates ORM entities and related constructs. + +""" + +_JoinTargetArgument = Union[_FromClauseArgument, roles.JoinTargetRole] +"""target for join() builds on _FromClauseArgument to include additional +join target roles such as those which come from the ORM. + +""" + +_OnClauseArgument = Union[_ColumnExpressionArgument[Any], roles.OnClauseRole] +"""target for an ON clause, includes additional roles such as those which +come from the ORM. + +""" + +_SelectStatementForCompoundArgument = Union[ + "SelectBase", roles.CompoundElementRole +] +"""SELECT statement acceptable by ``union()`` and other SQL set operations""" + +_DMLColumnArgument = Union[ + str, "ColumnClause[Any]", _HasClauseElement, roles.DMLColumnRole ] +"""A DML column expression. This is a "key" inside of insert().values(), +update().values(), and related. + +These are usually strings or SQL table columns. + +There's also edge cases like JSON expression assignment, which we would want +the DMLColumnRole to be able to accommodate. -_DMLColumnArgument = Union[str, "ColumnClause[Any]", _HasClauseElement] +""" + + +_DMLTableArgument = Union[ + "TableClause", + "Join", + "Alias", + Type[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, +] _PropagateAttrsType = util.immutabledict[str, Any] _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] +if TYPE_CHECKING: -def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: - return t.named_with_column + def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: + ... + def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: + ... -def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: - return c._is_column_element + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: + ... + def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: + ... -def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: - return c._is_text_clause + def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: + ... + + def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]: + ... + + def is_select_base(t: ReturnsRows) -> TypeGuard[SelectBase]: + ... + + def is_select_statement(t: ReturnsRows) -> TypeGuard[Select]: + ... + + def is_table(t: FromClause) -> TypeGuard[TableClause]: + ... + + def is_subquery(t: FromClause) -> TypeGuard[Subquery]: + ... + +else: + is_named_from_clause = operator.attrgetter("named_with_column") + is_column_element = operator.attrgetter("_is_column_element") + is_text_clause = operator.attrgetter("_is_text_clause") + is_from_clause = operator.attrgetter("_is_from_clause") + is_tuple_type = operator.attrgetter("_is_tuple_type") + is_table_value_type = operator.attrgetter("_is_table_value") + is_select_base = operator.attrgetter("_is_select_base") + is_select_statement = operator.attrgetter("_is_select_statement") + is_table = operator.attrgetter("_is_table") + is_subquery = operator.attrgetter("_is_subquery") def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]: @@ -95,9 +204,5 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]: return hasattr(s, "quote") -def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: - return t._is_tuple_type - - def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: return hasattr(s, "__clause_element__") diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index f1919d1d3..fa36c09fc 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -59,7 +59,9 @@ class SupportsAnnotations(ExternallyTraversible): _is_immutable: bool - def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations: + def _annotate( + self: SelfSupportsAnnotations, values: _AnnotationDict + ) -> SelfSupportsAnnotations: raise NotImplementedError() @overload @@ -105,11 +107,6 @@ class SupportsAnnotations(ExternallyTraversible): ) -SelfSupportsCloneAnnotations = TypeVar( - "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations" -) - - class SupportsCloneAnnotations(SupportsAnnotations): if not typing.TYPE_CHECKING: __slots__ = () @@ -119,8 +116,8 @@ class SupportsCloneAnnotations(SupportsAnnotations): ] def _annotate( - self: SelfSupportsCloneAnnotations, values: _AnnotationDict - ) -> SelfSupportsCloneAnnotations: + self: SelfSupportsAnnotations, values: _AnnotationDict + ) -> SelfSupportsAnnotations: """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -132,8 +129,8 @@ class SupportsCloneAnnotations(SupportsAnnotations): return new def _with_annotations( - self: SelfSupportsCloneAnnotations, values: _AnnotationDict - ) -> SelfSupportsCloneAnnotations: + self: SelfSupportsAnnotations, values: _AnnotationDict + ) -> SelfSupportsAnnotations: """return a copy of this ClauseElement with annotations replaced by the given dictionary. @@ -184,11 +181,6 @@ class SupportsCloneAnnotations(SupportsAnnotations): return self -SelfSupportsWrappingAnnotations = TypeVar( - "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations" -) - - class SupportsWrappingAnnotations(SupportsAnnotations): __slots__ = () @@ -200,19 +192,23 @@ class SupportsWrappingAnnotations(SupportsAnnotations): def entity_namespace(self) -> _EntityNamespace: ... - def _annotate(self, values: _AnnotationDict) -> Annotated: + def _annotate( + self: SelfSupportsAnnotations, values: _AnnotationDict + ) -> SelfSupportsAnnotations: """return a copy of this ClauseElement with annotations updated by the given dictionary. """ - return Annotated._as_annotated_instance(self, values) + return Annotated._as_annotated_instance(self, values) # type: ignore - def _with_annotations(self, values: _AnnotationDict) -> Annotated: + def _with_annotations( + self: SelfSupportsAnnotations, values: _AnnotationDict + ) -> SelfSupportsAnnotations: """return a copy of this ClauseElement with annotations replaced by the given dictionary. """ - return Annotated._as_annotated_instance(self, values) + return Annotated._as_annotated_instance(self, values) # type: ignore @overload def _deannotate( @@ -306,16 +302,17 @@ class Annotated(SupportsAnnotations): self: SelfAnnotated, values: _AnnotationDict ) -> SelfAnnotated: _values = self._annotations.union(values) - return self._with_annotations(_values) + new: SelfAnnotated = self._with_annotations(_values) # type: ignore + return new def _with_annotations( - self: SelfAnnotated, values: util.immutabledict[str, Any] - ) -> SelfAnnotated: + self: SelfAnnotated, values: _AnnotationDict + ) -> SupportsAnnotations: clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() clone.__dict__.pop("_annotations_cache_key", None) clone.__dict__.pop("_generate_cache_key", None) - clone._annotations = values + clone._annotations = util.immutabledict(values) return clone @overload diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 19e4c13d2..6b25d8fcd 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -63,12 +63,14 @@ if TYPE_CHECKING: from . import elements from . import type_api from ._typing import _ColumnsClauseArgument - from ._typing import _SelectIterable from .elements import BindParameter from .elements import ColumnClause from .elements import ColumnElement from .elements import NamedColumn from .elements import SQLCoreOperations + from .elements import TextClause + from .selectable import _JoinTargetElement + from .selectable import _SelectIterable from .selectable import FromClause from ..engine import Connection from ..engine import Result @@ -167,7 +169,11 @@ class SingletonConstant(Immutable): cls._singleton = obj -def _from_objects(*elements: ColumnElement[Any]) -> Iterator[FromClause]: +def _from_objects( + *elements: Union[ + ColumnElement[Any], FromClause, TextClause, _JoinTargetElement + ] +) -> Iterator[FromClause]: return itertools.chain.from_iterable( [element._from_objects for element in elements] ) @@ -255,6 +261,11 @@ def _expand_cloned(elements): predecessors. """ + # TODO: cython candidate + # and/or change approach: in + # https://gerrit.sqlalchemy.org/c/sqlalchemy/sqlalchemy/+/3712 we propose + # getting rid of _cloned_set. + # turning this into chain.from_iterable adds all kinds of callcount return itertools.chain(*[x._cloned_set for x in elements]) @@ -1559,6 +1570,11 @@ class ColumnCollection(Generic[_COLKEY, _COL]): was moved onto the :class:`_expression.ColumnCollection` itself. """ + # TODO: cython candidate + + # don't dig around if the column is locally present + if column in self._colset: + return column def embedded(expanded_proxy_set, target_set): for t in target_set.difference(expanded_proxy_set): @@ -1568,9 +1584,6 @@ class ColumnCollection(Generic[_COLKEY, _COL]): return False return True - # don't dig around if the column is locally present - if column in self._colset: - return column col, intersect = None, None target_set = column.proxy_set cols = [c for (k, c) in self._collection] diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 19a232c56..1f8b9c19e 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -54,6 +54,11 @@ class CacheConst(enum.Enum): NO_CACHE = CacheConst.NO_CACHE +_CacheKeyTraversalType = Union[ + "_TraverseInternalsType", Literal[CacheConst.NO_CACHE], Literal[None] +] + + class CacheTraverseTarget(enum.Enum): CACHE_IN_PLACE = 0 CALL_GEN_CACHE_KEY = 1 @@ -89,9 +94,7 @@ class HasCacheKey: __slots__ = () - _cache_key_traversal: Union[ - _TraverseInternalsType, Literal[CacheConst.NO_CACHE], Literal[None] - ] = NO_CACHE + _cache_key_traversal: _CacheKeyTraversalType = NO_CACHE _is_has_cache_key = True diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index ccc8fba8d..4c71ca38b 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -12,7 +12,6 @@ import numbers import re import typing from typing import Any -from typing import Any as TODO_Any from typing import Callable from typing import Dict from typing import List @@ -20,7 +19,9 @@ from typing import NoReturn from typing import Optional from typing import overload from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import operators from . import roles @@ -32,6 +33,7 @@ from .visitors import Visitable from .. import exc from .. import inspection from .. import util +from ..util.typing import Literal if not typing.TYPE_CHECKING: elements = None @@ -46,12 +48,26 @@ if typing.TYPE_CHECKING: from . import schema from . import selectable from . import traversals + from ._typing import _ColumnExpressionArgument from ._typing import _ColumnsClauseArgument + from ._typing import _DMLTableArgument + from ._typing import _FromClauseArgument + from .dml import _DMLTableElement from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement + from .elements import DQLDMLClauseElement from .elements import SQLCoreOperations - + from .schema import Column + from .selectable import _ColumnsClauseElement + from .selectable import _JoinTargetElement + from .selectable import _JoinTargetProtocol + from .selectable import _OnClauseElement + from .selectable import FromClause + from .selectable import HasCTE + from .selectable import SelectBase + from .selectable import Subquery + from .visitors import _TraverseCallableType _SR = TypeVar("_SR", bound=roles.SQLRole) _F = TypeVar("_F", bound=Callable[..., Any]) @@ -143,10 +159,6 @@ def _expression_collection_was_a_list(attrname, fnname, args): def expect( role: Type[roles.TruncatedLabelRole], element: Any, - *, - apply_propagate_attrs: Optional[ClauseElement] = None, - argname: Optional[str] = None, - post_inspect: bool = False, **kw: Any, ) -> str: ... @@ -154,12 +166,30 @@ def expect( @overload def expect( - role: Type[roles.ExpressionElementRole[_T]], + role: Type[roles.StatementOptionRole], element: Any, - *, - apply_propagate_attrs: Optional[ClauseElement] = None, - argname: Optional[str] = None, - post_inspect: bool = False, + **kw: Any, +) -> DQLDMLClauseElement: + ... + + +@overload +def expect( + role: Type[roles.DDLReferredColumnRole], + element: Any, + **kw: Any, +) -> Column[Any]: + ... + + +@overload +def expect( + role: Union[ + Type[roles.ExpressionElementRole[Any]], + Type[roles.LimitOffsetRole], + Type[roles.WhereHavingRole], + ], + element: _ColumnExpressionArgument[_T], **kw: Any, ) -> ColumnElement[_T]: ... @@ -167,40 +197,89 @@ def expect( @overload def expect( - role: Type[roles.DMLTableRole], + role: Union[ + Type[roles.ExpressionElementRole[Any]], + Type[roles.LimitOffsetRole], + Type[roles.WhereHavingRole], + ], element: Any, + **kw: Any, +) -> ColumnElement[Any]: + ... + + +@overload +def expect( + role: Type[roles.DMLTableRole], + element: _DMLTableArgument, + **kw: Any, +) -> _DMLTableElement: + ... + + +@overload +def expect( + role: Type[roles.HasCTERole], + element: HasCTE, + **kw: Any, +) -> HasCTE: + ... + + +@overload +def expect( + role: Type[roles.SelectStatementRole], + element: SelectBase, + **kw: Any, +) -> SelectBase: + ... + + +@overload +def expect( + role: Type[roles.FromClauseRole], + element: _FromClauseArgument, + **kw: Any, +) -> FromClause: + ... + + +@overload +def expect( + role: Type[roles.FromClauseRole], + element: SelectBase, *, - apply_propagate_attrs: Optional[ClauseElement] = None, - argname: Optional[str] = None, - post_inspect: bool = False, + explicit_subquery: Literal[True] = ..., **kw: Any, -) -> roles.DMLTableRole: +) -> Subquery: ... @overload def expect( role: Type[roles.ColumnsClauseRole], - element: Any, - *, - apply_propagate_attrs: Optional[ClauseElement] = None, - argname: Optional[str] = None, - post_inspect: bool = False, + element: _ColumnsClauseArgument, + **kw: Any, +) -> _ColumnsClauseElement: + ... + + +@overload +def expect( + role: Union[Type[roles.JoinTargetRole], Type[roles.OnClauseRole]], + element: _JoinTargetProtocol, **kw: Any, -) -> roles.ColumnsClauseRole: +) -> _JoinTargetProtocol: ... +# catchall for not-yet-implemented overloads @overload def expect( role: Type[_SR], element: Any, - *, - apply_propagate_attrs: Optional[ClauseElement] = None, - argname: Optional[str] = None, - post_inspect: bool = False, **kw: Any, -) -> TODO_Any: +) -> Any: ... @@ -212,7 +291,7 @@ def expect( argname: Optional[str] = None, post_inspect: bool = False, **kw: Any, -) -> TODO_Any: +) -> Any: if ( role.allows_lambda # note callable() will not invoke a __getattr__() method, whereas @@ -329,7 +408,8 @@ def expect_col_expression_collection(role, expressions): strname = resolved = expr else: cols: List[ColumnClause[Any]] = [] - visitors.traverse(resolved, {}, {"column": cols.append}) + col_append: _TraverseCallableType[ColumnClause[Any]] = cols.append + visitors.traverse(resolved, {}, {"column": col_append}) if cols: column = cols[0] add_element = column if column is not None else strname @@ -432,7 +512,7 @@ class _ColumnCoercions(RoleImpl): original_element = element if not getattr(resolved, "is_clause_element", False): self._raise_for_expected(original_element, argname, resolved) - elif resolved._is_select_statement: + elif resolved._is_select_base: self._warn_for_scalar_subquery_coercion() return resolved.scalar_subquery() elif resolved._is_from_clause and isinstance( @@ -670,7 +750,7 @@ class InElementImpl(RoleImpl): if resolved._is_from_clause: if ( isinstance(resolved, selectable.Alias) - and resolved.element._is_select_statement + and resolved.element._is_select_base ): self._warn_for_implicit_coercion(resolved) return self._post_coercion(resolved.element, **kw) @@ -722,7 +802,7 @@ class InElementImpl(RoleImpl): self._raise_for_expected(element, **kw) def _post_coercion(self, element, expr, operator, **kw): - if element._is_select_statement: + if element._is_select_base: # for IN, we are doing scalar_subquery() coercion without # a warning return element.scalar_subquery() @@ -1085,7 +1165,7 @@ class JoinTargetImpl(RoleImpl): # #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match # were set to False. return element - elif legacy and resolved._is_select_statement: + elif legacy and resolved._is_select_base: util.warn_deprecated( "Implicit coercion of SELECT and textual SELECT " "constructs into FROM clauses is deprecated; please call " @@ -1114,7 +1194,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): allow_select: bool = True, **kw: Any, ) -> Any: - if resolved._is_select_statement: + if resolved._is_select_base: if explicit_subquery: return resolved.subquery() elif allow_select: @@ -1150,7 +1230,7 @@ class StrictFromClauseImpl(FromClauseImpl): allow_select: bool = False, **kw: Any, ) -> Any: - if resolved._is_select_statement and allow_select: + if resolved._is_select_base and allow_select: util.warn_deprecated( "Implicit coercion of SELECT and textual SELECT constructs " "into FROM clauses is deprecated; please call .subquery() " @@ -1195,7 +1275,7 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl): if resolved._is_from_clause: if ( isinstance(resolved, selectable.Alias) - and resolved.element._is_select_statement + and resolved.element._is_select_base ): return resolved.element else: @@ -1235,3 +1315,9 @@ for name in dir(roles): if name in globals(): impl = globals()[name](cls) _impl_lookup[cls] = impl + +if not TYPE_CHECKING: + ee_impl = _impl_lookup[roles.ExpressionElementRole] + + for py_type in (int, bool, str, float): + _impl_lookup[roles.ExpressionElementRole[py_type]] = ee_impl diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b7f6d11f6..6ecfbf986 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -313,12 +313,12 @@ EXTRACT_MAP = { } COMPOUND_KEYWORDS = { - selectable.CompoundSelect.UNION: "UNION", - selectable.CompoundSelect.UNION_ALL: "UNION ALL", - selectable.CompoundSelect.EXCEPT: "EXCEPT", - selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL", - selectable.CompoundSelect.INTERSECT: "INTERSECT", - selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL", + selectable._CompoundSelectKeyword.UNION: "UNION", + selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL", + selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT", + selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL", + selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT", + selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL", } @@ -1468,6 +1468,10 @@ class SQLCompiler(Compiled): self.post_compile_params = frozenset() for key in expanded_state.parameter_expansion: bind = self.binds.pop(key) + + if TYPE_CHECKING: + assert bind.value is not None + self.bind_names.pop(bind) for value, expanded_key in zip( bind.value, expanded_state.parameter_expansion[key] @@ -3089,12 +3093,7 @@ class SQLCompiler(Compiled): self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) if cte.recursive: - if isinstance(cte.element, selectable.Select): - col_source = cte.element - elif isinstance(cte.element, selectable.CompoundSelect): - col_source = cte.element.selects[0] - else: - assert False, "cte should only be against SelectBase" + col_source = cte.element # TODO: can we get at the .columns_plus_names collection # that is already (or will be?) generated for the SELECT @@ -3315,7 +3314,9 @@ class SQLCompiler(Compiled): for elem in chunk ) - if isinstance(element.name, elements._truncated_label): + if element._unnamed: + name = None + elif isinstance(element.name, elements._truncated_label): name = self._truncated_identifier("values", element.name) else: name = element.name @@ -3980,7 +3981,7 @@ class SQLCompiler(Compiled): clause = " ".join( prefix._compiler_dispatch(self, **kw) for prefix, dialect_name in prefixes - if dialect_name is None or dialect_name == self.dialect.name + if dialect_name in (None, "*") or dialect_name == self.dialect.name ) if clause: clause += " " diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 0c9056aee..8a3a1b38f 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -45,11 +45,14 @@ from .base import Executable from .base import HasCompileState from .elements import BooleanClauseList from .elements import ClauseElement +from .elements import ColumnClause from .elements import ColumnElement from .elements import Null +from .selectable import Alias from .selectable import FromClause from .selectable import HasCTE from .selectable import HasPrefixes +from .selectable import Join from .selectable import ReturnsRows from .selectable import TableClause from .sqltypes import NullType @@ -59,15 +62,15 @@ from .. import util from ..util.typing import TypeGuard if TYPE_CHECKING: - + from ._typing import _ColumnExpressionArgument from ._typing import _ColumnsClauseArgument from ._typing import _DMLColumnArgument + from ._typing import _DMLTableArgument from ._typing import _FromClauseArgument - from ._typing import _HasClauseElement - from ._typing import _SelectIterable from .base import ReadOnlyColumnCollection from .compiler import SQLCompiler - from .elements import ColumnClause + from .selectable import _ColumnsClauseElement + from .selectable import _SelectIterable from .selectable import Select def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: @@ -85,7 +88,8 @@ else: isinsert = operator.attrgetter("isinsert") -_DMLColumnElement = Union[str, "ColumnClause[Any]"] +_DMLColumnElement = Union[str, ColumnClause[Any]] +_DMLTableElement = Union[TableClause, Alias, Join] class DMLState(CompileState): @@ -132,7 +136,7 @@ class DMLState(CompileState): ] @property - def dml_table(self) -> roles.DMLTableRole: + def dml_table(self) -> _DMLTableElement: return self.statement.table if TYPE_CHECKING: @@ -322,17 +326,17 @@ class UpdateBase( __visit_name__ = "update_base" _hints: util.immutabledict[ - Tuple[roles.DMLTableRole, str], str + Tuple[_DMLTableElement, str], str ] = util.EMPTY_DICT named_with_column = False - table: roles.DMLTableRole + table: _DMLTableElement _return_defaults = False _return_defaults_columns: Optional[ - Tuple[roles.ColumnsClauseRole, ...] + Tuple[_ColumnsClauseElement, ...] ] = None - _returning: Tuple[roles.ColumnsClauseRole, ...] = () + _returning: Tuple[_ColumnsClauseElement, ...] = () is_dml = True @@ -483,7 +487,7 @@ class UpdateBase( def with_hint( self: SelfUpdateBase, text: str, - selectable: Optional[roles.DMLTableRole] = None, + selectable: Optional[_DMLTableArgument] = None, dialect_name: str = "*", ) -> SelfUpdateBase: """Add a table hint for a single table to this @@ -517,7 +521,8 @@ class UpdateBase( """ if selectable is None: selectable = self.table - + else: + selectable = coercions.expect(roles.DMLTableRole, selectable) self._hints = self._hints.union({(selectable, dialect_name): text}) return self @@ -636,9 +641,9 @@ class ValuesBase(UpdateBase): _select_names: Optional[List[str]] = None _inline: bool = False - _returning: Tuple[roles.ColumnsClauseRole, ...] = () + _returning: Tuple[_ColumnsClauseElement, ...] = () - def __init__(self, table: _FromClauseArgument): + def __init__(self, table: _DMLTableArgument): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) @@ -970,7 +975,7 @@ class Insert(ValuesBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table: roles.FromClauseRole): + def __init__(self, table: _DMLTableArgument): super(Insert, self).__init__(table) @_generative @@ -1066,12 +1071,12 @@ SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase") class DMLWhereBase: - table: roles.DMLTableRole + table: _DMLTableElement _where_criteria: Tuple[ColumnElement[Any], ...] = () @_generative def where( - self: SelfDMLWhereBase, *whereclause: roles.ExpressionElementRole[Any] + self: SelfDMLWhereBase, *whereclause: _ColumnExpressionArgument[bool] ) -> SelfDMLWhereBase: """Return a new construct with the given expression(s) added to its WHERE clause, joined to the existing clause via AND, if any. @@ -1104,7 +1109,9 @@ class DMLWhereBase: """ for criterion in whereclause: - where_criteria = coercions.expect(roles.WhereHavingRole, criterion) + where_criteria: ColumnElement[Any] = coercions.expect( + roles.WhereHavingRole, criterion + ) self._where_criteria += (where_criteria,) return self @@ -1119,7 +1126,7 @@ class DMLWhereBase: return self.where(*criteria) - def _filter_by_zero(self) -> roles.DMLTableRole: + def _filter_by_zero(self) -> _DMLTableElement: return self.table def filter_by(self: SelfDMLWhereBase, **kwargs: Any) -> SelfDMLWhereBase: @@ -1189,7 +1196,7 @@ class Update(DMLWhereBase, ValuesBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table: roles.FromClauseRole): + def __init__(self, table: _DMLTableArgument): super(Update, self).__init__(table) @_generative @@ -1279,7 +1286,7 @@ class Delete(DMLWhereBase, UpdateBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table: roles.FromClauseRole): + def __init__(self, table: _DMLTableArgument): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index c735085f8..aec29d1b2 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -26,6 +26,7 @@ from typing import Dict from typing import FrozenSet from typing import Generic from typing import Iterable +from typing import Iterator from typing import List from typing import Mapping from typing import Optional @@ -77,8 +78,8 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _PropagateAttrsType - from ._typing import _SelectIterable from ._typing import _TypeEngineArgument + from .cache_key import _CacheKeyTraversalType from .cache_key import CacheKey from .compiler import Compiled from .compiler import SQLCompiler @@ -88,6 +89,7 @@ if typing.TYPE_CHECKING: from .schema import DefaultGenerator from .schema import FetchedValue from .schema import ForeignKey + from .selectable import _SelectIterable from .selectable import FromClause from .selectable import NamedFromClause from .selectable import ReturnsRows @@ -96,6 +98,7 @@ if typing.TYPE_CHECKING: from .sqltypes import Boolean from .sqltypes import TupleType from .type_api import TypeEngine + from .visitors import _CloneCallableType from .visitors import _TraverseInternalsType from ..engine import Connection from ..engine import Dialect @@ -310,6 +313,7 @@ class ClauseElement( _is_text_clause = False _is_from_container = False _is_select_container = False + _is_select_base = False _is_select_statement = False _is_bind_parameter = False _is_clause_list = False @@ -321,7 +325,7 @@ class ClauseElement( def _order_by_label_element(self) -> Optional[Label[Any]]: return None - _cache_key_traversal = None + _cache_key_traversal: _CacheKeyTraversalType = None negation_clause: ColumnElement[bool] @@ -528,7 +532,7 @@ class ClauseElement( """ return traversals.compare(self, other, **kw) - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Any: """Apply a 'grouping' to this :class:`_expression.ClauseElement`. This method is overridden by subclasses to return a "grouping" @@ -637,9 +641,9 @@ class ClauseElement( return self._negate() def _negate(self) -> ClauseElement: - return UnaryExpression( - self.self_group(against=operators.inv), operator=operators.inv - ) + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression(grouped, operator=operators.inv) def __bool__(self): raise TypeError("Boolean value of this clause is not defined") @@ -1290,12 +1294,6 @@ class ColumnElement( @overload def self_group( - self: ColumnElement[bool], against: Optional[OperatorType] = None - ) -> ColumnElement[bool]: - ... - - @overload - def self_group( self: ColumnElement[Any], against: Optional[OperatorType] = None ) -> ColumnElement[Any]: ... @@ -1764,6 +1762,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): key: str type: TypeEngine[_T] + value: Optional[_T] _is_crud = False _is_bind_parameter = True @@ -1883,7 +1882,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): return cloned @property - def effective_value(self): + def effective_value(self) -> Optional[_T]: """Return the value of this bound parameter, taking into account if the ``callable`` parameter was set. @@ -1893,11 +1892,12 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): """ if self.callable: - return self.callable() + # TODO: set up protocol for bind parameter callable + return self.callable() # type: ignore else: return self.value - def render_literal_execute(self): + def render_literal_execute(self) -> BindParameter[_T]: """Produce a copy of this bound parameter that will enable the :paramref:`_sql.BindParameter.literal_execute` flag. @@ -2513,8 +2513,10 @@ class ClauseList( self.operator = operator self.group = group self.group_contents = group_contents + clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses if _flatten_sub_clauses: - clauses = util.flatten_iterator(clauses) + clauses_iterator = util.flatten_iterator(clauses_iterator) + self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role text_converter_role: Type[roles.SQLRole] = _literal_as_text_role @@ -2523,31 +2525,35 @@ class ClauseList( coercions.expect( text_converter_role, clause, apply_propagate_attrs=self ).self_group(against=self.operator) - for clause in clauses + for clause in clauses_iterator ] else: self.clauses = [ coercions.expect( text_converter_role, clause, apply_propagate_attrs=self ) - for clause in clauses + for clause in clauses_iterator ] self._is_implicitly_boolean = operators.is_boolean(self.operator) @classmethod - def _construct_raw(cls, operator, clauses=None): + def _construct_raw( + cls, + operator: OperatorType, + clauses: Optional[Sequence[ColumnElement[Any]]] = None, + ) -> ClauseList: self = cls.__new__(cls) - self.clauses = clauses if clauses else [] + self.clauses = list(clauses) if clauses else [] self.group = True self.operator = operator self.group_contents = True self._is_implicitly_boolean = False return self - def __iter__(self): + def __iter__(self) -> Iterator[ColumnElement[Any]]: return iter(self.clauses) - def __len__(self): + def __len__(self) -> int: return len(self.clauses) @property @@ -2708,10 +2714,10 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): def _construct_raw( cls, operator: OperatorType, - clauses: Optional[List[ColumnElement[Any]]] = None, + clauses: Optional[Sequence[ColumnElement[Any]]] = None, ) -> BooleanClauseList: self = cls.__new__(cls) - self.clauses = clauses if clauses else [] + self.clauses = list(clauses) if clauses else [] self.group = True self.operator = operator self.group_contents = True @@ -2781,7 +2787,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): sqltypes = util.preloaded.sql_sqltypes if types is None: - init_clauses = [ + init_clauses: List[ColumnElement[Any]] = [ coercions.expect(roles.ExpressionElementRole, c) for c in clauses ] @@ -2908,7 +2914,7 @@ class Case(ColumnElement[_T]): ] if whenlist: - type_ = list(whenlist[-1])[-1].type + type_ = whenlist[-1][-1].type else: type_ = None @@ -3098,6 +3104,8 @@ class _label_reference(ColumnElement[_T]): ("element", InternalTraversal.dp_clauseelement) ] + element: ColumnElement[_T] + def __init__(self, element: ColumnElement[_T]): self.element = element @@ -3212,7 +3220,9 @@ class UnaryExpression(ColumnElement[_T]): cls, expr: _ColumnExpressionArgument[_T], ) -> UnaryExpression[_T]: - col_expr = coercions.expect(roles.ExpressionElementRole, expr) + col_expr: ColumnElement[_T] = coercions.expect( + roles.ExpressionElementRole, expr + ) return UnaryExpression( col_expr, operator=operators.distinct_op, @@ -3265,7 +3275,7 @@ class CollectionAggregate(UnaryExpression[_T]): def _create_any( cls, expr: _ColumnExpressionArgument[_T] ) -> CollectionAggregate[bool]: - col_expr = coercions.expect( + col_expr: ColumnElement[_T] = coercions.expect( roles.ExpressionElementRole, expr, ) @@ -3281,7 +3291,7 @@ class CollectionAggregate(UnaryExpression[_T]): def _create_all( cls, expr: _ColumnExpressionArgument[_T] ) -> CollectionAggregate[bool]: - col_expr = coercions.expect( + col_expr: ColumnElement[_T] = coercions.expect( roles.ExpressionElementRole, expr, ) @@ -3374,6 +3384,9 @@ class BinaryExpression(ColumnElement[_T]): modifiers: Optional[Mapping[str, Any]] + left: ColumnElement[Any] + right: Union[ColumnElement[Any], ClauseList] + def __init__( self, left: ColumnElement[Any], @@ -4147,7 +4160,13 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]): def foreign_keys(self): return self.element.foreign_keys - def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw): + def _copy_internals( + self, + *, + clone: _CloneCallableType = _clone, + anonymize_labels: bool = False, + **kw: Any, + ) -> None: self._reset_memoizations() self._element = clone(self._element, **kw) if anonymize_labels: @@ -4447,7 +4466,9 @@ class TableValuedColumn(NamedColumn[_T]): self.key = self.name = scalar_alias.name self.type = type_ - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: self.scalar_alias = clone(self.scalar_alias, **kw) self.key = self.name = self.scalar_alias.name @@ -4467,7 +4488,7 @@ class CollationClause(ColumnElement[str]): def _create_collation_expression( cls, expression: _ColumnExpressionArgument[str], collation: str ) -> BinaryExpression[str]: - expr = coercions.expect(roles.ExpressionElementRole, expression) + expr = coercions.expect(roles.ExpressionElementRole[str], expression) return BinaryExpression( expr, CollationClause(collation), diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 3bca8b502..db4bb5837 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -11,10 +11,15 @@ from __future__ import annotations +import datetime from typing import Any +from typing import cast +from typing import Dict +from typing import Mapping from typing import Optional from typing import overload -from typing import Sequence +from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -24,7 +29,9 @@ from . import operators from . import roles from . import schema from . import sqltypes +from . import type_api from . import util as sqlutil +from ._typing import is_table_value_type from .base import _entity_namespace from .base import ColumnCollection from .base import Executable @@ -46,16 +53,21 @@ from .elements import WithinGroup from .selectable import FromClause from .selectable import Select from .selectable import TableValuedAlias +from .sqltypes import _N +from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import InternalTraversal from .. import util + if TYPE_CHECKING: from ._typing import _TypeEngineArgument _T = TypeVar("_T", bound=Any) -_registry = util.defaultdict(dict) +_registry: util.defaultdict[ + str, Dict[str, Type[Function[Any]]] +] = util.defaultdict(dict) def register_function(identifier, fn, package="_default"): @@ -103,11 +115,18 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): ("_table_value_type", InternalTraversal.dp_has_cache_key), ] - packagenames = () + packagenames: Tuple[str, ...] = () _has_args = False _with_ordinality = False - _table_value_type = None + _table_value_type: Optional[TableValueType] = None + + # some attributes that are defined between both ColumnElement and + # FromClause are set to Any here to avoid typing errors + primary_key: Any + _is_clone_of: Any + + clause_expr: Grouping[Any] def __init__(self, *clauses: Any): r"""Construct a :class:`.FunctionElement`. @@ -135,9 +154,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): for c in clauses ] self._has_args = self._has_args or bool(args) - self.clause_expr = ClauseList( - operator=operators.comma_op, group_contents=True, *args - ).self_group() + self.clause_expr = Grouping( + ClauseList(operator=operators.comma_op, group_contents=True, *args) + ) _non_anon_label = None @@ -263,9 +282,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): expr += (with_ordinality,) new_func._with_ordinality = True - new_func.type = new_func._table_value_type = sqltypes.TableValueType( - *expr - ) + new_func.type = new_func._table_value_type = TableValueType(*expr) return new_func.alias(name=name, joins_implicitly=joins_implicitly) @@ -332,7 +349,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): @property def _all_selected_columns(self): - if self.type._is_table_value: + if is_table_value_type(self.type): cols = self.type._elements else: cols = [self.label(None)] @@ -344,12 +361,12 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return self.columns @HasMemoized.memoized_attribute - def clauses(self): + def clauses(self) -> ClauseList: """Return the underlying :class:`.ClauseList` which contains the arguments for this :class:`.FunctionElement`. """ - return self.clause_expr.element + return cast(ClauseList, self.clause_expr.element) def over(self, partition_by=None, order_by=None, rows=None, range_=None): """Produce an OVER clause against this function. @@ -647,7 +664,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return _entity_namespace(self.clause_expr) -class FunctionAsBinary(BinaryExpression): +class FunctionAsBinary(BinaryExpression[Any]): _traverse_internals = [ ("sql_function", InternalTraversal.dp_clauseelement), ("left_index", InternalTraversal.dp_plain_obj), @@ -655,10 +672,16 @@ class FunctionAsBinary(BinaryExpression): ("modifiers", InternalTraversal.dp_plain_dict), ] + sql_function: FunctionElement[Any] + left_index: int + right_index: int + def _gen_cache_key(self, anon_map, bindparams): return ColumnElement._gen_cache_key(self, anon_map, bindparams) - def __init__(self, fn, left_index, right_index): + def __init__( + self, fn: FunctionElement[Any], left_index: int, right_index: int + ): self.sql_function = fn self.left_index = left_index self.right_index = right_index @@ -670,23 +693,30 @@ class FunctionAsBinary(BinaryExpression): self.modifiers = {} @property - def left(self): + def left_expr(self) -> ColumnElement[Any]: return self.sql_function.clauses.clauses[self.left_index - 1] - @left.setter - def left(self, value): + @left_expr.setter + def left_expr(self, value: ColumnElement[Any]) -> None: self.sql_function.clauses.clauses[self.left_index - 1] = value @property - def right(self): + def right_expr(self) -> ColumnElement[Any]: return self.sql_function.clauses.clauses[self.right_index - 1] - @right.setter - def right(self, value): + @right_expr.setter + def right_expr(self, value: ColumnElement[Any]) -> None: self.sql_function.clauses.clauses[self.right_index - 1] = value + if not TYPE_CHECKING: + # mypy can't accommodate @property to replace an instance + # variable + + left = left_expr + right = right_expr + -class ScalarFunctionColumn(NamedColumn): +class ScalarFunctionColumn(NamedColumn[_T]): __visit_name__ = "scalar_function_column" _traverse_internals = [ @@ -698,10 +728,18 @@ class ScalarFunctionColumn(NamedColumn): is_literal = False table = None - def __init__(self, fn, name, type_=None): + def __init__( + self, + fn: FunctionElement[_T], + name: str, + type_: Optional[_TypeEngineArgument[_T]] = None, + ): self.fn = fn self.name = name - self.type = sqltypes.to_instance(type_) + + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore class _FunctionGenerator: @@ -789,7 +827,7 @@ class _FunctionGenerator: # passthru __ attributes; fixes pydoc if name.startswith("__"): try: - return self.__dict__[name] + return self.__dict__[name] # type: ignore except KeyError: raise AttributeError(name) @@ -883,8 +921,6 @@ class Function(FunctionElement[_T]): identifier: str - packagenames: Sequence[str] - type: TypeEngine[_T] """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. @@ -907,7 +943,7 @@ class Function(FunctionElement[_T]): name: str, *clauses: Any, type_: Optional[_TypeEngineArgument[_T]] = None, - packagenames: Optional[Sequence[str]] = None, + packagenames: Optional[Tuple[str, ...]] = None, ): """Construct a :class:`.Function`. @@ -918,7 +954,9 @@ class Function(FunctionElement[_T]): self.packagenames = packagenames or () self.name = name - self.type = sqltypes.to_instance(type_) + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore FunctionElement.__init__(self, *clauses) @@ -934,7 +972,7 @@ class Function(FunctionElement[_T]): ) -class GenericFunction(Function): +class GenericFunction(Function[_T]): """Define a 'generic' function. A generic function is a pre-established :class:`.Function` @@ -957,7 +995,7 @@ class GenericFunction(Function): from sqlalchemy.types import DateTime class as_utc(GenericFunction): - type = DateTime + type = DateTime() inherit_cache = True print(select(func.as_utc())) @@ -971,7 +1009,7 @@ class GenericFunction(Function): "time":: class as_utc(GenericFunction): - type = DateTime + type = DateTime() package = "time" inherit_cache = True @@ -987,7 +1025,7 @@ class GenericFunction(Function): the usage of ``name`` as the rendered name:: class GeoBuffer(GenericFunction): - type = Geometry + type = Geometry() package = "geo" name = "ST_Buffer" identifier = "buffer" @@ -1006,7 +1044,7 @@ class GenericFunction(Function): from sqlalchemy.sql import quoted_name class GeoBuffer(GenericFunction): - type = Geometry + type = Geometry() package = "geo" name = quoted_name("ST_Buffer", True) identifier = "buffer" @@ -1028,6 +1066,8 @@ class GenericFunction(Function): coerce_arguments = True inherit_cache = True + _register: bool + name = "GenericFunction" def __init_subclass__(cls) -> None: @@ -1036,7 +1076,9 @@ class GenericFunction(Function): super().__init_subclass__() @classmethod - def _register_generic_function(cls, clsname, clsdict): + def _register_generic_function( + cls, clsname: str, clsdict: Mapping[str, Any] + ) -> None: cls.name = name = clsdict.get("name", clsname) cls.identifier = identifier = clsdict.get("identifier", name) package = clsdict.get("package", "_default") @@ -1068,11 +1110,14 @@ class GenericFunction(Function): ] self._has_args = self._has_args or bool(parsed_args) self.packagenames = () - self.clause_expr = ClauseList( - operator=operators.comma_op, group_contents=True, *parsed_args - ).self_group() - self.type = sqltypes.to_instance( + self.clause_expr = Grouping( + ClauseList( + operator=operators.comma_op, group_contents=True, *parsed_args + ) + ) + + self.type = type_api.to_instance( # type: ignore kwargs.pop("type_", None) or getattr(self, "type", None) ) @@ -1081,7 +1126,7 @@ register_function("cast", Cast) register_function("extract", Extract) -class next_value(GenericFunction): +class next_value(GenericFunction[int]): """Represent the 'next value', given a :class:`.Sequence` as its single argument. @@ -1103,7 +1148,7 @@ class next_value(GenericFunction): seq, schema.Sequence ), "next_value() accepts a Sequence object as input." self.sequence = seq - self.type = sqltypes.to_instance( + self.type = sqltypes.to_instance( # type: ignore seq.data_type or getattr(self, "type", None) ) @@ -1118,7 +1163,7 @@ class next_value(GenericFunction): return [] -class AnsiFunction(GenericFunction): +class AnsiFunction(GenericFunction[_T]): """Define a function in "ansi" format, which doesn't render parenthesis.""" inherit_cache = True @@ -1127,13 +1172,13 @@ class AnsiFunction(GenericFunction): GenericFunction.__init__(self, *args, **kwargs) -class ReturnTypeFromArgs(GenericFunction): +class ReturnTypeFromArgs(GenericFunction[_T]): """Define a function whose return type is the same as its arguments.""" inherit_cache = True def __init__(self, *args, **kwargs): - args = [ + fn_args = [ coercions.expect( roles.ExpressionElementRole, c, @@ -1142,35 +1187,35 @@ class ReturnTypeFromArgs(GenericFunction): ) for c in args ] - kwargs.setdefault("type_", _type_from_args(args)) - kwargs["_parsed_args"] = args - super(ReturnTypeFromArgs, self).__init__(*args, **kwargs) + kwargs.setdefault("type_", _type_from_args(fn_args)) + kwargs["_parsed_args"] = fn_args + super(ReturnTypeFromArgs, self).__init__(*fn_args, **kwargs) -class coalesce(ReturnTypeFromArgs): +class coalesce(ReturnTypeFromArgs[_T]): _has_args = True inherit_cache = True -class max(ReturnTypeFromArgs): # noqa A001 +class max(ReturnTypeFromArgs[_T]): # noqa A001 """The SQL MAX() aggregate function.""" inherit_cache = True -class min(ReturnTypeFromArgs): # noqa A001 +class min(ReturnTypeFromArgs[_T]): # noqa A001 """The SQL MIN() aggregate function.""" inherit_cache = True -class sum(ReturnTypeFromArgs): # noqa A001 +class sum(ReturnTypeFromArgs[_T]): # noqa A001 """The SQL SUM() aggregate function.""" inherit_cache = True -class now(GenericFunction): +class now(GenericFunction[datetime.datetime]): """The SQL now() datetime function. SQLAlchemy dialects will usually render this particular function @@ -1178,11 +1223,11 @@ class now(GenericFunction): """ - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class concat(GenericFunction): +class concat(GenericFunction[str]): """The SQL CONCAT() function, which concatenates strings. E.g.:: @@ -1200,28 +1245,30 @@ class concat(GenericFunction): """ - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class char_length(GenericFunction): +class char_length(GenericFunction[int]): """The CHAR_LENGTH() SQL function.""" - type = sqltypes.Integer + type = sqltypes.Integer() inherit_cache = True - def __init__(self, arg, **kwargs): - GenericFunction.__init__(self, arg, **kwargs) + def __init__(self, arg, **kw): + # slight hack to limit to just one positional argument + # not sure why this one function has this special treatment + super().__init__(arg, **kw) -class random(GenericFunction): +class random(GenericFunction[float]): """The RANDOM() SQL function.""" _has_args = True inherit_cache = True -class count(GenericFunction): +class count(GenericFunction[int]): r"""The ANSI COUNT aggregate function. With no arguments, emits COUNT \*. @@ -1242,7 +1289,7 @@ class count(GenericFunction): """ - type = sqltypes.Integer + type = sqltypes.Integer() inherit_cache = True def __init__(self, expression=None, **kwargs): @@ -1251,70 +1298,70 @@ class count(GenericFunction): super(count, self).__init__(expression, **kwargs) -class current_date(AnsiFunction): +class current_date(AnsiFunction[datetime.date]): """The CURRENT_DATE() SQL function.""" - type = sqltypes.Date + type = sqltypes.Date() inherit_cache = True -class current_time(AnsiFunction): +class current_time(AnsiFunction[datetime.time]): """The CURRENT_TIME() SQL function.""" - type = sqltypes.Time + type = sqltypes.Time() inherit_cache = True -class current_timestamp(AnsiFunction): +class current_timestamp(AnsiFunction[datetime.datetime]): """The CURRENT_TIMESTAMP() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class current_user(AnsiFunction): +class current_user(AnsiFunction[str]): """The CURRENT_USER() SQL function.""" - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class localtime(AnsiFunction): +class localtime(AnsiFunction[datetime.datetime]): """The localtime() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class localtimestamp(AnsiFunction): +class localtimestamp(AnsiFunction[datetime.datetime]): """The localtimestamp() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class session_user(AnsiFunction): +class session_user(AnsiFunction[str]): """The SESSION_USER() SQL function.""" - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class sysdate(AnsiFunction): +class sysdate(AnsiFunction[datetime.datetime]): """The SYSDATE() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class user(AnsiFunction): +class user(AnsiFunction[str]): """The USER() SQL function.""" - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class array_agg(GenericFunction): +class array_agg(GenericFunction[_T]): """Support for the ARRAY_AGG function. The ``func.array_agg(expr)`` construct returns an expression of @@ -1334,11 +1381,10 @@ class array_agg(GenericFunction): """ - type = sqltypes.ARRAY inherit_cache = True def __init__(self, *args, **kwargs): - args = [ + fn_args = [ coercions.expect( roles.ExpressionElementRole, c, apply_propagate_attrs=self ) @@ -1348,16 +1394,16 @@ class array_agg(GenericFunction): default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY) if "type_" not in kwargs: - type_from_args = _type_from_args(args) + type_from_args = _type_from_args(fn_args) if isinstance(type_from_args, sqltypes.ARRAY): kwargs["type_"] = type_from_args else: kwargs["type_"] = default_array_type(type_from_args) - kwargs["_parsed_args"] = args - super(array_agg, self).__init__(*args, **kwargs) + kwargs["_parsed_args"] = fn_args + super(array_agg, self).__init__(*fn_args, **kwargs) -class OrderedSetAgg(GenericFunction): +class OrderedSetAgg(GenericFunction[_T]): """Define a function where the return type is based on the sort expression type as defined by the expression passed to the :meth:`.FunctionElement.within_group` method.""" @@ -1366,7 +1412,7 @@ class OrderedSetAgg(GenericFunction): inherit_cache = True def within_group_type(self, within_group): - func_clauses = self.clause_expr.element + func_clauses = cast(ClauseList, self.clause_expr.element) order_by = sqlutil.unwrap_order_by(within_group.order_by) if self.array_for_multi_clause and len(func_clauses.clauses) > 1: return sqltypes.ARRAY(order_by[0].type) @@ -1374,7 +1420,7 @@ class OrderedSetAgg(GenericFunction): return order_by[0].type -class mode(OrderedSetAgg): +class mode(OrderedSetAgg[_T]): """Implement the ``mode`` ordered-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1389,7 +1435,7 @@ class mode(OrderedSetAgg): inherit_cache = True -class percentile_cont(OrderedSetAgg): +class percentile_cont(OrderedSetAgg[_T]): """Implement the ``percentile_cont`` ordered-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1407,7 +1453,7 @@ class percentile_cont(OrderedSetAgg): inherit_cache = True -class percentile_disc(OrderedSetAgg): +class percentile_disc(OrderedSetAgg[_T]): """Implement the ``percentile_disc`` ordered-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1425,7 +1471,7 @@ class percentile_disc(OrderedSetAgg): inherit_cache = True -class rank(GenericFunction): +class rank(GenericFunction[int]): """Implement the ``rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1441,7 +1487,7 @@ class rank(GenericFunction): inherit_cache = True -class dense_rank(GenericFunction): +class dense_rank(GenericFunction[int]): """Implement the ``dense_rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1457,7 +1503,7 @@ class dense_rank(GenericFunction): inherit_cache = True -class percent_rank(GenericFunction): +class percent_rank(GenericFunction[_N]): """Implement the ``percent_rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1469,11 +1515,11 @@ class percent_rank(GenericFunction): """ - type = sqltypes.Numeric() + type: sqltypes.Numeric[_N] = sqltypes.Numeric() inherit_cache = True -class cume_dist(GenericFunction): +class cume_dist(GenericFunction[_N]): """Implement the ``cume_dist`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1485,11 +1531,11 @@ class cume_dist(GenericFunction): """ - type = sqltypes.Numeric() + type: sqltypes.Numeric[_N] = sqltypes.Numeric() inherit_cache = True -class cube(GenericFunction): +class cube(GenericFunction[_T]): r"""Implement the ``CUBE`` grouping operation. This function is used as part of the GROUP BY of a statement, @@ -1506,7 +1552,7 @@ class cube(GenericFunction): inherit_cache = True -class rollup(GenericFunction): +class rollup(GenericFunction[_T]): r"""Implement the ``ROLLUP`` grouping operation. This function is used as part of the GROUP BY of a statement, @@ -1523,7 +1569,7 @@ class rollup(GenericFunction): inherit_cache = True -class grouping_sets(GenericFunction): +class grouping_sets(GenericFunction[_T]): r"""Implement the ``GROUPING SETS`` grouping operation. This function is used as part of the GROUP BY of a statement, diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 9d011ef53..da15c305f 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -12,6 +12,18 @@ import inspect import itertools import operator import types +from types import CodeType +from typing import Any +from typing import Callable +from typing import cast +from typing import Iterable +from typing import List +from typing import MutableMapping +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union import weakref from . import cache_key as _cache_key @@ -19,37 +31,62 @@ from . import coercions from . import elements from . import roles from . import schema -from . import traversals from . import type_api from . import visitors from .base import _clone +from .base import Executable from .base import Options +from .cache_key import CacheConst from .operators import ColumnOperators from .. import exc from .. import inspection from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import Self -_closure_per_cache_key = util.LRUCache(1000) +if TYPE_CHECKING: + from .cache_key import CacheConst + from .cache_key import NO_CACHE + from .elements import BindParameter + from .elements import ClauseElement + from .roles import SQLRole + from .visitors import _CloneCallableType + +_LambdaCacheType = MutableMapping[ + Tuple[Any, ...], Union["NonAnalyzedFunction", "AnalyzedFunction"] +] +_BoundParameterGetter = Callable[..., Any] + +_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000) + + +class _LambdaType(Protocol): + __code__: CodeType + __closure__: Iterable[Tuple[Any, Any]] + + def __call__(self, *arg: Any, **kw: Any) -> ClauseElement: + ... class LambdaOptions(Options): enable_tracking = True track_closure_variables = True - track_on = None + track_on: Optional[object] = None global_track_bound_values = True track_bound_values = True - lambda_cache = None + lambda_cache: Optional[_LambdaCacheType] = None def lambda_stmt( - lmb, - enable_tracking=True, - track_closure_variables=True, - track_on=None, - global_track_bound_values=True, - track_bound_values=True, - lambda_cache=None, -): + lmb: _LambdaType, + enable_tracking: bool = True, + track_closure_variables: bool = True, + track_on: Optional[object] = None, + global_track_bound_values: bool = True, + track_bound_values: bool = True, + lambda_cache: Optional[_LambdaCacheType] = None, +) -> StatementLambdaElement: """Produce a SQL statement that is cached as a lambda. The Python code object within the lambda is scanned for both Python @@ -142,15 +179,28 @@ class LambdaElement(elements.ClauseElement): ("_resolved", visitors.InternalTraversal.dp_clauseelement) ] - _transforms = () + _transforms: Tuple[_CloneCallableType, ...] = () - parent_lambda = None + _resolved_bindparams: List[BindParameter[Any]] + parent_lambda: Optional[StatementLambdaElement] = None + closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]] + role: Type[SQLRole] + _rec: Union[AnalyzedFunction, NonAnalyzedFunction] + fn: _LambdaType + tracker_key: Tuple[CodeType, ...] def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self.fn.__code__) + return "%s(%r)" % ( + self.__class__.__name__, + self.fn.__code__, + ) def __init__( - self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None + self, + fn: _LambdaType, + role: Type[SQLRole], + opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, + apply_propagate_attrs: Optional[ClauseElement] = None, ): self.fn = fn self.role = role @@ -182,6 +232,7 @@ class LambdaElement(elements.ClauseElement): opts, ) + bindparams: List[BindParameter[Any]] self._resolved_bindparams = bindparams = [] if self.parent_lambda is not None: @@ -189,8 +240,10 @@ class LambdaElement(elements.ClauseElement): else: parent_closure_cache_key = () + cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]] + if parent_closure_cache_key is not _cache_key.NO_CACHE: - anon_map = traversals.anon_map() + anon_map = visitors.anon_map() cache_key = tuple( [ getter(closure, opts, anon_map, bindparams) @@ -241,7 +294,7 @@ class LambdaElement(elements.ClauseElement): if self.parent_lambda is not None: bindparams[:0] = self.parent_lambda._resolved_bindparams - lambda_element = self + lambda_element: Optional[LambdaElement] = self while lambda_element is not None: rec = lambda_element._rec if rec.bindparam_trackers: @@ -289,17 +342,21 @@ class LambdaElement(elements.ClauseElement): def _setup_binds_for_tracked_expr(self, expr): bindparam_lookup = {b.key: b for b in self._resolved_bindparams} - def replace(thing): - if isinstance(thing, elements.BindParameter): + def replace( + element: Optional[visitors.ExternallyTraversible], **kw: Any + ) -> Optional[visitors.ExternallyTraversible]: + if isinstance(element, elements.BindParameter): - if thing.key in bindparam_lookup: - bind = bindparam_lookup[thing.key] - if thing.expanding: + if element.key in bindparam_lookup: + bind = bindparam_lookup[element.key] + if element.expanding: bind.expanding = True - bind.expand_op = thing.expand_op - bind.type = thing.type + bind.expand_op = element.expand_op + bind.type = element.type return bind + return None + if self._rec.is_sequence: expr = [ visitors.replacement_traverse(sub_expr, {}, replace) @@ -311,8 +368,11 @@ class LambdaElement(elements.ClauseElement): return expr def _copy_internals( - self, clone=_clone, deferred_copy_internals=None, **kw - ): + self: Self, + clone: _CloneCallableType = _clone, + deferred_copy_internals: Optional[_CloneCallableType] = None, + **kw: Any, + ) -> None: # TODO: this needs A LOT of tests self._resolved = clone( self._resolved, @@ -340,9 +400,15 @@ class LambdaElement(elements.ClauseElement): ) + self.closure_cache_key parent = self.parent_lambda + while parent is not None: + assert parent.closure_cache_key is not CacheConst.NO_CACHE + parent_closure_cache_key: Tuple[ + Any, ... + ] = parent.closure_cache_key + cache_key = ( - (parent.fn.__code__,) + parent.closure_cache_key + cache_key + (parent.fn.__code__,) + parent_closure_cache_key + cache_key ) parent = parent.parent_lambda @@ -351,7 +417,7 @@ class LambdaElement(elements.ClauseElement): bindparams.extend(self._resolved_bindparams) return cache_key - def _invoke_user_fn(self, fn, *arg): + def _invoke_user_fn(self, fn: _LambdaType, *arg: Any) -> ClauseElement: return fn() @@ -365,7 +431,13 @@ class DeferredLambdaElement(LambdaElement): """ - def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()): + def __init__( + self, + fn: _LambdaType, + role: Type[roles.SQLRole], + opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, + lambda_args: Tuple[Any, ...] = (), + ): self.lambda_args = lambda_args super(DeferredLambdaElement, self).__init__(fn, role, opts) @@ -373,6 +445,7 @@ class DeferredLambdaElement(LambdaElement): return fn(*self.lambda_args) def _resolve_with_args(self, *lambda_args): + assert isinstance(self._rec, AnalyzedFunction) tracker_fn = self._rec.tracker_instrumented_fn expr = tracker_fn(*lambda_args) @@ -506,6 +579,8 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): def _execute_on_connection( self, connection, distilled_params, execution_options ): + if TYPE_CHECKING: + assert isinstance(self._rec.expected_expr, ClauseElement) if self._rec.expected_expr.supports_execution: return connection._execute_clauseelement( self, distilled_params, execution_options @@ -515,14 +590,20 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): @property def _with_options(self): + if TYPE_CHECKING: + assert isinstance(self._rec.expected_expr, Executable) return self._rec.expected_expr._with_options @property def _effective_plugin_target(self): + if TYPE_CHECKING: + assert isinstance(self._rec.expected_expr, Executable) return self._rec.expected_expr._effective_plugin_target @property def _execution_options(self): + if TYPE_CHECKING: + assert isinstance(self._rec.expected_expr, Executable) return self._rec.expected_expr._execution_options def spoil(self): @@ -583,9 +664,14 @@ class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement): class LinkedLambdaElement(StatementLambdaElement): """Represent subsequent links of a :class:`.StatementLambdaElement`.""" - role = None + parent_lambda: StatementLambdaElement - def __init__(self, fn, parent_lambda, opts): + def __init__( + self, + fn: _LambdaType, + parent_lambda: StatementLambdaElement, + opts: Union[Type[LambdaOptions], LambdaOptions], + ): self.opts = opts self.fn = fn self.parent_lambda = parent_lambda @@ -606,7 +692,9 @@ class AnalyzedCode: "closure_trackers", "build_py_wrappers", ) - _fns = weakref.WeakKeyDictionary() + _fns: weakref.WeakKeyDictionary[ + CodeType, AnalyzedCode + ] = weakref.WeakKeyDictionary() @classmethod def get(cls, fn, lambda_element, lambda_kw, **kw): @@ -615,6 +703,8 @@ class AnalyzedCode: return cls._fns[fn.__code__] except KeyError: pass + + analyzed: AnalyzedCode cls._fns[fn.__code__] = analyzed = AnalyzedCode( fn, lambda_element, lambda_kw, **kw ) @@ -947,14 +1037,18 @@ class AnalyzedCode: class NonAnalyzedFunction: __slots__ = ("expr",) - closure_bindparams = None - bindparam_trackers = None + closure_bindparams: Optional[List[BindParameter[Any]]] = None + bindparam_trackers: Optional[List[_BoundParameterGetter]] = None + + is_sequence = False + + expr: ClauseElement - def __init__(self, expr): + def __init__(self, expr: ClauseElement): self.expr = expr @property - def expected_expr(self): + def expected_expr(self) -> ClauseElement: return self.expr @@ -972,6 +1066,10 @@ class AnalyzedFunction: "closure_bindparams", ) + closure_bindparams: Optional[List[BindParameter[Any]]] + expected_expr: Union[ClauseElement, List[ClauseElement]] + bindparam_trackers: Optional[List[_BoundParameterGetter]] + def __init__( self, analyzed_code, @@ -1071,19 +1169,25 @@ class AnalyzedFunction: if parent_lambda is None: if isinstance(expr, collections_abc.Sequence): self.expected_expr = [ - coercions.expect( - lambda_element.role, - sub_expr, - apply_propagate_attrs=apply_propagate_attrs, + cast( + "ClauseElement", + coercions.expect( + lambda_element.role, + sub_expr, + apply_propagate_attrs=apply_propagate_attrs, + ), ) for sub_expr in expr ] self.is_sequence = True else: - self.expected_expr = coercions.expect( - lambda_element.role, - expr, - apply_propagate_attrs=apply_propagate_attrs, + self.expected_expr = cast( + "ClauseElement", + coercions.expect( + lambda_element.role, + expr, + apply_propagate_attrs=apply_propagate_attrs, + ), ) self.is_sequence = False else: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 86725f86f..577d868fd 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -19,7 +19,6 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _PropagateAttrsType - from ._typing import _SelectIterable from .base import _EntityNamespace from .base import ColumnCollection from .base import ReadOnlyColumnCollection @@ -28,6 +27,7 @@ if TYPE_CHECKING: from .elements import ColumnElement from .elements import Label from .elements import NamedColumn + from .selectable import _SelectIterable from .selectable import FromClause from .selectable import Subquery @@ -164,6 +164,12 @@ class WhereHavingRole(OnClauseRole): class ExpressionElementRole(Generic[_T], SQLRole): + # note when using generics for ExpressionElementRole, + # the generic type needs to be in + # sqlalchemy.sql.coercions._impl_lookup mapping also. + # these are set up for basic types like int, bool, str, float + # right now + __slots__ = () _role_name = "SQL expression element" diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index cbd0c77f4..883439ca5 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -36,6 +36,7 @@ import operator import typing from typing import Any from typing import Callable +from typing import cast from typing import Dict from typing import Iterator from typing import List @@ -68,6 +69,7 @@ from .elements import SQLCoreOperations from .elements import TextClause from .selectable import TableClause from .type_api import to_instance +from .visitors import ExternallyTraversible from .visitors import InternalTraversal from .. import event from .. import exc @@ -131,21 +133,33 @@ def _get_table_key(name: str, schema: Optional[str]) -> str: # this should really be in sql/util.py but we'd have to # break an import cycle -def _copy_expression(expression, source_table, target_table): +def _copy_expression( + expression: ColumnElement[Any], + source_table: Optional[Table], + target_table: Optional[Table], +) -> ColumnElement[Any]: if source_table is None or target_table is None: return expression - def replace(col): + fixed_source_table = source_table + fixed_target_table = target_table + + def replace( + element: ExternallyTraversible, **kw: Any + ) -> Optional[ExternallyTraversible]: if ( - isinstance(col, Column) - and col.table is source_table - and col.key in source_table.c + isinstance(element, Column) + and element.table is fixed_source_table + and element.key in fixed_source_table.c ): - return target_table.c[col.key] + return fixed_target_table.c[element.key] else: return None - return visitors.replacement_traverse(expression, {}, replace) + return cast( + ColumnElement[Any], + visitors.replacement_traverse(expression, {}, replace), + ) @inspection._self_inspects @@ -911,8 +925,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): def _reset_exported(self): pass - @property - def _autoincrement_column(self): + @util.ro_non_memoized_property + def _autoincrement_column(self) -> Optional[Column[Any]]: return self.primary_key._autoincrement_column @property @@ -2308,6 +2322,8 @@ class ForeignKey(DialectKWArgs, SchemaItem): parent: Column[Any] + _table_column: Optional[Column[Any]] + def __init__( self, column: Union[str, Column[Any], SQLCoreOperations[Any]], @@ -4290,11 +4306,11 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): self._columns.extend(columns) - PrimaryKeyConstraint._autoincrement_column._reset(self) + PrimaryKeyConstraint._autoincrement_column._reset(self) # type: ignore self._set_parent_with_dispatch(self.table) def _replace(self, col): - PrimaryKeyConstraint._autoincrement_column._reset(self) + PrimaryKeyConstraint._autoincrement_column._reset(self) # type: ignore self._columns.replace(col) self.dispatch._sa_event_column_added_to_pk_constraint(self, col) @@ -4308,8 +4324,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): else: return list(self._columns) - @util.memoized_property - def _autoincrement_column(self): + @util.ro_memoized_property + def _autoincrement_column(self) -> Optional[Column[Any]]: def _validate_autoinc(col, autoinc_true): if col.type._type_affinity is None or not issubclass( col.type._type_affinity, type_api.INTEGERTYPE._type_affinity @@ -4350,6 +4366,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): "ignore_fk", ) and _validate_autoinc(col, False): return col + else: + return None else: autoinc = None diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 4f6e3795e..6504449f1 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -17,16 +17,26 @@ import collections from enum import Enum import itertools import typing +from typing import AbstractSet from typing import Any as TODO_Any from typing import Any +from typing import Callable +from typing import cast +from typing import Dict from typing import Iterable +from typing import Iterator from typing import List from typing import NamedTuple +from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import cache_key from . import coercions @@ -37,6 +47,9 @@ from . import type_api from . import visitors from ._typing import _ColumnsClauseArgument from ._typing import is_column_element +from ._typing import is_select_statement +from ._typing import is_subquery +from ._typing import is_table from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -68,32 +81,80 @@ from .elements import ColumnClause from .elements import ColumnElement from .elements import DQLDMLClauseElement from .elements import GroupedElement -from .elements import Grouping from .elements import literal_column from .elements import TableValuedColumn from .elements import UnaryExpression +from .operators import OperatorType +from .visitors import _TraverseInternalsType from .visitors import InternalTraversal from .visitors import prefix_anon_map from .. import exc from .. import util +from ..util import HasMemoized_ro_memoized_attribute +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import Self and_ = BooleanClauseList.and_ _T = TypeVar("_T", bound=Any) if TYPE_CHECKING: - from ._typing import _SelectIterable + from ._typing import _ColumnExpressionArgument + from ._typing import _FromClauseArgument + from ._typing import _JoinTargetArgument + from ._typing import _OnClauseArgument + from ._typing import _SelectStatementForCompoundArgument + from ._typing import _TextCoercedExpressionArgument + from ._typing import _TypeEngineArgument + from .base import _AmbiguousTableNameMap + from .base import ExecutableOption from .base import ReadOnlyColumnCollection + from .cache_key import _CacheKeyTraversalType + from .compiler import SQLCompiler + from .dml import Delete + from .dml import Insert + from .dml import Update from .elements import NamedColumn + from .elements import TextClause + from .functions import Function + from .schema import Column from .schema import ForeignKey - from .schema import PrimaryKeyConstraint + from .schema import ForeignKeyConstraint + from .type_api import TypeEngine + from .util import ClauseAdapter + from .visitors import _CloneCallableType -class _OffsetLimitParam(BindParameter): +_ColumnsClauseElement = Union["FromClause", ColumnElement[Any], "TextClause"] + + +class _JoinTargetProtocol(Protocol): + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + ... + + +_JoinTargetElement = Union["FromClause", _JoinTargetProtocol] +_OnClauseElement = Union["ColumnElement[bool]", _JoinTargetProtocol] + + +_SetupJoinsElement = Tuple[ + _JoinTargetElement, + Optional[_OnClauseElement], + Optional["FromClause"], + Dict[str, Any], +] + + +_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]] + + +class _OffsetLimitParam(BindParameter[int]): inherit_cache = True @property - def _limit_offset_value(self): + def _limit_offset_value(self) -> Optional[int]: return self.effective_value @@ -114,11 +175,12 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): # sub-elements of returns_rows _is_from_clause = False + _is_select_base = False _is_select_statement = False _is_lateral = False @property - def selectable(self): + def selectable(self) -> ReturnsRows: return self @util.non_memoized_property @@ -133,8 +195,28 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): """ raise NotImplementedError() + def is_derived_from(self, fromclause: FromClause) -> bool: + """Return ``True`` if this :class:`.ReturnsRows` is + 'derived' from the given :class:`.FromClause`. + + An example would be an Alias of a Table is derived from that Table. + + """ + raise NotImplementedError() + + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: + """Populate columns into an :class:`.AliasedReturnsRows` object.""" + + raise NotImplementedError() + + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: + """reset internal collections for an incoming column being added.""" + raise NotImplementedError() + @property - def exported_columns(self): + def exported_columns(self) -> ReadOnlyColumnCollection[Any, Any]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.ReturnsRows`. @@ -160,6 +242,9 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): raise NotImplementedError() +SelfSelectable = TypeVar("SelfSelectable", bound="Selectable") + + class Selectable(ReturnsRows): """Mark a class as being selectable.""" @@ -167,10 +252,10 @@ class Selectable(ReturnsRows): is_selectable = True - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: raise NotImplementedError() - def lateral(self, name=None): + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a LATERAL alias of this :class:`_expression.Selectable`. The return value is the :class:`_expression.Lateral` construct also @@ -192,15 +277,21 @@ class Selectable(ReturnsRows): "functionality is available via the sqlalchemy.sql.visitors module.", ) @util.preload_module("sqlalchemy.sql.util") - def replace_selectable(self, old, alias): + def replace_selectable( + self: SelfSelectable, old: FromClause, alias: Alias + ) -> SelfSelectable: """Replace all occurrences of :class:`_expression.FromClause` 'old' with the given :class:`_expression.Alias` object, returning a copy of this :class:`_expression.FromClause`. """ - return util.preloaded.sql_util.ClauseAdapter(alias).traverse(self) + return util.preloaded.sql_util.ClauseAdapter(alias).traverse( # type: ignore # noqa E501 + self + ) - def corresponding_column(self, column, require_embedded=False): + def corresponding_column( + self, column: ColumnElement[Any], require_embedded: bool = False + ) -> Optional[ColumnElement[Any]]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from the :attr:`_expression.Selectable.exported_columns` @@ -242,19 +333,23 @@ SelfHasPrefixes = typing.TypeVar("SelfHasPrefixes", bound="HasPrefixes") class HasPrefixes: - _prefixes = () + _prefixes: Tuple[Tuple[DQLDMLClauseElement, str], ...] = () - _has_prefixes_traverse_internals = [ + _has_prefixes_traverse_internals: _TraverseInternalsType = [ ("_prefixes", InternalTraversal.dp_prefix_sequence) ] @_generative @_document_text_coercion( - "expr", + "prefixes", ":meth:`_expression.HasPrefixes.prefix_with`", - ":paramref:`.HasPrefixes.prefix_with.*expr`", + ":paramref:`.HasPrefixes.prefix_with.*prefixes`", ) - def prefix_with(self: SelfHasPrefixes, *expr, **kw) -> SelfHasPrefixes: + def prefix_with( + self: SelfHasPrefixes, + *prefixes: _TextCoercedExpressionArgument[Any], + dialect: str = "*", + ) -> SelfHasPrefixes: r"""Add one or more expressions following the statement keyword, i.e. SELECT, INSERT, UPDATE, or DELETE. Generative. @@ -272,49 +367,44 @@ class HasPrefixes: Multiple prefixes can be specified by multiple calls to :meth:`_expression.HasPrefixes.prefix_with`. - :param \*expr: textual or :class:`_expression.ClauseElement` + :param \*prefixes: textual or :class:`_expression.ClauseElement` construct which will be rendered following the INSERT, UPDATE, or DELETE keyword. - :param \**kw: A single keyword 'dialect' is accepted. This is an - optional string dialect name which will + :param dialect: optional string dialect name which will limit rendering of this prefix to only that dialect. """ - dialect = kw.pop("dialect", None) - if kw: - raise exc.ArgumentError( - "Unsupported argument(s): %s" % ",".join(kw) - ) - self._setup_prefixes(expr, dialect) - return self - - def _setup_prefixes(self, prefixes, dialect=None): self._prefixes = self._prefixes + tuple( [ (coercions.expect(roles.StatementOptionRole, p), dialect) for p in prefixes ] ) + return self SelfHasSuffixes = typing.TypeVar("SelfHasSuffixes", bound="HasSuffixes") class HasSuffixes: - _suffixes = () + _suffixes: Tuple[Tuple[DQLDMLClauseElement, str], ...] = () - _has_suffixes_traverse_internals = [ + _has_suffixes_traverse_internals: _TraverseInternalsType = [ ("_suffixes", InternalTraversal.dp_prefix_sequence) ] @_generative @_document_text_coercion( - "expr", + "suffixes", ":meth:`_expression.HasSuffixes.suffix_with`", - ":paramref:`.HasSuffixes.suffix_with.*expr`", + ":paramref:`.HasSuffixes.suffix_with.*suffixes`", ) - def suffix_with(self: SelfHasSuffixes, *expr, **kw) -> SelfHasSuffixes: + def suffix_with( + self: SelfHasSuffixes, + *suffixes: _TextCoercedExpressionArgument[Any], + dialect: str = "*", + ) -> SelfHasSuffixes: r"""Add one or more expressions following the statement as a whole. This is used to support backend-specific suffix keywords on @@ -328,44 +418,39 @@ class HasSuffixes: Multiple suffixes can be specified by multiple calls to :meth:`_expression.HasSuffixes.suffix_with`. - :param \*expr: textual or :class:`_expression.ClauseElement` + :param \*suffixes: textual or :class:`_expression.ClauseElement` construct which will be rendered following the target clause. - :param \**kw: A single keyword 'dialect' is accepted. This is an - optional string dialect name which will + :param dialect: Optional string dialect name which will limit rendering of this suffix to only that dialect. """ - dialect = kw.pop("dialect", None) - if kw: - raise exc.ArgumentError( - "Unsupported argument(s): %s" % ",".join(kw) - ) - self._setup_suffixes(expr, dialect) - return self - - def _setup_suffixes(self, suffixes, dialect=None): self._suffixes = self._suffixes + tuple( [ (coercions.expect(roles.StatementOptionRole, p), dialect) for p in suffixes ] ) + return self SelfHasHints = typing.TypeVar("SelfHasHints", bound="HasHints") class HasHints: - _hints = util.immutabledict() - _statement_hints = () + _hints: util.immutabledict[ + Tuple[FromClause, str], str + ] = util.immutabledict() + _statement_hints: Tuple[Tuple[str, str], ...] = () - _has_hints_traverse_internals = [ + _has_hints_traverse_internals: _TraverseInternalsType = [ ("_statement_hints", InternalTraversal.dp_statement_hint_list), ("_hints", InternalTraversal.dp_table_hint_list), ] - def with_statement_hint(self, text, dialect_name="*"): + def with_statement_hint( + self: SelfHasHints, text: str, dialect_name: str = "*" + ) -> SelfHasHints: """Add a statement hint to this :class:`_expression.Select` or other selectable object. @@ -389,11 +474,14 @@ class HasHints: MySQL optimizer hints """ - return self.with_hint(None, text, dialect_name) + return self._with_hint(None, text, dialect_name) @_generative def with_hint( - self: SelfHasHints, selectable, text, dialect_name="*" + self: SelfHasHints, + selectable: _FromClauseArgument, + text: str, + dialect_name: str = "*", ) -> SelfHasHints: r"""Add an indexing or other executional context hint for the given selectable to this :class:`_expression.Select` or other selectable @@ -429,6 +517,15 @@ class HasHints: :meth:`_expression.Select.with_statement_hint` """ + + return self._with_hint(selectable, text, dialect_name) + + def _with_hint( + self: SelfHasHints, + selectable: Optional[_FromClauseArgument], + text: str, + dialect_name: str, + ) -> SelfHasHints: if selectable is None: self._statement_hints += ((dialect_name, text),) else: @@ -443,6 +540,9 @@ class HasHints: return self +SelfFromClause = TypeVar("SelfFromClause", bound="FromClause") + + class FromClause(roles.AnonymizedFromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -473,6 +573,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _is_clone_of: Optional[FromClause] + _columns: ColumnCollection[Any, Any] + schema: Optional[str] = None """Define the 'schema' attribute for this :class:`_expression.FromClause`. @@ -488,7 +590,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _use_schema_map = False - def select(self) -> "Select": + def select(self) -> Select: r"""Return a SELECT of this :class:`_expression.FromClause`. @@ -504,7 +606,13 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return Select(self) - def join(self, right, onclause=None, isouter=False, full=False): + def join( + self, + right: _FromClauseArgument, + onclause: Optional[_ColumnExpressionArgument[bool]] = None, + isouter: bool = False, + full: bool = False, + ) -> Join: """Return a :class:`_expression.Join` from this :class:`_expression.FromClause` to another :class:`FromClause`. @@ -550,7 +658,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return Join(self, right, onclause, isouter, full) - def outerjoin(self, right, onclause=None, full=False): + def outerjoin( + self, + right: _FromClauseArgument, + onclause: Optional[_ColumnExpressionArgument[bool]] = None, + full: bool = False, + ) -> Join: """Return a :class:`_expression.Join` from this :class:`_expression.FromClause` to another :class:`FromClause`, with the "isouter" flag set to @@ -596,7 +709,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return Join(self, right, onclause, True, full) - def alias(self, name=None, flat=False): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromClause: """Return an alias of this :class:`_expression.FromClause`. E.g.:: @@ -617,35 +732,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return Alias._construct(self, name) - @util.preload_module("sqlalchemy.sql.sqltypes") - def table_valued(self): - """Return a :class:`_sql.TableValuedColumn` object for this - :class:`_expression.FromClause`. - - A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that - represents a complete row in a table. Support for this construct is - backend dependent, and is supported in various forms by backends - such as PostgreSQL, Oracle and SQL Server. - - E.g.:: - - >>> from sqlalchemy import select, column, func, table - >>> a = table("a", column("id"), column("x"), column("y")) - >>> stmt = select(func.row_to_json(a.table_valued())) - >>> print(stmt) - SELECT row_to_json(a) AS row_to_json_1 - FROM a - - .. versionadded:: 1.4.0b2 - - .. seealso:: - - :ref:`tutorial_functions` - in the :ref:`unified_tutorial` - - """ - return TableValuedColumn(self, type_api.TABLEVALUE) - - def tablesample(self, sampling, name=None, seed=None): + def tablesample( + self, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, + ) -> TableSample: """Return a TABLESAMPLE alias of this :class:`_expression.FromClause`. The return value is the :class:`_expression.TableSample` @@ -661,7 +753,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return TableSample._construct(self, sampling, name, seed) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: """Return ``True`` if this :class:`_expression.FromClause` is 'derived' from the given ``FromClause``. @@ -673,7 +765,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): # contained elements. return fromclause in self._cloned_set - def _is_lexical_equivalent(self, other): + def _is_lexical_equivalent(self, other: FromClause) -> bool: """Return ``True`` if this :class:`_expression.FromClause` and the other represent the same lexical identity. @@ -681,9 +773,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): if they are the same via annotation identity. """ - return self._cloned_set.intersection(other._cloned_set) + return bool(self._cloned_set.intersection(other._cloned_set)) - @util.non_memoized_property + @util.ro_non_memoized_property def description(self) -> str: """A brief description of this :class:`_expression.FromClause`. @@ -692,13 +784,15 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return getattr(self, "name", self.__class__.__name__ + " object") - def _generate_fromclause_column_proxies(self, fromclause): + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: fromclause._columns._populate_separate_keys( col._make_proxy(fromclause) for col in self.c ) @property - def exported_columns(self): + def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`. @@ -796,7 +890,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._populate_column_collection() return self.foreign_keys - def _reset_column_collection(self): + def _reset_column_collection(self) -> None: """Reset the attributes linked to the ``FromClause.c`` attribute. This collection is separate from all the other memoized things @@ -817,7 +911,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): def _select_iterable(self) -> _SelectIterable: return self.c - def _init_collections(self): + def _init_collections(self) -> None: assert "_columns" not in self.__dict__ assert "primary_key" not in self.__dict__ assert "foreign_keys" not in self.__dict__ @@ -827,10 +921,10 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self.foreign_keys = set() # type: ignore @property - def _cols_populated(self): + def _cols_populated(self) -> bool: return "_columns" in self.__dict__ - def _populate_column_collection(self): + def _populate_column_collection(self) -> None: """Called on subclasses to establish the .c collection. Each implementation has a different way of establishing @@ -838,7 +932,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: """Given a column added to the .c collection of an underlying selectable, produce the local version of that column, assuming this selectable ultimately should proxy this column. @@ -865,15 +959,60 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ self._reset_column_collection() - def _anonymous_fromclause(self, name=None, flat=False): + def _anonymous_fromclause( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromClause: return self.alias(name=name) + if TYPE_CHECKING: + + def self_group( + self: Self, against: Optional[OperatorType] = None + ) -> Union[FromGrouping, Self]: + ... + class NamedFromClause(FromClause): + """A :class:`.FromClause` that has a name. + + Examples include tables, subqueries, CTEs, aliased tables. + + .. versionadded:: 2.0 + + """ + named_with_column = True name: str + @util.preload_module("sqlalchemy.sql.sqltypes") + def table_valued(self) -> TableValuedColumn[Any]: + """Return a :class:`_sql.TableValuedColumn` object for this + :class:`_expression.FromClause`. + + A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that + represents a complete row in a table. Support for this construct is + backend dependent, and is supported in various forms by backends + such as PostgreSQL, Oracle and SQL Server. + + E.g.:: + + >>> from sqlalchemy import select, column, func, table + >>> a = table("a", column("id"), column("x"), column("y")) + >>> stmt = select(func.row_to_json(a.table_valued())) + >>> print(stmt) + SELECT row_to_json(a) AS row_to_json_1 + FROM a + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :ref:`tutorial_functions` - in the :ref:`unified_tutorial` + + """ + return TableValuedColumn(self, type_api.TABLEVALUE) + class SelectLabelStyle(Enum): """Label style constants that may be passed to @@ -992,7 +1131,7 @@ class Join(roles.DMLTableRole, FromClause): __visit_name__ = "join" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("left", InternalTraversal.dp_clauseelement), ("right", InternalTraversal.dp_clauseelement), ("onclause", InternalTraversal.dp_clauseelement), @@ -1002,7 +1141,20 @@ class Join(roles.DMLTableRole, FromClause): _is_join = True - def __init__(self, left, right, onclause=None, isouter=False, full=False): + left: FromClause + right: FromClause + onclause: Optional[ColumnElement[bool]] + isouter: bool + full: bool + + def __init__( + self, + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + ): """Construct a new :class:`_expression.Join`. The usual entrypoint here is the :func:`_expression.join` @@ -1010,11 +1162,23 @@ class Join(roles.DMLTableRole, FromClause): :class:`_expression.FromClause` object. """ + + # when deannotate was removed here, callcounts went up for ORM + # compilation of eager joins, since there were more comparisons of + # annotated objects. test_orm.py -> test_fetch_results + # was therefore changed to show a more real-world use case, where the + # compilation is cached; there's no change in post-cache callcounts. + # callcounts for a single compilation in that particular test + # that includes about eight joins about 1100 extra fn calls, from + # 29200 -> 30373 + self.left = coercions.expect( - roles.FromClauseRole, left, deannotate=True + roles.FromClauseRole, + left, ) self.right = coercions.expect( - roles.FromClauseRole, right, deannotate=True + roles.FromClauseRole, + right, ).self_group() if onclause is None: @@ -1029,7 +1193,7 @@ class Join(roles.DMLTableRole, FromClause): self.isouter = isouter self.full = full - @property + @util.ro_non_memoized_property def description(self) -> str: return "Join object on %s(%d) and %s(%d)" % ( self.left.description, @@ -1038,7 +1202,7 @@ class Join(roles.DMLTableRole, FromClause): id(self.right), ) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: return ( # use hash() to ensure direct comparison to annotated works # as well @@ -1047,7 +1211,10 @@ class Join(roles.DMLTableRole, FromClause): or self.right.is_derived_from(fromclause) ) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> FromGrouping: + ... return FromGrouping(self) @util.preload_module("sqlalchemy.sql.util") @@ -1055,7 +1222,7 @@ class Join(roles.DMLTableRole, FromClause): sqlutil = util.preloaded.sql_util columns = [c for c in self.left.c] + [c for c in self.right.c] - self.primary_key.extend( + self.primary_key.extend( # type: ignore sqlutil.reduce_columns( (c for c in columns if c.primary_key), self.onclause ) @@ -1063,11 +1230,13 @@ class Join(roles.DMLTableRole, FromClause): self._columns._populate_separate_keys( (col._tq_key_label, col) for col in columns ) - self.foreign_keys.update( + self.foreign_keys.update( # type: ignore itertools.chain(*[col.foreign_keys for col in columns]) ) - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: # see Select._copy_internals() for similar concept # here we pre-clone "left" and "right" so that we can @@ -1100,12 +1269,14 @@ class Join(roles.DMLTableRole, FromClause): self._reset_memoizations() - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: super(Join, self)._refresh_for_new_column(column) self.left._refresh_for_new_column(column) self.right._refresh_for_new_column(column) - def _match_primaries(self, left, right): + def _match_primaries( + self, left: FromClause, right: FromClause + ) -> ColumnElement[bool]: if isinstance(left, Join): left_right = left.right else: @@ -1114,8 +1285,15 @@ class Join(roles.DMLTableRole, FromClause): @classmethod def _join_condition( - cls, a, b, a_subset=None, consider_as_foreign_keys=None - ): + cls, + a: FromClause, + b: FromClause, + *, + a_subset: Optional[FromClause] = None, + consider_as_foreign_keys: Optional[ + AbstractSet[ColumnClause[Any]] + ] = None, + ) -> ColumnElement[bool]: """Create a join condition between two tables or selectables. See sqlalchemy.sql.util.join_condition() for full docs. @@ -1151,7 +1329,15 @@ class Join(roles.DMLTableRole, FromClause): return and_(*crit) @classmethod - def _can_join(cls, left, right, consider_as_foreign_keys=None): + def _can_join( + cls, + left: FromClause, + right: FromClause, + *, + consider_as_foreign_keys: Optional[ + AbstractSet[ColumnClause[Any]] + ] = None, + ) -> bool: if isinstance(left, Join): left_right = left.right else: @@ -1169,20 +1355,31 @@ class Join(roles.DMLTableRole, FromClause): @classmethod @util.preload_module("sqlalchemy.sql.util") def _joincond_scan_left_right( - cls, a, a_subset, b, consider_as_foreign_keys - ): + cls, + a: FromClause, + a_subset: Optional[FromClause], + b: FromClause, + consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]], + ) -> collections.defaultdict[ + Optional[ForeignKeyConstraint], + List[Tuple[ColumnClause[Any], ColumnClause[Any]]], + ]: sql_util = util.preloaded.sql_util a = coercions.expect(roles.FromClauseRole, a) b = coercions.expect(roles.FromClauseRole, b) - constraints = collections.defaultdict(list) + constraints: collections.defaultdict[ + Optional[ForeignKeyConstraint], + List[Tuple[ColumnClause[Any], ColumnClause[Any]]], + ] = collections.defaultdict(list) for left in (a_subset, a): if left is None: continue for fk in sorted( - b.foreign_keys, key=lambda fk: fk.parent._creation_order + b.foreign_keys, + key=lambda fk: fk.parent._creation_order, # type: ignore ): if ( consider_as_foreign_keys is not None @@ -1202,7 +1399,8 @@ class Join(roles.DMLTableRole, FromClause): constraints[fk.constraint].append((col, fk.parent)) if left is not b: for fk in sorted( - left.foreign_keys, key=lambda fk: fk.parent._creation_order + left.foreign_keys, + key=lambda fk: fk.parent._creation_order, # type: ignore ): if ( consider_as_foreign_keys is not None @@ -1309,7 +1507,8 @@ class Join(roles.DMLTableRole, FromClause): @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: - return [self] + self.left._from_objects + self.right._from_objects + self_list: List[FromClause] = [self] + return self_list + self.left._from_objects + self.right._from_objects class NoInit: @@ -1327,6 +1526,14 @@ class NoInit: ) +class LateralFromClause(NamedFromClause): + """mark a FROM clause as being able to render directly as LATERAL""" + + +_SelfAliasedReturnsRows = TypeVar( + "_SelfAliasedReturnsRows", bound="AliasedReturnsRows" +) + # FromClause -> # AliasedReturnsRows # -> Alias only for FromClause @@ -1335,6 +1542,8 @@ class NoInit: # -> Lateral -> FromClause, but we accept SelectBase # w/ non-deprecated coercion # -> TableSample -> only for FromClause + + class AliasedReturnsRows(NoInit, NamedFromClause): """Base class of aliases against tables, subqueries, and other selectables.""" @@ -1343,24 +1552,21 @@ class AliasedReturnsRows(NoInit, NamedFromClause): _supports_derived_columns = False - element: ClauseElement + element: ReturnsRows - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("name", InternalTraversal.dp_anon_name), ] @classmethod - def _construct(cls, *arg, **kw): + def _construct( + cls: Type[_SelfAliasedReturnsRows], *arg: Any, **kw: Any + ) -> _SelfAliasedReturnsRows: obj = cls.__new__(cls) obj._init(*arg, **kw) return obj - @classmethod - def _factory(cls, returnsrows, name=None): - """Base factory method. Subclasses need to provide this.""" - raise NotImplementedError() - def _init(self, selectable, name=None): self.element = coercions.expect( roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self @@ -1378,11 +1584,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause): name = _anonymous_label.safe_construct(id(self), name or "anon") self.name = name - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: super(AliasedReturnsRows, self)._refresh_for_new_column(column) self.element._refresh_for_new_column(column) - @property + def _populate_column_collection(self): + self.element._generate_fromclause_column_proxies(self) + + @util.ro_non_memoized_property def description(self) -> str: name = self.name if isinstance(name, _anonymous_label): @@ -1395,15 +1604,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause): """Legacy for dialects that are referring to Alias.original.""" return self.element - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: if fromclause in self._cloned_set: return True return self.element.is_derived_from(fromclause) - def _populate_column_collection(self): - self.element._generate_fromclause_column_proxies(self) - - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: existing_element = self.element super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw) @@ -1420,7 +1628,11 @@ class AliasedReturnsRows(NoInit, NamedFromClause): return [self] -class Alias(roles.DMLTableRole, AliasedReturnsRows): +class FromClauseAlias(AliasedReturnsRows): + element: FromClause + + +class Alias(roles.DMLTableRole, FromClauseAlias): """Represents an table or selectable alias (AS). Represents an alias, as typically applied to any table or @@ -1445,13 +1657,18 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows): element: FromClause @classmethod - def _factory(cls, selectable, name=None, flat=False): + def _factory( + cls, + selectable: FromClause, + name: Optional[str] = None, + flat: bool = False, + ) -> NamedFromClause: return coercions.expect( roles.FromClauseRole, selectable, allow_select=True ).alias(name=name, flat=flat) -class TableValuedAlias(Alias): +class TableValuedAlias(LateralFromClause, Alias): """An alias against a "table valued" SQL function. This construct provides for a SQL function that returns columns @@ -1480,7 +1697,7 @@ class TableValuedAlias(Alias): _render_derived_w_types = False joins_implicitly = False - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("name", InternalTraversal.dp_anon_name), ("_tableval_type", InternalTraversal.dp_type), @@ -1526,7 +1743,9 @@ class TableValuedAlias(Alias): return TableValuedColumn(self, self._tableval_type) - def alias(self, name=None): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> TableValuedAlias: """Return a new alias of this :class:`_sql.TableValuedAlias`. This creates a distinct FROM object that will be distinguished @@ -1547,7 +1766,7 @@ class TableValuedAlias(Alias): return tva - def lateral(self, name=None): + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a new :class:`_sql.TableValuedAlias` with the lateral flag set, so that it renders as LATERAL. @@ -1619,7 +1838,7 @@ class TableValuedAlias(Alias): return new_alias -class Lateral(AliasedReturnsRows): +class Lateral(FromClauseAlias, LateralFromClause): """Represent a LATERAL subquery. This object is constructed from the :func:`_expression.lateral` module @@ -1644,13 +1863,17 @@ class Lateral(AliasedReturnsRows): inherit_cache = True @classmethod - def _factory(cls, selectable, name=None): + def _factory( + cls, + selectable: Union[SelectBase, _FromClauseArgument], + name: Optional[str] = None, + ) -> LateralFromClause: return coercions.expect( roles.FromClauseRole, selectable, explicit_subquery=True ).lateral(name=name) -class TableSample(AliasedReturnsRows): +class TableSample(FromClauseAlias): """Represent a TABLESAMPLE clause. This object is constructed from the :func:`_expression.tablesample` module @@ -1668,13 +1891,22 @@ class TableSample(AliasedReturnsRows): __visit_name__ = "tablesample" - _traverse_internals = AliasedReturnsRows._traverse_internals + [ - ("sampling", InternalTraversal.dp_clauseelement), - ("seed", InternalTraversal.dp_clauseelement), - ] + _traverse_internals: _TraverseInternalsType = ( + AliasedReturnsRows._traverse_internals + + [ + ("sampling", InternalTraversal.dp_clauseelement), + ("seed", InternalTraversal.dp_clauseelement), + ] + ) @classmethod - def _factory(cls, selectable, sampling, name=None, seed=None): + def _factory( + cls, + selectable: _FromClauseArgument, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, + ) -> TableSample: return coercions.expect(roles.FromClauseRole, selectable).tablesample( sampling, name=name, seed=seed ) @@ -1721,7 +1953,7 @@ class CTE( __visit_name__ = "cte" - _traverse_internals = ( + _traverse_internals: _TraverseInternalsType = ( AliasedReturnsRows._traverse_internals + [ ("_cte_alias", InternalTraversal.dp_clauseelement), @@ -1736,7 +1968,12 @@ class CTE( element: HasCTE @classmethod - def _factory(cls, selectable, name=None, recursive=False): + def _factory( + cls, + selectable: HasCTE, + name: Optional[str] = None, + recursive: bool = False, + ) -> CTE: r"""Return a new :class:`_expression.CTE`, or Common Table Expression instance. @@ -1775,7 +2012,9 @@ class CTE( else: self.element._generate_fromclause_column_proxies(self) - def alias(self, name=None, flat=False): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromClause: """Return an :class:`_expression.Alias` of this :class:`_expression.CTE`. @@ -1814,6 +2053,10 @@ class CTE( :meth:`_sql.HasCTE.cte` - examples of calling styles """ + assert is_select_statement( + self.element + ), f"CTE element f{self.element} does not support union()" + return CTE._construct( self.element.union(*other), name=self.name, @@ -1839,6 +2082,11 @@ class CTE( :meth:`_sql.HasCTE.cte` - examples of calling styles """ + + assert is_select_statement( + self.element + ), f"CTE element f{self.element} does not support union_all()" + return CTE._construct( self.element.union_all(*other), name=self.name, @@ -1865,23 +2113,229 @@ class _CTEOpts(NamedTuple): nesting: bool -class HasCTE(roles.HasCTERole, ClauseElement): +class _ColumnsPlusNames(NamedTuple): + required_label_name: Optional[str] + """ + string label name, if non-None, must be rendered as a + label, i.e. "AS <name>" + """ + + proxy_key: Optional[str] + """ + proxy_key that is to be part of the result map for this + col. this is also the key in a fromclause.c or + select.selected_columns collection + """ + + fallback_label_name: Optional[str] + """ + name that can be used to render an "AS <name>" when + we have to render a label even though + required_label_name was not given + """ + + column: Union[ColumnElement[Any], TextClause] + """ + the ColumnElement itself + """ + + repeated: bool + """ + True if this is a duplicate of a previous column + in the list of columns + """ + + +class SelectsRows(ReturnsRows): + """Sub-base of ReturnsRows for elements that deliver rows + directly, namely SELECT and INSERT/UPDATE/DELETE..RETURNING""" + + _label_style: SelectLabelStyle = LABEL_STYLE_NONE + + def _generate_columns_plus_names( + self, anon_for_dupe_key: bool + ) -> List[_ColumnsPlusNames]: + """Generate column names as rendered in a SELECT statement by + the compiler. + + This is distinct from the _column_naming_convention generator that's + intended for population of .c collections and similar, which has + different rules. the collection returned here calls upon the + _column_naming_convention as well. + + """ + cols = self._all_selected_columns + + key_naming_convention = SelectState._column_naming_convention( + self._label_style + ) + + names = {} + + result: List[_ColumnsPlusNames] = [] + result_append = result.append + + table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + label_style_none = self._label_style is LABEL_STYLE_NONE + + # a counter used for "dedupe" labels, which have double underscores + # in them and are never referred by name; they only act + # as positional placeholders. they need only be unique within + # the single columns clause they're rendered within (required by + # some dbs such as mysql). So their anon identity is tracked against + # a fixed counter rather than hash() identity. + dedupe_hash = 1 + + for c in cols: + repeated = False + + if not c._render_label_in_columns_clause: + effective_name = ( + required_label_name + ) = fallback_label_name = None + elif label_style_none: + if TYPE_CHECKING: + assert is_column_element(c) + + effective_name = required_label_name = None + fallback_label_name = c._non_anon_label or c._anon_name_label + else: + if TYPE_CHECKING: + assert is_column_element(c) + + if table_qualified: + required_label_name = ( + effective_name + ) = fallback_label_name = c._tq_label + else: + effective_name = fallback_label_name = c._non_anon_label + required_label_name = None + + if effective_name is None: + # it seems like this could be _proxy_key and we would + # not need _expression_label but it isn't + # giving us a clue when to use anon_label instead + expr_label = c._expression_label + if expr_label is None: + repeated = c._anon_name_label in names + names[c._anon_name_label] = c + effective_name = required_label_name = None + + if repeated: + # here, "required_label_name" is sent as + # "None" and "fallback_label_name" is sent. + if table_qualified: + fallback_label_name = ( + c._dedupe_anon_tq_label_idx(dedupe_hash) + ) + dedupe_hash += 1 + else: + fallback_label_name = c._dedupe_anon_label_idx( + dedupe_hash + ) + dedupe_hash += 1 + else: + fallback_label_name = c._anon_name_label + else: + required_label_name = ( + effective_name + ) = fallback_label_name = expr_label + + if effective_name is not None: + if TYPE_CHECKING: + assert is_column_element(c) + + if effective_name in names: + # when looking to see if names[name] is the same column as + # c, use hash(), so that an annotated version of the column + # is seen as the same as the non-annotated + if hash(names[effective_name]) != hash(c): + + # different column under the same name. apply + # disambiguating label + if table_qualified: + required_label_name = ( + fallback_label_name + ) = c._anon_tq_label + else: + required_label_name = ( + fallback_label_name + ) = c._anon_name_label + + if anon_for_dupe_key and required_label_name in names: + # here, c._anon_tq_label is definitely unique to + # that column identity (or annotated version), so + # this should always be true. + # this is also an infrequent codepath because + # you need two levels of duplication to be here + assert hash(names[required_label_name]) == hash(c) + + # the column under the disambiguating label is + # already present. apply the "dedupe" label to + # subsequent occurrences of the column so that the + # original stays non-ambiguous + if table_qualified: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_tq_label_idx(dedupe_hash) + dedupe_hash += 1 + else: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_label_idx(dedupe_hash) + dedupe_hash += 1 + repeated = True + else: + names[required_label_name] = c + elif anon_for_dupe_key: + # same column under the same name. apply the "dedupe" + # label so that the original stays non-ambiguous + if table_qualified: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_tq_label_idx(dedupe_hash) + dedupe_hash += 1 + else: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_label_idx(dedupe_hash) + dedupe_hash += 1 + repeated = True + else: + names[effective_name] = c + + result_append( + _ColumnsPlusNames( + required_label_name, + key_naming_convention(c), + fallback_label_name, + c, + repeated, + ) + ) + + return result + + +class HasCTE(roles.HasCTERole, SelectsRows): """Mixin that declares a class to include CTE support. .. versionadded:: 1.1 """ - _has_ctes_traverse_internals = [ + _has_ctes_traverse_internals: _TraverseInternalsType = [ ("_independent_ctes", InternalTraversal.dp_clauseelement_list), ("_independent_ctes_opts", InternalTraversal.dp_plain_obj), ] - _independent_ctes = () - _independent_ctes_opts = () + _independent_ctes: Tuple[CTE, ...] = () + _independent_ctes_opts: Tuple[_CTEOpts, ...] = () @_generative - def add_cte(self: SelfHasCTE, *ctes, nest_here=False) -> SelfHasCTE: + def add_cte( + self: SelfHasCTE, *ctes: CTE, nest_here: bool = False + ) -> SelfHasCTE: r"""Add one or more :class:`_sql.CTE` constructs to this statement. This method will associate the given :class:`_sql.CTE` constructs with @@ -1985,7 +2439,12 @@ class HasCTE(roles.HasCTERole, ClauseElement): self._independent_ctes_opts += (opt,) return self - def cte(self, name=None, recursive=False, nesting=False): + def cte( + self, + name: Optional[str] = None, + recursive: bool = False, + nesting: bool = False, + ) -> CTE: r"""Return a new :class:`_expression.CTE`, or Common Table Expression instance. @@ -2293,10 +2752,12 @@ class Subquery(AliasedReturnsRows): inherit_cache = True - element: Select + element: SelectBase @classmethod - def _factory(cls, selectable, name=None): + def _factory( + cls, selectable: SelectBase, name: Optional[str] = None + ) -> Subquery: """Return a :class:`.Subquery` object.""" return coercions.expect( roles.SelectStatementRole, selectable @@ -2335,11 +2796,13 @@ class Subquery(AliasedReturnsRows): class FromGrouping(GroupedElement, FromClause): """Represent a grouping of a FROM clause""" - _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement) + ] element: FromClause - def __init__(self, element): + def __init__(self, element: FromClause): self.element = coercions.expect(roles.FromClauseRole, element) def _init_collections(self): @@ -2361,11 +2824,13 @@ class FromGrouping(GroupedElement, FromClause): def foreign_keys(self): return self.element.foreign_keys - def is_derived_from(self, element): - return self.element.is_derived_from(element) + def is_derived_from(self, fromclause: FromClause) -> bool: + return self.element.is_derived_from(fromclause) - def alias(self, **kw): - return FromGrouping(self.element.alias(**kw)) + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromGrouping: + return NamedFromGrouping(self.element.alias(name=name, flat=flat)) def _anonymous_fromclause(self, **kw): return FromGrouping(self.element._anonymous_fromclause(**kw)) @@ -2385,6 +2850,16 @@ class FromGrouping(GroupedElement, FromClause): self.element = state["element"] +class NamedFromGrouping(FromGrouping, NamedFromClause): + """represent a grouping of a named FROM clause + + .. versionadded:: 2.0 + + """ + + inherit_cache = True + + class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): """Represents a minimal "table" construct. @@ -2417,7 +2892,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): __visit_name__ = "table" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ( "columns", InternalTraversal.dp_fromclause_canonical_column_collection, @@ -2434,15 +2909,17 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): doesn't support having a primary key or column -level defaults, so implicit returning doesn't apply.""" - _autoincrement_column = None - """No PK or default support so no autoincrement column.""" + @util.ro_memoized_property + def _autoincrement_column(self) -> Optional[ColumnClause[Any]]: + """No PK or default support so no autoincrement column.""" + return None - def __init__(self, name, *columns, **kw): + def __init__(self, name: str, *columns: ColumnClause[Any], **kw: Any): super(TableClause, self).__init__() self.name = name self._columns = DedupeColumnCollection() - self.primary_key = ColumnSet() - self.foreign_keys = set() + self.primary_key = ColumnSet() # type: ignore + self.foreign_keys = set() # type: ignore for c in columns: self.append_column(c) @@ -2466,23 +2943,23 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: ... - def __str__(self): + def __str__(self) -> str: if self.schema is not None: return self.schema + "." + self.name else: return self.name - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: pass - def _init_collections(self): + def _init_collections(self) -> None: pass - @util.memoized_property + @util.ro_memoized_property def description(self) -> str: return self.name - def append_column(self, c, **kw): + def append_column(self, c: ColumnClause[Any]) -> None: existing = c.table if existing is not None and existing is not self: raise exc.ArgumentError( @@ -2494,7 +2971,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): c.table = self @util.preload_module("sqlalchemy.sql.dml") - def insert(self): + def insert(self) -> Insert: """Generate an :func:`_expression.insert` construct against this :class:`_expression.TableClause`. @@ -2505,10 +2982,11 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): See :func:`_expression.insert` for argument and usage information. """ + return util.preloaded.sql_dml.Insert(self) @util.preload_module("sqlalchemy.sql.dml") - def update(self): + def update(self) -> Update: """Generate an :func:`_expression.update` construct against this :class:`_expression.TableClause`. @@ -2524,7 +3002,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): ) @util.preload_module("sqlalchemy.sql.dml") - def delete(self): + def delete(self) -> Delete: """Generate a :func:`_expression.delete` construct against this :class:`_expression.TableClause`. @@ -2543,13 +3021,18 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): class ForUpdateArg(ClauseElement): - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("of", InternalTraversal.dp_clauseelement_list), ("nowait", InternalTraversal.dp_boolean), ("read", InternalTraversal.dp_boolean), ("skip_locked", InternalTraversal.dp_boolean), ] + of: Optional[Sequence[ClauseElement]] + nowait: bool + read: bool + skip_locked: bool + @classmethod def _from_argument(cls, with_for_update): if isinstance(with_for_update, ForUpdateArg): @@ -2606,7 +3089,7 @@ class ForUpdateArg(ClauseElement): SelfValues = typing.TypeVar("SelfValues", bound="Values") -class Values(Generative, NamedFromClause): +class Values(Generative, LateralFromClause): """Represent a ``VALUES`` construct that can be used as a FROM element in a statement. @@ -2619,28 +3102,42 @@ class Values(Generative, NamedFromClause): __visit_name__ = "values" - _data = () + _data: Tuple[List[Tuple[Any, ...]], ...] = () - _traverse_internals = [ + _unnamed: bool + _traverse_internals: _TraverseInternalsType = [ ("_column_args", InternalTraversal.dp_clauseelement_list), ("_data", InternalTraversal.dp_dml_multi_values), ("name", InternalTraversal.dp_string), ("literal_binds", InternalTraversal.dp_boolean), ] - def __init__(self, *columns, name=None, literal_binds=False): + def __init__( + self, + *columns: ColumnClause[Any], + name: Optional[str] = None, + literal_binds: bool = False, + ): super(Values, self).__init__() self._column_args = columns - self.name = name + if name is None: + self._unnamed = True + self.name = _anonymous_label.safe_construct(id(self), "anon") + else: + self._unnamed = False + self.name = name self.literal_binds = literal_binds - self.named_with_column = self.name is not None + self.named_with_column = not self._unnamed @property def _column_types(self): return [col.type for col in self._column_args] @_generative - def alias(self: SelfValues, name, **kw) -> SelfValues: + def alias( + self: SelfValues, name: Optional[str] = None, flat: bool = False + ) -> SelfValues: + """Return a new :class:`_expression.Values` construct that is a copy of this one with the given name. @@ -2655,12 +3152,20 @@ class Values(Generative, NamedFromClause): :func:`_expression.alias` """ - self.name = name - self.named_with_column = self.name is not None + non_none_name: str + + if name is None: + non_none_name = _anonymous_label.safe_construct(id(self), "anon") + else: + non_none_name = name + + self.name = non_none_name + self.named_with_column = True + self._unnamed = False return self @_generative - def lateral(self: SelfValues, name=None) -> SelfValues: + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a new :class:`_expression.Values` with the lateral flag set, so that it renders as LATERAL. @@ -2670,13 +3175,20 @@ class Values(Generative, NamedFromClause): :func:`_expression.lateral` """ + non_none_name: str + + if name is None: + non_none_name = self.name + else: + non_none_name = name + self._is_lateral = True - if name is not None: - self.name = name + self.name = non_none_name + self._unnamed = False return self @_generative - def data(self: SelfValues, values) -> SelfValues: + def data(self: SelfValues, values: List[Tuple[Any, ...]]) -> SelfValues: """Return a new :class:`_expression.Values` construct, adding the given data to the data list. @@ -2694,7 +3206,7 @@ class Values(Generative, NamedFromClause): self._data += (values,) return self - def _populate_column_collection(self): + def _populate_column_collection(self) -> None: for c in self._column_args: self._columns.add(c) c.table = self @@ -2727,32 +3239,16 @@ class SelectBase( """ - _is_select_statement = True + _is_select_base = True is_select = True - def _generate_fromclause_column_proxies( - self, fromclause: FromClause - ) -> None: - raise NotImplementedError() + _label_style: SelectLabelStyle = LABEL_STYLE_NONE def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: self._reset_memoizations() - def _generate_columns_plus_names( - self, anon_for_dupe_key: bool - ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - raise NotImplementedError() - - def set_label_style( - self: SelfSelectBase, label_style: SelectLabelStyle - ) -> SelfSelectBase: - raise NotImplementedError() - - def get_label_style(self) -> SelectLabelStyle: - raise NotImplementedError() - - @property - def selected_columns(self): + @util.ro_non_memoized_property + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set. @@ -2797,7 +3293,7 @@ class SelectBase( raise NotImplementedError() @property - def exported_columns(self): + def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`, not including @@ -2819,7 +3315,7 @@ class SelectBase( """ - return self.selected_columns + return self.selected_columns.as_readonly() @util.deprecated_property( "1.4", @@ -2841,6 +3337,26 @@ class SelectBase( def columns(self): return self.c + def get_label_style(self) -> SelectLabelStyle: + """ + Retrieve the current label style. + + Implemented by subclasses. + + """ + raise NotImplementedError() + + def set_label_style( + self: SelfSelectBase, style: SelectLabelStyle + ) -> SelfSelectBase: + """Return a new selectable with the specified label style. + + Implemented by subclasses. + + """ + + raise NotImplementedError() + @util.deprecated( "1.4", "The :meth:`_expression.SelectBase.select` method is deprecated " @@ -2857,6 +3373,9 @@ class SelectBase( def _implicit_subquery(self): return self.subquery() + def _scalar_type(self) -> TypeEngine[Any]: + raise NotImplementedError() + @util.deprecated( "1.4", "The :meth:`_expression.SelectBase.as_scalar` " @@ -2926,7 +3445,7 @@ class SelectBase( """ return self.scalar_subquery().label(name) - def lateral(self, name=None): + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a LATERAL alias of this :class:`_expression.Selectable`. The return value is the :class:`_expression.Lateral` construct also @@ -2941,11 +3460,7 @@ class SelectBase( """ return Lateral._factory(self, name) - @util.ro_non_memoized_property - def _from_objects(self) -> List[FromClause]: - return [self] - - def subquery(self, name=None): + def subquery(self, name: Optional[str] = None) -> Subquery: """Return a subquery of this :class:`_expression.SelectBase`. A subquery is from a SQL perspective a parenthesized, named @@ -2995,7 +3510,9 @@ class SelectBase( raise NotImplementedError() - def alias(self, name=None, flat=False): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> Subquery: """Return a named subquery against this :class:`_expression.SelectBase`. @@ -3023,7 +3540,9 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ __visit_name__ = "select_statement_grouping" - _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement) + ] _is_select_container = True @@ -3053,13 +3572,14 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def select_statement(self): return self.element - def self_group(self, against=None): + def self_group(self: Self, against: Optional[OperatorType] = None) -> Self: + ... return self - def _generate_columns_plus_names( - self, anon_for_dupe_key: bool - ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - return self.element._generate_columns_plus_names(anon_for_dupe_key) + # def _generate_columns_plus_names( + # self, anon_for_dupe_key: bool + # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: + # return self.element._generate_columns_plus_names(anon_for_dupe_key) def _generate_fromclause_column_proxies( self, subquery: FromClause @@ -3070,8 +3590,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def _all_selected_columns(self) -> _SelectIterable: return self.element._all_selected_columns - @property - def selected_columns(self): + @util.ro_non_memoized_property + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that the embedded SELECT statement returns in its result set, not including @@ -3112,25 +3632,30 @@ class GenerativeSelect(SelectBase): """ - _order_by_clauses = () - _group_by_clauses = () - _limit_clause = None - _offset_clause = None - _fetch_clause = None - _fetch_clause_options = None - _for_update_arg = None + _order_by_clauses: Tuple[ColumnElement[Any], ...] = () + _group_by_clauses: Tuple[ColumnElement[Any], ...] = () + _limit_clause: Optional[ColumnElement[Any]] = None + _offset_clause: Optional[ColumnElement[Any]] = None + _fetch_clause: Optional[ColumnElement[Any]] = None + _fetch_clause_options: Optional[Dict[str, bool]] = None + _for_update_arg: Optional[ForUpdateArg] = None - def __init__(self, _label_style=LABEL_STYLE_DEFAULT): + def __init__(self, _label_style: SelectLabelStyle = LABEL_STYLE_DEFAULT): self._label_style = _label_style @_generative def with_for_update( self: SelfGenerativeSelect, - nowait=False, - read=False, - of=None, - skip_locked=False, - key_share=False, + nowait: bool = False, + read: bool = False, + of: Optional[ + Union[ + _ColumnExpressionArgument[Any], + Sequence[_ColumnExpressionArgument[Any]], + ] + ] = None, + skip_locked: bool = False, + key_share: bool = False, ) -> SelfGenerativeSelect: """Specify a ``FOR UPDATE`` clause for this :class:`_expression.GenerativeSelect`. @@ -3241,20 +3766,25 @@ class GenerativeSelect(SelectBase): return self @property - def _group_by_clause(self): + def _group_by_clause(self) -> ClauseList: """ClauseList access to group_by_clauses for legacy dialects""" return ClauseList._construct_raw( operators.comma_op, self._group_by_clauses ) @property - def _order_by_clause(self): + def _order_by_clause(self) -> ClauseList: """ClauseList access to order_by_clauses for legacy dialects""" return ClauseList._construct_raw( operators.comma_op, self._order_by_clauses ) - def _offset_or_limit_clause(self, element, name=None, type_=None): + def _offset_or_limit_clause( + self, + element: Union[int, _ColumnExpressionArgument[Any]], + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[int]] = None, + ) -> ColumnElement[Any]: """Convert the given value to an "offset or limit" clause. This handles incoming integers and converts to an expression; if @@ -3265,7 +3795,21 @@ class GenerativeSelect(SelectBase): roles.LimitOffsetRole, element, name=name, type_=type_ ) - def _offset_or_limit_clause_asint(self, clause, attrname): + @overload + def _offset_or_limit_clause_asint( + self, clause: ColumnElement[Any], attrname: str + ) -> NoReturn: + ... + + @overload + def _offset_or_limit_clause_asint( + self, clause: Optional[_OffsetLimitParam], attrname: str + ) -> Optional[int]: + ... + + def _offset_or_limit_clause_asint( + self, clause: Optional[ColumnElement[Any]], attrname: str + ) -> Union[NoReturn, Optional[int]]: """Convert the "offset or limit" clause of a select construct to an integer. @@ -3286,7 +3830,7 @@ class GenerativeSelect(SelectBase): return util.asint(value) @property - def _limit(self): + def _limit(self) -> Optional[int]: """Get an integer value for the limit. This should only be used by code that cannot support a limit as a BindParameter or other custom clause as it will throw an exception if the limit @@ -3295,14 +3839,14 @@ class GenerativeSelect(SelectBase): """ return self._offset_or_limit_clause_asint(self._limit_clause, "limit") - def _simple_int_clause(self, clause): + def _simple_int_clause(self, clause: ClauseElement) -> bool: """True if the clause is a simple integer, False if it is not present or is a SQL expression. """ return isinstance(clause, _OffsetLimitParam) @property - def _offset(self): + def _offset(self) -> Optional[int]: """Get an integer value for the offset. This should only be used by code that cannot support an offset as a BindParameter or other custom clause as it will throw an exception if the @@ -3314,7 +3858,7 @@ class GenerativeSelect(SelectBase): ) @property - def _has_row_limiting_clause(self): + def _has_row_limiting_clause(self) -> bool: return ( self._limit_clause is not None or self._offset_clause is not None @@ -3322,7 +3866,10 @@ class GenerativeSelect(SelectBase): ) @_generative - def limit(self: SelfGenerativeSelect, limit) -> SelfGenerativeSelect: + def limit( + self: SelfGenerativeSelect, + limit: Union[int, _ColumnExpressionArgument[int]], + ) -> SelfGenerativeSelect: """Return a new selectable with the given LIMIT criterion applied. @@ -3356,7 +3903,10 @@ class GenerativeSelect(SelectBase): @_generative def fetch( - self: SelfGenerativeSelect, count, with_ties=False, percent=False + self: SelfGenerativeSelect, + count: Union[int, _ColumnExpressionArgument[int]], + with_ties: bool = False, + percent: bool = False, ) -> SelfGenerativeSelect: """Return a new selectable with the given FETCH FIRST criterion applied. @@ -3408,7 +3958,10 @@ class GenerativeSelect(SelectBase): return self @_generative - def offset(self: SelfGenerativeSelect, offset) -> SelfGenerativeSelect: + def offset( + self: SelfGenerativeSelect, + offset: Union[int, _ColumnExpressionArgument[int]], + ) -> SelfGenerativeSelect: """Return a new selectable with the given OFFSET criterion applied. @@ -3438,7 +3991,11 @@ class GenerativeSelect(SelectBase): @_generative @util.preload_module("sqlalchemy.sql.util") - def slice(self: SelfGenerativeSelect, start, stop) -> SelfGenerativeSelect: + def slice( + self: SelfGenerativeSelect, + start: int, + stop: int, + ) -> SelfGenerativeSelect: """Apply LIMIT / OFFSET to this statement based on a slice. The start and stop indices behave like the argument to Python's @@ -3485,7 +4042,9 @@ class GenerativeSelect(SelectBase): return self @_generative - def order_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect: + def order_by( + self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any] + ) -> SelfGenerativeSelect: r"""Return a new selectable with the given list of ORDER BY criteria applied. @@ -3522,7 +4081,9 @@ class GenerativeSelect(SelectBase): return self @_generative - def group_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect: + def group_by( + self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any] + ) -> SelfGenerativeSelect: r"""Return a new selectable with the given list of GROUP BY criterion applied. @@ -3567,6 +4128,15 @@ class CompoundSelectState(CompileState): return d, d, d +class _CompoundSelectKeyword(Enum): + UNION = "UNION" + UNION_ALL = "UNION ALL" + EXCEPT = "EXCEPT" + EXCEPT_ALL = "EXCEPT ALL" + INTERSECT = "INTERSECT" + INTERSECT_ALL = "INTERSECT ALL" + + class CompoundSelect(HasCompileState, GenerativeSelect): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations. @@ -3590,7 +4160,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect): __visit_name__ = "compound_select" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("selects", InternalTraversal.dp_clauseelement_list), ("_limit_clause", InternalTraversal.dp_clauseelement), ("_offset_clause", InternalTraversal.dp_clauseelement), @@ -3602,17 +4172,16 @@ class CompoundSelect(HasCompileState, GenerativeSelect): ("keyword", InternalTraversal.dp_string), ] + SupportsCloneAnnotations._clone_annotations_traverse_internals - UNION = util.symbol("UNION") - UNION_ALL = util.symbol("UNION ALL") - EXCEPT = util.symbol("EXCEPT") - EXCEPT_ALL = util.symbol("EXCEPT ALL") - INTERSECT = util.symbol("INTERSECT") - INTERSECT_ALL = util.symbol("INTERSECT ALL") + selects: List[SelectBase] _is_from_container = True _auto_correlate = False - def __init__(self, keyword, *selects): + def __init__( + self, + keyword: _CompoundSelectKeyword, + *selects: _SelectStatementForCompoundArgument, + ): self.keyword = keyword self.selects = [ coercions.expect(roles.CompoundElementRole, s).self_group( @@ -3624,36 +4193,50 @@ class CompoundSelect(HasCompileState, GenerativeSelect): GenerativeSelect.__init__(self) @classmethod - def _create_union(cls, *selects, **kwargs): - return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs) + def _create_union( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.UNION, *selects) @classmethod - def _create_union_all(cls, *selects): - return CompoundSelect(CompoundSelect.UNION_ALL, *selects) + def _create_union_all( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.UNION_ALL, *selects) @classmethod - def _create_except(cls, *selects): - return CompoundSelect(CompoundSelect.EXCEPT, *selects) + def _create_except( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.EXCEPT, *selects) @classmethod - def _create_except_all(cls, *selects): - return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects) + def _create_except_all( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.EXCEPT_ALL, *selects) @classmethod - def _create_intersect(cls, *selects): - return CompoundSelect(CompoundSelect.INTERSECT, *selects) + def _create_intersect( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.INTERSECT, *selects) @classmethod - def _create_intersect_all(cls, *selects): - return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects) + def _create_intersect_all( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.INTERSECT_ALL, *selects) - def _scalar_type(self): + def _scalar_type(self) -> TypeEngine[Any]: return self.selects[0]._scalar_type() - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> GroupedElement: return SelectStatementGrouping(self) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: for s in self.selects: if s.is_derived_from(fromclause): return True @@ -3675,7 +4258,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return self - def _generate_fromclause_column_proxies(self, subquery): + def _generate_fromclause_column_proxies( + self, subquery: FromClause + ) -> None: # this is a slightly hacky thing - the union exports a # column that resembles just that of the *first* selectable. @@ -3716,8 +4301,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect): def _all_selected_columns(self) -> _SelectIterable: return self.selects[0]._all_selected_columns - @property - def selected_columns(self): + @util.ro_non_memoized_property + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -3739,6 +4324,11 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return self.selects[0].selected_columns +# backwards compat +for elem in _CompoundSelectKeyword: + setattr(CompoundSelect, elem.name, elem) + + @CompileState.plugin_for("default", "select") class SelectState(util.MemoizedSlots, CompileState): __slots__ = ( @@ -3758,10 +4348,12 @@ class SelectState(util.MemoizedSlots, CompileState): if TYPE_CHECKING: @classmethod - def get_plugin_class(cls, statement: Select) -> SelectState: + def get_plugin_class(cls, statement: Executable) -> Type[SelectState]: ... - def __init__(self, statement, compiler, **kw): + def __init__( + self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any + ): self.statement = statement self.from_clauses = statement._from_obj @@ -3778,14 +4370,16 @@ class SelectState(util.MemoizedSlots, CompileState): self.columns_plus_names = statement._generate_columns_plus_names(True) @classmethod - def _plugin_not_implemented(cls): + def _plugin_not_implemented(cls) -> NoReturn: raise NotImplementedError( "The default SELECT construct without plugins does not " "implement this method." ) @classmethod - def get_column_descriptions(cls, statement): + def get_column_descriptions( + cls, statement: Select + ) -> List[Dict[str, Any]]: return [ { "name": name, @@ -3798,11 +4392,13 @@ class SelectState(util.MemoizedSlots, CompileState): ] @classmethod - def from_statement(cls, statement, from_statement): + def from_statement( + cls, statement: Select, from_statement: ReturnsRows + ) -> Any: cls._plugin_not_implemented() @classmethod - def get_columns_clause_froms(cls, statement): + def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]: return cls._normalize_froms( itertools.chain.from_iterable( element._from_objects for element in statement._raw_columns @@ -3810,7 +4406,9 @@ class SelectState(util.MemoizedSlots, CompileState): ) @classmethod - def _column_naming_convention(cls, label_style): + def _column_naming_convention( + cls, label_style: SelectLabelStyle + ) -> Callable[[Union[ColumnElement[Any], TextClause]], Optional[str]]: table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL dedupe = label_style is not LABEL_STYLE_NONE @@ -3850,7 +4448,8 @@ class SelectState(util.MemoizedSlots, CompileState): return go - def _get_froms(self, statement): + def _get_froms(self, statement: Select) -> List[FromClause]: + ambiguous_table_name_map: _AmbiguousTableNameMap self._ambiguous_table_name_map = ambiguous_table_name_map = {} return self._normalize_froms( @@ -3876,10 +4475,10 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def _normalize_froms( cls, - iterable_of_froms, - check_statement=None, - ambiguous_table_name_map=None, - ): + iterable_of_froms: Iterable[FromClause], + check_statement: Optional[Select] = None, + ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, + ) -> List[FromClause]: """given an iterable of things to select FROM, reduce them to what would actually render in the FROM clause of a SELECT. @@ -3888,12 +4487,12 @@ class SelectState(util.MemoizedSlots, CompileState): etc. """ - seen = set() - froms = [] + seen: Set[FromClause] = set() + froms: List[FromClause] = [] for item in iterable_of_froms: - if item._is_subquery and item.element is check_statement: + if is_subquery(item) and item.element is check_statement: raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" ) @@ -3923,7 +4522,7 @@ class SelectState(util.MemoizedSlots, CompileState): ) for item in froms for fr in item._from_objects - if fr._is_table + if is_table(fr) and fr.schema and fr.name not in ambiguous_table_name_map ) @@ -3931,8 +4530,10 @@ class SelectState(util.MemoizedSlots, CompileState): return froms def _get_display_froms( - self, explicit_correlate_froms=None, implicit_correlate_froms=None - ): + self, + explicit_correlate_froms: Optional[Sequence[FromClause]] = None, + implicit_correlate_froms: Optional[Sequence[FromClause]] = None, + ) -> List[FromClause]: """Return the full list of 'from' clauses to be displayed. Takes into account a set of existing froms which may be @@ -3998,25 +4599,33 @@ class SelectState(util.MemoizedSlots, CompileState): return froms - def _memoized_attr__label_resolve_dict(self): - with_cols = dict( - (c._tq_label or c.key, c) + def _memoized_attr__label_resolve_dict( + self, + ) -> Tuple[ + Dict[str, ColumnElement[Any]], + Dict[str, ColumnElement[Any]], + Dict[str, ColumnElement[Any]], + ]: + with_cols: Dict[str, ColumnElement[Any]] = dict( + (c._tq_label or c.key, c) # type: ignore for c in self.statement._all_selected_columns if c._allow_label_resolve ) - only_froms = dict( - (c.key, c) + only_froms: Dict[str, ColumnElement[Any]] = dict( + (c.key, c) # type: ignore for c in _select_iterables(self.froms) if c._allow_label_resolve ) - only_cols = with_cols.copy() + only_cols: Dict[str, ColumnElement[Any]] = with_cols.copy() for key, value in only_froms.items(): with_cols.setdefault(key, value) return with_cols, only_froms, only_cols @classmethod - def determine_last_joined_entity(cls, stmt): + def determine_last_joined_entity( + cls, stmt: Select + ) -> Optional[_JoinTargetElement]: if stmt._setup_joins: return stmt._setup_joins[-1][0] else: @@ -4026,8 +4635,16 @@ class SelectState(util.MemoizedSlots, CompileState): def all_selected_columns(cls, statement: Select) -> _SelectIterable: return [c for c in _select_iterables(statement._raw_columns)] - def _setup_joins(self, args, raw_columns): + def _setup_joins( + self, + args: Tuple[_SetupJoinsElement, ...], + raw_columns: List[_ColumnsClauseElement], + ) -> None: for (right, onclause, left, flags) in args: + if TYPE_CHECKING: + if onclause is not None: + assert isinstance(onclause, ColumnElement) + isouter = flags["isouter"] full = flags["full"] @@ -4043,6 +4660,16 @@ class SelectState(util.MemoizedSlots, CompileState): left ) + # these assertions can be made here, as if the right/onclause + # contained ORM elements, the select() statement would have been + # upgraded to an ORM select, and this method would not be called; + # orm.context.ORMSelectCompileState._join() would be + # used instead. + if TYPE_CHECKING: + assert isinstance(right, FromClause) + if onclause is not None: + assert isinstance(onclause, ColumnElement) + if replace_from_obj_index is not None: # splice into an existing element in the # self._from_obj list @@ -4062,15 +4689,19 @@ class SelectState(util.MemoizedSlots, CompileState): + self.from_clauses[replace_from_obj_index + 1 :] ) else: - + assert left is not None self.from_clauses = self.from_clauses + ( Join(left, right, onclause, isouter=isouter, full=full), ) @util.preload_module("sqlalchemy.sql.util") def _join_determine_implicit_left_side( - self, raw_columns, left, right, onclause - ): + self, + raw_columns: List[_ColumnsClauseElement], + left: Optional[FromClause], + right: _JoinTargetElement, + onclause: Optional[ColumnElement[Any]], + ) -> Tuple[Optional[FromClause], Optional[int]]: """When join conditions don't express the left side explicitly, determine if an existing FROM or entity in this query can serve as the left hand side. @@ -4079,13 +4710,13 @@ class SelectState(util.MemoizedSlots, CompileState): sql_util = util.preloaded.sql_util - replace_from_obj_index = None + replace_from_obj_index: Optional[int] = None from_clauses = self.from_clauses if from_clauses: - indexes = sql_util.find_left_clause_to_join_from( + indexes: List[int] = sql_util.find_left_clause_to_join_from( from_clauses, right, onclause ) @@ -4138,15 +4769,17 @@ class SelectState(util.MemoizedSlots, CompileState): return left, replace_from_obj_index @util.preload_module("sqlalchemy.sql.util") - def _join_place_explicit_left_side(self, left): - replace_from_obj_index = None + def _join_place_explicit_left_side( + self, left: FromClause + ) -> Optional[int]: + replace_from_obj_index: Optional[int] = None sql_util = util.preloaded.sql_util from_clauses = list(self.statement._iterate_from_elements()) if from_clauses: - indexes = sql_util.find_left_clause_that_matches_given( + indexes: List[int] = sql_util.find_left_clause_that_matches_given( self.from_clauses, left ) else: @@ -4171,7 +4804,13 @@ class SelectState(util.MemoizedSlots, CompileState): class _SelectFromElements: - def _iterate_from_elements(self): + __slots__ = () + + _raw_columns: List[_ColumnsClauseElement] + _where_criteria: Tuple[ColumnElement[Any], ...] + _from_obj: Tuple[FromClause, ...] + + def _iterate_from_elements(self) -> Iterator[FromClause]: # note this does not include elements # in _setup_joins @@ -4195,28 +4834,58 @@ class _SelectFromElements: yield element +Self_MemoizedSelectEntities = TypeVar("Self_MemoizedSelectEntities", bound=Any) + + class _MemoizedSelectEntities( cache_key.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible ): + """represents partial state from a Select object, for the case + where Select.columns() has redefined the set of columns/entities the + statement will be SELECTing from. This object represents + the entities from the SELECT before that transformation was applied, + so that transformations that were made in terms of the SELECT at that + time, such as join() as well as options(), can access the correct context. + + In previous SQLAlchemy versions, this wasn't needed because these + constructs calculated everything up front, like when you called join() + or options(), it did everything to figure out how that would translate + into specific SQL constructs that would be ready to send directly to the + SQL compiler when needed. But as of + 1.4, all of that stuff is done in the compilation phase, during the + "compile state" portion of the process, so that the work can all be + cached. So it needs to be able to resolve joins/options2 based on what + the list of entities was when those methods were called. + + + """ + __visit_name__ = "memoized_select_entities" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), ("_setup_joins", InternalTraversal.dp_setup_join_tuple), ("_with_options", InternalTraversal.dp_executable_options), ] + _is_clone_of: Optional[ClauseElement] + _raw_columns: List[_ColumnsClauseElement] + _setup_joins: Tuple[_SetupJoinsElement, ...] + _with_options: Tuple[ExecutableOption, ...] + _annotations = util.EMPTY_DICT - def _clone(self, **kw): + def _clone( + self: Self_MemoizedSelectEntities, **kw: Any + ) -> Self_MemoizedSelectEntities: c = self.__class__.__new__(self.__class__) c.__dict__ = {k: v for k, v in self.__dict__.items()} c._is_clone_of = self.__dict__.get("_is_clone_of", self) - return c + return c # type: ignore @classmethod - def _generate_for_statement(cls, select_stmt): + def _generate_for_statement(cls, select_stmt: Select) -> None: if select_stmt._setup_joins or select_stmt._with_options: self = _MemoizedSelectEntities() self._raw_columns = select_stmt._raw_columns @@ -4224,12 +4893,10 @@ class _MemoizedSelectEntities( self._with_options = select_stmt._with_options select_stmt._memoized_select_entities += (self,) - select_stmt._raw_columns = ( - select_stmt._setup_joins - ) = select_stmt._with_options = () + select_stmt._raw_columns = [] + select_stmt._setup_joins = select_stmt._with_options = () -# TODO: use pep-673 when feasible SelfSelect = typing.TypeVar("SelfSelect", bound="Select") @@ -4258,9 +4925,11 @@ class Select( __visit_name__ = "select" - _setup_joins: Tuple[TODO_Any, ...] = () + _setup_joins: Tuple[_SetupJoinsElement, ...] = () _memoized_select_entities: Tuple[TODO_Any, ...] = () + _raw_columns: List[_ColumnsClauseElement] + _distinct = False _distinct_on: Tuple[ColumnElement[Any], ...] = () _correlate: Tuple[FromClause, ...] = () @@ -4269,12 +4938,12 @@ class Select( _having_criteria: Tuple[ColumnElement[Any], ...] = () _from_obj: Tuple[FromClause, ...] = () _auto_correlate = True - + _is_select_statement = True _compile_options: CacheableOptions = ( SelectState.default_select_compile_options ) - _traverse_internals = ( + _traverse_internals: _TraverseInternalsType = ( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), ( @@ -4306,12 +4975,14 @@ class Select( + Executable._executable_traverse_internals ) - _cache_key_traversal = _traverse_internals + [ + _cache_key_traversal: _CacheKeyTraversalType = _traverse_internals + [ ("_compile_options", InternalTraversal.dp_has_cache_key) ] + _compile_state_factory: Type[SelectState] + @classmethod - def _create_raw_select(cls, **kw) -> "Select": + def _create_raw_select(cls, **kw: Any) -> Select: """Create a :class:`.Select` using raw ``__new__`` with no coercions. Used internally to build up :class:`.Select` constructs with @@ -4330,6 +5001,12 @@ class Select( :func:`_sql.select` function. """ + things = [ + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) + for ent in entities + ] self._raw_columns = [ coercions.expect( @@ -4340,7 +5017,7 @@ class Select( GenerativeSelect.__init__(self) - def _scalar_type(self): + def _scalar_type(self) -> TypeEngine[Any]: elem = self._raw_columns[0] cols = list(elem._select_iterable) return cols[0].type @@ -4446,7 +5123,12 @@ class Select( @_generative def join( - self: SelfSelect, target, onclause=None, *, isouter=False, full=False + self: SelfSelect, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + isouter: bool = False, + full: bool = False, ) -> SelfSelect: r"""Create a SQL JOIN against this :class:`_expression.Select` object's criterion @@ -4505,17 +5187,32 @@ class Select( :meth:`_expression.Select.outerjoin` """ # noqa: E501 - target = coercions.expect( + join_target = coercions.expect( roles.JoinTargetRole, target, apply_propagate_attrs=self ) if onclause is not None: - onclause = coercions.expect(roles.OnClauseRole, onclause) + onclause_element = coercions.expect(roles.OnClauseRole, onclause) + else: + onclause_element = None + self._setup_joins += ( - (target, onclause, None, {"isouter": isouter, "full": full}), + ( + join_target, + onclause_element, + None, + {"isouter": isouter, "full": full}, + ), ) return self - def outerjoin_from(self, from_, target, onclause=None, *, full=False): + def outerjoin_from( + self: SelfSelect, + from_: _FromClauseArgument, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + full: bool = False, + ) -> SelfSelect: r"""Create a SQL LEFT OUTER JOIN against this :class:`_expression.Select` object's criterion and apply generatively, returning the newly resulting @@ -4531,12 +5228,12 @@ class Select( @_generative def join_from( self: SelfSelect, - from_, - target, - onclause=None, + from_: _FromClauseArgument, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, *, - isouter=False, - full=False, + isouter: bool = False, + full: bool = False, ) -> SelfSelect: r"""Create a SQL JOIN against this :class:`_expression.Select` object's criterion @@ -4586,18 +5283,31 @@ class Select( from_ = coercions.expect( roles.FromClauseRole, from_, apply_propagate_attrs=self ) - target = coercions.expect( + join_target = coercions.expect( roles.JoinTargetRole, target, apply_propagate_attrs=self ) if onclause is not None: - onclause = coercions.expect(roles.OnClauseRole, onclause) + onclause_element = coercions.expect(roles.OnClauseRole, onclause) + else: + onclause_element = None self._setup_joins += ( - (target, onclause, from_, {"isouter": isouter, "full": full}), + ( + join_target, + onclause_element, + from_, + {"isouter": isouter, "full": full}, + ), ) return self - def outerjoin(self, target, onclause=None, *, full=False): + def outerjoin( + self: SelfSelect, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + full: bool = False, + ) -> SelfSelect: """Create a left outer join. Parameters are the same as that of :meth:`_expression.Select.join`. @@ -4634,7 +5344,7 @@ class Select( """ return self.join(target, onclause=onclause, isouter=True, full=full) - def get_final_froms(self): + def get_final_froms(self) -> Sequence[FromClause]: """Compute the final displayed list of :class:`_expression.FromClause` elements. @@ -4671,6 +5381,7 @@ class Select( :attr:`_sql.Select.columns_clause_froms` """ + return self._compile_state_factory(self, None)._get_display_froms() @util.deprecated_property( @@ -4678,7 +5389,7 @@ class Select( "The :attr:`_expression.Select.froms` attribute is moved to " "the :meth:`_expression.Select.get_final_froms` method.", ) - def froms(self): + def froms(self) -> Sequence[FromClause]: """Return the displayed list of :class:`_expression.FromClause` elements. @@ -4687,7 +5398,7 @@ class Select( return self.get_final_froms() @property - def columns_clause_froms(self): + def columns_clause_froms(self) -> List[FromClause]: """Return the set of :class:`_expression.FromClause` objects implied by the columns clause of this SELECT statement. @@ -4720,7 +5431,7 @@ class Select( return iter(self._all_selected_columns) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: if self in fromclause._cloned_set: return True @@ -4729,7 +5440,9 @@ class Select( return True return False - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: # Select() object has been cloned and probably adapted by the # given clone function. Apply the cloning function to internal # objects @@ -4786,13 +5499,15 @@ class Select( def get_children(self, **kwargs): return itertools.chain( super(Select, self).get_children( - omit_attrs=["_from_obj", "_correlate", "_correlate_except"] + omit_attrs=("_from_obj", "_correlate", "_correlate_except") ), self._iterate_from_elements(), ) @_generative - def add_columns(self: SelfSelect, *columns) -> SelfSelect: + def add_columns( + self: SelfSelect, *columns: _ColumnsClauseArgument + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given column expressions added to its columns clause. @@ -4816,7 +5531,9 @@ class Select( ] return self - def _set_entities(self, entities): + def _set_entities( + self, entities: Iterable[_ColumnsClauseArgument] + ) -> None: self._raw_columns = [ coercions.expect( roles.ColumnsClauseRole, ent, apply_propagate_attrs=self @@ -4830,7 +5547,7 @@ class Select( "be removed in a future release. Please use " ":meth:`_expression.Select.add_columns`", ) - def column(self, column): + def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given column expression added to its columns clause. @@ -4847,7 +5564,9 @@ class Select( return self.add_columns(column) @util.preload_module("sqlalchemy.sql.util") - def reduce_columns(self, only_synonyms=True): + def reduce_columns( + self: SelfSelect, only_synonyms: bool = True + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with redundantly named, equivalently-valued columns removed from the columns clause. @@ -4880,7 +5599,9 @@ class Select( @_generative def with_only_columns( - self: SelfSelect, *columns, maintain_column_froms=False + self: SelfSelect, + *columns: _ColumnsClauseArgument, + maintain_column_froms: bool = False, ) -> SelfSelect: r"""Return a new :func:`_expression.select` construct with its columns clause replaced with the given columns. @@ -4941,7 +5662,9 @@ class Select( self._assert_no_memoizations() if maintain_column_froms: - self.select_from.non_generative(self, *self.columns_clause_froms) + self.select_from.non_generative( # type: ignore + self, *self.columns_clause_froms + ) # then memoize the FROMs etc. _MemoizedSelectEntities._generate_for_statement(self) @@ -4974,7 +5697,9 @@ class Select( _whereclause = whereclause @_generative - def where(self: SelfSelect, *whereclause) -> SelfSelect: + def where( + self: SelfSelect, *whereclause: _ColumnExpressionArgument[bool] + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any. @@ -4984,24 +5709,33 @@ class Select( assert isinstance(self._where_criteria, tuple) for criterion in whereclause: - where_criteria = coercions.expect(roles.WhereHavingRole, criterion) + where_criteria: ColumnElement[Any] = coercions.expect( + roles.WhereHavingRole, criterion + ) self._where_criteria += (where_criteria,) return self @_generative - def having(self: SelfSelect, having) -> SelfSelect: + def having( + self: SelfSelect, *having: _ColumnExpressionArgument[bool] + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given expression added to its HAVING clause, joined to the existing clause via AND, if any. """ - self._having_criteria += ( - coercions.expect(roles.WhereHavingRole, having), - ) + + for criterion in having: + having_criteria = coercions.expect( + roles.WhereHavingRole, criterion + ) + self._having_criteria += (having_criteria,) return self @_generative - def distinct(self: SelfSelect, *expr) -> SelfSelect: + def distinct( + self: SelfSelect, *expr: _ColumnExpressionArgument[Any] + ) -> SelfSelect: r"""Return a new :func:`_expression.select` construct which will apply DISTINCT to its columns clause. @@ -5023,7 +5757,9 @@ class Select( return self @_generative - def select_from(self: SelfSelect, *froms) -> SelfSelect: + def select_from( + self: SelfSelect, *froms: _FromClauseArgument + ) -> SelfSelect: r"""Return a new :func:`_expression.select` construct with the given FROM expression(s) merged into its list of FROM objects. @@ -5067,7 +5803,10 @@ class Select( return self @_generative - def correlate(self: SelfSelect, *fromclauses) -> SelfSelect: + def correlate( + self: SelfSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfSelect: r"""Return a new :class:`_expression.Select` which will correlate the given FROM clauses to that of an enclosing :class:`_expression.Select`. @@ -5106,10 +5845,10 @@ class Select( none of its FROM entries, and all will render unconditionally in the local FROM clause. - :param \*fromclauses: a list of one or more - :class:`_expression.FromClause` - constructs, or other compatible constructs (i.e. ORM-mapped - classes) to become part of the correlate collection. + :param \*fromclauses: one or more :class:`.FromClause` or other + FROM-compatible construct such as an ORM mapped entity to become part + of the correlate collection; alternatively pass a single value + ``None`` to remove all existing correlations. .. seealso:: @@ -5119,8 +5858,16 @@ class Select( """ + # tests failing when we try to change how these + # arguments are passed + self._auto_correlate = False - if fromclauses and fromclauses[0] in {None, False}: + if not fromclauses or fromclauses[0] in {None, False}: + if len(fromclauses) > 1: + raise exc.ArgumentError( + "additional FROM objects not accepted when " + "passing None/False to correlate()" + ) self._correlate = () else: self._correlate = self._correlate + tuple( @@ -5129,7 +5876,10 @@ class Select( return self @_generative - def correlate_except(self: SelfSelect, *fromclauses) -> SelfSelect: + def correlate_except( + self: SelfSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfSelect: r"""Return a new :class:`_expression.Select` which will omit the given FROM clauses from the auto-correlation process. @@ -5141,9 +5891,9 @@ class Select( all other FROM elements remain subject to normal auto-correlation behaviors. - If ``None`` is passed, the :class:`_expression.Select` - object will correlate - all of its FROM entries. + If ``None`` is passed, or no arguments are passed, + the :class:`_expression.Select` object will correlate all of its + FROM entries. :param \*fromclauses: a list of one or more :class:`_expression.FromClause` @@ -5159,16 +5909,22 @@ class Select( """ self._auto_correlate = False - if fromclauses and fromclauses[0] in {None, False}: + if not fromclauses or fromclauses[0] in {None, False}: + if len(fromclauses) > 1: + raise exc.ArgumentError( + "additional FROM objects not accepted when " + "passing None/False to correlate_except()" + ) self._correlate_except = () else: self._correlate_except = (self._correlate_except or ()) + tuple( coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) + return self - @HasMemoized.memoized_attribute - def selected_columns(self): + @HasMemoized_ro_memoized_attribute + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -5214,18 +5970,22 @@ class Select( # generates the actual names used in the SELECT string. that # method is more complex because it also renders columns that are # fully ambiguous, e.g. same column more than once. - conv = SelectState._column_naming_convention(self._label_style) + conv = cast( + "Callable[[Any], str]", + SelectState._column_naming_convention(self._label_style), + ) - return ColumnCollection( + cc: ColumnCollection[str, ColumnElement[Any]] = ColumnCollection( [ (conv(c), c) for c in self._all_selected_columns if is_column_element(c) ] - ).as_readonly() + ) + return cc.as_readonly() @HasMemoized.memoized_attribute - def _all_selected_columns(self) -> Sequence[ColumnElement[Any]]: + def _all_selected_columns(self) -> _SelectIterable: meth = SelectState.get_plugin_class(self).all_selected_columns return list(meth(self)) @@ -5234,173 +5994,9 @@ class Select( self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) return self - def _generate_columns_plus_names( - self, anon_for_dupe_key: bool - ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - """Generate column names as rendered in a SELECT statement by - the compiler. - - This is distinct from the _column_naming_convention generator that's - intended for population of .c collections and similar, which has - different rules. the collection returned here calls upon the - _column_naming_convention as well. - - """ - cols = self._all_selected_columns - - key_naming_convention = SelectState._column_naming_convention( - self._label_style - ) - - names = {} - - result = [] - result_append = result.append - - table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL - label_style_none = self._label_style is LABEL_STYLE_NONE - - # a counter used for "dedupe" labels, which have double underscores - # in them and are never referred by name; they only act - # as positional placeholders. they need only be unique within - # the single columns clause they're rendered within (required by - # some dbs such as mysql). So their anon identity is tracked against - # a fixed counter rather than hash() identity. - dedupe_hash = 1 - - for c in cols: - repeated = False - - if not c._render_label_in_columns_clause: - effective_name = ( - required_label_name - ) = fallback_label_name = None - elif label_style_none: - effective_name = required_label_name = None - fallback_label_name = c._non_anon_label or c._anon_name_label - else: - if table_qualified: - required_label_name = ( - effective_name - ) = fallback_label_name = c._tq_label - else: - effective_name = fallback_label_name = c._non_anon_label - required_label_name = None - - if effective_name is None: - # it seems like this could be _proxy_key and we would - # not need _expression_label but it isn't - # giving us a clue when to use anon_label instead - expr_label = c._expression_label - if expr_label is None: - repeated = c._anon_name_label in names - names[c._anon_name_label] = c - effective_name = required_label_name = None - - if repeated: - # here, "required_label_name" is sent as - # "None" and "fallback_label_name" is sent. - if table_qualified: - fallback_label_name = ( - c._dedupe_anon_tq_label_idx(dedupe_hash) - ) - dedupe_hash += 1 - else: - fallback_label_name = c._dedupe_anon_label_idx( - dedupe_hash - ) - dedupe_hash += 1 - else: - fallback_label_name = c._anon_name_label - else: - required_label_name = ( - effective_name - ) = fallback_label_name = expr_label - - if effective_name is not None: - if effective_name in names: - # when looking to see if names[name] is the same column as - # c, use hash(), so that an annotated version of the column - # is seen as the same as the non-annotated - if hash(names[effective_name]) != hash(c): - - # different column under the same name. apply - # disambiguating label - if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._anon_tq_label - else: - required_label_name = ( - fallback_label_name - ) = c._anon_name_label - - if anon_for_dupe_key and required_label_name in names: - # here, c._anon_tq_label is definitely unique to - # that column identity (or annotated version), so - # this should always be true. - # this is also an infrequent codepath because - # you need two levels of duplication to be here - assert hash(names[required_label_name]) == hash(c) - - # the column under the disambiguating label is - # already present. apply the "dedupe" label to - # subsequent occurrences of the column so that the - # original stays non-ambiguous - if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_tq_label_idx(dedupe_hash) - dedupe_hash += 1 - else: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_label_idx(dedupe_hash) - dedupe_hash += 1 - repeated = True - else: - names[required_label_name] = c - elif anon_for_dupe_key: - # same column under the same name. apply the "dedupe" - # label so that the original stays non-ambiguous - if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_tq_label_idx(dedupe_hash) - dedupe_hash += 1 - else: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_label_idx(dedupe_hash) - dedupe_hash += 1 - repeated = True - else: - names[effective_name] = c - - result_append( - ( - # string label name, if non-None, must be rendered as a - # label, i.e. "AS <name>" - required_label_name, - # proxy_key that is to be part of the result map for this - # col. this is also the key in a fromclause.c or - # select.selected_columns collection - key_naming_convention(c), - # name that can be used to render an "AS <name>" when - # we have to render a label even though - # required_label_name was not given - fallback_label_name, - # the ColumnElement itself - c, - # True if this is a duplicate of a previous column - # in the list of columns - repeated, - ) - ) - - return result - - def _generate_fromclause_column_proxies(self, subquery): + def _generate_fromclause_column_proxies( + self, subquery: FromClause + ) -> None: """Generate column proxies to place in the exported ``.c`` collection of a subquery.""" @@ -5418,7 +6014,7 @@ class Select( c, repeated, ) in (self._generate_columns_plus_names(False)) - if not c._is_text_clause + if is_column_element(c) ] subquery._columns._populate_separate_keys(prox) @@ -5428,7 +6024,10 @@ class Select( self._order_by_clause.clauses ) - def self_group(self, against=None): + def self_group( + self: Self, against: Optional[OperatorType] = None + ) -> Union[SelectStatementGrouping, Self]: + ... """Return a 'grouping' construct as per the :class:`_expression.ClauseElement` specification. @@ -5445,7 +6044,9 @@ class Select( else: return SelectStatementGrouping(self) - def union(self, *other, **kwargs): + def union( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``UNION`` of this select() construct against the given selectables provided as positional arguments. @@ -5460,9 +6061,11 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_union(self, *other, **kwargs) + return CompoundSelect._create_union(self, *other) - def union_all(self, *other, **kwargs): + def union_all( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``UNION ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -5477,9 +6080,11 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_union_all(self, *other, **kwargs) + return CompoundSelect._create_union_all(self, *other) - def except_(self, *other, **kwargs): + def except_( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``EXCEPT`` of this select() construct against the given selectable provided as positional arguments. @@ -5490,13 +6095,12 @@ class Select( multiple elements are now accepted. - :param \**kwargs: keyword arguments are forwarded to the constructor - for the newly created :class:`_sql.CompoundSelect` object. - """ - return CompoundSelect._create_except(self, *other, **kwargs) + return CompoundSelect._create_except(self, *other) - def except_all(self, *other, **kwargs): + def except_all( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``EXCEPT ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -5507,13 +6111,12 @@ class Select( multiple elements are now accepted. - :param \**kwargs: keyword arguments are forwarded to the constructor - for the newly created :class:`_sql.CompoundSelect` object. - """ - return CompoundSelect._create_except_all(self, *other, **kwargs) + return CompoundSelect._create_except_all(self, *other) - def intersect(self, *other, **kwargs): + def intersect( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``INTERSECT`` of this select() construct against the given selectables provided as positional arguments. @@ -5528,9 +6131,11 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_intersect(self, *other, **kwargs) + return CompoundSelect._create_intersect(self, *other) - def intersect_all(self, *other, **kwargs): + def intersect_all( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``INTERSECT ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -5545,13 +6150,17 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_intersect_all(self, *other, **kwargs) + return CompoundSelect._create_intersect_all(self, *other) -SelfScalarSelect = typing.TypeVar("SelfScalarSelect", bound="ScalarSelect") +SelfScalarSelect = typing.TypeVar( + "SelfScalarSelect", bound="ScalarSelect[Any]" +) -class ScalarSelect(roles.InElementRole, Generative, Grouping): +class ScalarSelect( + roles.InElementRole, Generative, GroupedElement, ColumnElement[_T] +): """Represent a scalar subquery. @@ -5570,15 +6179,33 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): """ - _from_objects = [] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ] + + _from_objects: List[FromClause] = [] _is_from_container = True - _is_implicitly_boolean = False + if not TYPE_CHECKING: + _is_implicitly_boolean = False inherit_cache = True + element: SelectBase + def __init__(self, element): self.element = element self.type = element._scalar_type() + def __getattr__(self, attr): + return getattr(self.element, attr) + + def __getstate__(self): + return {"element": self.element, "type": self.type} + + def __setstate__(self, state): + self.element = state["element"] + self.type = state["type"] + @property def columns(self): raise exc.InvalidRequestError( @@ -5590,19 +6217,39 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): c = columns @_generative - def where(self: SelfScalarSelect, crit) -> SelfScalarSelect: + def where( + self: SelfScalarSelect, crit: _ColumnExpressionArgument[bool] + ) -> SelfScalarSelect: """Apply a WHERE clause to the SELECT statement referred to by this :class:`_expression.ScalarSelect`. """ - self.element = self.element.where(crit) + self.element = cast(Select, self.element).where(crit) return self - def self_group(self, **kwargs): + @overload + def self_group( + self: ScalarSelect[Any], against: Optional[OperatorType] = None + ) -> ScalarSelect[Any]: + ... + + @overload + def self_group( + self: ColumnElement[Any], against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: + ... + + def self_group( + self, against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: + return self @_generative - def correlate(self: SelfScalarSelect, *fromclauses) -> SelfScalarSelect: + def correlate( + self: SelfScalarSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfScalarSelect: r"""Return a new :class:`_expression.ScalarSelect` which will correlate the given FROM clauses to that of an enclosing :class:`_expression.Select`. @@ -5631,12 +6278,13 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): """ - self.element = self.element.correlate(*fromclauses) + self.element = cast(Select, self.element).correlate(*fromclauses) return self @_generative def correlate_except( - self: SelfScalarSelect, *fromclauses + self: SelfScalarSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], ) -> SelfScalarSelect: r"""Return a new :class:`_expression.ScalarSelect` which will omit the given FROM @@ -5668,11 +6316,16 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): """ - self.element = self.element.correlate_except(*fromclauses) + self.element = cast(Select, self.element).correlate_except( + *fromclauses + ) return self -class Exists(UnaryExpression[_T]): +SelfExists = TypeVar("SelfExists", bound="Exists") + + +class Exists(UnaryExpression[bool]): """Represent an ``EXISTS`` clause. See :func:`_sql.exists` for a description of usage. @@ -5682,10 +6335,14 @@ class Exists(UnaryExpression[_T]): """ - _from_objects = () inherit_cache = True - def __init__(self, __argument=None): + def __init__( + self, + __argument: Optional[ + Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + ] = None, + ): if __argument is None: s = Select(literal_column("*")).scalar_subquery() elif isinstance(__argument, (SelectBase, ScalarSelect)): @@ -5701,12 +6358,16 @@ class Exists(UnaryExpression[_T]): wraps_column_expression=True, ) + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + return [] + def _regroup(self, fn): element = self.element._ungroup() element = fn(element) return element.self_group(against=operators.exists) - def select(self) -> "Select": + def select(self) -> Select: r"""Return a SELECT of this :class:`_expression.Exists`. e.g.:: @@ -5726,7 +6387,10 @@ class Exists(UnaryExpression[_T]): return Select(self) - def correlate(self, *fromclause): + def correlate( + self: SelfExists, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfExists: """Apply correlation to the subquery noted by this :class:`_sql.Exists`. .. seealso:: @@ -5736,11 +6400,14 @@ class Exists(UnaryExpression[_T]): """ e = self._clone() e.element = self._regroup( - lambda element: element.correlate(*fromclause) + lambda element: element.correlate(*fromclauses) ) return e - def correlate_except(self, *fromclause): + def correlate_except( + self: SelfExists, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfExists: """Apply correlation to the subquery noted by this :class:`_sql.Exists`. .. seealso:: @@ -5751,11 +6418,11 @@ class Exists(UnaryExpression[_T]): e = self._clone() e.element = self._regroup( - lambda element: element.correlate_except(*fromclause) + lambda element: element.correlate_except(*fromclauses) ) return e - def select_from(self, *froms): + def select_from(self: SelfExists, *froms: FromClause) -> SelfExists: """Return a new :class:`_expression.Exists` construct, applying the given expression to the :meth:`_expression.Select.select_from` @@ -5772,7 +6439,9 @@ class Exists(UnaryExpression[_T]): e.element = self._regroup(lambda element: element.select_from(*froms)) return e - def where(self, *clause): + def where( + self: SelfExists, *clause: _ColumnExpressionArgument[bool] + ) -> SelfExists: """Return a new :func:`_expression.exists` construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any. @@ -5824,7 +6493,7 @@ class TextualSelect(SelectBase): _label_style = LABEL_STYLE_NONE - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("column_args", InternalTraversal.dp_clauseelement_list), ] + SupportsCloneAnnotations._clone_annotations_traverse_internals @@ -5842,8 +6511,8 @@ class TextualSelect(SelectBase): ] self.positional = positional - @HasMemoized.memoized_attribute - def selected_columns(self): + @HasMemoized_ro_memoized_attribute + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -5868,6 +6537,13 @@ class TextualSelect(SelectBase): (c.key, c) for c in self.column_args ).as_readonly() + # def _generate_columns_plus_names( + # self, anon_for_dupe_key: bool + # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: + # return Select._generate_columns_plus_names( + # self, anon_for_dupe_key=anon_for_dupe_key + # ) + @util.non_memoized_property def _all_selected_columns(self) -> _SelectIterable: return self.column_args @@ -5880,7 +6556,9 @@ class TextualSelect(SelectBase): @_generative def bindparams( - self: SelfTextualSelect, *binds, **bind_as_values + self: SelfTextualSelect, + *binds: BindParameter[Any], + **bind_as_values: Any, ) -> SelfTextualSelect: self.element = self.element.bindparams(*binds, **bind_as_values) return self diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 1f3d50876..c3653c264 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -228,7 +228,7 @@ class HasCopyInternals(HasTraverseInternals): raise NotImplementedError() def _copy_internals( - self, omit_attrs: Iterable[str] = (), **kw: Any + self, *, omit_attrs: Iterable[str] = (), **kw: Any ) -> None: """Reassign internal elements to be clones of themselves. diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9a934a50b..82adf4a4f 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1008,10 +1008,7 @@ class TypeEngine(Visitable, Generic[_T]): @util.preload_module("sqlalchemy.engine.default") def _default_dialect(self) -> Dialect: - if TYPE_CHECKING: - from ..engine import default - else: - default = util.preloaded.engine_default + default = util.preloaded.engine_default # dmypy / mypy seems to sporadically keep thinking this line is # returning Any, which seems to be caused by the @deprecated_params diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index cdce49f7b..80711c4b5 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -13,10 +13,20 @@ from __future__ import annotations from collections import deque from itertools import chain import typing +from typing import AbstractSet from typing import Any +from typing import Callable from typing import cast +from typing import Dict from typing import Iterator +from typing import List from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union from . import coercions from . import operators @@ -49,11 +59,22 @@ from .selectable import Join from .selectable import ScalarSelect from .selectable import SelectBase from .selectable import TableClause +from .visitors import _ET from .. import exc from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _ColumnExpressionArgument + from ._typing import _TypeEngineArgument from .roles import FromClauseRole + from .selectable import _JoinTargetElement + from .selectable import _OnClauseElement + from .selectable import Selectable + from .visitors import _TraverseCallableType + from .visitors import ExternallyTraversible + from .visitors import ExternalTraversal from ..engine.interfaces import _AnyExecuteParams from ..engine.interfaces import _AnyMultiExecuteParams from ..engine.interfaces import _AnySingleExecuteParams @@ -160,7 +181,11 @@ def find_left_clause_that_matches_given(clauses, join_from): return liberal_idx -def find_left_clause_to_join_from(clauses, join_to, onclause): +def find_left_clause_to_join_from( + clauses: Sequence[FromClause], + join_to: _JoinTargetElement, + onclause: Optional[ColumnElement[Any]], +) -> List[int]: """Given a list of FROM clauses, a selectable, and optional ON clause, return a list of integer indexes from the clauses list indicating the clauses that can be joined from. @@ -189,6 +214,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause): for i, f in enumerate(clauses): for s in selectables.difference([f]): if resolve_ambiguity: + assert cols_in_onclause is not None if set(f.c).union(s.c).issuperset(cols_in_onclause): idx.append(i) break @@ -207,7 +233,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause): # onclause was given and none of them resolved, so assume # all indexes can match if not idx and onclause is not None: - return range(len(clauses)) + return list(range(len(clauses))) else: return idx @@ -247,7 +273,7 @@ def visit_binary_product(fn, expr): a binary comparison is passed as pairs. """ - stack = [] + stack: List[ClauseElement] = [] def visit(element): if isinstance(element, ScalarSelect): @@ -272,21 +298,22 @@ def visit_binary_product(fn, expr): yield e list(visit(expr)) - visit = None # remove gc cycles + visit = None # type: ignore # remove gc cycles def find_tables( - clause, - check_columns=False, - include_aliases=False, - include_joins=False, - include_selects=False, - include_crud=False, -): + clause: ClauseElement, + *, + check_columns: bool = False, + include_aliases: bool = False, + include_joins: bool = False, + include_selects: bool = False, + include_crud: bool = False, +) -> List[TableClause]: """locate Table objects within the given expression.""" - tables = [] - _visitors = {} + tables: List[TableClause] = [] + _visitors: Dict[str, _TraverseCallableType[Any]] = {} if include_selects: _visitors["select"] = _visitors["compound_select"] = tables.append @@ -335,7 +362,7 @@ def unwrap_order_by(clause): t = stack.popleft() if isinstance(t, ColumnElement) and ( not isinstance(t, UnaryExpression) - or not operators.is_ordering_modifier(t.modifier) + or not operators.is_ordering_modifier(t.modifier) # type: ignore ): if isinstance(t, Label) and not isinstance( t.element, ScalarSelect @@ -365,9 +392,14 @@ def unwrap_order_by(clause): def unwrap_label_reference(element): - def replace(elem): - if isinstance(elem, (_label_reference, _textual_label_reference)): - return elem.element + def replace( + element: ExternallyTraversible, **kw: Any + ) -> Optional[ExternallyTraversible]: + if isinstance(element, _label_reference): + return element.element + elif isinstance(element, _textual_label_reference): + assert False, "can't unwrap a textual label reference" + return None return visitors.replacement_traverse(element, {}, replace) @@ -407,7 +439,7 @@ def clause_is_present(clause, search): return False -def tables_from_leftmost(clause: FromClauseRole) -> Iterator[FromClause]: +def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]: if isinstance(clause, Join): for t in tables_from_leftmost(clause.left): yield t @@ -509,6 +541,8 @@ class _repr_base: __slots__ = ("max_chars",) + max_chars: int + def trunc(self, value: Any) -> str: rep = repr(value) lenrep = len(rep) @@ -612,7 +646,7 @@ class _repr_params(_repr_base): def _repr_multi( self, multi_params: _AnyMultiExecuteParams, - typ, + typ: int, ) -> str: if multi_params: if isinstance(multi_params[0], list): @@ -639,7 +673,7 @@ class _repr_params(_repr_base): def _repr_params( self, - params: Optional[_AnySingleExecuteParams], + params: _AnySingleExecuteParams, typ: int, ) -> str: trunc = self.trunc @@ -653,9 +687,10 @@ class _repr_params(_repr_base): ) ) elif typ is self._TUPLE: + seq_params = cast("Sequence[Any]", params) return "(%s%s)" % ( - ", ".join(trunc(value) for value in params), - "," if len(params) == 1 else "", + ", ".join(trunc(value) for value in seq_params), + "," if len(seq_params) == 1 else "", ) else: return "[%s]" % (", ".join(trunc(value) for value in params)) @@ -688,11 +723,15 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {"binary": visit_binary}) -def splice_joins(left, right, stop_on=None): +def splice_joins( + left: Optional[FromClause], + right: Optional[FromClause], + stop_on: Optional[FromClause] = None, +) -> Optional[FromClause]: if left is None: return right - stack = [(right, None)] + stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)] adapter = ClauseAdapter(left) ret = None @@ -705,6 +744,7 @@ def splice_joins(left, right, stop_on=None): else: right = adapter.traverse(right) if prevright is not None: + assert right is not None prevright.left = right if ret is None: ret = right @@ -845,11 +885,14 @@ def criterion_as_pairs( elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) - pairs = [] + pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = [] visitors.traverse(expression, {}, {"binary": visit_binary}) return pairs +_CE = TypeVar("_CE", bound="ClauseElement") + + class ClauseAdapter(visitors.ReplacingExternalTraversal): """Clones and modifies clauses based on column correspondence. @@ -879,13 +922,15 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): def __init__( self, - selectable, - equivalents=None, - include_fn=None, - exclude_fn=None, - adapt_on_names=False, - anonymize_labels=False, - adapt_from_selectables=None, + selectable: Selectable, + equivalents: Optional[ + Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] + ] = None, + include_fn: Optional[Callable[[ClauseElement], bool]] = None, + exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, + adapt_on_names: bool = False, + anonymize_labels: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, ): self.__traverse_options__ = { "stop_on": [selectable], @@ -898,6 +943,29 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): self.adapt_on_names = adapt_on_names self.adapt_from_selectables = adapt_from_selectables + if TYPE_CHECKING: + + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + # note this specializes the ReplacingExternalTraversal.traverse() + # method to state + # that we will return the same kind of ExternalTraversal object as + # we were given. This is probably not 100% true, such as it's + # possible for us to swap out Alias for Table at the top level. + # Ideally there could be overloads specific to ColumnElement and + # FromClause but Mypy is not accepting those as compatible with + # the base ReplacingExternalTraversal + @overload + def traverse(self, obj: _ET) -> _ET: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: + ... + def _corresponding_column( self, col, require_embedded, _seen=util.EMPTY_SET ): @@ -919,9 +987,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): return newcol @util.preload_module("sqlalchemy.sql.functions") - def replace(self, col, _include_singleton_constants=False): + def replace( + self, col: _ET, _include_singleton_constants: bool = False + ) -> Optional[_ET]: functions = util.preloaded.sql_functions + # TODO: cython candidate + if isinstance(col, FromClause) and not isinstance( col, functions.FunctionElement ): @@ -933,7 +1005,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): break else: return None - return self.selectable + return self.selectable # type: ignore elif isinstance(col, Alias) and isinstance( col.element, TableClause ): @@ -944,7 +1016,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): # we are an alias of a table and we are not derived from an # alias of a table (which nonetheless may be the same table # as ours) so, same thing - return col + return col # type: ignore else: # other cases where we are a selectable and the element # is another join or selectable that contains a table which our @@ -972,12 +1044,22 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): else: return None + if TYPE_CHECKING: + assert isinstance(col, ColumnElement) + if self.include_fn and not self.include_fn(col): return None elif self.exclude_fn and self.exclude_fn(col): return None else: - return self._corresponding_column(col, True) + return self._corresponding_column(col, True) # type: ignore + + +class _ColumnLookup(Protocol): + def __getitem__( + self, key: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: + ... class ColumnAdapter(ClauseAdapter): @@ -1011,17 +1093,21 @@ class ColumnAdapter(ClauseAdapter): """ + columns: _ColumnLookup + def __init__( self, - selectable, - equivalents=None, - adapt_required=False, - include_fn=None, - exclude_fn=None, - adapt_on_names=False, - allow_label_resolve=True, - anonymize_labels=False, - adapt_from_selectables=None, + selectable: Selectable, + equivalents: Optional[ + Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] + ] = None, + adapt_required: bool = False, + include_fn: Optional[Callable[[ClauseElement], bool]] = None, + exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, + adapt_on_names: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, ): ClauseAdapter.__init__( self, @@ -1034,7 +1120,7 @@ class ColumnAdapter(ClauseAdapter): adapt_from_selectables=adapt_from_selectables, ) - self.columns = util.WeakPopulateDict(self._locate_col) + self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore if self.include_fn or self.exclude_fn: self.columns = self._IncludeExcludeMapping(self, self.columns) self.adapt_required = adapt_required @@ -1060,7 +1146,7 @@ class ColumnAdapter(ClauseAdapter): ac = self.__class__.__new__(self.__class__) ac.__dict__.update(self.__dict__) ac._wrap = adapter - ac.columns = util.WeakPopulateDict(ac._locate_col) + ac.columns = util.WeakPopulateDict(ac._locate_col) # type: ignore if ac.include_fn or ac.exclude_fn: ac.columns = self._IncludeExcludeMapping(ac, ac.columns) @@ -1069,6 +1155,17 @@ class ColumnAdapter(ClauseAdapter): def traverse(self, obj): return self.columns[obj] + def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: + assert isinstance(visitor, ColumnAdapter) + + return super().chain(visitor) + + if TYPE_CHECKING: + + @property + def visitor_iterator(self) -> Iterator[ColumnAdapter]: + ... + adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process @@ -1080,7 +1177,9 @@ class ColumnAdapter(ClauseAdapter): return newcol - def _locate_col(self, col): + def _locate_col( + self, col: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: # both replace and traverse() are overly complicated for what # we are doing here and we would do better to have an inlined # version that doesn't build up as much overhead. the issue is that @@ -1120,10 +1219,14 @@ class ColumnAdapter(ClauseAdapter): def __setstate__(self, state): self.__dict__.update(state) - self.columns = util.WeakPopulateDict(self._locate_col) + self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore -def _offset_or_limit_clause(element, name=None, type_=None): +def _offset_or_limit_clause( + element: Union[int, _ColumnExpressionArgument[int]], + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[int]] = None, +) -> ColumnElement[int]: """Convert the given value to an "offset or limit" clause. This handles incoming integers and converts to an expression; if @@ -1135,7 +1238,9 @@ def _offset_or_limit_clause(element, name=None, type_=None): ) -def _offset_or_limit_clause_asint_if_possible(clause): +def _offset_or_limit_clause_asint_if_possible( + clause: Optional[Union[int, _ColumnExpressionArgument[int]]] +) -> Optional[Union[int, _ColumnExpressionArgument[int]]]: """Return the offset or limit clause as a simple integer if possible, else return the clause. @@ -1143,18 +1248,27 @@ def _offset_or_limit_clause_asint_if_possible(clause): if clause is None: return None if hasattr(clause, "_limit_offset_value"): - value = clause._limit_offset_value + value = clause._limit_offset_value # type: ignore return util.asint(value) else: return clause -def _make_slice(limit_clause, offset_clause, start, stop): +def _make_slice( + limit_clause: Optional[Union[int, _ColumnExpressionArgument[int]]], + offset_clause: Optional[Union[int, _ColumnExpressionArgument[int]]], + start: int, + stop: int, +) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]: """Compute LIMIT/OFFSET in terms of slice start/end""" # for calculated limit/offset, try to do the addition of # values to offset in Python, however if a SQL clause is present # then the addition has to be on the SQL side. + + # TODO: typing is finding a few gaps in here, see if they can be + # closed up + if start is not None and stop is not None: offset_clause = _offset_or_limit_clause_asint_if_possible( offset_clause @@ -1163,11 +1277,12 @@ def _make_slice(limit_clause, offset_clause, start, stop): offset_clause = 0 if start != 0: - offset_clause = offset_clause + start + offset_clause = offset_clause + start # type: ignore if offset_clause == 0: offset_clause = None else: + assert offset_clause is not None offset_clause = _offset_or_limit_clause(offset_clause) limit_clause = _offset_or_limit_clause(stop - start) @@ -1182,11 +1297,13 @@ def _make_slice(limit_clause, offset_clause, start, stop): offset_clause = 0 if start != 0: - offset_clause = offset_clause + start + offset_clause = offset_clause + start # type: ignore if offset_clause == 0: offset_clause = None else: - offset_clause = _offset_or_limit_clause(offset_clause) + offset_clause = _offset_or_limit_clause( + offset_clause # type: ignore + ) - return limit_clause, offset_clause + return limit_clause, offset_clause # type: ignore diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 903aae648..081faf1e9 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -28,6 +28,7 @@ from typing import Iterator from typing import List from typing import Mapping from typing import Optional +from typing import overload from typing import Tuple from typing import Type from typing import TypeVar @@ -37,6 +38,7 @@ from .. import exc from .. import util from ..util import langhelpers from ..util._has_cy import HAS_CYEXTENSION +from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import Self @@ -599,8 +601,8 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): raise NotImplementedError() def _copy_internals( - self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any - ) -> Self: + self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any + ) -> None: """Reassign internal elements to be clones of themselves. Called during a copy-and-traverse operation on newly @@ -615,10 +617,24 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): _ET = TypeVar("_ET", bound=ExternallyTraversible) + + _TraverseCallableType = Callable[[_ET], None] -_TraverseTransformCallableType = Callable[ - [ExternallyTraversible], Optional[ExternallyTraversible] -] + + +class _CloneCallableType(Protocol): + def __call__(self, element: _ET, **kw: Any) -> _ET: + ... + + +class _TraverseTransformCallableType(Protocol): + def __call__( + self, element: ExternallyTraversible, **kw: Any + ) -> Optional[ExternallyTraversible]: + ... + + +_ExtT = TypeVar("_ExtT", bound="ExternalTraversal") class ExternalTraversal: @@ -640,7 +656,7 @@ class ExternalTraversal: return meth(obj, **kw) def iterate( - self, obj: ExternallyTraversible + self, obj: Optional[ExternallyTraversible] ) -> Iterator[ExternallyTraversible]: """Traverse the given expression structure, returning an iterator of all elements. @@ -648,7 +664,17 @@ class ExternalTraversal: """ return iterate(obj, self.__traverse_options__) + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + @overload def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: """Traverse and visit the given expression structure.""" return traverse(obj, self.__traverse_options__, self._visitor_dict) @@ -671,7 +697,7 @@ class ExternalTraversal: yield v v = getattr(v, "_next", None) - def chain(self, visitor: ExternalTraversal) -> ExternalTraversal: + def chain(self: _ExtT, visitor: ExternalTraversal) -> _ExtT: """'Chain' an additional ExternalTraversal onto this ExternalTraversal The chained visitor will receive all visit events after this one. @@ -701,7 +727,17 @@ class CloningExternalTraversal(ExternalTraversal): """ return [self.traverse(x) for x in list_] + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + @overload def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: """Traverse and visit the given expression structure.""" return cloned_traverse( @@ -729,14 +765,25 @@ class ReplacingExternalTraversal(CloningExternalTraversal): """ return None + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + @overload def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: """Traverse and visit the given expression structure.""" def replace( - elem: ExternallyTraversible, + element: ExternallyTraversible, + **kw: Any, ) -> Optional[ExternallyTraversible]: for v in self.visitor_iterator: - e = cast(ReplacingExternalTraversal, v).replace(elem) + e = cast(ReplacingExternalTraversal, v).replace(element) if e is not None: return e @@ -754,7 +801,8 @@ ReplacingCloningVisitor = ReplacingExternalTraversal def iterate( - obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT + obj: Optional[ExternallyTraversible], + opts: Mapping[str, Any] = util.EMPTY_DICT, ) -> Iterator[ExternallyTraversible]: r"""Traverse the given expression structure, returning an iterator. @@ -776,6 +824,9 @@ def iterate( empty in modern usage. """ + if obj is None: + return + yield obj children = obj.get_children(**opts) @@ -790,11 +841,29 @@ def iterate( stack.append(t.get_children(**opts)) +@overload +def traverse_using( + iterator: Iterable[ExternallyTraversible], + obj: Literal[None], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> None: + ... + + +@overload def traverse_using( iterator: Iterable[ExternallyTraversible], obj: ExternallyTraversible, visitors: Mapping[str, _TraverseCallableType[Any]], ) -> ExternallyTraversible: + ... + + +def traverse_using( + iterator: Iterable[ExternallyTraversible], + obj: Optional[ExternallyTraversible], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> Optional[ExternallyTraversible]: """Visit the given expression structure using the given iterator of objects. @@ -826,11 +895,29 @@ def traverse_using( return obj +@overload +def traverse( + obj: Literal[None], + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> None: + ... + + +@overload def traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], ) -> ExternallyTraversible: + ... + + +def traverse( + obj: Optional[ExternallyTraversible], + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> Optional[ExternallyTraversible]: """Traverse and visit the given expression structure using the default iterator. @@ -863,11 +950,29 @@ def traverse( return traverse_using(iterate(obj, opts), obj, visitors) +@overload +def cloned_traverse( + obj: Literal[None], + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> None: + ... + + +@overload def cloned_traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], ) -> ExternallyTraversible: + ... + + +def cloned_traverse( + obj: Optional[ExternallyTraversible], + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> Optional[ExternallyTraversible]: """Clone the given expression structure, allowing modifications by visitors. @@ -931,11 +1036,29 @@ def cloned_traverse( return obj +@overload +def replacement_traverse( + obj: Literal[None], + opts: Mapping[str, Any], + replace: _TraverseTransformCallableType, +) -> None: + ... + + +@overload def replacement_traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], replace: _TraverseTransformCallableType, ) -> ExternallyTraversible: + ... + + +def replacement_traverse( + obj: Optional[ExternallyTraversible], + opts: Mapping[str, Any], + replace: _TraverseTransformCallableType, +) -> Optional[ExternallyTraversible]: """Clone the given expression structure, allowing element replacement by a given replacement function. diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 4496b8ded..f1bf5c0c4 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -306,6 +306,7 @@ def testing_engine( options=None, asyncio=False, transfer_staticpool=False, + share_pool=False, _sqlite_savepoint=False, ): if asyncio: @@ -356,6 +357,8 @@ def testing_engine( if config.db is not None and isinstance(config.db.pool, StaticPool): use_reaper = False engine.pool._transfer_from(config.db.pool) + elif share_pool: + engine.pool = config.db.pool if scope == "global": if asyncio: diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 6c6b21fce..8c0120bcc 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -121,6 +121,7 @@ class TestBase: future=None, asyncio=False, transfer_staticpool=False, + share_pool=False, ): if options is None: options = {} @@ -130,6 +131,7 @@ class TestBase: options=options, asyncio=asyncio, transfer_staticpool=transfer_staticpool, + share_pool=share_pool, ) yield gen_testing_engine diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 406c8af24..c0c2e7dfb 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -9,7 +9,9 @@ from collections import defaultdict as defaultdict from functools import partial as partial from functools import update_wrapper as update_wrapper +from typing import TYPE_CHECKING +from . import preloaded as preloaded from ._collections import coerce_generator_arg as coerce_generator_arg from ._collections import coerce_to_immutabledict as coerce_to_immutabledict from ._collections import column_dict as column_dict @@ -44,8 +46,6 @@ from ._collections import UniqueAppender as UniqueAppender from ._collections import update_copy as update_copy from ._collections import WeakPopulateDict as WeakPopulateDict from ._collections import WeakSequence as WeakSequence -from ._preloaded import preload_module as preload_module -from ._preloaded import preloaded as preloaded from .compat import arm as arm from .compat import b as b from .compat import b64decode as b64decode @@ -148,3 +148,4 @@ from .langhelpers import warn as warn from .langhelpers import warn_exception as warn_exception from .langhelpers import warn_limited as warn_limited from .langhelpers import wrap_callable as wrap_callable +from .preloaded import preload_module as preload_module diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index bd73bf714..bcb2ad423 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -463,11 +463,12 @@ def update_copy(d, _new=None, **kw): return d -def flatten_iterator(x): +def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]: """Given an iterator of which further sub-elements may also be iterators, flatten the sub-elements into a single iterator. """ + elem: _T for elem in x: if not isinstance(elem, str) and hasattr(elem, "__iter__"): for y in flatten_iterator(elem): diff --git a/lib/sqlalchemy/util/_preloaded.py b/lib/sqlalchemy/util/preloaded.py index 511b93351..c861c83b3 100644 --- a/lib/sqlalchemy/util/_preloaded.py +++ b/lib/sqlalchemy/util/preloaded.py @@ -16,10 +16,16 @@ from types import ModuleType import typing from typing import Any from typing import Callable +from typing import TYPE_CHECKING from typing import TypeVar _FN = TypeVar("_FN", bound=Callable[..., Any]) +if TYPE_CHECKING: + from sqlalchemy.engine import default as engine_default + from sqlalchemy.sql import dml as sql_dml + from sqlalchemy.sql import util as sql_util + class _ModuleRegistry: """Registry of modules to load in a package init file. @@ -67,7 +73,7 @@ class _ModuleRegistry: not path or module.startswith(path) ) and key not in self.__dict__: __import__(module, globals(), locals()) - self.__dict__[key] = sys.modules[module] + self.__dict__[key] = globals()[key] = sys.modules[module] if typing.TYPE_CHECKING: @@ -75,5 +81,11 @@ class _ModuleRegistry: ... -preloaded = _ModuleRegistry() -preload_module = preloaded.preload_module +_reg = _ModuleRegistry() +preload_module = _reg.preload_module +import_prefix = _reg.import_prefix + +if TYPE_CHECKING: + + def __getattr__(key: str) -> ModuleType: + ... diff --git a/pyproject.toml b/pyproject.toml index cc79e8646..012f1bffa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,9 +68,6 @@ module = [ "sqlalchemy.ext.mutable", "sqlalchemy.ext.horizontal_shard", - "sqlalchemy.sql._selectable_constructors", - "sqlalchemy.sql._dml_constructors", - # TODO for non-strict: "sqlalchemy.ext.baked", "sqlalchemy.ext.instrumentation", @@ -78,11 +75,6 @@ module = [ "sqlalchemy.ext.orderinglist", "sqlalchemy.ext.serializer", - "sqlalchemy.sql.selectable", # would be nice as strict - "sqlalchemy.sql.functions", # would be nice as strict - "sqlalchemy.sql.lambdas", - "sqlalchemy.sql.util", - # not yet classified: "sqlalchemy.orm.*", "sqlalchemy.dialects.*", @@ -132,10 +124,14 @@ module = [ "sqlalchemy.sql.crud", "sqlalchemy.sql.ddl", # would be nice as strict "sqlalchemy.sql.elements", # would be nice as strict + "sqlalchemy.sql.functions", # would be nice as strict, requires sqltypes + "sqlalchemy.sql.lambdas", "sqlalchemy.sql.naming", + "sqlalchemy.sql.selectable", # would be nice as strict "sqlalchemy.sql.schema", # would be nice as strict "sqlalchemy.sql.sqltypes", # would be nice as strict "sqlalchemy.sql.traversals", + "sqlalchemy.sql.util", "sqlalchemy.util.*", ] diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 06be9fbd8..e03a8415d 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -835,27 +835,14 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest): ) s.commit() - def test_build_query(self): + def test_fetch_results_integrated(self, testing_engine): A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") - sess = fixture_session() + # this test has been reworked to use the compiled cache again, + # as a real-world scenario. - @profiling.function_call_count() - def go(): - for i in range(100): - q = sess.query(A).options( - joinedload(A.bs).joinedload(B.cs).joinedload(C.ds), - joinedload(A.es).joinedload(E.fs), - defaultload(A.es).joinedload(E.gs), - ) - q._compile_context() - - go() - - def test_fetch_results(self): - A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") - - sess = Session(testing.db) + eng = testing_engine(share_pool=True) + sess = Session(eng) q = sess.query(A).options( joinedload(A.bs).joinedload(B.cs).joinedload(C.ds), @@ -863,47 +850,27 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest): defaultload(A.es).joinedload(E.gs), ) - compile_state = q._compile_state() + @profiling.function_call_count() + def initial_run(): + list(q.all()) - from sqlalchemy.orm.context import ORMCompileState + initial_run() + sess.close() - @profiling.function_call_count(warmup=1) - def go(): - for i in range(100): - # NOTE: this test was broken in - # 77f1b7d236dba6b1c859bb428ef32d118ec372e6 because we started - # clearing out the attributes after the first iteration. make - # sure the attributes are there every time. - assert compile_state.attributes - exec_opts = {} - bind_arguments = {} - ORMCompileState.orm_pre_session_exec( - sess, - compile_state.select_statement, - {}, - exec_opts, - bind_arguments, - is_reentrant_invoke=False, - ) + @profiling.function_call_count() + def subsequent_run(): + list(q.all()) - r = sess.connection().execute( - compile_state.statement, - execution_options=exec_opts, - ) + subsequent_run() + sess.close() - r.context.compiled.compile_state = compile_state - obj = ORMCompileState.orm_setup_cursor_result( - sess, - compile_state.statement, - {}, - exec_opts, - {}, - r, - ) - list(obj.unique()) - sess.close() + @profiling.function_call_count() + def more_runs(): + for i in range(100): + list(q.all()) - go() + more_runs() + sess.close() class JoinConditionTest(NoCache, fixtures.DeclarativeMappedTest): diff --git a/test/base/test_utils.py b/test/base/test_utils.py index fc61e39b6..e22340da6 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -25,11 +25,11 @@ from sqlalchemy.testing import mock from sqlalchemy.testing import ne_ from sqlalchemy.testing.util import gc_collect from sqlalchemy.testing.util import picklers -from sqlalchemy.util import _preloaded from sqlalchemy.util import classproperty from sqlalchemy.util import compat from sqlalchemy.util import get_callable_argspec from sqlalchemy.util import langhelpers +from sqlalchemy.util import preloaded from sqlalchemy.util import WeakSequence from sqlalchemy.util._collections import merge_lists_w_ordering @@ -3187,7 +3187,7 @@ class TestModuleRegistry(fixtures.TestBase): for m in ("xml.dom", "wsgiref.simple_server"): to_restore.append((m, sys.modules.pop(m, None))) try: - mr = _preloaded._ModuleRegistry() + mr = preloaded._ModuleRegistry() ret = mr.preload_module( "xml.dom", "wsgiref.simple_server", "sqlalchemy.sql.util" diff --git a/test/profiles.txt b/test/profiles.txt index 31f72bd16..7b4f37734 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -196,16 +196,16 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96844 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102344 -# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query - -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 520615 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 522475 # TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 440705 test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 458805 +# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated + +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 30373,1014,96450 + # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 22984 diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index dd073d2a5..8d6dc7553 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -208,6 +208,8 @@ class CoreFixtures: column("q") == column("x"), column("q") == column("y"), column("z") == column("x"), + (column("z") == column("x")).self_group(), + (column("q") == column("x")).self_group(), column("z") + column("x"), column("z") - column("x"), column("x") - column("z"), diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index d20037e92..6ca06dc0e 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -82,7 +82,6 @@ from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import CompilerColumnElement from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import ClauseList -from sqlalchemy.sql.expression import HasPrefixes from sqlalchemy.sql.selectable import LABEL_STYLE_NONE from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.testing import assert_raises @@ -270,18 +269,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "columns", ) - def test_prefix_constructor(self): - class Pref(HasPrefixes): - def _generate(self): - return self - - assert_raises( - exc.ArgumentError, - Pref().prefix_with, - "some prefix", - not_a_dialect=True, - ) - def test_table_select(self): self.assert_compile( table1.select(), diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 686d4928d..d1d01a5c7 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1186,6 +1186,37 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): dialect="postgresql", ) + def test_recursive_dml_syntax(self): + orders = table( + "orders", + column("region"), + column("amount"), + column("product"), + column("quantity"), + ) + + upsert = ( + orders.update() + .where(orders.c.region == "Region1") + .values(amount=1.0, product="Product1", quantity=1) + .returning(*(orders.c._all_columns)) + .cte("upsert", recursive=True) + ) + stmt = select(upsert) + + # This statement probably makes no sense, just want to see that the + # column generation aspect needed by RECURSIVE works (new in 2.0) + self.assert_compile( + stmt, + "WITH RECURSIVE upsert(region, amount, product, quantity) " + "AS (UPDATE orders SET amount=:param_1, product=:param_2, " + "quantity=:param_3 WHERE orders.region = :region_1 " + "RETURNING orders.region, orders.amount, orders.product, " + "orders.quantity) " + "SELECT upsert.region, upsert.amount, upsert.product, " + "upsert.quantity FROM upsert", + ) + def test_upsert_from_select(self): orders = table( "orders", diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index ca5f43bb6..9fdc51938 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -432,6 +432,24 @@ class SelectableTest( ): select(stmt.subquery()).compile() + def test_correlate_none_arg_error(self): + stmt = select(table1) + with expect_raises_message( + exc.ArgumentError, + "additional FROM objects not accepted when passing " + "None/False to correlate", + ): + stmt.correlate(None, table2) + + def test_correlate_except_none_arg_error(self): + stmt = select(table1) + with expect_raises_message( + exc.ArgumentError, + "additional FROM objects not accepted when passing " + "None/False to correlate_except", + ): + stmt.correlate_except(None, table2) + def test_select_label_grouped_still_corresponds(self): label = select(table1.c.col1).label("foo") label2 = label.self_group() diff --git a/test/sql/test_text.py b/test/sql/test_text.py index 81b20f86f..0f645a2d2 100644 --- a/test/sql/test_text.py +++ b/test/sql/test_text.py @@ -688,6 +688,19 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): mapping = self._mapping(s) assert x not in mapping + def test_subquery_accessors(self): + t = self._xy_table_fixture() + + s = text("SELECT x from t").columns(t.c.x) + + self.assert_compile( + select(s.scalar_subquery()), "SELECT (SELECT x from t) AS anon_1" + ) + self.assert_compile( + select(s.subquery()), + "SELECT anon_1.x FROM (SELECT x from t) AS anon_1", + ) + def test_select_label_alt_name_table_alias_column(self): t = self._xy_table_fixture() x = t.c.x @@ -716,6 +729,36 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): "FROM mytable, t WHERE mytable.myid = t.id", ) + def test_cte_recursive(self): + t = ( + text("select id, name from user") + .columns(id=Integer, name=String) + .cte("t", recursive=True) + ) + + s = select(table1).where(table1.c.myid == t.c.id) + self.assert_compile( + s, + "WITH RECURSIVE t(id, name) AS (select id, name from user) " + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable, t WHERE mytable.myid = t.id", + ) + + def test_unions(self): + s1 = text("select id, name from user where id > 5").columns( + id=Integer, name=String + ) + s2 = text("select id, name from user where id < 15").columns( + id=Integer, name=String + ) + stmt = union(s1, s2) + eq_(stmt.selected_columns.keys(), ["id", "name"]) + self.assert_compile( + stmt, + "select id, name from user where id > 5 UNION " + "select id, name from user where id < 15", + ) + def test_subquery(self): t = ( text("select id, name from user") diff --git a/test/sql/test_values.py b/test/sql/test_values.py index f5ae9ea53..d14de9aee 100644 --- a/test/sql/test_values.py +++ b/test/sql/test_values.py @@ -294,6 +294,31 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL): checkparams={}, ) + def test_anon_alias(self): + people = self.tables.people + values = ( + Values( + column("bookcase_id", Integer), + column("bookcase_owner_id", Integer), + ) + .data([(1, 1), (2, 1), (3, 2), (3, 3)]) + .alias() + ) + stmt = select(people, values).select_from( + people.join( + values, values.c.bookcase_owner_id == people.c.people_id + ) + ) + self.assert_compile( + stmt, + "SELECT people.people_id, people.age, people.name, " + "anon_1.bookcase_id, anon_1.bookcase_owner_id FROM people " + "JOIN (VALUES (:param_1, :param_2), (:param_3, :param_4), " + "(:param_5, :param_6), (:param_7, :param_8)) AS anon_1 " + "(bookcase_id, bookcase_owner_id) " + "ON people.people_id = anon_1.bookcase_owner_id", + ) + def test_with_join_unnamed(self): people = self.tables.people values = Values( |