summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/exc.py9
-rw-r--r--lib/sqlalchemy/orm/query.py7
-rw-r--r--lib/sqlalchemy/sql/_dml_constructors.py11
-rw-r--r--lib/sqlalchemy/sql/_elements_constructors.py4
-rw-r--r--lib/sqlalchemy/sql/_selectable_constructors.py122
-rw-r--r--lib/sqlalchemy/sql/_typing.py135
-rw-r--r--lib/sqlalchemy/sql/annotation.py43
-rw-r--r--lib/sqlalchemy/sql/base.py23
-rw-r--r--lib/sqlalchemy/sql/cache_key.py9
-rw-r--r--lib/sqlalchemy/sql/coercions.py158
-rw-r--r--lib/sqlalchemy/sql/compiler.py29
-rw-r--r--lib/sqlalchemy/sql/dml.py49
-rw-r--r--lib/sqlalchemy/sql/elements.py85
-rw-r--r--lib/sqlalchemy/sql/functions.py246
-rw-r--r--lib/sqlalchemy/sql/lambdas.py194
-rw-r--r--lib/sqlalchemy/sql/roles.py8
-rw-r--r--lib/sqlalchemy/sql/schema.py44
-rw-r--r--lib/sqlalchemy/sql/selectable.py1868
-rw-r--r--lib/sqlalchemy/sql/traversals.py2
-rw-r--r--lib/sqlalchemy/sql/type_api.py5
-rw-r--r--lib/sqlalchemy/sql/util.py231
-rw-r--r--lib/sqlalchemy/sql/visitors.py143
-rw-r--r--lib/sqlalchemy/testing/engines.py3
-rw-r--r--lib/sqlalchemy/testing/fixtures.py2
-rw-r--r--lib/sqlalchemy/util/__init__.py5
-rw-r--r--lib/sqlalchemy/util/_collections.py3
-rw-r--r--lib/sqlalchemy/util/preloaded.py (renamed from lib/sqlalchemy/util/_preloaded.py)18
-rw-r--r--pyproject.toml12
-rw-r--r--test/aaa_profiling/test_orm.py75
-rw-r--r--test/base/test_utils.py4
-rw-r--r--test/profiles.txt8
-rw-r--r--test/sql/test_compare.py2
-rw-r--r--test/sql/test_compiler.py13
-rw-r--r--test/sql/test_cte.py31
-rw-r--r--test/sql/test_selectable.py18
-rw-r--r--test/sql/test_text.py43
-rw-r--r--test/sql/test_values.py25
37 files changed, 2568 insertions, 1119 deletions
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py
index 8f4b963eb..9db6f3f52 100644
--- a/lib/sqlalchemy/exc.py
+++ b/lib/sqlalchemy/exc.py
@@ -23,8 +23,8 @@ from typing import Tuple
from typing import Type
from typing import Union
-from .util import _preloaded
from .util import compat
+from .util import preloaded as _preloaded
if typing.TYPE_CHECKING:
from .engine.interfaces import _AnyExecuteParams
@@ -345,6 +345,8 @@ class MultipleResultsFound(InvalidRequestError):
class NoReferenceError(InvalidRequestError):
"""Raised by ``ForeignKey`` to indicate a reference cannot be resolved."""
+ table_name: str
+
class AwaitRequired(InvalidRequestError):
"""Error raised by the async greenlet spawn if no async operation
@@ -501,10 +503,7 @@ class StatementError(SQLAlchemyError):
@_preloaded.preload_module("sqlalchemy.sql.util")
def _sql_message(self) -> str:
- if typing.TYPE_CHECKING:
- from .sql import util
- else:
- util = _preloaded.preloaded.sql_util
+ util = _preloaded.sql_util
details = [self._message()]
if self.statement:
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 18a14012f..b9ced44d5 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -28,6 +28,8 @@ from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
from typing import TypeVar
from . import exc as orm_exc
@@ -78,6 +80,8 @@ from ..sql.selectable import SelectBase
from ..sql.selectable import SelectStatementGrouping
from ..sql.visitors import InternalTraversal
+if TYPE_CHECKING:
+ from ..sql.selectable import _SetupJoinsElement
__all__ = ["Query", "QueryContext"]
@@ -134,7 +138,8 @@ class Query(
_correlate = ()
_auto_correlate = True
_from_obj = ()
- _setup_joins = ()
+ _setup_joins: Tuple[_SetupJoinsElement, ...] = ()
+
_label_style = LABEL_STYLE_LEGACY_ORM
_memoized_select_entities = ()
diff --git a/lib/sqlalchemy/sql/_dml_constructors.py b/lib/sqlalchemy/sql/_dml_constructors.py
index 835819bac..926e5257b 100644
--- a/lib/sqlalchemy/sql/_dml_constructors.py
+++ b/lib/sqlalchemy/sql/_dml_constructors.py
@@ -7,12 +7,17 @@
from __future__ import annotations
+from typing import TYPE_CHECKING
+
from .dml import Delete
from .dml import Insert
from .dml import Update
+if TYPE_CHECKING:
+ from ._typing import _DMLTableArgument
+
-def insert(table):
+def insert(table: _DMLTableArgument) -> Insert:
"""Construct an :class:`_expression.Insert` object.
E.g.::
@@ -82,7 +87,7 @@ def insert(table):
return Insert(table)
-def update(table):
+def update(table: _DMLTableArgument) -> Update:
r"""Construct an :class:`_expression.Update` object.
E.g.::
@@ -122,7 +127,7 @@ def update(table):
return Update(table)
-def delete(table):
+def delete(table: _DMLTableArgument) -> Delete:
r"""Construct :class:`_expression.Delete` object.
E.g.::
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py
index fc925a8b3..ea21e01c6 100644
--- a/lib/sqlalchemy/sql/_elements_constructors.py
+++ b/lib/sqlalchemy/sql/_elements_constructors.py
@@ -345,8 +345,8 @@ def between(
:meth:`_expression.ColumnElement.between`
"""
- expr = coercions.expect(roles.ExpressionElementRole, expr)
- return expr.between(lower_bound, upper_bound, symmetric=symmetric)
+ col_expr = coercions.expect(roles.ExpressionElementRole, expr)
+ return col_expr.between(lower_bound, upper_bound, symmetric=symmetric)
def outparam(
diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py
index a17ee4ce8..7896c02c2 100644
--- a/lib/sqlalchemy/sql/_selectable_constructors.py
+++ b/lib/sqlalchemy/sql/_selectable_constructors.py
@@ -9,6 +9,8 @@ from __future__ import annotations
from typing import Any
from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from . import coercions
from . import roles
@@ -17,64 +19,65 @@ from .elements import ColumnClause
from .selectable import Alias
from .selectable import CompoundSelect
from .selectable import Exists
+from .selectable import FromClause
from .selectable import Join
from .selectable import Lateral
+from .selectable import LateralFromClause
+from .selectable import NamedFromClause
from .selectable import Select
from .selectable import TableClause
from .selectable import TableSample
from .selectable import Values
+if TYPE_CHECKING:
+ from ._typing import _ColumnsClauseArgument
+ from ._typing import _FromClauseArgument
+ from ._typing import _OnClauseArgument
+ from ._typing import _SelectStatementForCompoundArgument
+ from .functions import Function
+ from .selectable import CTE
+ from .selectable import HasCTE
+ from .selectable import ScalarSelect
+ from .selectable import SelectBase
-def alias(selectable, name=None, flat=False):
- """Return an :class:`_expression.Alias` object.
- An :class:`_expression.Alias` represents any
- :class:`_expression.FromClause`
- with an alternate name assigned within SQL, typically using the ``AS``
- clause when generated, e.g. ``SELECT * FROM table AS aliasname``.
+def alias(
+ selectable: FromClause, name: Optional[str] = None, flat: bool = False
+) -> NamedFromClause:
+ """Return a named alias of the given :class:`.FromClause`.
+
+ For :class:`.Table` and :class:`.Join` objects, the return type is the
+ :class:`_expression.Alias` object. Other kinds of :class:`.NamedFromClause`
+ objects may be returned for other kinds of :class:`.FromClause` objects.
+
+ The named alias represents any :class:`_expression.FromClause` with an
+ alternate name assigned within SQL, typically using the ``AS`` clause when
+ generated, e.g. ``SELECT * FROM table AS aliasname``.
- Similar functionality is available via the
+ Equivalent functionality is available via the
:meth:`_expression.FromClause.alias`
- method available on all :class:`_expression.FromClause` subclasses.
- In terms of
- a SELECT object as generated from the :func:`_expression.select`
- function, the :meth:`_expression.SelectBase.alias` method returns an
- :class:`_expression.Alias` or similar object which represents a named,
- parenthesized subquery.
-
- When an :class:`_expression.Alias` is created from a
- :class:`_schema.Table` object,
- this has the effect of the table being rendered
- as ``tablename AS aliasname`` in a SELECT statement.
-
- For :func:`_expression.select` objects, the effect is that of
- creating a named subquery, i.e. ``(select ...) AS aliasname``.
-
- The ``name`` parameter is optional, and provides the name
- to use in the rendered SQL. If blank, an "anonymous" name
- will be deterministically generated at compile time.
- Deterministic means the name is guaranteed to be unique against
- other constructs used in the same statement, and will also be the
- same name for each successive compilation of the same statement
- object.
+ method available on all :class:`_expression.FromClause` objects.
:param selectable: any :class:`_expression.FromClause` subclass,
such as a table, select statement, etc.
:param name: string name to be assigned as the alias.
- If ``None``, a name will be deterministically generated
- at compile time.
+ If ``None``, a name will be deterministically generated at compile
+ time. Deterministic means the name is guaranteed to be unique against
+ other constructs used in the same statement, and will also be the same
+ name for each successive compilation of the same statement object.
:param flat: Will be passed through to if the given selectable
is an instance of :class:`_expression.Join` - see
- :meth:`_expression.Join.alias`
- for details.
+ :meth:`_expression.Join.alias` for details.
"""
return Alias._factory(selectable, name=name, flat=flat)
-def cte(selectable, name=None, recursive=False):
+def cte(
+ selectable: HasCTE, name: Optional[str] = None, recursive: bool = False
+) -> CTE:
r"""Return a new :class:`_expression.CTE`,
or Common Table Expression instance.
@@ -86,7 +89,7 @@ def cte(selectable, name=None, recursive=False):
)
-def except_(*selects):
+def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
r"""Return an ``EXCEPT`` of multiple selectables.
The returned object is an instance of
@@ -99,7 +102,9 @@ def except_(*selects):
return CompoundSelect._create_except(*selects)
-def except_all(*selects):
+def except_all(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return an ``EXCEPT ALL`` of multiple selectables.
The returned object is an instance of
@@ -112,7 +117,11 @@ def except_all(*selects):
return CompoundSelect._create_except_all(*selects)
-def exists(__argument=None):
+def exists(
+ __argument: Optional[
+ Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]]
+ ] = None,
+) -> Exists:
"""Construct a new :class:`_expression.Exists` construct.
The :func:`_sql.exists` can be invoked by itself to produce an
@@ -153,7 +162,7 @@ def exists(__argument=None):
return Exists(__argument)
-def intersect(*selects):
+def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
r"""Return an ``INTERSECT`` of multiple selectables.
The returned object is an instance of
@@ -166,7 +175,9 @@ def intersect(*selects):
return CompoundSelect._create_intersect(*selects)
-def intersect_all(*selects):
+def intersect_all(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return an ``INTERSECT ALL`` of multiple selectables.
The returned object is an instance of
@@ -180,7 +191,13 @@ def intersect_all(*selects):
return CompoundSelect._create_intersect_all(*selects)
-def join(left, right, onclause=None, isouter=False, full=False):
+def join(
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ isouter: bool = False,
+ full: bool = False,
+) -> Join:
"""Produce a :class:`_expression.Join` object, given two
:class:`_expression.FromClause`
expressions.
@@ -232,7 +249,10 @@ def join(left, right, onclause=None, isouter=False, full=False):
return Join(left, right, onclause, isouter, full)
-def lateral(selectable, name=None):
+def lateral(
+ selectable: Union[SelectBase, _FromClauseArgument],
+ name: Optional[str] = None,
+) -> LateralFromClause:
"""Return a :class:`_expression.Lateral` object.
:class:`_expression.Lateral` is an :class:`_expression.Alias`
@@ -255,7 +275,12 @@ def lateral(selectable, name=None):
return Lateral._factory(selectable, name=name)
-def outerjoin(left, right, onclause=None, full=False):
+def outerjoin(
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ full: bool = False,
+) -> Join:
"""Return an ``OUTER JOIN`` clause element.
The returned object is an instance of :class:`_expression.Join`.
@@ -349,7 +374,12 @@ def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause:
return TableClause(name, *columns, **kw)
-def tablesample(selectable, sampling, name=None, seed=None):
+def tablesample(
+ selectable: _FromClauseArgument,
+ sampling: Union[float, Function[Any]],
+ name: Optional[str] = None,
+ seed: Optional[roles.ExpressionElementRole[Any]] = None,
+) -> TableSample:
"""Return a :class:`_expression.TableSample` object.
:class:`_expression.TableSample` is an :class:`_expression.Alias`
@@ -395,7 +425,7 @@ def tablesample(selectable, sampling, name=None, seed=None):
return TableSample._factory(selectable, sampling, name=name, seed=seed)
-def union(*selects, **kwargs):
+def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
r"""Return a ``UNION`` of multiple selectables.
The returned object is an instance of
@@ -412,10 +442,10 @@ def union(*selects, **kwargs):
:func:`select`.
"""
- return CompoundSelect._create_union(*selects, **kwargs)
+ return CompoundSelect._create_union(*selects)
-def union_all(*selects):
+def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
r"""Return a ``UNION ALL`` of multiple selectables.
The returned object is an instance of
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index a5da87802..0a72a93c5 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -1,7 +1,7 @@
from __future__ import annotations
+import operator
from typing import Any
-from typing import Iterable
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -24,9 +24,16 @@ if TYPE_CHECKING:
from .roles import FromClauseRole
from .schema import DefaultGenerator
from .schema import Sequence
+ from .selectable import Alias
from .selectable import FromClause
+ from .selectable import Join
from .selectable import NamedFromClause
+ from .selectable import ReturnsRows
+ from .selectable import Select
+ from .selectable import SelectBase
+ from .selectable import Subquery
from .selectable import TableClause
+ from .sqltypes import TableValueType
from .sqltypes import TupleType
from .type_api import TypeEngine
from ..util.typing import TypeGuard
@@ -47,6 +54,14 @@ class _HasClauseElement(Protocol):
# the coercions system is responsible for converting from XYZArgument to
# XYZElement.
+_TextCoercedExpressionArgument = Union[
+ str,
+ "TextClause",
+ "ColumnElement[_T]",
+ _HasClauseElement,
+ roles.ExpressionElementRole[_T],
+]
+
_ColumnsClauseArgument = Union[
Literal["*", 1],
roles.ColumnsClauseRole,
@@ -54,8 +69,31 @@ _ColumnsClauseArgument = Union[
Inspectable[_HasClauseElement],
_HasClauseElement,
]
+"""open-ended SELECT columns clause argument.
+
+Includes column expressions, tables, ORM mapped entities, a few literal values.
+
+This type is used for lists of columns / entities to be returned in result
+sets; select(...), insert().returning(...), etc.
+
+
+"""
+
+_ColumnExpressionArgument = Union[
+ "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T]
+]
+"""narrower "column expression" argument.
+
+This type is used for all the other "column" kinds of expressions that
+typically represent a single SQL column expression, not a set of columns the
+way a table or ORM entity does.
+
+This includes ColumnElement, or ORM-mapped attributes that will have a
+`__clause_element__()` method, it also has the ExpressionElementRole
+overall which brings in the TextClause object also.
+
+"""
-_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]]
_FromClauseArgument = Union[
roles.FromClauseRole,
@@ -63,28 +101,99 @@ _FromClauseArgument = Union[
Inspectable[_HasClauseElement],
_HasClauseElement,
]
+"""A FROM clause, like we would send to select().select_from().
-_ColumnExpressionArgument = Union[
- "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T]
+Also accommodates ORM entities and related constructs.
+
+"""
+
+_JoinTargetArgument = Union[_FromClauseArgument, roles.JoinTargetRole]
+"""target for join() builds on _FromClauseArgument to include additional
+join target roles such as those which come from the ORM.
+
+"""
+
+_OnClauseArgument = Union[_ColumnExpressionArgument[Any], roles.OnClauseRole]
+"""target for an ON clause, includes additional roles such as those which
+come from the ORM.
+
+"""
+
+_SelectStatementForCompoundArgument = Union[
+ "SelectBase", roles.CompoundElementRole
+]
+"""SELECT statement acceptable by ``union()`` and other SQL set operations"""
+
+_DMLColumnArgument = Union[
+ str, "ColumnClause[Any]", _HasClauseElement, roles.DMLColumnRole
]
+"""A DML column expression. This is a "key" inside of insert().values(),
+update().values(), and related.
+
+These are usually strings or SQL table columns.
+
+There's also edge cases like JSON expression assignment, which we would want
+the DMLColumnRole to be able to accommodate.
-_DMLColumnArgument = Union[str, "ColumnClause[Any]", _HasClauseElement]
+"""
+
+
+_DMLTableArgument = Union[
+ "TableClause",
+ "Join",
+ "Alias",
+ Type[Any],
+ Inspectable[_HasClauseElement],
+ _HasClauseElement,
+]
_PropagateAttrsType = util.immutabledict[str, Any]
_TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
+if TYPE_CHECKING:
-def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]:
- return t.named_with_column
+ def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]:
+ ...
+ def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]:
+ ...
-def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]:
- return c._is_column_element
+ def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]:
+ ...
+ def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]:
+ ...
-def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]:
- return c._is_text_clause
+ def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]:
+ ...
+
+ def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]:
+ ...
+
+ def is_select_base(t: ReturnsRows) -> TypeGuard[SelectBase]:
+ ...
+
+ def is_select_statement(t: ReturnsRows) -> TypeGuard[Select]:
+ ...
+
+ def is_table(t: FromClause) -> TypeGuard[TableClause]:
+ ...
+
+ def is_subquery(t: FromClause) -> TypeGuard[Subquery]:
+ ...
+
+else:
+ is_named_from_clause = operator.attrgetter("named_with_column")
+ is_column_element = operator.attrgetter("_is_column_element")
+ is_text_clause = operator.attrgetter("_is_text_clause")
+ is_from_clause = operator.attrgetter("_is_from_clause")
+ is_tuple_type = operator.attrgetter("_is_tuple_type")
+ is_table_value_type = operator.attrgetter("_is_table_value")
+ is_select_base = operator.attrgetter("_is_select_base")
+ is_select_statement = operator.attrgetter("_is_select_statement")
+ is_table = operator.attrgetter("_is_table")
+ is_subquery = operator.attrgetter("_is_subquery")
def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]:
@@ -95,9 +204,5 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]:
return hasattr(s, "quote")
-def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]:
- return t._is_tuple_type
-
-
def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]:
return hasattr(s, "__clause_element__")
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index f1919d1d3..fa36c09fc 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -59,7 +59,9 @@ class SupportsAnnotations(ExternallyTraversible):
_is_immutable: bool
- def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations:
+ def _annotate(
+ self: SelfSupportsAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsAnnotations:
raise NotImplementedError()
@overload
@@ -105,11 +107,6 @@ class SupportsAnnotations(ExternallyTraversible):
)
-SelfSupportsCloneAnnotations = TypeVar(
- "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations"
-)
-
-
class SupportsCloneAnnotations(SupportsAnnotations):
if not typing.TYPE_CHECKING:
__slots__ = ()
@@ -119,8 +116,8 @@ class SupportsCloneAnnotations(SupportsAnnotations):
]
def _annotate(
- self: SelfSupportsCloneAnnotations, values: _AnnotationDict
- ) -> SelfSupportsCloneAnnotations:
+ self: SelfSupportsAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsAnnotations:
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -132,8 +129,8 @@ class SupportsCloneAnnotations(SupportsAnnotations):
return new
def _with_annotations(
- self: SelfSupportsCloneAnnotations, values: _AnnotationDict
- ) -> SelfSupportsCloneAnnotations:
+ self: SelfSupportsAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsAnnotations:
"""return a copy of this ClauseElement with annotations
replaced by the given dictionary.
@@ -184,11 +181,6 @@ class SupportsCloneAnnotations(SupportsAnnotations):
return self
-SelfSupportsWrappingAnnotations = TypeVar(
- "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations"
-)
-
-
class SupportsWrappingAnnotations(SupportsAnnotations):
__slots__ = ()
@@ -200,19 +192,23 @@ class SupportsWrappingAnnotations(SupportsAnnotations):
def entity_namespace(self) -> _EntityNamespace:
...
- def _annotate(self, values: _AnnotationDict) -> Annotated:
+ def _annotate(
+ self: SelfSupportsAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsAnnotations:
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
"""
- return Annotated._as_annotated_instance(self, values)
+ return Annotated._as_annotated_instance(self, values) # type: ignore
- def _with_annotations(self, values: _AnnotationDict) -> Annotated:
+ def _with_annotations(
+ self: SelfSupportsAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsAnnotations:
"""return a copy of this ClauseElement with annotations
replaced by the given dictionary.
"""
- return Annotated._as_annotated_instance(self, values)
+ return Annotated._as_annotated_instance(self, values) # type: ignore
@overload
def _deannotate(
@@ -306,16 +302,17 @@ class Annotated(SupportsAnnotations):
self: SelfAnnotated, values: _AnnotationDict
) -> SelfAnnotated:
_values = self._annotations.union(values)
- return self._with_annotations(_values)
+ new: SelfAnnotated = self._with_annotations(_values) # type: ignore
+ return new
def _with_annotations(
- self: SelfAnnotated, values: util.immutabledict[str, Any]
- ) -> SelfAnnotated:
+ self: SelfAnnotated, values: _AnnotationDict
+ ) -> SupportsAnnotations:
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
clone.__dict__.pop("_annotations_cache_key", None)
clone.__dict__.pop("_generate_cache_key", None)
- clone._annotations = values
+ clone._annotations = util.immutabledict(values)
return clone
@overload
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 19e4c13d2..6b25d8fcd 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -63,12 +63,14 @@ if TYPE_CHECKING:
from . import elements
from . import type_api
from ._typing import _ColumnsClauseArgument
- from ._typing import _SelectIterable
from .elements import BindParameter
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import NamedColumn
from .elements import SQLCoreOperations
+ from .elements import TextClause
+ from .selectable import _JoinTargetElement
+ from .selectable import _SelectIterable
from .selectable import FromClause
from ..engine import Connection
from ..engine import Result
@@ -167,7 +169,11 @@ class SingletonConstant(Immutable):
cls._singleton = obj
-def _from_objects(*elements: ColumnElement[Any]) -> Iterator[FromClause]:
+def _from_objects(
+ *elements: Union[
+ ColumnElement[Any], FromClause, TextClause, _JoinTargetElement
+ ]
+) -> Iterator[FromClause]:
return itertools.chain.from_iterable(
[element._from_objects for element in elements]
)
@@ -255,6 +261,11 @@ def _expand_cloned(elements):
predecessors.
"""
+ # TODO: cython candidate
+ # and/or change approach: in
+ # https://gerrit.sqlalchemy.org/c/sqlalchemy/sqlalchemy/+/3712 we propose
+ # getting rid of _cloned_set.
+ # turning this into chain.from_iterable adds all kinds of callcount
return itertools.chain(*[x._cloned_set for x in elements])
@@ -1559,6 +1570,11 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
was moved onto the :class:`_expression.ColumnCollection` itself.
"""
+ # TODO: cython candidate
+
+ # don't dig around if the column is locally present
+ if column in self._colset:
+ return column
def embedded(expanded_proxy_set, target_set):
for t in target_set.difference(expanded_proxy_set):
@@ -1568,9 +1584,6 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
return False
return True
- # don't dig around if the column is locally present
- if column in self._colset:
- return column
col, intersect = None, None
target_set = column.proxy_set
cols = [c for (k, c) in self._collection]
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index 19a232c56..1f8b9c19e 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -54,6 +54,11 @@ class CacheConst(enum.Enum):
NO_CACHE = CacheConst.NO_CACHE
+_CacheKeyTraversalType = Union[
+ "_TraverseInternalsType", Literal[CacheConst.NO_CACHE], Literal[None]
+]
+
+
class CacheTraverseTarget(enum.Enum):
CACHE_IN_PLACE = 0
CALL_GEN_CACHE_KEY = 1
@@ -89,9 +94,7 @@ class HasCacheKey:
__slots__ = ()
- _cache_key_traversal: Union[
- _TraverseInternalsType, Literal[CacheConst.NO_CACHE], Literal[None]
- ] = NO_CACHE
+ _cache_key_traversal: _CacheKeyTraversalType = NO_CACHE
_is_has_cache_key = True
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index ccc8fba8d..4c71ca38b 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -12,7 +12,6 @@ import numbers
import re
import typing
from typing import Any
-from typing import Any as TODO_Any
from typing import Callable
from typing import Dict
from typing import List
@@ -20,7 +19,9 @@ from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from . import operators
from . import roles
@@ -32,6 +33,7 @@ from .visitors import Visitable
from .. import exc
from .. import inspection
from .. import util
+from ..util.typing import Literal
if not typing.TYPE_CHECKING:
elements = None
@@ -46,12 +48,26 @@ if typing.TYPE_CHECKING:
from . import schema
from . import selectable
from . import traversals
+ from ._typing import _ColumnExpressionArgument
from ._typing import _ColumnsClauseArgument
+ from ._typing import _DMLTableArgument
+ from ._typing import _FromClauseArgument
+ from .dml import _DMLTableElement
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
+ from .elements import DQLDMLClauseElement
from .elements import SQLCoreOperations
-
+ from .schema import Column
+ from .selectable import _ColumnsClauseElement
+ from .selectable import _JoinTargetElement
+ from .selectable import _JoinTargetProtocol
+ from .selectable import _OnClauseElement
+ from .selectable import FromClause
+ from .selectable import HasCTE
+ from .selectable import SelectBase
+ from .selectable import Subquery
+ from .visitors import _TraverseCallableType
_SR = TypeVar("_SR", bound=roles.SQLRole)
_F = TypeVar("_F", bound=Callable[..., Any])
@@ -143,10 +159,6 @@ def _expression_collection_was_a_list(attrname, fnname, args):
def expect(
role: Type[roles.TruncatedLabelRole],
element: Any,
- *,
- apply_propagate_attrs: Optional[ClauseElement] = None,
- argname: Optional[str] = None,
- post_inspect: bool = False,
**kw: Any,
) -> str:
...
@@ -154,12 +166,30 @@ def expect(
@overload
def expect(
- role: Type[roles.ExpressionElementRole[_T]],
+ role: Type[roles.StatementOptionRole],
element: Any,
- *,
- apply_propagate_attrs: Optional[ClauseElement] = None,
- argname: Optional[str] = None,
- post_inspect: bool = False,
+ **kw: Any,
+) -> DQLDMLClauseElement:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.DDLReferredColumnRole],
+ element: Any,
+ **kw: Any,
+) -> Column[Any]:
+ ...
+
+
+@overload
+def expect(
+ role: Union[
+ Type[roles.ExpressionElementRole[Any]],
+ Type[roles.LimitOffsetRole],
+ Type[roles.WhereHavingRole],
+ ],
+ element: _ColumnExpressionArgument[_T],
**kw: Any,
) -> ColumnElement[_T]:
...
@@ -167,40 +197,89 @@ def expect(
@overload
def expect(
- role: Type[roles.DMLTableRole],
+ role: Union[
+ Type[roles.ExpressionElementRole[Any]],
+ Type[roles.LimitOffsetRole],
+ Type[roles.WhereHavingRole],
+ ],
element: Any,
+ **kw: Any,
+) -> ColumnElement[Any]:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.DMLTableRole],
+ element: _DMLTableArgument,
+ **kw: Any,
+) -> _DMLTableElement:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.HasCTERole],
+ element: HasCTE,
+ **kw: Any,
+) -> HasCTE:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.SelectStatementRole],
+ element: SelectBase,
+ **kw: Any,
+) -> SelectBase:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.FromClauseRole],
+ element: _FromClauseArgument,
+ **kw: Any,
+) -> FromClause:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.FromClauseRole],
+ element: SelectBase,
*,
- apply_propagate_attrs: Optional[ClauseElement] = None,
- argname: Optional[str] = None,
- post_inspect: bool = False,
+ explicit_subquery: Literal[True] = ...,
**kw: Any,
-) -> roles.DMLTableRole:
+) -> Subquery:
...
@overload
def expect(
role: Type[roles.ColumnsClauseRole],
- element: Any,
- *,
- apply_propagate_attrs: Optional[ClauseElement] = None,
- argname: Optional[str] = None,
- post_inspect: bool = False,
+ element: _ColumnsClauseArgument,
+ **kw: Any,
+) -> _ColumnsClauseElement:
+ ...
+
+
+@overload
+def expect(
+ role: Union[Type[roles.JoinTargetRole], Type[roles.OnClauseRole]],
+ element: _JoinTargetProtocol,
**kw: Any,
-) -> roles.ColumnsClauseRole:
+) -> _JoinTargetProtocol:
...
+# catchall for not-yet-implemented overloads
@overload
def expect(
role: Type[_SR],
element: Any,
- *,
- apply_propagate_attrs: Optional[ClauseElement] = None,
- argname: Optional[str] = None,
- post_inspect: bool = False,
**kw: Any,
-) -> TODO_Any:
+) -> Any:
...
@@ -212,7 +291,7 @@ def expect(
argname: Optional[str] = None,
post_inspect: bool = False,
**kw: Any,
-) -> TODO_Any:
+) -> Any:
if (
role.allows_lambda
# note callable() will not invoke a __getattr__() method, whereas
@@ -329,7 +408,8 @@ def expect_col_expression_collection(role, expressions):
strname = resolved = expr
else:
cols: List[ColumnClause[Any]] = []
- visitors.traverse(resolved, {}, {"column": cols.append})
+ col_append: _TraverseCallableType[ColumnClause[Any]] = cols.append
+ visitors.traverse(resolved, {}, {"column": col_append})
if cols:
column = cols[0]
add_element = column if column is not None else strname
@@ -432,7 +512,7 @@ class _ColumnCoercions(RoleImpl):
original_element = element
if not getattr(resolved, "is_clause_element", False):
self._raise_for_expected(original_element, argname, resolved)
- elif resolved._is_select_statement:
+ elif resolved._is_select_base:
self._warn_for_scalar_subquery_coercion()
return resolved.scalar_subquery()
elif resolved._is_from_clause and isinstance(
@@ -670,7 +750,7 @@ class InElementImpl(RoleImpl):
if resolved._is_from_clause:
if (
isinstance(resolved, selectable.Alias)
- and resolved.element._is_select_statement
+ and resolved.element._is_select_base
):
self._warn_for_implicit_coercion(resolved)
return self._post_coercion(resolved.element, **kw)
@@ -722,7 +802,7 @@ class InElementImpl(RoleImpl):
self._raise_for_expected(element, **kw)
def _post_coercion(self, element, expr, operator, **kw):
- if element._is_select_statement:
+ if element._is_select_base:
# for IN, we are doing scalar_subquery() coercion without
# a warning
return element.scalar_subquery()
@@ -1085,7 +1165,7 @@ class JoinTargetImpl(RoleImpl):
# #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match
# were set to False.
return element
- elif legacy and resolved._is_select_statement:
+ elif legacy and resolved._is_select_base:
util.warn_deprecated(
"Implicit coercion of SELECT and textual SELECT "
"constructs into FROM clauses is deprecated; please call "
@@ -1114,7 +1194,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
allow_select: bool = True,
**kw: Any,
) -> Any:
- if resolved._is_select_statement:
+ if resolved._is_select_base:
if explicit_subquery:
return resolved.subquery()
elif allow_select:
@@ -1150,7 +1230,7 @@ class StrictFromClauseImpl(FromClauseImpl):
allow_select: bool = False,
**kw: Any,
) -> Any:
- if resolved._is_select_statement and allow_select:
+ if resolved._is_select_base and allow_select:
util.warn_deprecated(
"Implicit coercion of SELECT and textual SELECT constructs "
"into FROM clauses is deprecated; please call .subquery() "
@@ -1195,7 +1275,7 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl):
if resolved._is_from_clause:
if (
isinstance(resolved, selectable.Alias)
- and resolved.element._is_select_statement
+ and resolved.element._is_select_base
):
return resolved.element
else:
@@ -1235,3 +1315,9 @@ for name in dir(roles):
if name in globals():
impl = globals()[name](cls)
_impl_lookup[cls] = impl
+
+if not TYPE_CHECKING:
+ ee_impl = _impl_lookup[roles.ExpressionElementRole]
+
+ for py_type in (int, bool, str, float):
+ _impl_lookup[roles.ExpressionElementRole[py_type]] = ee_impl
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b7f6d11f6..6ecfbf986 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -313,12 +313,12 @@ EXTRACT_MAP = {
}
COMPOUND_KEYWORDS = {
- selectable.CompoundSelect.UNION: "UNION",
- selectable.CompoundSelect.UNION_ALL: "UNION ALL",
- selectable.CompoundSelect.EXCEPT: "EXCEPT",
- selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL",
- selectable.CompoundSelect.INTERSECT: "INTERSECT",
- selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL",
+ selectable._CompoundSelectKeyword.UNION: "UNION",
+ selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL",
+ selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT",
+ selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL",
+ selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT",
+ selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL",
}
@@ -1468,6 +1468,10 @@ class SQLCompiler(Compiled):
self.post_compile_params = frozenset()
for key in expanded_state.parameter_expansion:
bind = self.binds.pop(key)
+
+ if TYPE_CHECKING:
+ assert bind.value is not None
+
self.bind_names.pop(bind)
for value, expanded_key in zip(
bind.value, expanded_state.parameter_expansion[key]
@@ -3089,12 +3093,7 @@ class SQLCompiler(Compiled):
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
if cte.recursive:
- if isinstance(cte.element, selectable.Select):
- col_source = cte.element
- elif isinstance(cte.element, selectable.CompoundSelect):
- col_source = cte.element.selects[0]
- else:
- assert False, "cte should only be against SelectBase"
+ col_source = cte.element
# TODO: can we get at the .columns_plus_names collection
# that is already (or will be?) generated for the SELECT
@@ -3315,7 +3314,9 @@ class SQLCompiler(Compiled):
for elem in chunk
)
- if isinstance(element.name, elements._truncated_label):
+ if element._unnamed:
+ name = None
+ elif isinstance(element.name, elements._truncated_label):
name = self._truncated_identifier("values", element.name)
else:
name = element.name
@@ -3980,7 +3981,7 @@ class SQLCompiler(Compiled):
clause = " ".join(
prefix._compiler_dispatch(self, **kw)
for prefix, dialect_name in prefixes
- if dialect_name is None or dialect_name == self.dialect.name
+ if dialect_name in (None, "*") or dialect_name == self.dialect.name
)
if clause:
clause += " "
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 0c9056aee..8a3a1b38f 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -45,11 +45,14 @@ from .base import Executable
from .base import HasCompileState
from .elements import BooleanClauseList
from .elements import ClauseElement
+from .elements import ColumnClause
from .elements import ColumnElement
from .elements import Null
+from .selectable import Alias
from .selectable import FromClause
from .selectable import HasCTE
from .selectable import HasPrefixes
+from .selectable import Join
from .selectable import ReturnsRows
from .selectable import TableClause
from .sqltypes import NullType
@@ -59,15 +62,15 @@ from .. import util
from ..util.typing import TypeGuard
if TYPE_CHECKING:
-
+ from ._typing import _ColumnExpressionArgument
from ._typing import _ColumnsClauseArgument
from ._typing import _DMLColumnArgument
+ from ._typing import _DMLTableArgument
from ._typing import _FromClauseArgument
- from ._typing import _HasClauseElement
- from ._typing import _SelectIterable
from .base import ReadOnlyColumnCollection
from .compiler import SQLCompiler
- from .elements import ColumnClause
+ from .selectable import _ColumnsClauseElement
+ from .selectable import _SelectIterable
from .selectable import Select
def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]:
@@ -85,7 +88,8 @@ else:
isinsert = operator.attrgetter("isinsert")
-_DMLColumnElement = Union[str, "ColumnClause[Any]"]
+_DMLColumnElement = Union[str, ColumnClause[Any]]
+_DMLTableElement = Union[TableClause, Alias, Join]
class DMLState(CompileState):
@@ -132,7 +136,7 @@ class DMLState(CompileState):
]
@property
- def dml_table(self) -> roles.DMLTableRole:
+ def dml_table(self) -> _DMLTableElement:
return self.statement.table
if TYPE_CHECKING:
@@ -322,17 +326,17 @@ class UpdateBase(
__visit_name__ = "update_base"
_hints: util.immutabledict[
- Tuple[roles.DMLTableRole, str], str
+ Tuple[_DMLTableElement, str], str
] = util.EMPTY_DICT
named_with_column = False
- table: roles.DMLTableRole
+ table: _DMLTableElement
_return_defaults = False
_return_defaults_columns: Optional[
- Tuple[roles.ColumnsClauseRole, ...]
+ Tuple[_ColumnsClauseElement, ...]
] = None
- _returning: Tuple[roles.ColumnsClauseRole, ...] = ()
+ _returning: Tuple[_ColumnsClauseElement, ...] = ()
is_dml = True
@@ -483,7 +487,7 @@ class UpdateBase(
def with_hint(
self: SelfUpdateBase,
text: str,
- selectable: Optional[roles.DMLTableRole] = None,
+ selectable: Optional[_DMLTableArgument] = None,
dialect_name: str = "*",
) -> SelfUpdateBase:
"""Add a table hint for a single table to this
@@ -517,7 +521,8 @@ class UpdateBase(
"""
if selectable is None:
selectable = self.table
-
+ else:
+ selectable = coercions.expect(roles.DMLTableRole, selectable)
self._hints = self._hints.union({(selectable, dialect_name): text})
return self
@@ -636,9 +641,9 @@ class ValuesBase(UpdateBase):
_select_names: Optional[List[str]] = None
_inline: bool = False
- _returning: Tuple[roles.ColumnsClauseRole, ...] = ()
+ _returning: Tuple[_ColumnsClauseElement, ...] = ()
- def __init__(self, table: _FromClauseArgument):
+ def __init__(self, table: _DMLTableArgument):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)
@@ -970,7 +975,7 @@ class Insert(ValuesBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table: roles.FromClauseRole):
+ def __init__(self, table: _DMLTableArgument):
super(Insert, self).__init__(table)
@_generative
@@ -1066,12 +1071,12 @@ SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase")
class DMLWhereBase:
- table: roles.DMLTableRole
+ table: _DMLTableElement
_where_criteria: Tuple[ColumnElement[Any], ...] = ()
@_generative
def where(
- self: SelfDMLWhereBase, *whereclause: roles.ExpressionElementRole[Any]
+ self: SelfDMLWhereBase, *whereclause: _ColumnExpressionArgument[bool]
) -> SelfDMLWhereBase:
"""Return a new construct with the given expression(s) added to
its WHERE clause, joined to the existing clause via AND, if any.
@@ -1104,7 +1109,9 @@ class DMLWhereBase:
"""
for criterion in whereclause:
- where_criteria = coercions.expect(roles.WhereHavingRole, criterion)
+ where_criteria: ColumnElement[Any] = coercions.expect(
+ roles.WhereHavingRole, criterion
+ )
self._where_criteria += (where_criteria,)
return self
@@ -1119,7 +1126,7 @@ class DMLWhereBase:
return self.where(*criteria)
- def _filter_by_zero(self) -> roles.DMLTableRole:
+ def _filter_by_zero(self) -> _DMLTableElement:
return self.table
def filter_by(self: SelfDMLWhereBase, **kwargs: Any) -> SelfDMLWhereBase:
@@ -1189,7 +1196,7 @@ class Update(DMLWhereBase, ValuesBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table: roles.FromClauseRole):
+ def __init__(self, table: _DMLTableArgument):
super(Update, self).__init__(table)
@_generative
@@ -1279,7 +1286,7 @@ class Delete(DMLWhereBase, UpdateBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table: roles.FromClauseRole):
+ def __init__(self, table: _DMLTableArgument):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index c735085f8..aec29d1b2 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -26,6 +26,7 @@ from typing import Dict
from typing import FrozenSet
from typing import Generic
from typing import Iterable
+from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
@@ -77,8 +78,8 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ._typing import _PropagateAttrsType
- from ._typing import _SelectIterable
from ._typing import _TypeEngineArgument
+ from .cache_key import _CacheKeyTraversalType
from .cache_key import CacheKey
from .compiler import Compiled
from .compiler import SQLCompiler
@@ -88,6 +89,7 @@ if typing.TYPE_CHECKING:
from .schema import DefaultGenerator
from .schema import FetchedValue
from .schema import ForeignKey
+ from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import NamedFromClause
from .selectable import ReturnsRows
@@ -96,6 +98,7 @@ if typing.TYPE_CHECKING:
from .sqltypes import Boolean
from .sqltypes import TupleType
from .type_api import TypeEngine
+ from .visitors import _CloneCallableType
from .visitors import _TraverseInternalsType
from ..engine import Connection
from ..engine import Dialect
@@ -310,6 +313,7 @@ class ClauseElement(
_is_text_clause = False
_is_from_container = False
_is_select_container = False
+ _is_select_base = False
_is_select_statement = False
_is_bind_parameter = False
_is_clause_list = False
@@ -321,7 +325,7 @@ class ClauseElement(
def _order_by_label_element(self) -> Optional[Label[Any]]:
return None
- _cache_key_traversal = None
+ _cache_key_traversal: _CacheKeyTraversalType = None
negation_clause: ColumnElement[bool]
@@ -528,7 +532,7 @@ class ClauseElement(
"""
return traversals.compare(self, other, **kw)
- def self_group(self, against=None):
+ def self_group(self, against: Optional[OperatorType] = None) -> Any:
"""Apply a 'grouping' to this :class:`_expression.ClauseElement`.
This method is overridden by subclasses to return a "grouping"
@@ -637,9 +641,9 @@ class ClauseElement(
return self._negate()
def _negate(self) -> ClauseElement:
- return UnaryExpression(
- self.self_group(against=operators.inv), operator=operators.inv
- )
+ grouped = self.self_group(against=operators.inv)
+ assert isinstance(grouped, ColumnElement)
+ return UnaryExpression(grouped, operator=operators.inv)
def __bool__(self):
raise TypeError("Boolean value of this clause is not defined")
@@ -1290,12 +1294,6 @@ class ColumnElement(
@overload
def self_group(
- self: ColumnElement[bool], against: Optional[OperatorType] = None
- ) -> ColumnElement[bool]:
- ...
-
- @overload
- def self_group(
self: ColumnElement[Any], against: Optional[OperatorType] = None
) -> ColumnElement[Any]:
...
@@ -1764,6 +1762,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
key: str
type: TypeEngine[_T]
+ value: Optional[_T]
_is_crud = False
_is_bind_parameter = True
@@ -1883,7 +1882,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
return cloned
@property
- def effective_value(self):
+ def effective_value(self) -> Optional[_T]:
"""Return the value of this bound parameter,
taking into account if the ``callable`` parameter
was set.
@@ -1893,11 +1892,12 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
"""
if self.callable:
- return self.callable()
+ # TODO: set up protocol for bind parameter callable
+ return self.callable() # type: ignore
else:
return self.value
- def render_literal_execute(self):
+ def render_literal_execute(self) -> BindParameter[_T]:
"""Produce a copy of this bound parameter that will enable the
:paramref:`_sql.BindParameter.literal_execute` flag.
@@ -2513,8 +2513,10 @@ class ClauseList(
self.operator = operator
self.group = group
self.group_contents = group_contents
+ clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses
if _flatten_sub_clauses:
- clauses = util.flatten_iterator(clauses)
+ clauses_iterator = util.flatten_iterator(clauses_iterator)
+
self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
@@ -2523,31 +2525,35 @@ class ClauseList(
coercions.expect(
text_converter_role, clause, apply_propagate_attrs=self
).self_group(against=self.operator)
- for clause in clauses
+ for clause in clauses_iterator
]
else:
self.clauses = [
coercions.expect(
text_converter_role, clause, apply_propagate_attrs=self
)
- for clause in clauses
+ for clause in clauses_iterator
]
self._is_implicitly_boolean = operators.is_boolean(self.operator)
@classmethod
- def _construct_raw(cls, operator, clauses=None):
+ def _construct_raw(
+ cls,
+ operator: OperatorType,
+ clauses: Optional[Sequence[ColumnElement[Any]]] = None,
+ ) -> ClauseList:
self = cls.__new__(cls)
- self.clauses = clauses if clauses else []
+ self.clauses = list(clauses) if clauses else []
self.group = True
self.operator = operator
self.group_contents = True
self._is_implicitly_boolean = False
return self
- def __iter__(self):
+ def __iter__(self) -> Iterator[ColumnElement[Any]]:
return iter(self.clauses)
- def __len__(self):
+ def __len__(self) -> int:
return len(self.clauses)
@property
@@ -2708,10 +2714,10 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
def _construct_raw(
cls,
operator: OperatorType,
- clauses: Optional[List[ColumnElement[Any]]] = None,
+ clauses: Optional[Sequence[ColumnElement[Any]]] = None,
) -> BooleanClauseList:
self = cls.__new__(cls)
- self.clauses = clauses if clauses else []
+ self.clauses = list(clauses) if clauses else []
self.group = True
self.operator = operator
self.group_contents = True
@@ -2781,7 +2787,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]):
sqltypes = util.preloaded.sql_sqltypes
if types is None:
- init_clauses = [
+ init_clauses: List[ColumnElement[Any]] = [
coercions.expect(roles.ExpressionElementRole, c)
for c in clauses
]
@@ -2908,7 +2914,7 @@ class Case(ColumnElement[_T]):
]
if whenlist:
- type_ = list(whenlist[-1])[-1].type
+ type_ = whenlist[-1][-1].type
else:
type_ = None
@@ -3098,6 +3104,8 @@ class _label_reference(ColumnElement[_T]):
("element", InternalTraversal.dp_clauseelement)
]
+ element: ColumnElement[_T]
+
def __init__(self, element: ColumnElement[_T]):
self.element = element
@@ -3212,7 +3220,9 @@ class UnaryExpression(ColumnElement[_T]):
cls,
expr: _ColumnExpressionArgument[_T],
) -> UnaryExpression[_T]:
- col_expr = coercions.expect(roles.ExpressionElementRole, expr)
+ col_expr: ColumnElement[_T] = coercions.expect(
+ roles.ExpressionElementRole, expr
+ )
return UnaryExpression(
col_expr,
operator=operators.distinct_op,
@@ -3265,7 +3275,7 @@ class CollectionAggregate(UnaryExpression[_T]):
def _create_any(
cls, expr: _ColumnExpressionArgument[_T]
) -> CollectionAggregate[bool]:
- col_expr = coercions.expect(
+ col_expr: ColumnElement[_T] = coercions.expect(
roles.ExpressionElementRole,
expr,
)
@@ -3281,7 +3291,7 @@ class CollectionAggregate(UnaryExpression[_T]):
def _create_all(
cls, expr: _ColumnExpressionArgument[_T]
) -> CollectionAggregate[bool]:
- col_expr = coercions.expect(
+ col_expr: ColumnElement[_T] = coercions.expect(
roles.ExpressionElementRole,
expr,
)
@@ -3374,6 +3384,9 @@ class BinaryExpression(ColumnElement[_T]):
modifiers: Optional[Mapping[str, Any]]
+ left: ColumnElement[Any]
+ right: Union[ColumnElement[Any], ClauseList]
+
def __init__(
self,
left: ColumnElement[Any],
@@ -4147,7 +4160,13 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
def foreign_keys(self):
return self.element.foreign_keys
- def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+ def _copy_internals(
+ self,
+ *,
+ clone: _CloneCallableType = _clone,
+ anonymize_labels: bool = False,
+ **kw: Any,
+ ) -> None:
self._reset_memoizations()
self._element = clone(self._element, **kw)
if anonymize_labels:
@@ -4447,7 +4466,9 @@ class TableValuedColumn(NamedColumn[_T]):
self.key = self.name = scalar_alias.name
self.type = type_
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(
+ self, clone: _CloneCallableType = _clone, **kw: Any
+ ) -> None:
self.scalar_alias = clone(self.scalar_alias, **kw)
self.key = self.name = self.scalar_alias.name
@@ -4467,7 +4488,7 @@ class CollationClause(ColumnElement[str]):
def _create_collation_expression(
cls, expression: _ColumnExpressionArgument[str], collation: str
) -> BinaryExpression[str]:
- expr = coercions.expect(roles.ExpressionElementRole, expression)
+ expr = coercions.expect(roles.ExpressionElementRole[str], expression)
return BinaryExpression(
expr,
CollationClause(collation),
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 3bca8b502..db4bb5837 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -11,10 +11,15 @@
from __future__ import annotations
+import datetime
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Mapping
from typing import Optional
from typing import overload
-from typing import Sequence
+from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -24,7 +29,9 @@ from . import operators
from . import roles
from . import schema
from . import sqltypes
+from . import type_api
from . import util as sqlutil
+from ._typing import is_table_value_type
from .base import _entity_namespace
from .base import ColumnCollection
from .base import Executable
@@ -46,16 +53,21 @@ from .elements import WithinGroup
from .selectable import FromClause
from .selectable import Select
from .selectable import TableValuedAlias
+from .sqltypes import _N
+from .sqltypes import TableValueType
from .type_api import TypeEngine
from .visitors import InternalTraversal
from .. import util
+
if TYPE_CHECKING:
from ._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
-_registry = util.defaultdict(dict)
+_registry: util.defaultdict[
+ str, Dict[str, Type[Function[Any]]]
+] = util.defaultdict(dict)
def register_function(identifier, fn, package="_default"):
@@ -103,11 +115,18 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
("_table_value_type", InternalTraversal.dp_has_cache_key),
]
- packagenames = ()
+ packagenames: Tuple[str, ...] = ()
_has_args = False
_with_ordinality = False
- _table_value_type = None
+ _table_value_type: Optional[TableValueType] = None
+
+ # some attributes that are defined between both ColumnElement and
+ # FromClause are set to Any here to avoid typing errors
+ primary_key: Any
+ _is_clone_of: Any
+
+ clause_expr: Grouping[Any]
def __init__(self, *clauses: Any):
r"""Construct a :class:`.FunctionElement`.
@@ -135,9 +154,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
for c in clauses
]
self._has_args = self._has_args or bool(args)
- self.clause_expr = ClauseList(
- operator=operators.comma_op, group_contents=True, *args
- ).self_group()
+ self.clause_expr = Grouping(
+ ClauseList(operator=operators.comma_op, group_contents=True, *args)
+ )
_non_anon_label = None
@@ -263,9 +282,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
expr += (with_ordinality,)
new_func._with_ordinality = True
- new_func.type = new_func._table_value_type = sqltypes.TableValueType(
- *expr
- )
+ new_func.type = new_func._table_value_type = TableValueType(*expr)
return new_func.alias(name=name, joins_implicitly=joins_implicitly)
@@ -332,7 +349,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
@property
def _all_selected_columns(self):
- if self.type._is_table_value:
+ if is_table_value_type(self.type):
cols = self.type._elements
else:
cols = [self.label(None)]
@@ -344,12 +361,12 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
return self.columns
@HasMemoized.memoized_attribute
- def clauses(self):
+ def clauses(self) -> ClauseList:
"""Return the underlying :class:`.ClauseList` which contains
the arguments for this :class:`.FunctionElement`.
"""
- return self.clause_expr.element
+ return cast(ClauseList, self.clause_expr.element)
def over(self, partition_by=None, order_by=None, rows=None, range_=None):
"""Produce an OVER clause against this function.
@@ -647,7 +664,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
return _entity_namespace(self.clause_expr)
-class FunctionAsBinary(BinaryExpression):
+class FunctionAsBinary(BinaryExpression[Any]):
_traverse_internals = [
("sql_function", InternalTraversal.dp_clauseelement),
("left_index", InternalTraversal.dp_plain_obj),
@@ -655,10 +672,16 @@ class FunctionAsBinary(BinaryExpression):
("modifiers", InternalTraversal.dp_plain_dict),
]
+ sql_function: FunctionElement[Any]
+ left_index: int
+ right_index: int
+
def _gen_cache_key(self, anon_map, bindparams):
return ColumnElement._gen_cache_key(self, anon_map, bindparams)
- def __init__(self, fn, left_index, right_index):
+ def __init__(
+ self, fn: FunctionElement[Any], left_index: int, right_index: int
+ ):
self.sql_function = fn
self.left_index = left_index
self.right_index = right_index
@@ -670,23 +693,30 @@ class FunctionAsBinary(BinaryExpression):
self.modifiers = {}
@property
- def left(self):
+ def left_expr(self) -> ColumnElement[Any]:
return self.sql_function.clauses.clauses[self.left_index - 1]
- @left.setter
- def left(self, value):
+ @left_expr.setter
+ def left_expr(self, value: ColumnElement[Any]) -> None:
self.sql_function.clauses.clauses[self.left_index - 1] = value
@property
- def right(self):
+ def right_expr(self) -> ColumnElement[Any]:
return self.sql_function.clauses.clauses[self.right_index - 1]
- @right.setter
- def right(self, value):
+ @right_expr.setter
+ def right_expr(self, value: ColumnElement[Any]) -> None:
self.sql_function.clauses.clauses[self.right_index - 1] = value
+ if not TYPE_CHECKING:
+ # mypy can't accommodate @property to replace an instance
+ # variable
+
+ left = left_expr
+ right = right_expr
+
-class ScalarFunctionColumn(NamedColumn):
+class ScalarFunctionColumn(NamedColumn[_T]):
__visit_name__ = "scalar_function_column"
_traverse_internals = [
@@ -698,10 +728,18 @@ class ScalarFunctionColumn(NamedColumn):
is_literal = False
table = None
- def __init__(self, fn, name, type_=None):
+ def __init__(
+ self,
+ fn: FunctionElement[_T],
+ name: str,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ ):
self.fn = fn
self.name = name
- self.type = sqltypes.to_instance(type_)
+
+ # if type is None, we get NULLTYPE, which is our _T. But I don't
+ # know how to get the overloads to express that correctly
+ self.type = type_api.to_instance(type_) # type: ignore
class _FunctionGenerator:
@@ -789,7 +827,7 @@ class _FunctionGenerator:
# passthru __ attributes; fixes pydoc
if name.startswith("__"):
try:
- return self.__dict__[name]
+ return self.__dict__[name] # type: ignore
except KeyError:
raise AttributeError(name)
@@ -883,8 +921,6 @@ class Function(FunctionElement[_T]):
identifier: str
- packagenames: Sequence[str]
-
type: TypeEngine[_T]
"""A :class:`_types.TypeEngine` object which refers to the SQL return
type represented by this SQL function.
@@ -907,7 +943,7 @@ class Function(FunctionElement[_T]):
name: str,
*clauses: Any,
type_: Optional[_TypeEngineArgument[_T]] = None,
- packagenames: Optional[Sequence[str]] = None,
+ packagenames: Optional[Tuple[str, ...]] = None,
):
"""Construct a :class:`.Function`.
@@ -918,7 +954,9 @@ class Function(FunctionElement[_T]):
self.packagenames = packagenames or ()
self.name = name
- self.type = sqltypes.to_instance(type_)
+ # if type is None, we get NULLTYPE, which is our _T. But I don't
+ # know how to get the overloads to express that correctly
+ self.type = type_api.to_instance(type_) # type: ignore
FunctionElement.__init__(self, *clauses)
@@ -934,7 +972,7 @@ class Function(FunctionElement[_T]):
)
-class GenericFunction(Function):
+class GenericFunction(Function[_T]):
"""Define a 'generic' function.
A generic function is a pre-established :class:`.Function`
@@ -957,7 +995,7 @@ class GenericFunction(Function):
from sqlalchemy.types import DateTime
class as_utc(GenericFunction):
- type = DateTime
+ type = DateTime()
inherit_cache = True
print(select(func.as_utc()))
@@ -971,7 +1009,7 @@ class GenericFunction(Function):
"time"::
class as_utc(GenericFunction):
- type = DateTime
+ type = DateTime()
package = "time"
inherit_cache = True
@@ -987,7 +1025,7 @@ class GenericFunction(Function):
the usage of ``name`` as the rendered name::
class GeoBuffer(GenericFunction):
- type = Geometry
+ type = Geometry()
package = "geo"
name = "ST_Buffer"
identifier = "buffer"
@@ -1006,7 +1044,7 @@ class GenericFunction(Function):
from sqlalchemy.sql import quoted_name
class GeoBuffer(GenericFunction):
- type = Geometry
+ type = Geometry()
package = "geo"
name = quoted_name("ST_Buffer", True)
identifier = "buffer"
@@ -1028,6 +1066,8 @@ class GenericFunction(Function):
coerce_arguments = True
inherit_cache = True
+ _register: bool
+
name = "GenericFunction"
def __init_subclass__(cls) -> None:
@@ -1036,7 +1076,9 @@ class GenericFunction(Function):
super().__init_subclass__()
@classmethod
- def _register_generic_function(cls, clsname, clsdict):
+ def _register_generic_function(
+ cls, clsname: str, clsdict: Mapping[str, Any]
+ ) -> None:
cls.name = name = clsdict.get("name", clsname)
cls.identifier = identifier = clsdict.get("identifier", name)
package = clsdict.get("package", "_default")
@@ -1068,11 +1110,14 @@ class GenericFunction(Function):
]
self._has_args = self._has_args or bool(parsed_args)
self.packagenames = ()
- self.clause_expr = ClauseList(
- operator=operators.comma_op, group_contents=True, *parsed_args
- ).self_group()
- self.type = sqltypes.to_instance(
+ self.clause_expr = Grouping(
+ ClauseList(
+ operator=operators.comma_op, group_contents=True, *parsed_args
+ )
+ )
+
+ self.type = type_api.to_instance( # type: ignore
kwargs.pop("type_", None) or getattr(self, "type", None)
)
@@ -1081,7 +1126,7 @@ register_function("cast", Cast)
register_function("extract", Extract)
-class next_value(GenericFunction):
+class next_value(GenericFunction[int]):
"""Represent the 'next value', given a :class:`.Sequence`
as its single argument.
@@ -1103,7 +1148,7 @@ class next_value(GenericFunction):
seq, schema.Sequence
), "next_value() accepts a Sequence object as input."
self.sequence = seq
- self.type = sqltypes.to_instance(
+ self.type = sqltypes.to_instance( # type: ignore
seq.data_type or getattr(self, "type", None)
)
@@ -1118,7 +1163,7 @@ class next_value(GenericFunction):
return []
-class AnsiFunction(GenericFunction):
+class AnsiFunction(GenericFunction[_T]):
"""Define a function in "ansi" format, which doesn't render parenthesis."""
inherit_cache = True
@@ -1127,13 +1172,13 @@ class AnsiFunction(GenericFunction):
GenericFunction.__init__(self, *args, **kwargs)
-class ReturnTypeFromArgs(GenericFunction):
+class ReturnTypeFromArgs(GenericFunction[_T]):
"""Define a function whose return type is the same as its arguments."""
inherit_cache = True
def __init__(self, *args, **kwargs):
- args = [
+ fn_args = [
coercions.expect(
roles.ExpressionElementRole,
c,
@@ -1142,35 +1187,35 @@ class ReturnTypeFromArgs(GenericFunction):
)
for c in args
]
- kwargs.setdefault("type_", _type_from_args(args))
- kwargs["_parsed_args"] = args
- super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
+ kwargs.setdefault("type_", _type_from_args(fn_args))
+ kwargs["_parsed_args"] = fn_args
+ super(ReturnTypeFromArgs, self).__init__(*fn_args, **kwargs)
-class coalesce(ReturnTypeFromArgs):
+class coalesce(ReturnTypeFromArgs[_T]):
_has_args = True
inherit_cache = True
-class max(ReturnTypeFromArgs): # noqa A001
+class max(ReturnTypeFromArgs[_T]): # noqa A001
"""The SQL MAX() aggregate function."""
inherit_cache = True
-class min(ReturnTypeFromArgs): # noqa A001
+class min(ReturnTypeFromArgs[_T]): # noqa A001
"""The SQL MIN() aggregate function."""
inherit_cache = True
-class sum(ReturnTypeFromArgs): # noqa A001
+class sum(ReturnTypeFromArgs[_T]): # noqa A001
"""The SQL SUM() aggregate function."""
inherit_cache = True
-class now(GenericFunction):
+class now(GenericFunction[datetime.datetime]):
"""The SQL now() datetime function.
SQLAlchemy dialects will usually render this particular function
@@ -1178,11 +1223,11 @@ class now(GenericFunction):
"""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class concat(GenericFunction):
+class concat(GenericFunction[str]):
"""The SQL CONCAT() function, which concatenates strings.
E.g.::
@@ -1200,28 +1245,30 @@ class concat(GenericFunction):
"""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class char_length(GenericFunction):
+class char_length(GenericFunction[int]):
"""The CHAR_LENGTH() SQL function."""
- type = sqltypes.Integer
+ type = sqltypes.Integer()
inherit_cache = True
- def __init__(self, arg, **kwargs):
- GenericFunction.__init__(self, arg, **kwargs)
+ def __init__(self, arg, **kw):
+ # slight hack to limit to just one positional argument
+ # not sure why this one function has this special treatment
+ super().__init__(arg, **kw)
-class random(GenericFunction):
+class random(GenericFunction[float]):
"""The RANDOM() SQL function."""
_has_args = True
inherit_cache = True
-class count(GenericFunction):
+class count(GenericFunction[int]):
r"""The ANSI COUNT aggregate function. With no arguments,
emits COUNT \*.
@@ -1242,7 +1289,7 @@ class count(GenericFunction):
"""
- type = sqltypes.Integer
+ type = sqltypes.Integer()
inherit_cache = True
def __init__(self, expression=None, **kwargs):
@@ -1251,70 +1298,70 @@ class count(GenericFunction):
super(count, self).__init__(expression, **kwargs)
-class current_date(AnsiFunction):
+class current_date(AnsiFunction[datetime.date]):
"""The CURRENT_DATE() SQL function."""
- type = sqltypes.Date
+ type = sqltypes.Date()
inherit_cache = True
-class current_time(AnsiFunction):
+class current_time(AnsiFunction[datetime.time]):
"""The CURRENT_TIME() SQL function."""
- type = sqltypes.Time
+ type = sqltypes.Time()
inherit_cache = True
-class current_timestamp(AnsiFunction):
+class current_timestamp(AnsiFunction[datetime.datetime]):
"""The CURRENT_TIMESTAMP() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class current_user(AnsiFunction):
+class current_user(AnsiFunction[str]):
"""The CURRENT_USER() SQL function."""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class localtime(AnsiFunction):
+class localtime(AnsiFunction[datetime.datetime]):
"""The localtime() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class localtimestamp(AnsiFunction):
+class localtimestamp(AnsiFunction[datetime.datetime]):
"""The localtimestamp() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class session_user(AnsiFunction):
+class session_user(AnsiFunction[str]):
"""The SESSION_USER() SQL function."""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class sysdate(AnsiFunction):
+class sysdate(AnsiFunction[datetime.datetime]):
"""The SYSDATE() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class user(AnsiFunction):
+class user(AnsiFunction[str]):
"""The USER() SQL function."""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class array_agg(GenericFunction):
+class array_agg(GenericFunction[_T]):
"""Support for the ARRAY_AGG function.
The ``func.array_agg(expr)`` construct returns an expression of
@@ -1334,11 +1381,10 @@ class array_agg(GenericFunction):
"""
- type = sqltypes.ARRAY
inherit_cache = True
def __init__(self, *args, **kwargs):
- args = [
+ fn_args = [
coercions.expect(
roles.ExpressionElementRole, c, apply_propagate_attrs=self
)
@@ -1348,16 +1394,16 @@ class array_agg(GenericFunction):
default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
if "type_" not in kwargs:
- type_from_args = _type_from_args(args)
+ type_from_args = _type_from_args(fn_args)
if isinstance(type_from_args, sqltypes.ARRAY):
kwargs["type_"] = type_from_args
else:
kwargs["type_"] = default_array_type(type_from_args)
- kwargs["_parsed_args"] = args
- super(array_agg, self).__init__(*args, **kwargs)
+ kwargs["_parsed_args"] = fn_args
+ super(array_agg, self).__init__(*fn_args, **kwargs)
-class OrderedSetAgg(GenericFunction):
+class OrderedSetAgg(GenericFunction[_T]):
"""Define a function where the return type is based on the sort
expression type as defined by the expression passed to the
:meth:`.FunctionElement.within_group` method."""
@@ -1366,7 +1412,7 @@ class OrderedSetAgg(GenericFunction):
inherit_cache = True
def within_group_type(self, within_group):
- func_clauses = self.clause_expr.element
+ func_clauses = cast(ClauseList, self.clause_expr.element)
order_by = sqlutil.unwrap_order_by(within_group.order_by)
if self.array_for_multi_clause and len(func_clauses.clauses) > 1:
return sqltypes.ARRAY(order_by[0].type)
@@ -1374,7 +1420,7 @@ class OrderedSetAgg(GenericFunction):
return order_by[0].type
-class mode(OrderedSetAgg):
+class mode(OrderedSetAgg[_T]):
"""Implement the ``mode`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1389,7 +1435,7 @@ class mode(OrderedSetAgg):
inherit_cache = True
-class percentile_cont(OrderedSetAgg):
+class percentile_cont(OrderedSetAgg[_T]):
"""Implement the ``percentile_cont`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1407,7 +1453,7 @@ class percentile_cont(OrderedSetAgg):
inherit_cache = True
-class percentile_disc(OrderedSetAgg):
+class percentile_disc(OrderedSetAgg[_T]):
"""Implement the ``percentile_disc`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1425,7 +1471,7 @@ class percentile_disc(OrderedSetAgg):
inherit_cache = True
-class rank(GenericFunction):
+class rank(GenericFunction[int]):
"""Implement the ``rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1441,7 +1487,7 @@ class rank(GenericFunction):
inherit_cache = True
-class dense_rank(GenericFunction):
+class dense_rank(GenericFunction[int]):
"""Implement the ``dense_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1457,7 +1503,7 @@ class dense_rank(GenericFunction):
inherit_cache = True
-class percent_rank(GenericFunction):
+class percent_rank(GenericFunction[_N]):
"""Implement the ``percent_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1469,11 +1515,11 @@ class percent_rank(GenericFunction):
"""
- type = sqltypes.Numeric()
+ type: sqltypes.Numeric[_N] = sqltypes.Numeric()
inherit_cache = True
-class cume_dist(GenericFunction):
+class cume_dist(GenericFunction[_N]):
"""Implement the ``cume_dist`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1485,11 +1531,11 @@ class cume_dist(GenericFunction):
"""
- type = sqltypes.Numeric()
+ type: sqltypes.Numeric[_N] = sqltypes.Numeric()
inherit_cache = True
-class cube(GenericFunction):
+class cube(GenericFunction[_T]):
r"""Implement the ``CUBE`` grouping operation.
This function is used as part of the GROUP BY of a statement,
@@ -1506,7 +1552,7 @@ class cube(GenericFunction):
inherit_cache = True
-class rollup(GenericFunction):
+class rollup(GenericFunction[_T]):
r"""Implement the ``ROLLUP`` grouping operation.
This function is used as part of the GROUP BY of a statement,
@@ -1523,7 +1569,7 @@ class rollup(GenericFunction):
inherit_cache = True
-class grouping_sets(GenericFunction):
+class grouping_sets(GenericFunction[_T]):
r"""Implement the ``GROUPING SETS`` grouping operation.
This function is used as part of the GROUP BY of a statement,
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index 9d011ef53..da15c305f 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -12,6 +12,18 @@ import inspect
import itertools
import operator
import types
+from types import CodeType
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Iterable
+from typing import List
+from typing import MutableMapping
+from typing import Optional
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
import weakref
from . import cache_key as _cache_key
@@ -19,37 +31,62 @@ from . import coercions
from . import elements
from . import roles
from . import schema
-from . import traversals
from . import type_api
from . import visitors
from .base import _clone
+from .base import Executable
from .base import Options
+from .cache_key import CacheConst
from .operators import ColumnOperators
from .. import exc
from .. import inspection
from .. import util
+from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import Self
-_closure_per_cache_key = util.LRUCache(1000)
+if TYPE_CHECKING:
+ from .cache_key import CacheConst
+ from .cache_key import NO_CACHE
+ from .elements import BindParameter
+ from .elements import ClauseElement
+ from .roles import SQLRole
+ from .visitors import _CloneCallableType
+
+_LambdaCacheType = MutableMapping[
+ Tuple[Any, ...], Union["NonAnalyzedFunction", "AnalyzedFunction"]
+]
+_BoundParameterGetter = Callable[..., Any]
+
+_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000)
+
+
+class _LambdaType(Protocol):
+ __code__: CodeType
+ __closure__: Iterable[Tuple[Any, Any]]
+
+ def __call__(self, *arg: Any, **kw: Any) -> ClauseElement:
+ ...
class LambdaOptions(Options):
enable_tracking = True
track_closure_variables = True
- track_on = None
+ track_on: Optional[object] = None
global_track_bound_values = True
track_bound_values = True
- lambda_cache = None
+ lambda_cache: Optional[_LambdaCacheType] = None
def lambda_stmt(
- lmb,
- enable_tracking=True,
- track_closure_variables=True,
- track_on=None,
- global_track_bound_values=True,
- track_bound_values=True,
- lambda_cache=None,
-):
+ lmb: _LambdaType,
+ enable_tracking: bool = True,
+ track_closure_variables: bool = True,
+ track_on: Optional[object] = None,
+ global_track_bound_values: bool = True,
+ track_bound_values: bool = True,
+ lambda_cache: Optional[_LambdaCacheType] = None,
+) -> StatementLambdaElement:
"""Produce a SQL statement that is cached as a lambda.
The Python code object within the lambda is scanned for both Python
@@ -142,15 +179,28 @@ class LambdaElement(elements.ClauseElement):
("_resolved", visitors.InternalTraversal.dp_clauseelement)
]
- _transforms = ()
+ _transforms: Tuple[_CloneCallableType, ...] = ()
- parent_lambda = None
+ _resolved_bindparams: List[BindParameter[Any]]
+ parent_lambda: Optional[StatementLambdaElement] = None
+ closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
+ role: Type[SQLRole]
+ _rec: Union[AnalyzedFunction, NonAnalyzedFunction]
+ fn: _LambdaType
+ tracker_key: Tuple[CodeType, ...]
def __repr__(self):
- return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
+ return "%s(%r)" % (
+ self.__class__.__name__,
+ self.fn.__code__,
+ )
def __init__(
- self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
+ self,
+ fn: _LambdaType,
+ role: Type[SQLRole],
+ opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
+ apply_propagate_attrs: Optional[ClauseElement] = None,
):
self.fn = fn
self.role = role
@@ -182,6 +232,7 @@ class LambdaElement(elements.ClauseElement):
opts,
)
+ bindparams: List[BindParameter[Any]]
self._resolved_bindparams = bindparams = []
if self.parent_lambda is not None:
@@ -189,8 +240,10 @@ class LambdaElement(elements.ClauseElement):
else:
parent_closure_cache_key = ()
+ cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
+
if parent_closure_cache_key is not _cache_key.NO_CACHE:
- anon_map = traversals.anon_map()
+ anon_map = visitors.anon_map()
cache_key = tuple(
[
getter(closure, opts, anon_map, bindparams)
@@ -241,7 +294,7 @@ class LambdaElement(elements.ClauseElement):
if self.parent_lambda is not None:
bindparams[:0] = self.parent_lambda._resolved_bindparams
- lambda_element = self
+ lambda_element: Optional[LambdaElement] = self
while lambda_element is not None:
rec = lambda_element._rec
if rec.bindparam_trackers:
@@ -289,17 +342,21 @@ class LambdaElement(elements.ClauseElement):
def _setup_binds_for_tracked_expr(self, expr):
bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
- def replace(thing):
- if isinstance(thing, elements.BindParameter):
+ def replace(
+ element: Optional[visitors.ExternallyTraversible], **kw: Any
+ ) -> Optional[visitors.ExternallyTraversible]:
+ if isinstance(element, elements.BindParameter):
- if thing.key in bindparam_lookup:
- bind = bindparam_lookup[thing.key]
- if thing.expanding:
+ if element.key in bindparam_lookup:
+ bind = bindparam_lookup[element.key]
+ if element.expanding:
bind.expanding = True
- bind.expand_op = thing.expand_op
- bind.type = thing.type
+ bind.expand_op = element.expand_op
+ bind.type = element.type
return bind
+ return None
+
if self._rec.is_sequence:
expr = [
visitors.replacement_traverse(sub_expr, {}, replace)
@@ -311,8 +368,11 @@ class LambdaElement(elements.ClauseElement):
return expr
def _copy_internals(
- self, clone=_clone, deferred_copy_internals=None, **kw
- ):
+ self: Self,
+ clone: _CloneCallableType = _clone,
+ deferred_copy_internals: Optional[_CloneCallableType] = None,
+ **kw: Any,
+ ) -> None:
# TODO: this needs A LOT of tests
self._resolved = clone(
self._resolved,
@@ -340,9 +400,15 @@ class LambdaElement(elements.ClauseElement):
) + self.closure_cache_key
parent = self.parent_lambda
+
while parent is not None:
+ assert parent.closure_cache_key is not CacheConst.NO_CACHE
+ parent_closure_cache_key: Tuple[
+ Any, ...
+ ] = parent.closure_cache_key
+
cache_key = (
- (parent.fn.__code__,) + parent.closure_cache_key + cache_key
+ (parent.fn.__code__,) + parent_closure_cache_key + cache_key
)
parent = parent.parent_lambda
@@ -351,7 +417,7 @@ class LambdaElement(elements.ClauseElement):
bindparams.extend(self._resolved_bindparams)
return cache_key
- def _invoke_user_fn(self, fn, *arg):
+ def _invoke_user_fn(self, fn: _LambdaType, *arg: Any) -> ClauseElement:
return fn()
@@ -365,7 +431,13 @@ class DeferredLambdaElement(LambdaElement):
"""
- def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
+ def __init__(
+ self,
+ fn: _LambdaType,
+ role: Type[roles.SQLRole],
+ opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
+ lambda_args: Tuple[Any, ...] = (),
+ ):
self.lambda_args = lambda_args
super(DeferredLambdaElement, self).__init__(fn, role, opts)
@@ -373,6 +445,7 @@ class DeferredLambdaElement(LambdaElement):
return fn(*self.lambda_args)
def _resolve_with_args(self, *lambda_args):
+ assert isinstance(self._rec, AnalyzedFunction)
tracker_fn = self._rec.tracker_instrumented_fn
expr = tracker_fn(*lambda_args)
@@ -506,6 +579,8 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
def _execute_on_connection(
self, connection, distilled_params, execution_options
):
+ if TYPE_CHECKING:
+ assert isinstance(self._rec.expected_expr, ClauseElement)
if self._rec.expected_expr.supports_execution:
return connection._execute_clauseelement(
self, distilled_params, execution_options
@@ -515,14 +590,20 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
@property
def _with_options(self):
+ if TYPE_CHECKING:
+ assert isinstance(self._rec.expected_expr, Executable)
return self._rec.expected_expr._with_options
@property
def _effective_plugin_target(self):
+ if TYPE_CHECKING:
+ assert isinstance(self._rec.expected_expr, Executable)
return self._rec.expected_expr._effective_plugin_target
@property
def _execution_options(self):
+ if TYPE_CHECKING:
+ assert isinstance(self._rec.expected_expr, Executable)
return self._rec.expected_expr._execution_options
def spoil(self):
@@ -583,9 +664,14 @@ class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
class LinkedLambdaElement(StatementLambdaElement):
"""Represent subsequent links of a :class:`.StatementLambdaElement`."""
- role = None
+ parent_lambda: StatementLambdaElement
- def __init__(self, fn, parent_lambda, opts):
+ def __init__(
+ self,
+ fn: _LambdaType,
+ parent_lambda: StatementLambdaElement,
+ opts: Union[Type[LambdaOptions], LambdaOptions],
+ ):
self.opts = opts
self.fn = fn
self.parent_lambda = parent_lambda
@@ -606,7 +692,9 @@ class AnalyzedCode:
"closure_trackers",
"build_py_wrappers",
)
- _fns = weakref.WeakKeyDictionary()
+ _fns: weakref.WeakKeyDictionary[
+ CodeType, AnalyzedCode
+ ] = weakref.WeakKeyDictionary()
@classmethod
def get(cls, fn, lambda_element, lambda_kw, **kw):
@@ -615,6 +703,8 @@ class AnalyzedCode:
return cls._fns[fn.__code__]
except KeyError:
pass
+
+ analyzed: AnalyzedCode
cls._fns[fn.__code__] = analyzed = AnalyzedCode(
fn, lambda_element, lambda_kw, **kw
)
@@ -947,14 +1037,18 @@ class AnalyzedCode:
class NonAnalyzedFunction:
__slots__ = ("expr",)
- closure_bindparams = None
- bindparam_trackers = None
+ closure_bindparams: Optional[List[BindParameter[Any]]] = None
+ bindparam_trackers: Optional[List[_BoundParameterGetter]] = None
+
+ is_sequence = False
+
+ expr: ClauseElement
- def __init__(self, expr):
+ def __init__(self, expr: ClauseElement):
self.expr = expr
@property
- def expected_expr(self):
+ def expected_expr(self) -> ClauseElement:
return self.expr
@@ -972,6 +1066,10 @@ class AnalyzedFunction:
"closure_bindparams",
)
+ closure_bindparams: Optional[List[BindParameter[Any]]]
+ expected_expr: Union[ClauseElement, List[ClauseElement]]
+ bindparam_trackers: Optional[List[_BoundParameterGetter]]
+
def __init__(
self,
analyzed_code,
@@ -1071,19 +1169,25 @@ class AnalyzedFunction:
if parent_lambda is None:
if isinstance(expr, collections_abc.Sequence):
self.expected_expr = [
- coercions.expect(
- lambda_element.role,
- sub_expr,
- apply_propagate_attrs=apply_propagate_attrs,
+ cast(
+ "ClauseElement",
+ coercions.expect(
+ lambda_element.role,
+ sub_expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ ),
)
for sub_expr in expr
]
self.is_sequence = True
else:
- self.expected_expr = coercions.expect(
- lambda_element.role,
- expr,
- apply_propagate_attrs=apply_propagate_attrs,
+ self.expected_expr = cast(
+ "ClauseElement",
+ coercions.expect(
+ lambda_element.role,
+ expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ ),
)
self.is_sequence = False
else:
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 86725f86f..577d868fd 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -19,7 +19,6 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from ._typing import _PropagateAttrsType
- from ._typing import _SelectIterable
from .base import _EntityNamespace
from .base import ColumnCollection
from .base import ReadOnlyColumnCollection
@@ -28,6 +27,7 @@ if TYPE_CHECKING:
from .elements import ColumnElement
from .elements import Label
from .elements import NamedColumn
+ from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import Subquery
@@ -164,6 +164,12 @@ class WhereHavingRole(OnClauseRole):
class ExpressionElementRole(Generic[_T], SQLRole):
+ # note when using generics for ExpressionElementRole,
+ # the generic type needs to be in
+ # sqlalchemy.sql.coercions._impl_lookup mapping also.
+ # these are set up for basic types like int, bool, str, float
+ # right now
+
__slots__ = ()
_role_name = "SQL expression element"
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index cbd0c77f4..883439ca5 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -36,6 +36,7 @@ import operator
import typing
from typing import Any
from typing import Callable
+from typing import cast
from typing import Dict
from typing import Iterator
from typing import List
@@ -68,6 +69,7 @@ from .elements import SQLCoreOperations
from .elements import TextClause
from .selectable import TableClause
from .type_api import to_instance
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
from .. import event
from .. import exc
@@ -131,21 +133,33 @@ def _get_table_key(name: str, schema: Optional[str]) -> str:
# this should really be in sql/util.py but we'd have to
# break an import cycle
-def _copy_expression(expression, source_table, target_table):
+def _copy_expression(
+ expression: ColumnElement[Any],
+ source_table: Optional[Table],
+ target_table: Optional[Table],
+) -> ColumnElement[Any]:
if source_table is None or target_table is None:
return expression
- def replace(col):
+ fixed_source_table = source_table
+ fixed_target_table = target_table
+
+ def replace(
+ element: ExternallyTraversible, **kw: Any
+ ) -> Optional[ExternallyTraversible]:
if (
- isinstance(col, Column)
- and col.table is source_table
- and col.key in source_table.c
+ isinstance(element, Column)
+ and element.table is fixed_source_table
+ and element.key in fixed_source_table.c
):
- return target_table.c[col.key]
+ return fixed_target_table.c[element.key]
else:
return None
- return visitors.replacement_traverse(expression, {}, replace)
+ return cast(
+ ColumnElement[Any],
+ visitors.replacement_traverse(expression, {}, replace),
+ )
@inspection._self_inspects
@@ -911,8 +925,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
def _reset_exported(self):
pass
- @property
- def _autoincrement_column(self):
+ @util.ro_non_memoized_property
+ def _autoincrement_column(self) -> Optional[Column[Any]]:
return self.primary_key._autoincrement_column
@property
@@ -2308,6 +2322,8 @@ class ForeignKey(DialectKWArgs, SchemaItem):
parent: Column[Any]
+ _table_column: Optional[Column[Any]]
+
def __init__(
self,
column: Union[str, Column[Any], SQLCoreOperations[Any]],
@@ -4290,11 +4306,11 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
self._columns.extend(columns)
- PrimaryKeyConstraint._autoincrement_column._reset(self)
+ PrimaryKeyConstraint._autoincrement_column._reset(self) # type: ignore
self._set_parent_with_dispatch(self.table)
def _replace(self, col):
- PrimaryKeyConstraint._autoincrement_column._reset(self)
+ PrimaryKeyConstraint._autoincrement_column._reset(self) # type: ignore
self._columns.replace(col)
self.dispatch._sa_event_column_added_to_pk_constraint(self, col)
@@ -4308,8 +4324,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
else:
return list(self._columns)
- @util.memoized_property
- def _autoincrement_column(self):
+ @util.ro_memoized_property
+ def _autoincrement_column(self) -> Optional[Column[Any]]:
def _validate_autoinc(col, autoinc_true):
if col.type._type_affinity is None or not issubclass(
col.type._type_affinity, type_api.INTEGERTYPE._type_affinity
@@ -4350,6 +4366,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
"ignore_fk",
) and _validate_autoinc(col, False):
return col
+ else:
+ return None
else:
autoinc = None
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 4f6e3795e..6504449f1 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -17,16 +17,26 @@ import collections
from enum import Enum
import itertools
import typing
+from typing import AbstractSet
from typing import Any as TODO_Any
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
from typing import Iterable
+from typing import Iterator
from typing import List
from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from . import cache_key
from . import coercions
@@ -37,6 +47,9 @@ from . import type_api
from . import visitors
from ._typing import _ColumnsClauseArgument
from ._typing import is_column_element
+from ._typing import is_select_statement
+from ._typing import is_subquery
+from ._typing import is_table
from .annotation import Annotated
from .annotation import SupportsCloneAnnotations
from .base import _clone
@@ -68,32 +81,80 @@ from .elements import ColumnClause
from .elements import ColumnElement
from .elements import DQLDMLClauseElement
from .elements import GroupedElement
-from .elements import Grouping
from .elements import literal_column
from .elements import TableValuedColumn
from .elements import UnaryExpression
+from .operators import OperatorType
+from .visitors import _TraverseInternalsType
from .visitors import InternalTraversal
from .visitors import prefix_anon_map
from .. import exc
from .. import util
+from ..util import HasMemoized_ro_memoized_attribute
+from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import Self
and_ = BooleanClauseList.and_
_T = TypeVar("_T", bound=Any)
if TYPE_CHECKING:
- from ._typing import _SelectIterable
+ from ._typing import _ColumnExpressionArgument
+ from ._typing import _FromClauseArgument
+ from ._typing import _JoinTargetArgument
+ from ._typing import _OnClauseArgument
+ from ._typing import _SelectStatementForCompoundArgument
+ from ._typing import _TextCoercedExpressionArgument
+ from ._typing import _TypeEngineArgument
+ from .base import _AmbiguousTableNameMap
+ from .base import ExecutableOption
from .base import ReadOnlyColumnCollection
+ from .cache_key import _CacheKeyTraversalType
+ from .compiler import SQLCompiler
+ from .dml import Delete
+ from .dml import Insert
+ from .dml import Update
from .elements import NamedColumn
+ from .elements import TextClause
+ from .functions import Function
+ from .schema import Column
from .schema import ForeignKey
- from .schema import PrimaryKeyConstraint
+ from .schema import ForeignKeyConstraint
+ from .type_api import TypeEngine
+ from .util import ClauseAdapter
+ from .visitors import _CloneCallableType
-class _OffsetLimitParam(BindParameter):
+_ColumnsClauseElement = Union["FromClause", ColumnElement[Any], "TextClause"]
+
+
+class _JoinTargetProtocol(Protocol):
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
+ ...
+
+
+_JoinTargetElement = Union["FromClause", _JoinTargetProtocol]
+_OnClauseElement = Union["ColumnElement[bool]", _JoinTargetProtocol]
+
+
+_SetupJoinsElement = Tuple[
+ _JoinTargetElement,
+ Optional[_OnClauseElement],
+ Optional["FromClause"],
+ Dict[str, Any],
+]
+
+
+_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]]
+
+
+class _OffsetLimitParam(BindParameter[int]):
inherit_cache = True
@property
- def _limit_offset_value(self):
+ def _limit_offset_value(self) -> Optional[int]:
return self.effective_value
@@ -114,11 +175,12 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
# sub-elements of returns_rows
_is_from_clause = False
+ _is_select_base = False
_is_select_statement = False
_is_lateral = False
@property
- def selectable(self):
+ def selectable(self) -> ReturnsRows:
return self
@util.non_memoized_property
@@ -133,8 +195,28 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
"""
raise NotImplementedError()
+ def is_derived_from(self, fromclause: FromClause) -> bool:
+ """Return ``True`` if this :class:`.ReturnsRows` is
+ 'derived' from the given :class:`.FromClause`.
+
+ An example would be an Alias of a Table is derived from that Table.
+
+ """
+ raise NotImplementedError()
+
+ def _generate_fromclause_column_proxies(
+ self, fromclause: FromClause
+ ) -> None:
+ """Populate columns into an :class:`.AliasedReturnsRows` object."""
+
+ raise NotImplementedError()
+
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
+ """reset internal collections for an incoming column being added."""
+ raise NotImplementedError()
+
@property
- def exported_columns(self):
+ def exported_columns(self) -> ReadOnlyColumnCollection[Any, Any]:
"""A :class:`_expression.ColumnCollection`
that represents the "exported"
columns of this :class:`_expression.ReturnsRows`.
@@ -160,6 +242,9 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
raise NotImplementedError()
+SelfSelectable = TypeVar("SelfSelectable", bound="Selectable")
+
+
class Selectable(ReturnsRows):
"""Mark a class as being selectable."""
@@ -167,10 +252,10 @@ class Selectable(ReturnsRows):
is_selectable = True
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
raise NotImplementedError()
- def lateral(self, name=None):
+ def lateral(self, name: Optional[str] = None) -> LateralFromClause:
"""Return a LATERAL alias of this :class:`_expression.Selectable`.
The return value is the :class:`_expression.Lateral` construct also
@@ -192,15 +277,21 @@ class Selectable(ReturnsRows):
"functionality is available via the sqlalchemy.sql.visitors module.",
)
@util.preload_module("sqlalchemy.sql.util")
- def replace_selectable(self, old, alias):
+ def replace_selectable(
+ self: SelfSelectable, old: FromClause, alias: Alias
+ ) -> SelfSelectable:
"""Replace all occurrences of :class:`_expression.FromClause`
'old' with the given :class:`_expression.Alias`
object, returning a copy of this :class:`_expression.FromClause`.
"""
- return util.preloaded.sql_util.ClauseAdapter(alias).traverse(self)
+ return util.preloaded.sql_util.ClauseAdapter(alias).traverse( # type: ignore # noqa E501
+ self
+ )
- def corresponding_column(self, column, require_embedded=False):
+ def corresponding_column(
+ self, column: ColumnElement[Any], require_embedded: bool = False
+ ) -> Optional[ColumnElement[Any]]:
"""Given a :class:`_expression.ColumnElement`, return the exported
:class:`_expression.ColumnElement` object from the
:attr:`_expression.Selectable.exported_columns`
@@ -242,19 +333,23 @@ SelfHasPrefixes = typing.TypeVar("SelfHasPrefixes", bound="HasPrefixes")
class HasPrefixes:
- _prefixes = ()
+ _prefixes: Tuple[Tuple[DQLDMLClauseElement, str], ...] = ()
- _has_prefixes_traverse_internals = [
+ _has_prefixes_traverse_internals: _TraverseInternalsType = [
("_prefixes", InternalTraversal.dp_prefix_sequence)
]
@_generative
@_document_text_coercion(
- "expr",
+ "prefixes",
":meth:`_expression.HasPrefixes.prefix_with`",
- ":paramref:`.HasPrefixes.prefix_with.*expr`",
+ ":paramref:`.HasPrefixes.prefix_with.*prefixes`",
)
- def prefix_with(self: SelfHasPrefixes, *expr, **kw) -> SelfHasPrefixes:
+ def prefix_with(
+ self: SelfHasPrefixes,
+ *prefixes: _TextCoercedExpressionArgument[Any],
+ dialect: str = "*",
+ ) -> SelfHasPrefixes:
r"""Add one or more expressions following the statement keyword, i.e.
SELECT, INSERT, UPDATE, or DELETE. Generative.
@@ -272,49 +367,44 @@ class HasPrefixes:
Multiple prefixes can be specified by multiple calls
to :meth:`_expression.HasPrefixes.prefix_with`.
- :param \*expr: textual or :class:`_expression.ClauseElement`
+ :param \*prefixes: textual or :class:`_expression.ClauseElement`
construct which
will be rendered following the INSERT, UPDATE, or DELETE
keyword.
- :param \**kw: A single keyword 'dialect' is accepted. This is an
- optional string dialect name which will
+ :param dialect: optional string dialect name which will
limit rendering of this prefix to only that dialect.
"""
- dialect = kw.pop("dialect", None)
- if kw:
- raise exc.ArgumentError(
- "Unsupported argument(s): %s" % ",".join(kw)
- )
- self._setup_prefixes(expr, dialect)
- return self
-
- def _setup_prefixes(self, prefixes, dialect=None):
self._prefixes = self._prefixes + tuple(
[
(coercions.expect(roles.StatementOptionRole, p), dialect)
for p in prefixes
]
)
+ return self
SelfHasSuffixes = typing.TypeVar("SelfHasSuffixes", bound="HasSuffixes")
class HasSuffixes:
- _suffixes = ()
+ _suffixes: Tuple[Tuple[DQLDMLClauseElement, str], ...] = ()
- _has_suffixes_traverse_internals = [
+ _has_suffixes_traverse_internals: _TraverseInternalsType = [
("_suffixes", InternalTraversal.dp_prefix_sequence)
]
@_generative
@_document_text_coercion(
- "expr",
+ "suffixes",
":meth:`_expression.HasSuffixes.suffix_with`",
- ":paramref:`.HasSuffixes.suffix_with.*expr`",
+ ":paramref:`.HasSuffixes.suffix_with.*suffixes`",
)
- def suffix_with(self: SelfHasSuffixes, *expr, **kw) -> SelfHasSuffixes:
+ def suffix_with(
+ self: SelfHasSuffixes,
+ *suffixes: _TextCoercedExpressionArgument[Any],
+ dialect: str = "*",
+ ) -> SelfHasSuffixes:
r"""Add one or more expressions following the statement as a whole.
This is used to support backend-specific suffix keywords on
@@ -328,44 +418,39 @@ class HasSuffixes:
Multiple suffixes can be specified by multiple calls
to :meth:`_expression.HasSuffixes.suffix_with`.
- :param \*expr: textual or :class:`_expression.ClauseElement`
+ :param \*suffixes: textual or :class:`_expression.ClauseElement`
construct which
will be rendered following the target clause.
- :param \**kw: A single keyword 'dialect' is accepted. This is an
- optional string dialect name which will
+ :param dialect: Optional string dialect name which will
limit rendering of this suffix to only that dialect.
"""
- dialect = kw.pop("dialect", None)
- if kw:
- raise exc.ArgumentError(
- "Unsupported argument(s): %s" % ",".join(kw)
- )
- self._setup_suffixes(expr, dialect)
- return self
-
- def _setup_suffixes(self, suffixes, dialect=None):
self._suffixes = self._suffixes + tuple(
[
(coercions.expect(roles.StatementOptionRole, p), dialect)
for p in suffixes
]
)
+ return self
SelfHasHints = typing.TypeVar("SelfHasHints", bound="HasHints")
class HasHints:
- _hints = util.immutabledict()
- _statement_hints = ()
+ _hints: util.immutabledict[
+ Tuple[FromClause, str], str
+ ] = util.immutabledict()
+ _statement_hints: Tuple[Tuple[str, str], ...] = ()
- _has_hints_traverse_internals = [
+ _has_hints_traverse_internals: _TraverseInternalsType = [
("_statement_hints", InternalTraversal.dp_statement_hint_list),
("_hints", InternalTraversal.dp_table_hint_list),
]
- def with_statement_hint(self, text, dialect_name="*"):
+ def with_statement_hint(
+ self: SelfHasHints, text: str, dialect_name: str = "*"
+ ) -> SelfHasHints:
"""Add a statement hint to this :class:`_expression.Select` or
other selectable object.
@@ -389,11 +474,14 @@ class HasHints:
MySQL optimizer hints
"""
- return self.with_hint(None, text, dialect_name)
+ return self._with_hint(None, text, dialect_name)
@_generative
def with_hint(
- self: SelfHasHints, selectable, text, dialect_name="*"
+ self: SelfHasHints,
+ selectable: _FromClauseArgument,
+ text: str,
+ dialect_name: str = "*",
) -> SelfHasHints:
r"""Add an indexing or other executional context hint for the given
selectable to this :class:`_expression.Select` or other selectable
@@ -429,6 +517,15 @@ class HasHints:
:meth:`_expression.Select.with_statement_hint`
"""
+
+ return self._with_hint(selectable, text, dialect_name)
+
+ def _with_hint(
+ self: SelfHasHints,
+ selectable: Optional[_FromClauseArgument],
+ text: str,
+ dialect_name: str,
+ ) -> SelfHasHints:
if selectable is None:
self._statement_hints += ((dialect_name, text),)
else:
@@ -443,6 +540,9 @@ class HasHints:
return self
+SelfFromClause = TypeVar("SelfFromClause", bound="FromClause")
+
+
class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""Represent an element that can be used within the ``FROM``
clause of a ``SELECT`` statement.
@@ -473,6 +573,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
_is_clone_of: Optional[FromClause]
+ _columns: ColumnCollection[Any, Any]
+
schema: Optional[str] = None
"""Define the 'schema' attribute for this :class:`_expression.FromClause`.
@@ -488,7 +590,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
_use_schema_map = False
- def select(self) -> "Select":
+ def select(self) -> Select:
r"""Return a SELECT of this :class:`_expression.FromClause`.
@@ -504,7 +606,13 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return Select(self)
- def join(self, right, onclause=None, isouter=False, full=False):
+ def join(
+ self,
+ right: _FromClauseArgument,
+ onclause: Optional[_ColumnExpressionArgument[bool]] = None,
+ isouter: bool = False,
+ full: bool = False,
+ ) -> Join:
"""Return a :class:`_expression.Join` from this
:class:`_expression.FromClause`
to another :class:`FromClause`.
@@ -550,7 +658,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return Join(self, right, onclause, isouter, full)
- def outerjoin(self, right, onclause=None, full=False):
+ def outerjoin(
+ self,
+ right: _FromClauseArgument,
+ onclause: Optional[_ColumnExpressionArgument[bool]] = None,
+ full: bool = False,
+ ) -> Join:
"""Return a :class:`_expression.Join` from this
:class:`_expression.FromClause`
to another :class:`FromClause`, with the "isouter" flag set to
@@ -596,7 +709,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return Join(self, right, onclause, True, full)
- def alias(self, name=None, flat=False):
+ def alias(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> NamedFromClause:
"""Return an alias of this :class:`_expression.FromClause`.
E.g.::
@@ -617,35 +732,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return Alias._construct(self, name)
- @util.preload_module("sqlalchemy.sql.sqltypes")
- def table_valued(self):
- """Return a :class:`_sql.TableValuedColumn` object for this
- :class:`_expression.FromClause`.
-
- A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that
- represents a complete row in a table. Support for this construct is
- backend dependent, and is supported in various forms by backends
- such as PostgreSQL, Oracle and SQL Server.
-
- E.g.::
-
- >>> from sqlalchemy import select, column, func, table
- >>> a = table("a", column("id"), column("x"), column("y"))
- >>> stmt = select(func.row_to_json(a.table_valued()))
- >>> print(stmt)
- SELECT row_to_json(a) AS row_to_json_1
- FROM a
-
- .. versionadded:: 1.4.0b2
-
- .. seealso::
-
- :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
-
- """
- return TableValuedColumn(self, type_api.TABLEVALUE)
-
- def tablesample(self, sampling, name=None, seed=None):
+ def tablesample(
+ self,
+ sampling: Union[float, Function[Any]],
+ name: Optional[str] = None,
+ seed: Optional[roles.ExpressionElementRole[Any]] = None,
+ ) -> TableSample:
"""Return a TABLESAMPLE alias of this :class:`_expression.FromClause`.
The return value is the :class:`_expression.TableSample`
@@ -661,7 +753,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return TableSample._construct(self, sampling, name, seed)
- def is_derived_from(self, fromclause):
+ def is_derived_from(self, fromclause: FromClause) -> bool:
"""Return ``True`` if this :class:`_expression.FromClause` is
'derived' from the given ``FromClause``.
@@ -673,7 +765,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
# contained elements.
return fromclause in self._cloned_set
- def _is_lexical_equivalent(self, other):
+ def _is_lexical_equivalent(self, other: FromClause) -> bool:
"""Return ``True`` if this :class:`_expression.FromClause` and
the other represent the same lexical identity.
@@ -681,9 +773,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
if they are the same via annotation identity.
"""
- return self._cloned_set.intersection(other._cloned_set)
+ return bool(self._cloned_set.intersection(other._cloned_set))
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def description(self) -> str:
"""A brief description of this :class:`_expression.FromClause`.
@@ -692,13 +784,15 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return getattr(self, "name", self.__class__.__name__ + " object")
- def _generate_fromclause_column_proxies(self, fromclause):
+ def _generate_fromclause_column_proxies(
+ self, fromclause: FromClause
+ ) -> None:
fromclause._columns._populate_separate_keys(
col._make_proxy(fromclause) for col in self.c
)
@property
- def exported_columns(self):
+ def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]:
"""A :class:`_expression.ColumnCollection`
that represents the "exported"
columns of this :class:`_expression.Selectable`.
@@ -796,7 +890,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
self._populate_column_collection()
return self.foreign_keys
- def _reset_column_collection(self):
+ def _reset_column_collection(self) -> None:
"""Reset the attributes linked to the ``FromClause.c`` attribute.
This collection is separate from all the other memoized things
@@ -817,7 +911,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
def _select_iterable(self) -> _SelectIterable:
return self.c
- def _init_collections(self):
+ def _init_collections(self) -> None:
assert "_columns" not in self.__dict__
assert "primary_key" not in self.__dict__
assert "foreign_keys" not in self.__dict__
@@ -827,10 +921,10 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
self.foreign_keys = set() # type: ignore
@property
- def _cols_populated(self):
+ def _cols_populated(self) -> bool:
return "_columns" in self.__dict__
- def _populate_column_collection(self):
+ def _populate_column_collection(self) -> None:
"""Called on subclasses to establish the .c collection.
Each implementation has a different way of establishing
@@ -838,7 +932,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
"""Given a column added to the .c collection of an underlying
selectable, produce the local version of that column, assuming this
selectable ultimately should proxy this column.
@@ -865,15 +959,60 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
self._reset_column_collection()
- def _anonymous_fromclause(self, name=None, flat=False):
+ def _anonymous_fromclause(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> NamedFromClause:
return self.alias(name=name)
+ if TYPE_CHECKING:
+
+ def self_group(
+ self: Self, against: Optional[OperatorType] = None
+ ) -> Union[FromGrouping, Self]:
+ ...
+
class NamedFromClause(FromClause):
+ """A :class:`.FromClause` that has a name.
+
+ Examples include tables, subqueries, CTEs, aliased tables.
+
+ .. versionadded:: 2.0
+
+ """
+
named_with_column = True
name: str
+ @util.preload_module("sqlalchemy.sql.sqltypes")
+ def table_valued(self) -> TableValuedColumn[Any]:
+ """Return a :class:`_sql.TableValuedColumn` object for this
+ :class:`_expression.FromClause`.
+
+ A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that
+ represents a complete row in a table. Support for this construct is
+ backend dependent, and is supported in various forms by backends
+ such as PostgreSQL, Oracle and SQL Server.
+
+ E.g.::
+
+ >>> from sqlalchemy import select, column, func, table
+ >>> a = table("a", column("id"), column("x"), column("y"))
+ >>> stmt = select(func.row_to_json(a.table_valued()))
+ >>> print(stmt)
+ SELECT row_to_json(a) AS row_to_json_1
+ FROM a
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
+
+ """
+ return TableValuedColumn(self, type_api.TABLEVALUE)
+
class SelectLabelStyle(Enum):
"""Label style constants that may be passed to
@@ -992,7 +1131,7 @@ class Join(roles.DMLTableRole, FromClause):
__visit_name__ = "join"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("left", InternalTraversal.dp_clauseelement),
("right", InternalTraversal.dp_clauseelement),
("onclause", InternalTraversal.dp_clauseelement),
@@ -1002,7 +1141,20 @@ class Join(roles.DMLTableRole, FromClause):
_is_join = True
- def __init__(self, left, right, onclause=None, isouter=False, full=False):
+ left: FromClause
+ right: FromClause
+ onclause: Optional[ColumnElement[bool]]
+ isouter: bool
+ full: bool
+
+ def __init__(
+ self,
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ isouter: bool = False,
+ full: bool = False,
+ ):
"""Construct a new :class:`_expression.Join`.
The usual entrypoint here is the :func:`_expression.join`
@@ -1010,11 +1162,23 @@ class Join(roles.DMLTableRole, FromClause):
:class:`_expression.FromClause` object.
"""
+
+ # when deannotate was removed here, callcounts went up for ORM
+ # compilation of eager joins, since there were more comparisons of
+ # annotated objects. test_orm.py -> test_fetch_results
+ # was therefore changed to show a more real-world use case, where the
+ # compilation is cached; there's no change in post-cache callcounts.
+ # callcounts for a single compilation in that particular test
+ # that includes about eight joins about 1100 extra fn calls, from
+ # 29200 -> 30373
+
self.left = coercions.expect(
- roles.FromClauseRole, left, deannotate=True
+ roles.FromClauseRole,
+ left,
)
self.right = coercions.expect(
- roles.FromClauseRole, right, deannotate=True
+ roles.FromClauseRole,
+ right,
).self_group()
if onclause is None:
@@ -1029,7 +1193,7 @@ class Join(roles.DMLTableRole, FromClause):
self.isouter = isouter
self.full = full
- @property
+ @util.ro_non_memoized_property
def description(self) -> str:
return "Join object on %s(%d) and %s(%d)" % (
self.left.description,
@@ -1038,7 +1202,7 @@ class Join(roles.DMLTableRole, FromClause):
id(self.right),
)
- def is_derived_from(self, fromclause):
+ def is_derived_from(self, fromclause: FromClause) -> bool:
return (
# use hash() to ensure direct comparison to annotated works
# as well
@@ -1047,7 +1211,10 @@ class Join(roles.DMLTableRole, FromClause):
or self.right.is_derived_from(fromclause)
)
- def self_group(self, against=None):
+ def self_group(
+ self, against: Optional[OperatorType] = None
+ ) -> FromGrouping:
+ ...
return FromGrouping(self)
@util.preload_module("sqlalchemy.sql.util")
@@ -1055,7 +1222,7 @@ class Join(roles.DMLTableRole, FromClause):
sqlutil = util.preloaded.sql_util
columns = [c for c in self.left.c] + [c for c in self.right.c]
- self.primary_key.extend(
+ self.primary_key.extend( # type: ignore
sqlutil.reduce_columns(
(c for c in columns if c.primary_key), self.onclause
)
@@ -1063,11 +1230,13 @@ class Join(roles.DMLTableRole, FromClause):
self._columns._populate_separate_keys(
(col._tq_key_label, col) for col in columns
)
- self.foreign_keys.update(
+ self.foreign_keys.update( # type: ignore
itertools.chain(*[col.foreign_keys for col in columns])
)
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(
+ self, clone: _CloneCallableType = _clone, **kw: Any
+ ) -> None:
# see Select._copy_internals() for similar concept
# here we pre-clone "left" and "right" so that we can
@@ -1100,12 +1269,14 @@ class Join(roles.DMLTableRole, FromClause):
self._reset_memoizations()
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
super(Join, self)._refresh_for_new_column(column)
self.left._refresh_for_new_column(column)
self.right._refresh_for_new_column(column)
- def _match_primaries(self, left, right):
+ def _match_primaries(
+ self, left: FromClause, right: FromClause
+ ) -> ColumnElement[bool]:
if isinstance(left, Join):
left_right = left.right
else:
@@ -1114,8 +1285,15 @@ class Join(roles.DMLTableRole, FromClause):
@classmethod
def _join_condition(
- cls, a, b, a_subset=None, consider_as_foreign_keys=None
- ):
+ cls,
+ a: FromClause,
+ b: FromClause,
+ *,
+ a_subset: Optional[FromClause] = None,
+ consider_as_foreign_keys: Optional[
+ AbstractSet[ColumnClause[Any]]
+ ] = None,
+ ) -> ColumnElement[bool]:
"""Create a join condition between two tables or selectables.
See sqlalchemy.sql.util.join_condition() for full docs.
@@ -1151,7 +1329,15 @@ class Join(roles.DMLTableRole, FromClause):
return and_(*crit)
@classmethod
- def _can_join(cls, left, right, consider_as_foreign_keys=None):
+ def _can_join(
+ cls,
+ left: FromClause,
+ right: FromClause,
+ *,
+ consider_as_foreign_keys: Optional[
+ AbstractSet[ColumnClause[Any]]
+ ] = None,
+ ) -> bool:
if isinstance(left, Join):
left_right = left.right
else:
@@ -1169,20 +1355,31 @@ class Join(roles.DMLTableRole, FromClause):
@classmethod
@util.preload_module("sqlalchemy.sql.util")
def _joincond_scan_left_right(
- cls, a, a_subset, b, consider_as_foreign_keys
- ):
+ cls,
+ a: FromClause,
+ a_subset: Optional[FromClause],
+ b: FromClause,
+ consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]],
+ ) -> collections.defaultdict[
+ Optional[ForeignKeyConstraint],
+ List[Tuple[ColumnClause[Any], ColumnClause[Any]]],
+ ]:
sql_util = util.preloaded.sql_util
a = coercions.expect(roles.FromClauseRole, a)
b = coercions.expect(roles.FromClauseRole, b)
- constraints = collections.defaultdict(list)
+ constraints: collections.defaultdict[
+ Optional[ForeignKeyConstraint],
+ List[Tuple[ColumnClause[Any], ColumnClause[Any]]],
+ ] = collections.defaultdict(list)
for left in (a_subset, a):
if left is None:
continue
for fk in sorted(
- b.foreign_keys, key=lambda fk: fk.parent._creation_order
+ b.foreign_keys,
+ key=lambda fk: fk.parent._creation_order, # type: ignore
):
if (
consider_as_foreign_keys is not None
@@ -1202,7 +1399,8 @@ class Join(roles.DMLTableRole, FromClause):
constraints[fk.constraint].append((col, fk.parent))
if left is not b:
for fk in sorted(
- left.foreign_keys, key=lambda fk: fk.parent._creation_order
+ left.foreign_keys,
+ key=lambda fk: fk.parent._creation_order, # type: ignore
):
if (
consider_as_foreign_keys is not None
@@ -1309,7 +1507,8 @@ class Join(roles.DMLTableRole, FromClause):
@util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
- return [self] + self.left._from_objects + self.right._from_objects
+ self_list: List[FromClause] = [self]
+ return self_list + self.left._from_objects + self.right._from_objects
class NoInit:
@@ -1327,6 +1526,14 @@ class NoInit:
)
+class LateralFromClause(NamedFromClause):
+ """mark a FROM clause as being able to render directly as LATERAL"""
+
+
+_SelfAliasedReturnsRows = TypeVar(
+ "_SelfAliasedReturnsRows", bound="AliasedReturnsRows"
+)
+
# FromClause ->
# AliasedReturnsRows
# -> Alias only for FromClause
@@ -1335,6 +1542,8 @@ class NoInit:
# -> Lateral -> FromClause, but we accept SelectBase
# w/ non-deprecated coercion
# -> TableSample -> only for FromClause
+
+
class AliasedReturnsRows(NoInit, NamedFromClause):
"""Base class of aliases against tables, subqueries, and other
selectables."""
@@ -1343,24 +1552,21 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
_supports_derived_columns = False
- element: ClauseElement
+ element: ReturnsRows
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement),
("name", InternalTraversal.dp_anon_name),
]
@classmethod
- def _construct(cls, *arg, **kw):
+ def _construct(
+ cls: Type[_SelfAliasedReturnsRows], *arg: Any, **kw: Any
+ ) -> _SelfAliasedReturnsRows:
obj = cls.__new__(cls)
obj._init(*arg, **kw)
return obj
- @classmethod
- def _factory(cls, returnsrows, name=None):
- """Base factory method. Subclasses need to provide this."""
- raise NotImplementedError()
-
def _init(self, selectable, name=None):
self.element = coercions.expect(
roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self
@@ -1378,11 +1584,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
name = _anonymous_label.safe_construct(id(self), name or "anon")
self.name = name
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
super(AliasedReturnsRows, self)._refresh_for_new_column(column)
self.element._refresh_for_new_column(column)
- @property
+ def _populate_column_collection(self):
+ self.element._generate_fromclause_column_proxies(self)
+
+ @util.ro_non_memoized_property
def description(self) -> str:
name = self.name
if isinstance(name, _anonymous_label):
@@ -1395,15 +1604,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
"""Legacy for dialects that are referring to Alias.original."""
return self.element
- def is_derived_from(self, fromclause):
+ def is_derived_from(self, fromclause: FromClause) -> bool:
if fromclause in self._cloned_set:
return True
return self.element.is_derived_from(fromclause)
- def _populate_column_collection(self):
- self.element._generate_fromclause_column_proxies(self)
-
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(
+ self, clone: _CloneCallableType = _clone, **kw: Any
+ ) -> None:
existing_element = self.element
super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw)
@@ -1420,7 +1628,11 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
return [self]
-class Alias(roles.DMLTableRole, AliasedReturnsRows):
+class FromClauseAlias(AliasedReturnsRows):
+ element: FromClause
+
+
+class Alias(roles.DMLTableRole, FromClauseAlias):
"""Represents an table or selectable alias (AS).
Represents an alias, as typically applied to any table or
@@ -1445,13 +1657,18 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows):
element: FromClause
@classmethod
- def _factory(cls, selectable, name=None, flat=False):
+ def _factory(
+ cls,
+ selectable: FromClause,
+ name: Optional[str] = None,
+ flat: bool = False,
+ ) -> NamedFromClause:
return coercions.expect(
roles.FromClauseRole, selectable, allow_select=True
).alias(name=name, flat=flat)
-class TableValuedAlias(Alias):
+class TableValuedAlias(LateralFromClause, Alias):
"""An alias against a "table valued" SQL function.
This construct provides for a SQL function that returns columns
@@ -1480,7 +1697,7 @@ class TableValuedAlias(Alias):
_render_derived_w_types = False
joins_implicitly = False
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement),
("name", InternalTraversal.dp_anon_name),
("_tableval_type", InternalTraversal.dp_type),
@@ -1526,7 +1743,9 @@ class TableValuedAlias(Alias):
return TableValuedColumn(self, self._tableval_type)
- def alias(self, name=None):
+ def alias(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> TableValuedAlias:
"""Return a new alias of this :class:`_sql.TableValuedAlias`.
This creates a distinct FROM object that will be distinguished
@@ -1547,7 +1766,7 @@ class TableValuedAlias(Alias):
return tva
- def lateral(self, name=None):
+ def lateral(self, name: Optional[str] = None) -> LateralFromClause:
"""Return a new :class:`_sql.TableValuedAlias` with the lateral flag set,
so that it renders as LATERAL.
@@ -1619,7 +1838,7 @@ class TableValuedAlias(Alias):
return new_alias
-class Lateral(AliasedReturnsRows):
+class Lateral(FromClauseAlias, LateralFromClause):
"""Represent a LATERAL subquery.
This object is constructed from the :func:`_expression.lateral` module
@@ -1644,13 +1863,17 @@ class Lateral(AliasedReturnsRows):
inherit_cache = True
@classmethod
- def _factory(cls, selectable, name=None):
+ def _factory(
+ cls,
+ selectable: Union[SelectBase, _FromClauseArgument],
+ name: Optional[str] = None,
+ ) -> LateralFromClause:
return coercions.expect(
roles.FromClauseRole, selectable, explicit_subquery=True
).lateral(name=name)
-class TableSample(AliasedReturnsRows):
+class TableSample(FromClauseAlias):
"""Represent a TABLESAMPLE clause.
This object is constructed from the :func:`_expression.tablesample` module
@@ -1668,13 +1891,22 @@ class TableSample(AliasedReturnsRows):
__visit_name__ = "tablesample"
- _traverse_internals = AliasedReturnsRows._traverse_internals + [
- ("sampling", InternalTraversal.dp_clauseelement),
- ("seed", InternalTraversal.dp_clauseelement),
- ]
+ _traverse_internals: _TraverseInternalsType = (
+ AliasedReturnsRows._traverse_internals
+ + [
+ ("sampling", InternalTraversal.dp_clauseelement),
+ ("seed", InternalTraversal.dp_clauseelement),
+ ]
+ )
@classmethod
- def _factory(cls, selectable, sampling, name=None, seed=None):
+ def _factory(
+ cls,
+ selectable: _FromClauseArgument,
+ sampling: Union[float, Function[Any]],
+ name: Optional[str] = None,
+ seed: Optional[roles.ExpressionElementRole[Any]] = None,
+ ) -> TableSample:
return coercions.expect(roles.FromClauseRole, selectable).tablesample(
sampling, name=name, seed=seed
)
@@ -1721,7 +1953,7 @@ class CTE(
__visit_name__ = "cte"
- _traverse_internals = (
+ _traverse_internals: _TraverseInternalsType = (
AliasedReturnsRows._traverse_internals
+ [
("_cte_alias", InternalTraversal.dp_clauseelement),
@@ -1736,7 +1968,12 @@ class CTE(
element: HasCTE
@classmethod
- def _factory(cls, selectable, name=None, recursive=False):
+ def _factory(
+ cls,
+ selectable: HasCTE,
+ name: Optional[str] = None,
+ recursive: bool = False,
+ ) -> CTE:
r"""Return a new :class:`_expression.CTE`,
or Common Table Expression instance.
@@ -1775,7 +2012,9 @@ class CTE(
else:
self.element._generate_fromclause_column_proxies(self)
- def alias(self, name=None, flat=False):
+ def alias(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> NamedFromClause:
"""Return an :class:`_expression.Alias` of this
:class:`_expression.CTE`.
@@ -1814,6 +2053,10 @@ class CTE(
:meth:`_sql.HasCTE.cte` - examples of calling styles
"""
+ assert is_select_statement(
+ self.element
+ ), f"CTE element f{self.element} does not support union()"
+
return CTE._construct(
self.element.union(*other),
name=self.name,
@@ -1839,6 +2082,11 @@ class CTE(
:meth:`_sql.HasCTE.cte` - examples of calling styles
"""
+
+ assert is_select_statement(
+ self.element
+ ), f"CTE element f{self.element} does not support union_all()"
+
return CTE._construct(
self.element.union_all(*other),
name=self.name,
@@ -1865,23 +2113,229 @@ class _CTEOpts(NamedTuple):
nesting: bool
-class HasCTE(roles.HasCTERole, ClauseElement):
+class _ColumnsPlusNames(NamedTuple):
+ required_label_name: Optional[str]
+ """
+ string label name, if non-None, must be rendered as a
+ label, i.e. "AS <name>"
+ """
+
+ proxy_key: Optional[str]
+ """
+ proxy_key that is to be part of the result map for this
+ col. this is also the key in a fromclause.c or
+ select.selected_columns collection
+ """
+
+ fallback_label_name: Optional[str]
+ """
+ name that can be used to render an "AS <name>" when
+ we have to render a label even though
+ required_label_name was not given
+ """
+
+ column: Union[ColumnElement[Any], TextClause]
+ """
+ the ColumnElement itself
+ """
+
+ repeated: bool
+ """
+ True if this is a duplicate of a previous column
+ in the list of columns
+ """
+
+
+class SelectsRows(ReturnsRows):
+ """Sub-base of ReturnsRows for elements that deliver rows
+ directly, namely SELECT and INSERT/UPDATE/DELETE..RETURNING"""
+
+ _label_style: SelectLabelStyle = LABEL_STYLE_NONE
+
+ def _generate_columns_plus_names(
+ self, anon_for_dupe_key: bool
+ ) -> List[_ColumnsPlusNames]:
+ """Generate column names as rendered in a SELECT statement by
+ the compiler.
+
+ This is distinct from the _column_naming_convention generator that's
+ intended for population of .c collections and similar, which has
+ different rules. the collection returned here calls upon the
+ _column_naming_convention as well.
+
+ """
+ cols = self._all_selected_columns
+
+ key_naming_convention = SelectState._column_naming_convention(
+ self._label_style
+ )
+
+ names = {}
+
+ result: List[_ColumnsPlusNames] = []
+ result_append = result.append
+
+ table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+ label_style_none = self._label_style is LABEL_STYLE_NONE
+
+ # a counter used for "dedupe" labels, which have double underscores
+ # in them and are never referred by name; they only act
+ # as positional placeholders. they need only be unique within
+ # the single columns clause they're rendered within (required by
+ # some dbs such as mysql). So their anon identity is tracked against
+ # a fixed counter rather than hash() identity.
+ dedupe_hash = 1
+
+ for c in cols:
+ repeated = False
+
+ if not c._render_label_in_columns_clause:
+ effective_name = (
+ required_label_name
+ ) = fallback_label_name = None
+ elif label_style_none:
+ if TYPE_CHECKING:
+ assert is_column_element(c)
+
+ effective_name = required_label_name = None
+ fallback_label_name = c._non_anon_label or c._anon_name_label
+ else:
+ if TYPE_CHECKING:
+ assert is_column_element(c)
+
+ if table_qualified:
+ required_label_name = (
+ effective_name
+ ) = fallback_label_name = c._tq_label
+ else:
+ effective_name = fallback_label_name = c._non_anon_label
+ required_label_name = None
+
+ if effective_name is None:
+ # it seems like this could be _proxy_key and we would
+ # not need _expression_label but it isn't
+ # giving us a clue when to use anon_label instead
+ expr_label = c._expression_label
+ if expr_label is None:
+ repeated = c._anon_name_label in names
+ names[c._anon_name_label] = c
+ effective_name = required_label_name = None
+
+ if repeated:
+ # here, "required_label_name" is sent as
+ # "None" and "fallback_label_name" is sent.
+ if table_qualified:
+ fallback_label_name = (
+ c._dedupe_anon_tq_label_idx(dedupe_hash)
+ )
+ dedupe_hash += 1
+ else:
+ fallback_label_name = c._dedupe_anon_label_idx(
+ dedupe_hash
+ )
+ dedupe_hash += 1
+ else:
+ fallback_label_name = c._anon_name_label
+ else:
+ required_label_name = (
+ effective_name
+ ) = fallback_label_name = expr_label
+
+ if effective_name is not None:
+ if TYPE_CHECKING:
+ assert is_column_element(c)
+
+ if effective_name in names:
+ # when looking to see if names[name] is the same column as
+ # c, use hash(), so that an annotated version of the column
+ # is seen as the same as the non-annotated
+ if hash(names[effective_name]) != hash(c):
+
+ # different column under the same name. apply
+ # disambiguating label
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._anon_tq_label
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._anon_name_label
+
+ if anon_for_dupe_key and required_label_name in names:
+ # here, c._anon_tq_label is definitely unique to
+ # that column identity (or annotated version), so
+ # this should always be true.
+ # this is also an infrequent codepath because
+ # you need two levels of duplication to be here
+ assert hash(names[required_label_name]) == hash(c)
+
+ # the column under the disambiguating label is
+ # already present. apply the "dedupe" label to
+ # subsequent occurrences of the column so that the
+ # original stays non-ambiguous
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ repeated = True
+ else:
+ names[required_label_name] = c
+ elif anon_for_dupe_key:
+ # same column under the same name. apply the "dedupe"
+ # label so that the original stays non-ambiguous
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ repeated = True
+ else:
+ names[effective_name] = c
+
+ result_append(
+ _ColumnsPlusNames(
+ required_label_name,
+ key_naming_convention(c),
+ fallback_label_name,
+ c,
+ repeated,
+ )
+ )
+
+ return result
+
+
+class HasCTE(roles.HasCTERole, SelectsRows):
"""Mixin that declares a class to include CTE support.
.. versionadded:: 1.1
"""
- _has_ctes_traverse_internals = [
+ _has_ctes_traverse_internals: _TraverseInternalsType = [
("_independent_ctes", InternalTraversal.dp_clauseelement_list),
("_independent_ctes_opts", InternalTraversal.dp_plain_obj),
]
- _independent_ctes = ()
- _independent_ctes_opts = ()
+ _independent_ctes: Tuple[CTE, ...] = ()
+ _independent_ctes_opts: Tuple[_CTEOpts, ...] = ()
@_generative
- def add_cte(self: SelfHasCTE, *ctes, nest_here=False) -> SelfHasCTE:
+ def add_cte(
+ self: SelfHasCTE, *ctes: CTE, nest_here: bool = False
+ ) -> SelfHasCTE:
r"""Add one or more :class:`_sql.CTE` constructs to this statement.
This method will associate the given :class:`_sql.CTE` constructs with
@@ -1985,7 +2439,12 @@ class HasCTE(roles.HasCTERole, ClauseElement):
self._independent_ctes_opts += (opt,)
return self
- def cte(self, name=None, recursive=False, nesting=False):
+ def cte(
+ self,
+ name: Optional[str] = None,
+ recursive: bool = False,
+ nesting: bool = False,
+ ) -> CTE:
r"""Return a new :class:`_expression.CTE`,
or Common Table Expression instance.
@@ -2293,10 +2752,12 @@ class Subquery(AliasedReturnsRows):
inherit_cache = True
- element: Select
+ element: SelectBase
@classmethod
- def _factory(cls, selectable, name=None):
+ def _factory(
+ cls, selectable: SelectBase, name: Optional[str] = None
+ ) -> Subquery:
"""Return a :class:`.Subquery` object."""
return coercions.expect(
roles.SelectStatementRole, selectable
@@ -2335,11 +2796,13 @@ class Subquery(AliasedReturnsRows):
class FromGrouping(GroupedElement, FromClause):
"""Represent a grouping of a FROM clause"""
- _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+ _traverse_internals: _TraverseInternalsType = [
+ ("element", InternalTraversal.dp_clauseelement)
+ ]
element: FromClause
- def __init__(self, element):
+ def __init__(self, element: FromClause):
self.element = coercions.expect(roles.FromClauseRole, element)
def _init_collections(self):
@@ -2361,11 +2824,13 @@ class FromGrouping(GroupedElement, FromClause):
def foreign_keys(self):
return self.element.foreign_keys
- def is_derived_from(self, element):
- return self.element.is_derived_from(element)
+ def is_derived_from(self, fromclause: FromClause) -> bool:
+ return self.element.is_derived_from(fromclause)
- def alias(self, **kw):
- return FromGrouping(self.element.alias(**kw))
+ def alias(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> NamedFromGrouping:
+ return NamedFromGrouping(self.element.alias(name=name, flat=flat))
def _anonymous_fromclause(self, **kw):
return FromGrouping(self.element._anonymous_fromclause(**kw))
@@ -2385,6 +2850,16 @@ class FromGrouping(GroupedElement, FromClause):
self.element = state["element"]
+class NamedFromGrouping(FromGrouping, NamedFromClause):
+ """represent a grouping of a named FROM clause
+
+ .. versionadded:: 2.0
+
+ """
+
+ inherit_cache = True
+
+
class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
"""Represents a minimal "table" construct.
@@ -2417,7 +2892,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
__visit_name__ = "table"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
(
"columns",
InternalTraversal.dp_fromclause_canonical_column_collection,
@@ -2434,15 +2909,17 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
doesn't support having a primary key or column
-level defaults, so implicit returning doesn't apply."""
- _autoincrement_column = None
- """No PK or default support so no autoincrement column."""
+ @util.ro_memoized_property
+ def _autoincrement_column(self) -> Optional[ColumnClause[Any]]:
+ """No PK or default support so no autoincrement column."""
+ return None
- def __init__(self, name, *columns, **kw):
+ def __init__(self, name: str, *columns: ColumnClause[Any], **kw: Any):
super(TableClause, self).__init__()
self.name = name
self._columns = DedupeColumnCollection()
- self.primary_key = ColumnSet()
- self.foreign_keys = set()
+ self.primary_key = ColumnSet() # type: ignore
+ self.foreign_keys = set() # type: ignore
for c in columns:
self.append_column(c)
@@ -2466,23 +2943,23 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
...
- def __str__(self):
+ def __str__(self) -> str:
if self.schema is not None:
return self.schema + "." + self.name
else:
return self.name
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
pass
- def _init_collections(self):
+ def _init_collections(self) -> None:
pass
- @util.memoized_property
+ @util.ro_memoized_property
def description(self) -> str:
return self.name
- def append_column(self, c, **kw):
+ def append_column(self, c: ColumnClause[Any]) -> None:
existing = c.table
if existing is not None and existing is not self:
raise exc.ArgumentError(
@@ -2494,7 +2971,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
c.table = self
@util.preload_module("sqlalchemy.sql.dml")
- def insert(self):
+ def insert(self) -> Insert:
"""Generate an :func:`_expression.insert` construct against this
:class:`_expression.TableClause`.
@@ -2505,10 +2982,11 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
See :func:`_expression.insert` for argument and usage information.
"""
+
return util.preloaded.sql_dml.Insert(self)
@util.preload_module("sqlalchemy.sql.dml")
- def update(self):
+ def update(self) -> Update:
"""Generate an :func:`_expression.update` construct against this
:class:`_expression.TableClause`.
@@ -2524,7 +3002,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
)
@util.preload_module("sqlalchemy.sql.dml")
- def delete(self):
+ def delete(self) -> Delete:
"""Generate a :func:`_expression.delete` construct against this
:class:`_expression.TableClause`.
@@ -2543,13 +3021,18 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
class ForUpdateArg(ClauseElement):
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("of", InternalTraversal.dp_clauseelement_list),
("nowait", InternalTraversal.dp_boolean),
("read", InternalTraversal.dp_boolean),
("skip_locked", InternalTraversal.dp_boolean),
]
+ of: Optional[Sequence[ClauseElement]]
+ nowait: bool
+ read: bool
+ skip_locked: bool
+
@classmethod
def _from_argument(cls, with_for_update):
if isinstance(with_for_update, ForUpdateArg):
@@ -2606,7 +3089,7 @@ class ForUpdateArg(ClauseElement):
SelfValues = typing.TypeVar("SelfValues", bound="Values")
-class Values(Generative, NamedFromClause):
+class Values(Generative, LateralFromClause):
"""Represent a ``VALUES`` construct that can be used as a FROM element
in a statement.
@@ -2619,28 +3102,42 @@ class Values(Generative, NamedFromClause):
__visit_name__ = "values"
- _data = ()
+ _data: Tuple[List[Tuple[Any, ...]], ...] = ()
- _traverse_internals = [
+ _unnamed: bool
+ _traverse_internals: _TraverseInternalsType = [
("_column_args", InternalTraversal.dp_clauseelement_list),
("_data", InternalTraversal.dp_dml_multi_values),
("name", InternalTraversal.dp_string),
("literal_binds", InternalTraversal.dp_boolean),
]
- def __init__(self, *columns, name=None, literal_binds=False):
+ def __init__(
+ self,
+ *columns: ColumnClause[Any],
+ name: Optional[str] = None,
+ literal_binds: bool = False,
+ ):
super(Values, self).__init__()
self._column_args = columns
- self.name = name
+ if name is None:
+ self._unnamed = True
+ self.name = _anonymous_label.safe_construct(id(self), "anon")
+ else:
+ self._unnamed = False
+ self.name = name
self.literal_binds = literal_binds
- self.named_with_column = self.name is not None
+ self.named_with_column = not self._unnamed
@property
def _column_types(self):
return [col.type for col in self._column_args]
@_generative
- def alias(self: SelfValues, name, **kw) -> SelfValues:
+ def alias(
+ self: SelfValues, name: Optional[str] = None, flat: bool = False
+ ) -> SelfValues:
+
"""Return a new :class:`_expression.Values`
construct that is a copy of this
one with the given name.
@@ -2655,12 +3152,20 @@ class Values(Generative, NamedFromClause):
:func:`_expression.alias`
"""
- self.name = name
- self.named_with_column = self.name is not None
+ non_none_name: str
+
+ if name is None:
+ non_none_name = _anonymous_label.safe_construct(id(self), "anon")
+ else:
+ non_none_name = name
+
+ self.name = non_none_name
+ self.named_with_column = True
+ self._unnamed = False
return self
@_generative
- def lateral(self: SelfValues, name=None) -> SelfValues:
+ def lateral(self, name: Optional[str] = None) -> LateralFromClause:
"""Return a new :class:`_expression.Values` with the lateral flag set,
so that
it renders as LATERAL.
@@ -2670,13 +3175,20 @@ class Values(Generative, NamedFromClause):
:func:`_expression.lateral`
"""
+ non_none_name: str
+
+ if name is None:
+ non_none_name = self.name
+ else:
+ non_none_name = name
+
self._is_lateral = True
- if name is not None:
- self.name = name
+ self.name = non_none_name
+ self._unnamed = False
return self
@_generative
- def data(self: SelfValues, values) -> SelfValues:
+ def data(self: SelfValues, values: List[Tuple[Any, ...]]) -> SelfValues:
"""Return a new :class:`_expression.Values` construct,
adding the given data
to the data list.
@@ -2694,7 +3206,7 @@ class Values(Generative, NamedFromClause):
self._data += (values,)
return self
- def _populate_column_collection(self):
+ def _populate_column_collection(self) -> None:
for c in self._column_args:
self._columns.add(c)
c.table = self
@@ -2727,32 +3239,16 @@ class SelectBase(
"""
- _is_select_statement = True
+ _is_select_base = True
is_select = True
- def _generate_fromclause_column_proxies(
- self, fromclause: FromClause
- ) -> None:
- raise NotImplementedError()
+ _label_style: SelectLabelStyle = LABEL_STYLE_NONE
def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
self._reset_memoizations()
- def _generate_columns_plus_names(
- self, anon_for_dupe_key: bool
- ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
- raise NotImplementedError()
-
- def set_label_style(
- self: SelfSelectBase, label_style: SelectLabelStyle
- ) -> SelfSelectBase:
- raise NotImplementedError()
-
- def get_label_style(self) -> SelectLabelStyle:
- raise NotImplementedError()
-
- @property
- def selected_columns(self):
+ @util.ro_non_memoized_property
+ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set.
@@ -2797,7 +3293,7 @@ class SelectBase(
raise NotImplementedError()
@property
- def exported_columns(self):
+ def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]:
"""A :class:`_expression.ColumnCollection`
that represents the "exported"
columns of this :class:`_expression.Selectable`, not including
@@ -2819,7 +3315,7 @@ class SelectBase(
"""
- return self.selected_columns
+ return self.selected_columns.as_readonly()
@util.deprecated_property(
"1.4",
@@ -2841,6 +3337,26 @@ class SelectBase(
def columns(self):
return self.c
+ def get_label_style(self) -> SelectLabelStyle:
+ """
+ Retrieve the current label style.
+
+ Implemented by subclasses.
+
+ """
+ raise NotImplementedError()
+
+ def set_label_style(
+ self: SelfSelectBase, style: SelectLabelStyle
+ ) -> SelfSelectBase:
+ """Return a new selectable with the specified label style.
+
+ Implemented by subclasses.
+
+ """
+
+ raise NotImplementedError()
+
@util.deprecated(
"1.4",
"The :meth:`_expression.SelectBase.select` method is deprecated "
@@ -2857,6 +3373,9 @@ class SelectBase(
def _implicit_subquery(self):
return self.subquery()
+ def _scalar_type(self) -> TypeEngine[Any]:
+ raise NotImplementedError()
+
@util.deprecated(
"1.4",
"The :meth:`_expression.SelectBase.as_scalar` "
@@ -2926,7 +3445,7 @@ class SelectBase(
"""
return self.scalar_subquery().label(name)
- def lateral(self, name=None):
+ def lateral(self, name: Optional[str] = None) -> LateralFromClause:
"""Return a LATERAL alias of this :class:`_expression.Selectable`.
The return value is the :class:`_expression.Lateral` construct also
@@ -2941,11 +3460,7 @@ class SelectBase(
"""
return Lateral._factory(self, name)
- @util.ro_non_memoized_property
- def _from_objects(self) -> List[FromClause]:
- return [self]
-
- def subquery(self, name=None):
+ def subquery(self, name: Optional[str] = None) -> Subquery:
"""Return a subquery of this :class:`_expression.SelectBase`.
A subquery is from a SQL perspective a parenthesized, named
@@ -2995,7 +3510,9 @@ class SelectBase(
raise NotImplementedError()
- def alias(self, name=None, flat=False):
+ def alias(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> Subquery:
"""Return a named subquery against this
:class:`_expression.SelectBase`.
@@ -3023,7 +3540,9 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
"""
__visit_name__ = "select_statement_grouping"
- _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+ _traverse_internals: _TraverseInternalsType = [
+ ("element", InternalTraversal.dp_clauseelement)
+ ]
_is_select_container = True
@@ -3053,13 +3572,14 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
def select_statement(self):
return self.element
- def self_group(self, against=None):
+ def self_group(self: Self, against: Optional[OperatorType] = None) -> Self:
+ ...
return self
- def _generate_columns_plus_names(
- self, anon_for_dupe_key: bool
- ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
- return self.element._generate_columns_plus_names(anon_for_dupe_key)
+ # def _generate_columns_plus_names(
+ # self, anon_for_dupe_key: bool
+ # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
+ # return self.element._generate_columns_plus_names(anon_for_dupe_key)
def _generate_fromclause_column_proxies(
self, subquery: FromClause
@@ -3070,8 +3590,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
def _all_selected_columns(self) -> _SelectIterable:
return self.element._all_selected_columns
- @property
- def selected_columns(self):
+ @util.ro_non_memoized_property
+ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
the embedded SELECT statement returns in its result set, not including
@@ -3112,25 +3632,30 @@ class GenerativeSelect(SelectBase):
"""
- _order_by_clauses = ()
- _group_by_clauses = ()
- _limit_clause = None
- _offset_clause = None
- _fetch_clause = None
- _fetch_clause_options = None
- _for_update_arg = None
+ _order_by_clauses: Tuple[ColumnElement[Any], ...] = ()
+ _group_by_clauses: Tuple[ColumnElement[Any], ...] = ()
+ _limit_clause: Optional[ColumnElement[Any]] = None
+ _offset_clause: Optional[ColumnElement[Any]] = None
+ _fetch_clause: Optional[ColumnElement[Any]] = None
+ _fetch_clause_options: Optional[Dict[str, bool]] = None
+ _for_update_arg: Optional[ForUpdateArg] = None
- def __init__(self, _label_style=LABEL_STYLE_DEFAULT):
+ def __init__(self, _label_style: SelectLabelStyle = LABEL_STYLE_DEFAULT):
self._label_style = _label_style
@_generative
def with_for_update(
self: SelfGenerativeSelect,
- nowait=False,
- read=False,
- of=None,
- skip_locked=False,
- key_share=False,
+ nowait: bool = False,
+ read: bool = False,
+ of: Optional[
+ Union[
+ _ColumnExpressionArgument[Any],
+ Sequence[_ColumnExpressionArgument[Any]],
+ ]
+ ] = None,
+ skip_locked: bool = False,
+ key_share: bool = False,
) -> SelfGenerativeSelect:
"""Specify a ``FOR UPDATE`` clause for this
:class:`_expression.GenerativeSelect`.
@@ -3241,20 +3766,25 @@ class GenerativeSelect(SelectBase):
return self
@property
- def _group_by_clause(self):
+ def _group_by_clause(self) -> ClauseList:
"""ClauseList access to group_by_clauses for legacy dialects"""
return ClauseList._construct_raw(
operators.comma_op, self._group_by_clauses
)
@property
- def _order_by_clause(self):
+ def _order_by_clause(self) -> ClauseList:
"""ClauseList access to order_by_clauses for legacy dialects"""
return ClauseList._construct_raw(
operators.comma_op, self._order_by_clauses
)
- def _offset_or_limit_clause(self, element, name=None, type_=None):
+ def _offset_or_limit_clause(
+ self,
+ element: Union[int, _ColumnExpressionArgument[Any]],
+ name: Optional[str] = None,
+ type_: Optional[_TypeEngineArgument[int]] = None,
+ ) -> ColumnElement[Any]:
"""Convert the given value to an "offset or limit" clause.
This handles incoming integers and converts to an expression; if
@@ -3265,7 +3795,21 @@ class GenerativeSelect(SelectBase):
roles.LimitOffsetRole, element, name=name, type_=type_
)
- def _offset_or_limit_clause_asint(self, clause, attrname):
+ @overload
+ def _offset_or_limit_clause_asint(
+ self, clause: ColumnElement[Any], attrname: str
+ ) -> NoReturn:
+ ...
+
+ @overload
+ def _offset_or_limit_clause_asint(
+ self, clause: Optional[_OffsetLimitParam], attrname: str
+ ) -> Optional[int]:
+ ...
+
+ def _offset_or_limit_clause_asint(
+ self, clause: Optional[ColumnElement[Any]], attrname: str
+ ) -> Union[NoReturn, Optional[int]]:
"""Convert the "offset or limit" clause of a select construct to an
integer.
@@ -3286,7 +3830,7 @@ class GenerativeSelect(SelectBase):
return util.asint(value)
@property
- def _limit(self):
+ def _limit(self) -> Optional[int]:
"""Get an integer value for the limit. This should only be used
by code that cannot support a limit as a BindParameter or
other custom clause as it will throw an exception if the limit
@@ -3295,14 +3839,14 @@ class GenerativeSelect(SelectBase):
"""
return self._offset_or_limit_clause_asint(self._limit_clause, "limit")
- def _simple_int_clause(self, clause):
+ def _simple_int_clause(self, clause: ClauseElement) -> bool:
"""True if the clause is a simple integer, False
if it is not present or is a SQL expression.
"""
return isinstance(clause, _OffsetLimitParam)
@property
- def _offset(self):
+ def _offset(self) -> Optional[int]:
"""Get an integer value for the offset. This should only be used
by code that cannot support an offset as a BindParameter or
other custom clause as it will throw an exception if the
@@ -3314,7 +3858,7 @@ class GenerativeSelect(SelectBase):
)
@property
- def _has_row_limiting_clause(self):
+ def _has_row_limiting_clause(self) -> bool:
return (
self._limit_clause is not None
or self._offset_clause is not None
@@ -3322,7 +3866,10 @@ class GenerativeSelect(SelectBase):
)
@_generative
- def limit(self: SelfGenerativeSelect, limit) -> SelfGenerativeSelect:
+ def limit(
+ self: SelfGenerativeSelect,
+ limit: Union[int, _ColumnExpressionArgument[int]],
+ ) -> SelfGenerativeSelect:
"""Return a new selectable with the given LIMIT criterion
applied.
@@ -3356,7 +3903,10 @@ class GenerativeSelect(SelectBase):
@_generative
def fetch(
- self: SelfGenerativeSelect, count, with_ties=False, percent=False
+ self: SelfGenerativeSelect,
+ count: Union[int, _ColumnExpressionArgument[int]],
+ with_ties: bool = False,
+ percent: bool = False,
) -> SelfGenerativeSelect:
"""Return a new selectable with the given FETCH FIRST criterion
applied.
@@ -3408,7 +3958,10 @@ class GenerativeSelect(SelectBase):
return self
@_generative
- def offset(self: SelfGenerativeSelect, offset) -> SelfGenerativeSelect:
+ def offset(
+ self: SelfGenerativeSelect,
+ offset: Union[int, _ColumnExpressionArgument[int]],
+ ) -> SelfGenerativeSelect:
"""Return a new selectable with the given OFFSET criterion
applied.
@@ -3438,7 +3991,11 @@ class GenerativeSelect(SelectBase):
@_generative
@util.preload_module("sqlalchemy.sql.util")
- def slice(self: SelfGenerativeSelect, start, stop) -> SelfGenerativeSelect:
+ def slice(
+ self: SelfGenerativeSelect,
+ start: int,
+ stop: int,
+ ) -> SelfGenerativeSelect:
"""Apply LIMIT / OFFSET to this statement based on a slice.
The start and stop indices behave like the argument to Python's
@@ -3485,7 +4042,9 @@ class GenerativeSelect(SelectBase):
return self
@_generative
- def order_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect:
+ def order_by(
+ self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+ ) -> SelfGenerativeSelect:
r"""Return a new selectable with the given list of ORDER BY
criteria applied.
@@ -3522,7 +4081,9 @@ class GenerativeSelect(SelectBase):
return self
@_generative
- def group_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect:
+ def group_by(
+ self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+ ) -> SelfGenerativeSelect:
r"""Return a new selectable with the given list of GROUP BY
criterion applied.
@@ -3567,6 +4128,15 @@ class CompoundSelectState(CompileState):
return d, d, d
+class _CompoundSelectKeyword(Enum):
+ UNION = "UNION"
+ UNION_ALL = "UNION ALL"
+ EXCEPT = "EXCEPT"
+ EXCEPT_ALL = "EXCEPT ALL"
+ INTERSECT = "INTERSECT"
+ INTERSECT_ALL = "INTERSECT ALL"
+
+
class CompoundSelect(HasCompileState, GenerativeSelect):
"""Forms the basis of ``UNION``, ``UNION ALL``, and other
SELECT-based set operations.
@@ -3590,7 +4160,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
__visit_name__ = "compound_select"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("selects", InternalTraversal.dp_clauseelement_list),
("_limit_clause", InternalTraversal.dp_clauseelement),
("_offset_clause", InternalTraversal.dp_clauseelement),
@@ -3602,17 +4172,16 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
("keyword", InternalTraversal.dp_string),
] + SupportsCloneAnnotations._clone_annotations_traverse_internals
- UNION = util.symbol("UNION")
- UNION_ALL = util.symbol("UNION ALL")
- EXCEPT = util.symbol("EXCEPT")
- EXCEPT_ALL = util.symbol("EXCEPT ALL")
- INTERSECT = util.symbol("INTERSECT")
- INTERSECT_ALL = util.symbol("INTERSECT ALL")
+ selects: List[SelectBase]
_is_from_container = True
_auto_correlate = False
- def __init__(self, keyword, *selects):
+ def __init__(
+ self,
+ keyword: _CompoundSelectKeyword,
+ *selects: _SelectStatementForCompoundArgument,
+ ):
self.keyword = keyword
self.selects = [
coercions.expect(roles.CompoundElementRole, s).self_group(
@@ -3624,36 +4193,50 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
GenerativeSelect.__init__(self)
@classmethod
- def _create_union(cls, *selects, **kwargs):
- return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs)
+ def _create_union(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.UNION, *selects)
@classmethod
- def _create_union_all(cls, *selects):
- return CompoundSelect(CompoundSelect.UNION_ALL, *selects)
+ def _create_union_all(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.UNION_ALL, *selects)
@classmethod
- def _create_except(cls, *selects):
- return CompoundSelect(CompoundSelect.EXCEPT, *selects)
+ def _create_except(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.EXCEPT, *selects)
@classmethod
- def _create_except_all(cls, *selects):
- return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects)
+ def _create_except_all(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.EXCEPT_ALL, *selects)
@classmethod
- def _create_intersect(cls, *selects):
- return CompoundSelect(CompoundSelect.INTERSECT, *selects)
+ def _create_intersect(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.INTERSECT, *selects)
@classmethod
- def _create_intersect_all(cls, *selects):
- return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects)
+ def _create_intersect_all(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.INTERSECT_ALL, *selects)
- def _scalar_type(self):
+ def _scalar_type(self) -> TypeEngine[Any]:
return self.selects[0]._scalar_type()
- def self_group(self, against=None):
+ def self_group(
+ self, against: Optional[OperatorType] = None
+ ) -> GroupedElement:
return SelectStatementGrouping(self)
- def is_derived_from(self, fromclause):
+ def is_derived_from(self, fromclause: FromClause) -> bool:
for s in self.selects:
if s.is_derived_from(fromclause):
return True
@@ -3675,7 +4258,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
return self
- def _generate_fromclause_column_proxies(self, subquery):
+ def _generate_fromclause_column_proxies(
+ self, subquery: FromClause
+ ) -> None:
# this is a slightly hacky thing - the union exports a
# column that resembles just that of the *first* selectable.
@@ -3716,8 +4301,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
def _all_selected_columns(self) -> _SelectIterable:
return self.selects[0]._all_selected_columns
- @property
- def selected_columns(self):
+ @util.ro_non_memoized_property
+ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
@@ -3739,6 +4324,11 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
return self.selects[0].selected_columns
+# backwards compat
+for elem in _CompoundSelectKeyword:
+ setattr(CompoundSelect, elem.name, elem)
+
+
@CompileState.plugin_for("default", "select")
class SelectState(util.MemoizedSlots, CompileState):
__slots__ = (
@@ -3758,10 +4348,12 @@ class SelectState(util.MemoizedSlots, CompileState):
if TYPE_CHECKING:
@classmethod
- def get_plugin_class(cls, statement: Select) -> SelectState:
+ def get_plugin_class(cls, statement: Executable) -> Type[SelectState]:
...
- def __init__(self, statement, compiler, **kw):
+ def __init__(
+ self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any
+ ):
self.statement = statement
self.from_clauses = statement._from_obj
@@ -3778,14 +4370,16 @@ class SelectState(util.MemoizedSlots, CompileState):
self.columns_plus_names = statement._generate_columns_plus_names(True)
@classmethod
- def _plugin_not_implemented(cls):
+ def _plugin_not_implemented(cls) -> NoReturn:
raise NotImplementedError(
"The default SELECT construct without plugins does not "
"implement this method."
)
@classmethod
- def get_column_descriptions(cls, statement):
+ def get_column_descriptions(
+ cls, statement: Select
+ ) -> List[Dict[str, Any]]:
return [
{
"name": name,
@@ -3798,11 +4392,13 @@ class SelectState(util.MemoizedSlots, CompileState):
]
@classmethod
- def from_statement(cls, statement, from_statement):
+ def from_statement(
+ cls, statement: Select, from_statement: ReturnsRows
+ ) -> Any:
cls._plugin_not_implemented()
@classmethod
- def get_columns_clause_froms(cls, statement):
+ def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]:
return cls._normalize_froms(
itertools.chain.from_iterable(
element._from_objects for element in statement._raw_columns
@@ -3810,7 +4406,9 @@ class SelectState(util.MemoizedSlots, CompileState):
)
@classmethod
- def _column_naming_convention(cls, label_style):
+ def _column_naming_convention(
+ cls, label_style: SelectLabelStyle
+ ) -> Callable[[Union[ColumnElement[Any], TextClause]], Optional[str]]:
table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL
dedupe = label_style is not LABEL_STYLE_NONE
@@ -3850,7 +4448,8 @@ class SelectState(util.MemoizedSlots, CompileState):
return go
- def _get_froms(self, statement):
+ def _get_froms(self, statement: Select) -> List[FromClause]:
+ ambiguous_table_name_map: _AmbiguousTableNameMap
self._ambiguous_table_name_map = ambiguous_table_name_map = {}
return self._normalize_froms(
@@ -3876,10 +4475,10 @@ class SelectState(util.MemoizedSlots, CompileState):
@classmethod
def _normalize_froms(
cls,
- iterable_of_froms,
- check_statement=None,
- ambiguous_table_name_map=None,
- ):
+ iterable_of_froms: Iterable[FromClause],
+ check_statement: Optional[Select] = None,
+ ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
+ ) -> List[FromClause]:
"""given an iterable of things to select FROM, reduce them to what
would actually render in the FROM clause of a SELECT.
@@ -3888,12 +4487,12 @@ class SelectState(util.MemoizedSlots, CompileState):
etc.
"""
- seen = set()
- froms = []
+ seen: Set[FromClause] = set()
+ froms: List[FromClause] = []
for item in iterable_of_froms:
- if item._is_subquery and item.element is check_statement:
+ if is_subquery(item) and item.element is check_statement:
raise exc.InvalidRequestError(
"select() construct refers to itself as a FROM"
)
@@ -3923,7 +4522,7 @@ class SelectState(util.MemoizedSlots, CompileState):
)
for item in froms
for fr in item._from_objects
- if fr._is_table
+ if is_table(fr)
and fr.schema
and fr.name not in ambiguous_table_name_map
)
@@ -3931,8 +4530,10 @@ class SelectState(util.MemoizedSlots, CompileState):
return froms
def _get_display_froms(
- self, explicit_correlate_froms=None, implicit_correlate_froms=None
- ):
+ self,
+ explicit_correlate_froms: Optional[Sequence[FromClause]] = None,
+ implicit_correlate_froms: Optional[Sequence[FromClause]] = None,
+ ) -> List[FromClause]:
"""Return the full list of 'from' clauses to be displayed.
Takes into account a set of existing froms which may be
@@ -3998,25 +4599,33 @@ class SelectState(util.MemoizedSlots, CompileState):
return froms
- def _memoized_attr__label_resolve_dict(self):
- with_cols = dict(
- (c._tq_label or c.key, c)
+ def _memoized_attr__label_resolve_dict(
+ self,
+ ) -> Tuple[
+ Dict[str, ColumnElement[Any]],
+ Dict[str, ColumnElement[Any]],
+ Dict[str, ColumnElement[Any]],
+ ]:
+ with_cols: Dict[str, ColumnElement[Any]] = dict(
+ (c._tq_label or c.key, c) # type: ignore
for c in self.statement._all_selected_columns
if c._allow_label_resolve
)
- only_froms = dict(
- (c.key, c)
+ only_froms: Dict[str, ColumnElement[Any]] = dict(
+ (c.key, c) # type: ignore
for c in _select_iterables(self.froms)
if c._allow_label_resolve
)
- only_cols = with_cols.copy()
+ only_cols: Dict[str, ColumnElement[Any]] = with_cols.copy()
for key, value in only_froms.items():
with_cols.setdefault(key, value)
return with_cols, only_froms, only_cols
@classmethod
- def determine_last_joined_entity(cls, stmt):
+ def determine_last_joined_entity(
+ cls, stmt: Select
+ ) -> Optional[_JoinTargetElement]:
if stmt._setup_joins:
return stmt._setup_joins[-1][0]
else:
@@ -4026,8 +4635,16 @@ class SelectState(util.MemoizedSlots, CompileState):
def all_selected_columns(cls, statement: Select) -> _SelectIterable:
return [c for c in _select_iterables(statement._raw_columns)]
- def _setup_joins(self, args, raw_columns):
+ def _setup_joins(
+ self,
+ args: Tuple[_SetupJoinsElement, ...],
+ raw_columns: List[_ColumnsClauseElement],
+ ) -> None:
for (right, onclause, left, flags) in args:
+ if TYPE_CHECKING:
+ if onclause is not None:
+ assert isinstance(onclause, ColumnElement)
+
isouter = flags["isouter"]
full = flags["full"]
@@ -4043,6 +4660,16 @@ class SelectState(util.MemoizedSlots, CompileState):
left
)
+ # these assertions can be made here, as if the right/onclause
+ # contained ORM elements, the select() statement would have been
+ # upgraded to an ORM select, and this method would not be called;
+ # orm.context.ORMSelectCompileState._join() would be
+ # used instead.
+ if TYPE_CHECKING:
+ assert isinstance(right, FromClause)
+ if onclause is not None:
+ assert isinstance(onclause, ColumnElement)
+
if replace_from_obj_index is not None:
# splice into an existing element in the
# self._from_obj list
@@ -4062,15 +4689,19 @@ class SelectState(util.MemoizedSlots, CompileState):
+ self.from_clauses[replace_from_obj_index + 1 :]
)
else:
-
+ assert left is not None
self.from_clauses = self.from_clauses + (
Join(left, right, onclause, isouter=isouter, full=full),
)
@util.preload_module("sqlalchemy.sql.util")
def _join_determine_implicit_left_side(
- self, raw_columns, left, right, onclause
- ):
+ self,
+ raw_columns: List[_ColumnsClauseElement],
+ left: Optional[FromClause],
+ right: _JoinTargetElement,
+ onclause: Optional[ColumnElement[Any]],
+ ) -> Tuple[Optional[FromClause], Optional[int]]:
"""When join conditions don't express the left side explicitly,
determine if an existing FROM or entity in this query
can serve as the left hand side.
@@ -4079,13 +4710,13 @@ class SelectState(util.MemoizedSlots, CompileState):
sql_util = util.preloaded.sql_util
- replace_from_obj_index = None
+ replace_from_obj_index: Optional[int] = None
from_clauses = self.from_clauses
if from_clauses:
- indexes = sql_util.find_left_clause_to_join_from(
+ indexes: List[int] = sql_util.find_left_clause_to_join_from(
from_clauses, right, onclause
)
@@ -4138,15 +4769,17 @@ class SelectState(util.MemoizedSlots, CompileState):
return left, replace_from_obj_index
@util.preload_module("sqlalchemy.sql.util")
- def _join_place_explicit_left_side(self, left):
- replace_from_obj_index = None
+ def _join_place_explicit_left_side(
+ self, left: FromClause
+ ) -> Optional[int]:
+ replace_from_obj_index: Optional[int] = None
sql_util = util.preloaded.sql_util
from_clauses = list(self.statement._iterate_from_elements())
if from_clauses:
- indexes = sql_util.find_left_clause_that_matches_given(
+ indexes: List[int] = sql_util.find_left_clause_that_matches_given(
self.from_clauses, left
)
else:
@@ -4171,7 +4804,13 @@ class SelectState(util.MemoizedSlots, CompileState):
class _SelectFromElements:
- def _iterate_from_elements(self):
+ __slots__ = ()
+
+ _raw_columns: List[_ColumnsClauseElement]
+ _where_criteria: Tuple[ColumnElement[Any], ...]
+ _from_obj: Tuple[FromClause, ...]
+
+ def _iterate_from_elements(self) -> Iterator[FromClause]:
# note this does not include elements
# in _setup_joins
@@ -4195,28 +4834,58 @@ class _SelectFromElements:
yield element
+Self_MemoizedSelectEntities = TypeVar("Self_MemoizedSelectEntities", bound=Any)
+
+
class _MemoizedSelectEntities(
cache_key.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
):
+ """represents partial state from a Select object, for the case
+ where Select.columns() has redefined the set of columns/entities the
+ statement will be SELECTing from. This object represents
+ the entities from the SELECT before that transformation was applied,
+ so that transformations that were made in terms of the SELECT at that
+ time, such as join() as well as options(), can access the correct context.
+
+ In previous SQLAlchemy versions, this wasn't needed because these
+ constructs calculated everything up front, like when you called join()
+ or options(), it did everything to figure out how that would translate
+ into specific SQL constructs that would be ready to send directly to the
+ SQL compiler when needed. But as of
+ 1.4, all of that stuff is done in the compilation phase, during the
+ "compile state" portion of the process, so that the work can all be
+ cached. So it needs to be able to resolve joins/options2 based on what
+ the list of entities was when those methods were called.
+
+
+ """
+
__visit_name__ = "memoized_select_entities"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("_raw_columns", InternalTraversal.dp_clauseelement_list),
("_setup_joins", InternalTraversal.dp_setup_join_tuple),
("_with_options", InternalTraversal.dp_executable_options),
]
+ _is_clone_of: Optional[ClauseElement]
+ _raw_columns: List[_ColumnsClauseElement]
+ _setup_joins: Tuple[_SetupJoinsElement, ...]
+ _with_options: Tuple[ExecutableOption, ...]
+
_annotations = util.EMPTY_DICT
- def _clone(self, **kw):
+ def _clone(
+ self: Self_MemoizedSelectEntities, **kw: Any
+ ) -> Self_MemoizedSelectEntities:
c = self.__class__.__new__(self.__class__)
c.__dict__ = {k: v for k, v in self.__dict__.items()}
c._is_clone_of = self.__dict__.get("_is_clone_of", self)
- return c
+ return c # type: ignore
@classmethod
- def _generate_for_statement(cls, select_stmt):
+ def _generate_for_statement(cls, select_stmt: Select) -> None:
if select_stmt._setup_joins or select_stmt._with_options:
self = _MemoizedSelectEntities()
self._raw_columns = select_stmt._raw_columns
@@ -4224,12 +4893,10 @@ class _MemoizedSelectEntities(
self._with_options = select_stmt._with_options
select_stmt._memoized_select_entities += (self,)
- select_stmt._raw_columns = (
- select_stmt._setup_joins
- ) = select_stmt._with_options = ()
+ select_stmt._raw_columns = []
+ select_stmt._setup_joins = select_stmt._with_options = ()
-# TODO: use pep-673 when feasible
SelfSelect = typing.TypeVar("SelfSelect", bound="Select")
@@ -4258,9 +4925,11 @@ class Select(
__visit_name__ = "select"
- _setup_joins: Tuple[TODO_Any, ...] = ()
+ _setup_joins: Tuple[_SetupJoinsElement, ...] = ()
_memoized_select_entities: Tuple[TODO_Any, ...] = ()
+ _raw_columns: List[_ColumnsClauseElement]
+
_distinct = False
_distinct_on: Tuple[ColumnElement[Any], ...] = ()
_correlate: Tuple[FromClause, ...] = ()
@@ -4269,12 +4938,12 @@ class Select(
_having_criteria: Tuple[ColumnElement[Any], ...] = ()
_from_obj: Tuple[FromClause, ...] = ()
_auto_correlate = True
-
+ _is_select_statement = True
_compile_options: CacheableOptions = (
SelectState.default_select_compile_options
)
- _traverse_internals = (
+ _traverse_internals: _TraverseInternalsType = (
[
("_raw_columns", InternalTraversal.dp_clauseelement_list),
(
@@ -4306,12 +4975,14 @@ class Select(
+ Executable._executable_traverse_internals
)
- _cache_key_traversal = _traverse_internals + [
+ _cache_key_traversal: _CacheKeyTraversalType = _traverse_internals + [
("_compile_options", InternalTraversal.dp_has_cache_key)
]
+ _compile_state_factory: Type[SelectState]
+
@classmethod
- def _create_raw_select(cls, **kw) -> "Select":
+ def _create_raw_select(cls, **kw: Any) -> Select:
"""Create a :class:`.Select` using raw ``__new__`` with no coercions.
Used internally to build up :class:`.Select` constructs with
@@ -4330,6 +5001,12 @@ class Select(
:func:`_sql.select` function.
"""
+ things = [
+ coercions.expect(
+ roles.ColumnsClauseRole, ent, apply_propagate_attrs=self
+ )
+ for ent in entities
+ ]
self._raw_columns = [
coercions.expect(
@@ -4340,7 +5017,7 @@ class Select(
GenerativeSelect.__init__(self)
- def _scalar_type(self):
+ def _scalar_type(self) -> TypeEngine[Any]:
elem = self._raw_columns[0]
cols = list(elem._select_iterable)
return cols[0].type
@@ -4446,7 +5123,12 @@ class Select(
@_generative
def join(
- self: SelfSelect, target, onclause=None, *, isouter=False, full=False
+ self: SelfSelect,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ *,
+ isouter: bool = False,
+ full: bool = False,
) -> SelfSelect:
r"""Create a SQL JOIN against this :class:`_expression.Select`
object's criterion
@@ -4505,17 +5187,32 @@ class Select(
:meth:`_expression.Select.outerjoin`
""" # noqa: E501
- target = coercions.expect(
+ join_target = coercions.expect(
roles.JoinTargetRole, target, apply_propagate_attrs=self
)
if onclause is not None:
- onclause = coercions.expect(roles.OnClauseRole, onclause)
+ onclause_element = coercions.expect(roles.OnClauseRole, onclause)
+ else:
+ onclause_element = None
+
self._setup_joins += (
- (target, onclause, None, {"isouter": isouter, "full": full}),
+ (
+ join_target,
+ onclause_element,
+ None,
+ {"isouter": isouter, "full": full},
+ ),
)
return self
- def outerjoin_from(self, from_, target, onclause=None, *, full=False):
+ def outerjoin_from(
+ self: SelfSelect,
+ from_: _FromClauseArgument,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ *,
+ full: bool = False,
+ ) -> SelfSelect:
r"""Create a SQL LEFT OUTER JOIN against this :class:`_expression.Select`
object's criterion
and apply generatively, returning the newly resulting
@@ -4531,12 +5228,12 @@ class Select(
@_generative
def join_from(
self: SelfSelect,
- from_,
- target,
- onclause=None,
+ from_: _FromClauseArgument,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
*,
- isouter=False,
- full=False,
+ isouter: bool = False,
+ full: bool = False,
) -> SelfSelect:
r"""Create a SQL JOIN against this :class:`_expression.Select`
object's criterion
@@ -4586,18 +5283,31 @@ class Select(
from_ = coercions.expect(
roles.FromClauseRole, from_, apply_propagate_attrs=self
)
- target = coercions.expect(
+ join_target = coercions.expect(
roles.JoinTargetRole, target, apply_propagate_attrs=self
)
if onclause is not None:
- onclause = coercions.expect(roles.OnClauseRole, onclause)
+ onclause_element = coercions.expect(roles.OnClauseRole, onclause)
+ else:
+ onclause_element = None
self._setup_joins += (
- (target, onclause, from_, {"isouter": isouter, "full": full}),
+ (
+ join_target,
+ onclause_element,
+ from_,
+ {"isouter": isouter, "full": full},
+ ),
)
return self
- def outerjoin(self, target, onclause=None, *, full=False):
+ def outerjoin(
+ self: SelfSelect,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ *,
+ full: bool = False,
+ ) -> SelfSelect:
"""Create a left outer join.
Parameters are the same as that of :meth:`_expression.Select.join`.
@@ -4634,7 +5344,7 @@ class Select(
"""
return self.join(target, onclause=onclause, isouter=True, full=full)
- def get_final_froms(self):
+ def get_final_froms(self) -> Sequence[FromClause]:
"""Compute the final displayed list of :class:`_expression.FromClause`
elements.
@@ -4671,6 +5381,7 @@ class Select(
:attr:`_sql.Select.columns_clause_froms`
"""
+
return self._compile_state_factory(self, None)._get_display_froms()
@util.deprecated_property(
@@ -4678,7 +5389,7 @@ class Select(
"The :attr:`_expression.Select.froms` attribute is moved to "
"the :meth:`_expression.Select.get_final_froms` method.",
)
- def froms(self):
+ def froms(self) -> Sequence[FromClause]:
"""Return the displayed list of :class:`_expression.FromClause`
elements.
@@ -4687,7 +5398,7 @@ class Select(
return self.get_final_froms()
@property
- def columns_clause_froms(self):
+ def columns_clause_froms(self) -> List[FromClause]:
"""Return the set of :class:`_expression.FromClause` objects implied
by the columns clause of this SELECT statement.
@@ -4720,7 +5431,7 @@ class Select(
return iter(self._all_selected_columns)
- def is_derived_from(self, fromclause):
+ def is_derived_from(self, fromclause: FromClause) -> bool:
if self in fromclause._cloned_set:
return True
@@ -4729,7 +5440,9 @@ class Select(
return True
return False
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(
+ self, clone: _CloneCallableType = _clone, **kw: Any
+ ) -> None:
# Select() object has been cloned and probably adapted by the
# given clone function. Apply the cloning function to internal
# objects
@@ -4786,13 +5499,15 @@ class Select(
def get_children(self, **kwargs):
return itertools.chain(
super(Select, self).get_children(
- omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
+ omit_attrs=("_from_obj", "_correlate", "_correlate_except")
),
self._iterate_from_elements(),
)
@_generative
- def add_columns(self: SelfSelect, *columns) -> SelfSelect:
+ def add_columns(
+ self: SelfSelect, *columns: _ColumnsClauseArgument
+ ) -> SelfSelect:
"""Return a new :func:`_expression.select` construct with
the given column expressions added to its columns clause.
@@ -4816,7 +5531,9 @@ class Select(
]
return self
- def _set_entities(self, entities):
+ def _set_entities(
+ self, entities: Iterable[_ColumnsClauseArgument]
+ ) -> None:
self._raw_columns = [
coercions.expect(
roles.ColumnsClauseRole, ent, apply_propagate_attrs=self
@@ -4830,7 +5547,7 @@ class Select(
"be removed in a future release. Please use "
":meth:`_expression.Select.add_columns`",
)
- def column(self, column):
+ def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect:
"""Return a new :func:`_expression.select` construct with
the given column expression added to its columns clause.
@@ -4847,7 +5564,9 @@ class Select(
return self.add_columns(column)
@util.preload_module("sqlalchemy.sql.util")
- def reduce_columns(self, only_synonyms=True):
+ def reduce_columns(
+ self: SelfSelect, only_synonyms: bool = True
+ ) -> SelfSelect:
"""Return a new :func:`_expression.select` construct with redundantly
named, equivalently-valued columns removed from the columns clause.
@@ -4880,7 +5599,9 @@ class Select(
@_generative
def with_only_columns(
- self: SelfSelect, *columns, maintain_column_froms=False
+ self: SelfSelect,
+ *columns: _ColumnsClauseArgument,
+ maintain_column_froms: bool = False,
) -> SelfSelect:
r"""Return a new :func:`_expression.select` construct with its columns
clause replaced with the given columns.
@@ -4941,7 +5662,9 @@ class Select(
self._assert_no_memoizations()
if maintain_column_froms:
- self.select_from.non_generative(self, *self.columns_clause_froms)
+ self.select_from.non_generative( # type: ignore
+ self, *self.columns_clause_froms
+ )
# then memoize the FROMs etc.
_MemoizedSelectEntities._generate_for_statement(self)
@@ -4974,7 +5697,9 @@ class Select(
_whereclause = whereclause
@_generative
- def where(self: SelfSelect, *whereclause) -> SelfSelect:
+ def where(
+ self: SelfSelect, *whereclause: _ColumnExpressionArgument[bool]
+ ) -> SelfSelect:
"""Return a new :func:`_expression.select` construct with
the given expression added to
its WHERE clause, joined to the existing clause via AND, if any.
@@ -4984,24 +5709,33 @@ class Select(
assert isinstance(self._where_criteria, tuple)
for criterion in whereclause:
- where_criteria = coercions.expect(roles.WhereHavingRole, criterion)
+ where_criteria: ColumnElement[Any] = coercions.expect(
+ roles.WhereHavingRole, criterion
+ )
self._where_criteria += (where_criteria,)
return self
@_generative
- def having(self: SelfSelect, having) -> SelfSelect:
+ def having(
+ self: SelfSelect, *having: _ColumnExpressionArgument[bool]
+ ) -> SelfSelect:
"""Return a new :func:`_expression.select` construct with
the given expression added to
its HAVING clause, joined to the existing clause via AND, if any.
"""
- self._having_criteria += (
- coercions.expect(roles.WhereHavingRole, having),
- )
+
+ for criterion in having:
+ having_criteria = coercions.expect(
+ roles.WhereHavingRole, criterion
+ )
+ self._having_criteria += (having_criteria,)
return self
@_generative
- def distinct(self: SelfSelect, *expr) -> SelfSelect:
+ def distinct(
+ self: SelfSelect, *expr: _ColumnExpressionArgument[Any]
+ ) -> SelfSelect:
r"""Return a new :func:`_expression.select` construct which
will apply DISTINCT to its columns clause.
@@ -5023,7 +5757,9 @@ class Select(
return self
@_generative
- def select_from(self: SelfSelect, *froms) -> SelfSelect:
+ def select_from(
+ self: SelfSelect, *froms: _FromClauseArgument
+ ) -> SelfSelect:
r"""Return a new :func:`_expression.select` construct with the
given FROM expression(s)
merged into its list of FROM objects.
@@ -5067,7 +5803,10 @@ class Select(
return self
@_generative
- def correlate(self: SelfSelect, *fromclauses) -> SelfSelect:
+ def correlate(
+ self: SelfSelect,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfSelect:
r"""Return a new :class:`_expression.Select`
which will correlate the given FROM
clauses to that of an enclosing :class:`_expression.Select`.
@@ -5106,10 +5845,10 @@ class Select(
none of its FROM entries, and all will render unconditionally
in the local FROM clause.
- :param \*fromclauses: a list of one or more
- :class:`_expression.FromClause`
- constructs, or other compatible constructs (i.e. ORM-mapped
- classes) to become part of the correlate collection.
+ :param \*fromclauses: one or more :class:`.FromClause` or other
+ FROM-compatible construct such as an ORM mapped entity to become part
+ of the correlate collection; alternatively pass a single value
+ ``None`` to remove all existing correlations.
.. seealso::
@@ -5119,8 +5858,16 @@ class Select(
"""
+ # tests failing when we try to change how these
+ # arguments are passed
+
self._auto_correlate = False
- if fromclauses and fromclauses[0] in {None, False}:
+ if not fromclauses or fromclauses[0] in {None, False}:
+ if len(fromclauses) > 1:
+ raise exc.ArgumentError(
+ "additional FROM objects not accepted when "
+ "passing None/False to correlate()"
+ )
self._correlate = ()
else:
self._correlate = self._correlate + tuple(
@@ -5129,7 +5876,10 @@ class Select(
return self
@_generative
- def correlate_except(self: SelfSelect, *fromclauses) -> SelfSelect:
+ def correlate_except(
+ self: SelfSelect,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfSelect:
r"""Return a new :class:`_expression.Select`
which will omit the given FROM
clauses from the auto-correlation process.
@@ -5141,9 +5891,9 @@ class Select(
all other FROM elements remain subject to normal auto-correlation
behaviors.
- If ``None`` is passed, the :class:`_expression.Select`
- object will correlate
- all of its FROM entries.
+ If ``None`` is passed, or no arguments are passed,
+ the :class:`_expression.Select` object will correlate all of its
+ FROM entries.
:param \*fromclauses: a list of one or more
:class:`_expression.FromClause`
@@ -5159,16 +5909,22 @@ class Select(
"""
self._auto_correlate = False
- if fromclauses and fromclauses[0] in {None, False}:
+ if not fromclauses or fromclauses[0] in {None, False}:
+ if len(fromclauses) > 1:
+ raise exc.ArgumentError(
+ "additional FROM objects not accepted when "
+ "passing None/False to correlate_except()"
+ )
self._correlate_except = ()
else:
self._correlate_except = (self._correlate_except or ()) + tuple(
coercions.expect(roles.FromClauseRole, f) for f in fromclauses
)
+
return self
- @HasMemoized.memoized_attribute
- def selected_columns(self):
+ @HasMemoized_ro_memoized_attribute
+ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
@@ -5214,18 +5970,22 @@ class Select(
# generates the actual names used in the SELECT string. that
# method is more complex because it also renders columns that are
# fully ambiguous, e.g. same column more than once.
- conv = SelectState._column_naming_convention(self._label_style)
+ conv = cast(
+ "Callable[[Any], str]",
+ SelectState._column_naming_convention(self._label_style),
+ )
- return ColumnCollection(
+ cc: ColumnCollection[str, ColumnElement[Any]] = ColumnCollection(
[
(conv(c), c)
for c in self._all_selected_columns
if is_column_element(c)
]
- ).as_readonly()
+ )
+ return cc.as_readonly()
@HasMemoized.memoized_attribute
- def _all_selected_columns(self) -> Sequence[ColumnElement[Any]]:
+ def _all_selected_columns(self) -> _SelectIterable:
meth = SelectState.get_plugin_class(self).all_selected_columns
return list(meth(self))
@@ -5234,173 +5994,9 @@ class Select(
self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
return self
- def _generate_columns_plus_names(
- self, anon_for_dupe_key: bool
- ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
- """Generate column names as rendered in a SELECT statement by
- the compiler.
-
- This is distinct from the _column_naming_convention generator that's
- intended for population of .c collections and similar, which has
- different rules. the collection returned here calls upon the
- _column_naming_convention as well.
-
- """
- cols = self._all_selected_columns
-
- key_naming_convention = SelectState._column_naming_convention(
- self._label_style
- )
-
- names = {}
-
- result = []
- result_append = result.append
-
- table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
- label_style_none = self._label_style is LABEL_STYLE_NONE
-
- # a counter used for "dedupe" labels, which have double underscores
- # in them and are never referred by name; they only act
- # as positional placeholders. they need only be unique within
- # the single columns clause they're rendered within (required by
- # some dbs such as mysql). So their anon identity is tracked against
- # a fixed counter rather than hash() identity.
- dedupe_hash = 1
-
- for c in cols:
- repeated = False
-
- if not c._render_label_in_columns_clause:
- effective_name = (
- required_label_name
- ) = fallback_label_name = None
- elif label_style_none:
- effective_name = required_label_name = None
- fallback_label_name = c._non_anon_label or c._anon_name_label
- else:
- if table_qualified:
- required_label_name = (
- effective_name
- ) = fallback_label_name = c._tq_label
- else:
- effective_name = fallback_label_name = c._non_anon_label
- required_label_name = None
-
- if effective_name is None:
- # it seems like this could be _proxy_key and we would
- # not need _expression_label but it isn't
- # giving us a clue when to use anon_label instead
- expr_label = c._expression_label
- if expr_label is None:
- repeated = c._anon_name_label in names
- names[c._anon_name_label] = c
- effective_name = required_label_name = None
-
- if repeated:
- # here, "required_label_name" is sent as
- # "None" and "fallback_label_name" is sent.
- if table_qualified:
- fallback_label_name = (
- c._dedupe_anon_tq_label_idx(dedupe_hash)
- )
- dedupe_hash += 1
- else:
- fallback_label_name = c._dedupe_anon_label_idx(
- dedupe_hash
- )
- dedupe_hash += 1
- else:
- fallback_label_name = c._anon_name_label
- else:
- required_label_name = (
- effective_name
- ) = fallback_label_name = expr_label
-
- if effective_name is not None:
- if effective_name in names:
- # when looking to see if names[name] is the same column as
- # c, use hash(), so that an annotated version of the column
- # is seen as the same as the non-annotated
- if hash(names[effective_name]) != hash(c):
-
- # different column under the same name. apply
- # disambiguating label
- if table_qualified:
- required_label_name = (
- fallback_label_name
- ) = c._anon_tq_label
- else:
- required_label_name = (
- fallback_label_name
- ) = c._anon_name_label
-
- if anon_for_dupe_key and required_label_name in names:
- # here, c._anon_tq_label is definitely unique to
- # that column identity (or annotated version), so
- # this should always be true.
- # this is also an infrequent codepath because
- # you need two levels of duplication to be here
- assert hash(names[required_label_name]) == hash(c)
-
- # the column under the disambiguating label is
- # already present. apply the "dedupe" label to
- # subsequent occurrences of the column so that the
- # original stays non-ambiguous
- if table_qualified:
- required_label_name = (
- fallback_label_name
- ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
- dedupe_hash += 1
- else:
- required_label_name = (
- fallback_label_name
- ) = c._dedupe_anon_label_idx(dedupe_hash)
- dedupe_hash += 1
- repeated = True
- else:
- names[required_label_name] = c
- elif anon_for_dupe_key:
- # same column under the same name. apply the "dedupe"
- # label so that the original stays non-ambiguous
- if table_qualified:
- required_label_name = (
- fallback_label_name
- ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
- dedupe_hash += 1
- else:
- required_label_name = (
- fallback_label_name
- ) = c._dedupe_anon_label_idx(dedupe_hash)
- dedupe_hash += 1
- repeated = True
- else:
- names[effective_name] = c
-
- result_append(
- (
- # string label name, if non-None, must be rendered as a
- # label, i.e. "AS <name>"
- required_label_name,
- # proxy_key that is to be part of the result map for this
- # col. this is also the key in a fromclause.c or
- # select.selected_columns collection
- key_naming_convention(c),
- # name that can be used to render an "AS <name>" when
- # we have to render a label even though
- # required_label_name was not given
- fallback_label_name,
- # the ColumnElement itself
- c,
- # True if this is a duplicate of a previous column
- # in the list of columns
- repeated,
- )
- )
-
- return result
-
- def _generate_fromclause_column_proxies(self, subquery):
+ def _generate_fromclause_column_proxies(
+ self, subquery: FromClause
+ ) -> None:
"""Generate column proxies to place in the exported ``.c``
collection of a subquery."""
@@ -5418,7 +6014,7 @@ class Select(
c,
repeated,
) in (self._generate_columns_plus_names(False))
- if not c._is_text_clause
+ if is_column_element(c)
]
subquery._columns._populate_separate_keys(prox)
@@ -5428,7 +6024,10 @@ class Select(
self._order_by_clause.clauses
)
- def self_group(self, against=None):
+ def self_group(
+ self: Self, against: Optional[OperatorType] = None
+ ) -> Union[SelectStatementGrouping, Self]:
+ ...
"""Return a 'grouping' construct as per the
:class:`_expression.ClauseElement` specification.
@@ -5445,7 +6044,9 @@ class Select(
else:
return SelectStatementGrouping(self)
- def union(self, *other, **kwargs):
+ def union(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``UNION`` of this select() construct against
the given selectables provided as positional arguments.
@@ -5460,9 +6061,11 @@ class Select(
for the newly created :class:`_sql.CompoundSelect` object.
"""
- return CompoundSelect._create_union(self, *other, **kwargs)
+ return CompoundSelect._create_union(self, *other)
- def union_all(self, *other, **kwargs):
+ def union_all(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``UNION ALL`` of this select() construct against
the given selectables provided as positional arguments.
@@ -5477,9 +6080,11 @@ class Select(
for the newly created :class:`_sql.CompoundSelect` object.
"""
- return CompoundSelect._create_union_all(self, *other, **kwargs)
+ return CompoundSelect._create_union_all(self, *other)
- def except_(self, *other, **kwargs):
+ def except_(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``EXCEPT`` of this select() construct against
the given selectable provided as positional arguments.
@@ -5490,13 +6095,12 @@ class Select(
multiple elements are now accepted.
- :param \**kwargs: keyword arguments are forwarded to the constructor
- for the newly created :class:`_sql.CompoundSelect` object.
-
"""
- return CompoundSelect._create_except(self, *other, **kwargs)
+ return CompoundSelect._create_except(self, *other)
- def except_all(self, *other, **kwargs):
+ def except_all(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``EXCEPT ALL`` of this select() construct against
the given selectables provided as positional arguments.
@@ -5507,13 +6111,12 @@ class Select(
multiple elements are now accepted.
- :param \**kwargs: keyword arguments are forwarded to the constructor
- for the newly created :class:`_sql.CompoundSelect` object.
-
"""
- return CompoundSelect._create_except_all(self, *other, **kwargs)
+ return CompoundSelect._create_except_all(self, *other)
- def intersect(self, *other, **kwargs):
+ def intersect(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``INTERSECT`` of this select() construct against
the given selectables provided as positional arguments.
@@ -5528,9 +6131,11 @@ class Select(
for the newly created :class:`_sql.CompoundSelect` object.
"""
- return CompoundSelect._create_intersect(self, *other, **kwargs)
+ return CompoundSelect._create_intersect(self, *other)
- def intersect_all(self, *other, **kwargs):
+ def intersect_all(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``INTERSECT ALL`` of this select() construct
against the given selectables provided as positional arguments.
@@ -5545,13 +6150,17 @@ class Select(
for the newly created :class:`_sql.CompoundSelect` object.
"""
- return CompoundSelect._create_intersect_all(self, *other, **kwargs)
+ return CompoundSelect._create_intersect_all(self, *other)
-SelfScalarSelect = typing.TypeVar("SelfScalarSelect", bound="ScalarSelect")
+SelfScalarSelect = typing.TypeVar(
+ "SelfScalarSelect", bound="ScalarSelect[Any]"
+)
-class ScalarSelect(roles.InElementRole, Generative, Grouping):
+class ScalarSelect(
+ roles.InElementRole, Generative, GroupedElement, ColumnElement[_T]
+):
"""Represent a scalar subquery.
@@ -5570,15 +6179,33 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
"""
- _from_objects = []
+ _traverse_internals: _TraverseInternalsType = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ _from_objects: List[FromClause] = []
_is_from_container = True
- _is_implicitly_boolean = False
+ if not TYPE_CHECKING:
+ _is_implicitly_boolean = False
inherit_cache = True
+ element: SelectBase
+
def __init__(self, element):
self.element = element
self.type = element._scalar_type()
+ def __getattr__(self, attr):
+ return getattr(self.element, attr)
+
+ def __getstate__(self):
+ return {"element": self.element, "type": self.type}
+
+ def __setstate__(self, state):
+ self.element = state["element"]
+ self.type = state["type"]
+
@property
def columns(self):
raise exc.InvalidRequestError(
@@ -5590,19 +6217,39 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
c = columns
@_generative
- def where(self: SelfScalarSelect, crit) -> SelfScalarSelect:
+ def where(
+ self: SelfScalarSelect, crit: _ColumnExpressionArgument[bool]
+ ) -> SelfScalarSelect:
"""Apply a WHERE clause to the SELECT statement referred to
by this :class:`_expression.ScalarSelect`.
"""
- self.element = self.element.where(crit)
+ self.element = cast(Select, self.element).where(crit)
return self
- def self_group(self, **kwargs):
+ @overload
+ def self_group(
+ self: ScalarSelect[Any], against: Optional[OperatorType] = None
+ ) -> ScalarSelect[Any]:
+ ...
+
+ @overload
+ def self_group(
+ self: ColumnElement[Any], against: Optional[OperatorType] = None
+ ) -> ColumnElement[Any]:
+ ...
+
+ def self_group(
+ self, against: Optional[OperatorType] = None
+ ) -> ColumnElement[Any]:
+
return self
@_generative
- def correlate(self: SelfScalarSelect, *fromclauses) -> SelfScalarSelect:
+ def correlate(
+ self: SelfScalarSelect,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfScalarSelect:
r"""Return a new :class:`_expression.ScalarSelect`
which will correlate the given FROM
clauses to that of an enclosing :class:`_expression.Select`.
@@ -5631,12 +6278,13 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
"""
- self.element = self.element.correlate(*fromclauses)
+ self.element = cast(Select, self.element).correlate(*fromclauses)
return self
@_generative
def correlate_except(
- self: SelfScalarSelect, *fromclauses
+ self: SelfScalarSelect,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
) -> SelfScalarSelect:
r"""Return a new :class:`_expression.ScalarSelect`
which will omit the given FROM
@@ -5668,11 +6316,16 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
"""
- self.element = self.element.correlate_except(*fromclauses)
+ self.element = cast(Select, self.element).correlate_except(
+ *fromclauses
+ )
return self
-class Exists(UnaryExpression[_T]):
+SelfExists = TypeVar("SelfExists", bound="Exists")
+
+
+class Exists(UnaryExpression[bool]):
"""Represent an ``EXISTS`` clause.
See :func:`_sql.exists` for a description of usage.
@@ -5682,10 +6335,14 @@ class Exists(UnaryExpression[_T]):
"""
- _from_objects = ()
inherit_cache = True
- def __init__(self, __argument=None):
+ def __init__(
+ self,
+ __argument: Optional[
+ Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]]
+ ] = None,
+ ):
if __argument is None:
s = Select(literal_column("*")).scalar_subquery()
elif isinstance(__argument, (SelectBase, ScalarSelect)):
@@ -5701,12 +6358,16 @@ class Exists(UnaryExpression[_T]):
wraps_column_expression=True,
)
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
+ return []
+
def _regroup(self, fn):
element = self.element._ungroup()
element = fn(element)
return element.self_group(against=operators.exists)
- def select(self) -> "Select":
+ def select(self) -> Select:
r"""Return a SELECT of this :class:`_expression.Exists`.
e.g.::
@@ -5726,7 +6387,10 @@ class Exists(UnaryExpression[_T]):
return Select(self)
- def correlate(self, *fromclause):
+ def correlate(
+ self: SelfExists,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfExists:
"""Apply correlation to the subquery noted by this :class:`_sql.Exists`.
.. seealso::
@@ -5736,11 +6400,14 @@ class Exists(UnaryExpression[_T]):
"""
e = self._clone()
e.element = self._regroup(
- lambda element: element.correlate(*fromclause)
+ lambda element: element.correlate(*fromclauses)
)
return e
- def correlate_except(self, *fromclause):
+ def correlate_except(
+ self: SelfExists,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfExists:
"""Apply correlation to the subquery noted by this :class:`_sql.Exists`.
.. seealso::
@@ -5751,11 +6418,11 @@ class Exists(UnaryExpression[_T]):
e = self._clone()
e.element = self._regroup(
- lambda element: element.correlate_except(*fromclause)
+ lambda element: element.correlate_except(*fromclauses)
)
return e
- def select_from(self, *froms):
+ def select_from(self: SelfExists, *froms: FromClause) -> SelfExists:
"""Return a new :class:`_expression.Exists` construct,
applying the given
expression to the :meth:`_expression.Select.select_from`
@@ -5772,7 +6439,9 @@ class Exists(UnaryExpression[_T]):
e.element = self._regroup(lambda element: element.select_from(*froms))
return e
- def where(self, *clause):
+ def where(
+ self: SelfExists, *clause: _ColumnExpressionArgument[bool]
+ ) -> SelfExists:
"""Return a new :func:`_expression.exists` construct with the
given expression added to
its WHERE clause, joined to the existing clause via AND, if any.
@@ -5824,7 +6493,7 @@ class TextualSelect(SelectBase):
_label_style = LABEL_STYLE_NONE
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement),
("column_args", InternalTraversal.dp_clauseelement_list),
] + SupportsCloneAnnotations._clone_annotations_traverse_internals
@@ -5842,8 +6511,8 @@ class TextualSelect(SelectBase):
]
self.positional = positional
- @HasMemoized.memoized_attribute
- def selected_columns(self):
+ @HasMemoized_ro_memoized_attribute
+ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
@@ -5868,6 +6537,13 @@ class TextualSelect(SelectBase):
(c.key, c) for c in self.column_args
).as_readonly()
+ # def _generate_columns_plus_names(
+ # self, anon_for_dupe_key: bool
+ # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
+ # return Select._generate_columns_plus_names(
+ # self, anon_for_dupe_key=anon_for_dupe_key
+ # )
+
@util.non_memoized_property
def _all_selected_columns(self) -> _SelectIterable:
return self.column_args
@@ -5880,7 +6556,9 @@ class TextualSelect(SelectBase):
@_generative
def bindparams(
- self: SelfTextualSelect, *binds, **bind_as_values
+ self: SelfTextualSelect,
+ *binds: BindParameter[Any],
+ **bind_as_values: Any,
) -> SelfTextualSelect:
self.element = self.element.bindparams(*binds, **bind_as_values)
return self
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 1f3d50876..c3653c264 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -228,7 +228,7 @@ class HasCopyInternals(HasTraverseInternals):
raise NotImplementedError()
def _copy_internals(
- self, omit_attrs: Iterable[str] = (), **kw: Any
+ self, *, omit_attrs: Iterable[str] = (), **kw: Any
) -> None:
"""Reassign internal elements to be clones of themselves.
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 9a934a50b..82adf4a4f 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -1008,10 +1008,7 @@ class TypeEngine(Visitable, Generic[_T]):
@util.preload_module("sqlalchemy.engine.default")
def _default_dialect(self) -> Dialect:
- if TYPE_CHECKING:
- from ..engine import default
- else:
- default = util.preloaded.engine_default
+ default = util.preloaded.engine_default
# dmypy / mypy seems to sporadically keep thinking this line is
# returning Any, which seems to be caused by the @deprecated_params
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index cdce49f7b..80711c4b5 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -13,10 +13,20 @@ from __future__ import annotations
from collections import deque
from itertools import chain
import typing
+from typing import AbstractSet
from typing import Any
+from typing import Callable
from typing import cast
+from typing import Dict
from typing import Iterator
+from typing import List
from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from . import coercions
from . import operators
@@ -49,11 +59,22 @@ from .selectable import Join
from .selectable import ScalarSelect
from .selectable import SelectBase
from .selectable import TableClause
+from .visitors import _ET
from .. import exc
from .. import util
+from ..util.typing import Literal
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _ColumnExpressionArgument
+ from ._typing import _TypeEngineArgument
from .roles import FromClauseRole
+ from .selectable import _JoinTargetElement
+ from .selectable import _OnClauseElement
+ from .selectable import Selectable
+ from .visitors import _TraverseCallableType
+ from .visitors import ExternallyTraversible
+ from .visitors import ExternalTraversal
from ..engine.interfaces import _AnyExecuteParams
from ..engine.interfaces import _AnyMultiExecuteParams
from ..engine.interfaces import _AnySingleExecuteParams
@@ -160,7 +181,11 @@ def find_left_clause_that_matches_given(clauses, join_from):
return liberal_idx
-def find_left_clause_to_join_from(clauses, join_to, onclause):
+def find_left_clause_to_join_from(
+ clauses: Sequence[FromClause],
+ join_to: _JoinTargetElement,
+ onclause: Optional[ColumnElement[Any]],
+) -> List[int]:
"""Given a list of FROM clauses, a selectable,
and optional ON clause, return a list of integer indexes from the
clauses list indicating the clauses that can be joined from.
@@ -189,6 +214,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause):
for i, f in enumerate(clauses):
for s in selectables.difference([f]):
if resolve_ambiguity:
+ assert cols_in_onclause is not None
if set(f.c).union(s.c).issuperset(cols_in_onclause):
idx.append(i)
break
@@ -207,7 +233,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause):
# onclause was given and none of them resolved, so assume
# all indexes can match
if not idx and onclause is not None:
- return range(len(clauses))
+ return list(range(len(clauses)))
else:
return idx
@@ -247,7 +273,7 @@ def visit_binary_product(fn, expr):
a binary comparison is passed as pairs.
"""
- stack = []
+ stack: List[ClauseElement] = []
def visit(element):
if isinstance(element, ScalarSelect):
@@ -272,21 +298,22 @@ def visit_binary_product(fn, expr):
yield e
list(visit(expr))
- visit = None # remove gc cycles
+ visit = None # type: ignore # remove gc cycles
def find_tables(
- clause,
- check_columns=False,
- include_aliases=False,
- include_joins=False,
- include_selects=False,
- include_crud=False,
-):
+ clause: ClauseElement,
+ *,
+ check_columns: bool = False,
+ include_aliases: bool = False,
+ include_joins: bool = False,
+ include_selects: bool = False,
+ include_crud: bool = False,
+) -> List[TableClause]:
"""locate Table objects within the given expression."""
- tables = []
- _visitors = {}
+ tables: List[TableClause] = []
+ _visitors: Dict[str, _TraverseCallableType[Any]] = {}
if include_selects:
_visitors["select"] = _visitors["compound_select"] = tables.append
@@ -335,7 +362,7 @@ def unwrap_order_by(clause):
t = stack.popleft()
if isinstance(t, ColumnElement) and (
not isinstance(t, UnaryExpression)
- or not operators.is_ordering_modifier(t.modifier)
+ or not operators.is_ordering_modifier(t.modifier) # type: ignore
):
if isinstance(t, Label) and not isinstance(
t.element, ScalarSelect
@@ -365,9 +392,14 @@ def unwrap_order_by(clause):
def unwrap_label_reference(element):
- def replace(elem):
- if isinstance(elem, (_label_reference, _textual_label_reference)):
- return elem.element
+ def replace(
+ element: ExternallyTraversible, **kw: Any
+ ) -> Optional[ExternallyTraversible]:
+ if isinstance(element, _label_reference):
+ return element.element
+ elif isinstance(element, _textual_label_reference):
+ assert False, "can't unwrap a textual label reference"
+ return None
return visitors.replacement_traverse(element, {}, replace)
@@ -407,7 +439,7 @@ def clause_is_present(clause, search):
return False
-def tables_from_leftmost(clause: FromClauseRole) -> Iterator[FromClause]:
+def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
if isinstance(clause, Join):
for t in tables_from_leftmost(clause.left):
yield t
@@ -509,6 +541,8 @@ class _repr_base:
__slots__ = ("max_chars",)
+ max_chars: int
+
def trunc(self, value: Any) -> str:
rep = repr(value)
lenrep = len(rep)
@@ -612,7 +646,7 @@ class _repr_params(_repr_base):
def _repr_multi(
self,
multi_params: _AnyMultiExecuteParams,
- typ,
+ typ: int,
) -> str:
if multi_params:
if isinstance(multi_params[0], list):
@@ -639,7 +673,7 @@ class _repr_params(_repr_base):
def _repr_params(
self,
- params: Optional[_AnySingleExecuteParams],
+ params: _AnySingleExecuteParams,
typ: int,
) -> str:
trunc = self.trunc
@@ -653,9 +687,10 @@ class _repr_params(_repr_base):
)
)
elif typ is self._TUPLE:
+ seq_params = cast("Sequence[Any]", params)
return "(%s%s)" % (
- ", ".join(trunc(value) for value in params),
- "," if len(params) == 1 else "",
+ ", ".join(trunc(value) for value in seq_params),
+ "," if len(seq_params) == 1 else "",
)
else:
return "[%s]" % (", ".join(trunc(value) for value in params))
@@ -688,11 +723,15 @@ def adapt_criterion_to_null(crit, nulls):
return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
-def splice_joins(left, right, stop_on=None):
+def splice_joins(
+ left: Optional[FromClause],
+ right: Optional[FromClause],
+ stop_on: Optional[FromClause] = None,
+) -> Optional[FromClause]:
if left is None:
return right
- stack = [(right, None)]
+ stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)]
adapter = ClauseAdapter(left)
ret = None
@@ -705,6 +744,7 @@ def splice_joins(left, right, stop_on=None):
else:
right = adapter.traverse(right)
if prevright is not None:
+ assert right is not None
prevright.left = right
if ret is None:
ret = right
@@ -845,11 +885,14 @@ def criterion_as_pairs(
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
- pairs = []
+ pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []
visitors.traverse(expression, {}, {"binary": visit_binary})
return pairs
+_CE = TypeVar("_CE", bound="ClauseElement")
+
+
class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
@@ -879,13 +922,15 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
def __init__(
self,
- selectable,
- equivalents=None,
- include_fn=None,
- exclude_fn=None,
- adapt_on_names=False,
- anonymize_labels=False,
- adapt_from_selectables=None,
+ selectable: Selectable,
+ equivalents: Optional[
+ Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
+ ] = None,
+ include_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ adapt_on_names: bool = False,
+ anonymize_labels: bool = False,
+ adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
):
self.__traverse_options__ = {
"stop_on": [selectable],
@@ -898,6 +943,29 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
self.adapt_on_names = adapt_on_names
self.adapt_from_selectables = adapt_from_selectables
+ if TYPE_CHECKING:
+
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ # note this specializes the ReplacingExternalTraversal.traverse()
+ # method to state
+ # that we will return the same kind of ExternalTraversal object as
+ # we were given. This is probably not 100% true, such as it's
+ # possible for us to swap out Alias for Table at the top level.
+ # Ideally there could be overloads specific to ColumnElement and
+ # FromClause but Mypy is not accepting those as compatible with
+ # the base ReplacingExternalTraversal
+ @overload
+ def traverse(self, obj: _ET) -> _ET:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
+ ...
+
def _corresponding_column(
self, col, require_embedded, _seen=util.EMPTY_SET
):
@@ -919,9 +987,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
return newcol
@util.preload_module("sqlalchemy.sql.functions")
- def replace(self, col, _include_singleton_constants=False):
+ def replace(
+ self, col: _ET, _include_singleton_constants: bool = False
+ ) -> Optional[_ET]:
functions = util.preloaded.sql_functions
+ # TODO: cython candidate
+
if isinstance(col, FromClause) and not isinstance(
col, functions.FunctionElement
):
@@ -933,7 +1005,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
break
else:
return None
- return self.selectable
+ return self.selectable # type: ignore
elif isinstance(col, Alias) and isinstance(
col.element, TableClause
):
@@ -944,7 +1016,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
# we are an alias of a table and we are not derived from an
# alias of a table (which nonetheless may be the same table
# as ours) so, same thing
- return col
+ return col # type: ignore
else:
# other cases where we are a selectable and the element
# is another join or selectable that contains a table which our
@@ -972,12 +1044,22 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
else:
return None
+ if TYPE_CHECKING:
+ assert isinstance(col, ColumnElement)
+
if self.include_fn and not self.include_fn(col):
return None
elif self.exclude_fn and self.exclude_fn(col):
return None
else:
- return self._corresponding_column(col, True)
+ return self._corresponding_column(col, True) # type: ignore
+
+
+class _ColumnLookup(Protocol):
+ def __getitem__(
+ self, key: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
+ ...
class ColumnAdapter(ClauseAdapter):
@@ -1011,17 +1093,21 @@ class ColumnAdapter(ClauseAdapter):
"""
+ columns: _ColumnLookup
+
def __init__(
self,
- selectable,
- equivalents=None,
- adapt_required=False,
- include_fn=None,
- exclude_fn=None,
- adapt_on_names=False,
- allow_label_resolve=True,
- anonymize_labels=False,
- adapt_from_selectables=None,
+ selectable: Selectable,
+ equivalents: Optional[
+ Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
+ ] = None,
+ adapt_required: bool = False,
+ include_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ adapt_on_names: bool = False,
+ allow_label_resolve: bool = True,
+ anonymize_labels: bool = False,
+ adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
):
ClauseAdapter.__init__(
self,
@@ -1034,7 +1120,7 @@ class ColumnAdapter(ClauseAdapter):
adapt_from_selectables=adapt_from_selectables,
)
- self.columns = util.WeakPopulateDict(self._locate_col)
+ self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore
if self.include_fn or self.exclude_fn:
self.columns = self._IncludeExcludeMapping(self, self.columns)
self.adapt_required = adapt_required
@@ -1060,7 +1146,7 @@ class ColumnAdapter(ClauseAdapter):
ac = self.__class__.__new__(self.__class__)
ac.__dict__.update(self.__dict__)
ac._wrap = adapter
- ac.columns = util.WeakPopulateDict(ac._locate_col)
+ ac.columns = util.WeakPopulateDict(ac._locate_col) # type: ignore
if ac.include_fn or ac.exclude_fn:
ac.columns = self._IncludeExcludeMapping(ac, ac.columns)
@@ -1069,6 +1155,17 @@ class ColumnAdapter(ClauseAdapter):
def traverse(self, obj):
return self.columns[obj]
+ def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
+ assert isinstance(visitor, ColumnAdapter)
+
+ return super().chain(visitor)
+
+ if TYPE_CHECKING:
+
+ @property
+ def visitor_iterator(self) -> Iterator[ColumnAdapter]:
+ ...
+
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
@@ -1080,7 +1177,9 @@ class ColumnAdapter(ClauseAdapter):
return newcol
- def _locate_col(self, col):
+ def _locate_col(
+ self, col: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
# both replace and traverse() are overly complicated for what
# we are doing here and we would do better to have an inlined
# version that doesn't build up as much overhead. the issue is that
@@ -1120,10 +1219,14 @@ class ColumnAdapter(ClauseAdapter):
def __setstate__(self, state):
self.__dict__.update(state)
- self.columns = util.WeakPopulateDict(self._locate_col)
+ self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore
-def _offset_or_limit_clause(element, name=None, type_=None):
+def _offset_or_limit_clause(
+ element: Union[int, _ColumnExpressionArgument[int]],
+ name: Optional[str] = None,
+ type_: Optional[_TypeEngineArgument[int]] = None,
+) -> ColumnElement[int]:
"""Convert the given value to an "offset or limit" clause.
This handles incoming integers and converts to an expression; if
@@ -1135,7 +1238,9 @@ def _offset_or_limit_clause(element, name=None, type_=None):
)
-def _offset_or_limit_clause_asint_if_possible(clause):
+def _offset_or_limit_clause_asint_if_possible(
+ clause: Optional[Union[int, _ColumnExpressionArgument[int]]]
+) -> Optional[Union[int, _ColumnExpressionArgument[int]]]:
"""Return the offset or limit clause as a simple integer if possible,
else return the clause.
@@ -1143,18 +1248,27 @@ def _offset_or_limit_clause_asint_if_possible(clause):
if clause is None:
return None
if hasattr(clause, "_limit_offset_value"):
- value = clause._limit_offset_value
+ value = clause._limit_offset_value # type: ignore
return util.asint(value)
else:
return clause
-def _make_slice(limit_clause, offset_clause, start, stop):
+def _make_slice(
+ limit_clause: Optional[Union[int, _ColumnExpressionArgument[int]]],
+ offset_clause: Optional[Union[int, _ColumnExpressionArgument[int]]],
+ start: int,
+ stop: int,
+) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
"""Compute LIMIT/OFFSET in terms of slice start/end"""
# for calculated limit/offset, try to do the addition of
# values to offset in Python, however if a SQL clause is present
# then the addition has to be on the SQL side.
+
+ # TODO: typing is finding a few gaps in here, see if they can be
+ # closed up
+
if start is not None and stop is not None:
offset_clause = _offset_or_limit_clause_asint_if_possible(
offset_clause
@@ -1163,11 +1277,12 @@ def _make_slice(limit_clause, offset_clause, start, stop):
offset_clause = 0
if start != 0:
- offset_clause = offset_clause + start
+ offset_clause = offset_clause + start # type: ignore
if offset_clause == 0:
offset_clause = None
else:
+ assert offset_clause is not None
offset_clause = _offset_or_limit_clause(offset_clause)
limit_clause = _offset_or_limit_clause(stop - start)
@@ -1182,11 +1297,13 @@ def _make_slice(limit_clause, offset_clause, start, stop):
offset_clause = 0
if start != 0:
- offset_clause = offset_clause + start
+ offset_clause = offset_clause + start # type: ignore
if offset_clause == 0:
offset_clause = None
else:
- offset_clause = _offset_or_limit_clause(offset_clause)
+ offset_clause = _offset_or_limit_clause(
+ offset_clause # type: ignore
+ )
- return limit_clause, offset_clause
+ return limit_clause, offset_clause # type: ignore
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 903aae648..081faf1e9 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -28,6 +28,7 @@ from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
+from typing import overload
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -37,6 +38,7 @@ from .. import exc
from .. import util
from ..util import langhelpers
from ..util._has_cy import HAS_CYEXTENSION
+from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
@@ -599,8 +601,8 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
raise NotImplementedError()
def _copy_internals(
- self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any
- ) -> Self:
+ self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> None:
"""Reassign internal elements to be clones of themselves.
Called during a copy-and-traverse operation on newly
@@ -615,10 +617,24 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
_ET = TypeVar("_ET", bound=ExternallyTraversible)
+
+
_TraverseCallableType = Callable[[_ET], None]
-_TraverseTransformCallableType = Callable[
- [ExternallyTraversible], Optional[ExternallyTraversible]
-]
+
+
+class _CloneCallableType(Protocol):
+ def __call__(self, element: _ET, **kw: Any) -> _ET:
+ ...
+
+
+class _TraverseTransformCallableType(Protocol):
+ def __call__(
+ self, element: ExternallyTraversible, **kw: Any
+ ) -> Optional[ExternallyTraversible]:
+ ...
+
+
+_ExtT = TypeVar("_ExtT", bound="ExternalTraversal")
class ExternalTraversal:
@@ -640,7 +656,7 @@ class ExternalTraversal:
return meth(obj, **kw)
def iterate(
- self, obj: ExternallyTraversible
+ self, obj: Optional[ExternallyTraversible]
) -> Iterator[ExternallyTraversible]:
"""Traverse the given expression structure, returning an iterator
of all elements.
@@ -648,7 +664,17 @@ class ExternalTraversal:
"""
return iterate(obj, self.__traverse_options__)
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure."""
return traverse(obj, self.__traverse_options__, self._visitor_dict)
@@ -671,7 +697,7 @@ class ExternalTraversal:
yield v
v = getattr(v, "_next", None)
- def chain(self, visitor: ExternalTraversal) -> ExternalTraversal:
+ def chain(self: _ExtT, visitor: ExternalTraversal) -> _ExtT:
"""'Chain' an additional ExternalTraversal onto this ExternalTraversal
The chained visitor will receive all visit events after this one.
@@ -701,7 +727,17 @@ class CloningExternalTraversal(ExternalTraversal):
"""
return [self.traverse(x) for x in list_]
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure."""
return cloned_traverse(
@@ -729,14 +765,25 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
"""
return None
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure."""
def replace(
- elem: ExternallyTraversible,
+ element: ExternallyTraversible,
+ **kw: Any,
) -> Optional[ExternallyTraversible]:
for v in self.visitor_iterator:
- e = cast(ReplacingExternalTraversal, v).replace(elem)
+ e = cast(ReplacingExternalTraversal, v).replace(element)
if e is not None:
return e
@@ -754,7 +801,8 @@ ReplacingCloningVisitor = ReplacingExternalTraversal
def iterate(
- obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any] = util.EMPTY_DICT,
) -> Iterator[ExternallyTraversible]:
r"""Traverse the given expression structure, returning an iterator.
@@ -776,6 +824,9 @@ def iterate(
empty in modern usage.
"""
+ if obj is None:
+ return
+
yield obj
children = obj.get_children(**opts)
@@ -790,11 +841,29 @@ def iterate(
stack.append(t.get_children(**opts))
+@overload
+def traverse_using(
+ iterator: Iterable[ExternallyTraversible],
+ obj: Literal[None],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> None:
+ ...
+
+
+@overload
def traverse_using(
iterator: Iterable[ExternallyTraversible],
obj: ExternallyTraversible,
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
+ ...
+
+
+def traverse_using(
+ iterator: Iterable[ExternallyTraversible],
+ obj: Optional[ExternallyTraversible],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> Optional[ExternallyTraversible]:
"""Visit the given expression structure using the given iterator of
objects.
@@ -826,11 +895,29 @@ def traverse_using(
return obj
+@overload
+def traverse(
+ obj: Literal[None],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> None:
+ ...
+
+
+@overload
def traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
+ ...
+
+
+def traverse(
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure using the default
iterator.
@@ -863,11 +950,29 @@ def traverse(
return traverse_using(iterate(obj, opts), obj, visitors)
+@overload
+def cloned_traverse(
+ obj: Literal[None],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> None:
+ ...
+
+
+@overload
def cloned_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
+ ...
+
+
+def cloned_traverse(
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> Optional[ExternallyTraversible]:
"""Clone the given expression structure, allowing modifications by
visitors.
@@ -931,11 +1036,29 @@ def cloned_traverse(
return obj
+@overload
+def replacement_traverse(
+ obj: Literal[None],
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType,
+) -> None:
+ ...
+
+
+@overload
def replacement_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
replace: _TraverseTransformCallableType,
) -> ExternallyTraversible:
+ ...
+
+
+def replacement_traverse(
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType,
+) -> Optional[ExternallyTraversible]:
"""Clone the given expression structure, allowing element
replacement by a given replacement function.
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index 4496b8ded..f1bf5c0c4 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -306,6 +306,7 @@ def testing_engine(
options=None,
asyncio=False,
transfer_staticpool=False,
+ share_pool=False,
_sqlite_savepoint=False,
):
if asyncio:
@@ -356,6 +357,8 @@ def testing_engine(
if config.db is not None and isinstance(config.db.pool, StaticPool):
use_reaper = False
engine.pool._transfer_from(config.db.pool)
+ elif share_pool:
+ engine.pool = config.db.pool
if scope == "global":
if asyncio:
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index 6c6b21fce..8c0120bcc 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -121,6 +121,7 @@ class TestBase:
future=None,
asyncio=False,
transfer_staticpool=False,
+ share_pool=False,
):
if options is None:
options = {}
@@ -130,6 +131,7 @@ class TestBase:
options=options,
asyncio=asyncio,
transfer_staticpool=transfer_staticpool,
+ share_pool=share_pool,
)
yield gen_testing_engine
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 406c8af24..c0c2e7dfb 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -9,7 +9,9 @@
from collections import defaultdict as defaultdict
from functools import partial as partial
from functools import update_wrapper as update_wrapper
+from typing import TYPE_CHECKING
+from . import preloaded as preloaded
from ._collections import coerce_generator_arg as coerce_generator_arg
from ._collections import coerce_to_immutabledict as coerce_to_immutabledict
from ._collections import column_dict as column_dict
@@ -44,8 +46,6 @@ from ._collections import UniqueAppender as UniqueAppender
from ._collections import update_copy as update_copy
from ._collections import WeakPopulateDict as WeakPopulateDict
from ._collections import WeakSequence as WeakSequence
-from ._preloaded import preload_module as preload_module
-from ._preloaded import preloaded as preloaded
from .compat import arm as arm
from .compat import b as b
from .compat import b64decode as b64decode
@@ -148,3 +148,4 @@ from .langhelpers import warn as warn
from .langhelpers import warn_exception as warn_exception
from .langhelpers import warn_limited as warn_limited
from .langhelpers import wrap_callable as wrap_callable
+from .preloaded import preload_module as preload_module
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index bd73bf714..bcb2ad423 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -463,11 +463,12 @@ def update_copy(d, _new=None, **kw):
return d
-def flatten_iterator(x):
+def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
"""Given an iterator of which further sub-elements may also be
iterators, flatten the sub-elements into a single iterator.
"""
+ elem: _T
for elem in x:
if not isinstance(elem, str) and hasattr(elem, "__iter__"):
for y in flatten_iterator(elem):
diff --git a/lib/sqlalchemy/util/_preloaded.py b/lib/sqlalchemy/util/preloaded.py
index 511b93351..c861c83b3 100644
--- a/lib/sqlalchemy/util/_preloaded.py
+++ b/lib/sqlalchemy/util/preloaded.py
@@ -16,10 +16,16 @@ from types import ModuleType
import typing
from typing import Any
from typing import Callable
+from typing import TYPE_CHECKING
from typing import TypeVar
_FN = TypeVar("_FN", bound=Callable[..., Any])
+if TYPE_CHECKING:
+ from sqlalchemy.engine import default as engine_default
+ from sqlalchemy.sql import dml as sql_dml
+ from sqlalchemy.sql import util as sql_util
+
class _ModuleRegistry:
"""Registry of modules to load in a package init file.
@@ -67,7 +73,7 @@ class _ModuleRegistry:
not path or module.startswith(path)
) and key not in self.__dict__:
__import__(module, globals(), locals())
- self.__dict__[key] = sys.modules[module]
+ self.__dict__[key] = globals()[key] = sys.modules[module]
if typing.TYPE_CHECKING:
@@ -75,5 +81,11 @@ class _ModuleRegistry:
...
-preloaded = _ModuleRegistry()
-preload_module = preloaded.preload_module
+_reg = _ModuleRegistry()
+preload_module = _reg.preload_module
+import_prefix = _reg.import_prefix
+
+if TYPE_CHECKING:
+
+ def __getattr__(key: str) -> ModuleType:
+ ...
diff --git a/pyproject.toml b/pyproject.toml
index cc79e8646..012f1bffa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,9 +68,6 @@ module = [
"sqlalchemy.ext.mutable",
"sqlalchemy.ext.horizontal_shard",
- "sqlalchemy.sql._selectable_constructors",
- "sqlalchemy.sql._dml_constructors",
-
# TODO for non-strict:
"sqlalchemy.ext.baked",
"sqlalchemy.ext.instrumentation",
@@ -78,11 +75,6 @@ module = [
"sqlalchemy.ext.orderinglist",
"sqlalchemy.ext.serializer",
- "sqlalchemy.sql.selectable", # would be nice as strict
- "sqlalchemy.sql.functions", # would be nice as strict
- "sqlalchemy.sql.lambdas",
- "sqlalchemy.sql.util",
-
# not yet classified:
"sqlalchemy.orm.*",
"sqlalchemy.dialects.*",
@@ -132,10 +124,14 @@ module = [
"sqlalchemy.sql.crud",
"sqlalchemy.sql.ddl", # would be nice as strict
"sqlalchemy.sql.elements", # would be nice as strict
+ "sqlalchemy.sql.functions", # would be nice as strict, requires sqltypes
+ "sqlalchemy.sql.lambdas",
"sqlalchemy.sql.naming",
+ "sqlalchemy.sql.selectable", # would be nice as strict
"sqlalchemy.sql.schema", # would be nice as strict
"sqlalchemy.sql.sqltypes", # would be nice as strict
"sqlalchemy.sql.traversals",
+ "sqlalchemy.sql.util",
"sqlalchemy.util.*",
]
diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py
index 06be9fbd8..e03a8415d 100644
--- a/test/aaa_profiling/test_orm.py
+++ b/test/aaa_profiling/test_orm.py
@@ -835,27 +835,14 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest):
)
s.commit()
- def test_build_query(self):
+ def test_fetch_results_integrated(self, testing_engine):
A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G")
- sess = fixture_session()
+ # this test has been reworked to use the compiled cache again,
+ # as a real-world scenario.
- @profiling.function_call_count()
- def go():
- for i in range(100):
- q = sess.query(A).options(
- joinedload(A.bs).joinedload(B.cs).joinedload(C.ds),
- joinedload(A.es).joinedload(E.fs),
- defaultload(A.es).joinedload(E.gs),
- )
- q._compile_context()
-
- go()
-
- def test_fetch_results(self):
- A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G")
-
- sess = Session(testing.db)
+ eng = testing_engine(share_pool=True)
+ sess = Session(eng)
q = sess.query(A).options(
joinedload(A.bs).joinedload(B.cs).joinedload(C.ds),
@@ -863,47 +850,27 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest):
defaultload(A.es).joinedload(E.gs),
)
- compile_state = q._compile_state()
+ @profiling.function_call_count()
+ def initial_run():
+ list(q.all())
- from sqlalchemy.orm.context import ORMCompileState
+ initial_run()
+ sess.close()
- @profiling.function_call_count(warmup=1)
- def go():
- for i in range(100):
- # NOTE: this test was broken in
- # 77f1b7d236dba6b1c859bb428ef32d118ec372e6 because we started
- # clearing out the attributes after the first iteration. make
- # sure the attributes are there every time.
- assert compile_state.attributes
- exec_opts = {}
- bind_arguments = {}
- ORMCompileState.orm_pre_session_exec(
- sess,
- compile_state.select_statement,
- {},
- exec_opts,
- bind_arguments,
- is_reentrant_invoke=False,
- )
+ @profiling.function_call_count()
+ def subsequent_run():
+ list(q.all())
- r = sess.connection().execute(
- compile_state.statement,
- execution_options=exec_opts,
- )
+ subsequent_run()
+ sess.close()
- r.context.compiled.compile_state = compile_state
- obj = ORMCompileState.orm_setup_cursor_result(
- sess,
- compile_state.statement,
- {},
- exec_opts,
- {},
- r,
- )
- list(obj.unique())
- sess.close()
+ @profiling.function_call_count()
+ def more_runs():
+ for i in range(100):
+ list(q.all())
- go()
+ more_runs()
+ sess.close()
class JoinConditionTest(NoCache, fixtures.DeclarativeMappedTest):
diff --git a/test/base/test_utils.py b/test/base/test_utils.py
index fc61e39b6..e22340da6 100644
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -25,11 +25,11 @@ from sqlalchemy.testing import mock
from sqlalchemy.testing import ne_
from sqlalchemy.testing.util import gc_collect
from sqlalchemy.testing.util import picklers
-from sqlalchemy.util import _preloaded
from sqlalchemy.util import classproperty
from sqlalchemy.util import compat
from sqlalchemy.util import get_callable_argspec
from sqlalchemy.util import langhelpers
+from sqlalchemy.util import preloaded
from sqlalchemy.util import WeakSequence
from sqlalchemy.util._collections import merge_lists_w_ordering
@@ -3187,7 +3187,7 @@ class TestModuleRegistry(fixtures.TestBase):
for m in ("xml.dom", "wsgiref.simple_server"):
to_restore.append((m, sys.modules.pop(m, None)))
try:
- mr = _preloaded._ModuleRegistry()
+ mr = preloaded._ModuleRegistry()
ret = mr.preload_module(
"xml.dom", "wsgiref.simple_server", "sqlalchemy.sql.util"
diff --git a/test/profiles.txt b/test/profiles.txt
index 31f72bd16..7b4f37734 100644
--- a/test/profiles.txt
+++ b/test/profiles.txt
@@ -196,16 +196,16 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3
test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96844
test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102344
-# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query
-
-test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 520615
-test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 522475
# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results
test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 440705
test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 458805
+# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated
+
+test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 30373,1014,96450
+
# TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity
test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 22984
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index dd073d2a5..8d6dc7553 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -208,6 +208,8 @@ class CoreFixtures:
column("q") == column("x"),
column("q") == column("y"),
column("z") == column("x"),
+ (column("z") == column("x")).self_group(),
+ (column("q") == column("x")).self_group(),
column("z") + column("x"),
column("z") - column("x"),
column("x") - column("z"),
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index d20037e92..6ca06dc0e 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -82,7 +82,6 @@ from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import CompilerColumnElement
from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.sql.expression import ClauseList
-from sqlalchemy.sql.expression import HasPrefixes
from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from sqlalchemy.testing import assert_raises
@@ -270,18 +269,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
"columns",
)
- def test_prefix_constructor(self):
- class Pref(HasPrefixes):
- def _generate(self):
- return self
-
- assert_raises(
- exc.ArgumentError,
- Pref().prefix_with,
- "some prefix",
- not_a_dialect=True,
- )
-
def test_table_select(self):
self.assert_compile(
table1.select(),
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py
index 686d4928d..d1d01a5c7 100644
--- a/test/sql/test_cte.py
+++ b/test/sql/test_cte.py
@@ -1186,6 +1186,37 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
dialect="postgresql",
)
+ def test_recursive_dml_syntax(self):
+ orders = table(
+ "orders",
+ column("region"),
+ column("amount"),
+ column("product"),
+ column("quantity"),
+ )
+
+ upsert = (
+ orders.update()
+ .where(orders.c.region == "Region1")
+ .values(amount=1.0, product="Product1", quantity=1)
+ .returning(*(orders.c._all_columns))
+ .cte("upsert", recursive=True)
+ )
+ stmt = select(upsert)
+
+ # This statement probably makes no sense, just want to see that the
+ # column generation aspect needed by RECURSIVE works (new in 2.0)
+ self.assert_compile(
+ stmt,
+ "WITH RECURSIVE upsert(region, amount, product, quantity) "
+ "AS (UPDATE orders SET amount=:param_1, product=:param_2, "
+ "quantity=:param_3 WHERE orders.region = :region_1 "
+ "RETURNING orders.region, orders.amount, orders.product, "
+ "orders.quantity) "
+ "SELECT upsert.region, upsert.amount, upsert.product, "
+ "upsert.quantity FROM upsert",
+ )
+
def test_upsert_from_select(self):
orders = table(
"orders",
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index ca5f43bb6..9fdc51938 100644
--- a/test/sql/test_selectable.py
+++ b/test/sql/test_selectable.py
@@ -432,6 +432,24 @@ class SelectableTest(
):
select(stmt.subquery()).compile()
+ def test_correlate_none_arg_error(self):
+ stmt = select(table1)
+ with expect_raises_message(
+ exc.ArgumentError,
+ "additional FROM objects not accepted when passing "
+ "None/False to correlate",
+ ):
+ stmt.correlate(None, table2)
+
+ def test_correlate_except_none_arg_error(self):
+ stmt = select(table1)
+ with expect_raises_message(
+ exc.ArgumentError,
+ "additional FROM objects not accepted when passing "
+ "None/False to correlate_except",
+ ):
+ stmt.correlate_except(None, table2)
+
def test_select_label_grouped_still_corresponds(self):
label = select(table1.c.col1).label("foo")
label2 = label.self_group()
diff --git a/test/sql/test_text.py b/test/sql/test_text.py
index 81b20f86f..0f645a2d2 100644
--- a/test/sql/test_text.py
+++ b/test/sql/test_text.py
@@ -688,6 +688,19 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL):
mapping = self._mapping(s)
assert x not in mapping
+ def test_subquery_accessors(self):
+ t = self._xy_table_fixture()
+
+ s = text("SELECT x from t").columns(t.c.x)
+
+ self.assert_compile(
+ select(s.scalar_subquery()), "SELECT (SELECT x from t) AS anon_1"
+ )
+ self.assert_compile(
+ select(s.subquery()),
+ "SELECT anon_1.x FROM (SELECT x from t) AS anon_1",
+ )
+
def test_select_label_alt_name_table_alias_column(self):
t = self._xy_table_fixture()
x = t.c.x
@@ -716,6 +729,36 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL):
"FROM mytable, t WHERE mytable.myid = t.id",
)
+ def test_cte_recursive(self):
+ t = (
+ text("select id, name from user")
+ .columns(id=Integer, name=String)
+ .cte("t", recursive=True)
+ )
+
+ s = select(table1).where(table1.c.myid == t.c.id)
+ self.assert_compile(
+ s,
+ "WITH RECURSIVE t(id, name) AS (select id, name from user) "
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable, t WHERE mytable.myid = t.id",
+ )
+
+ def test_unions(self):
+ s1 = text("select id, name from user where id > 5").columns(
+ id=Integer, name=String
+ )
+ s2 = text("select id, name from user where id < 15").columns(
+ id=Integer, name=String
+ )
+ stmt = union(s1, s2)
+ eq_(stmt.selected_columns.keys(), ["id", "name"])
+ self.assert_compile(
+ stmt,
+ "select id, name from user where id > 5 UNION "
+ "select id, name from user where id < 15",
+ )
+
def test_subquery(self):
t = (
text("select id, name from user")
diff --git a/test/sql/test_values.py b/test/sql/test_values.py
index f5ae9ea53..d14de9aee 100644
--- a/test/sql/test_values.py
+++ b/test/sql/test_values.py
@@ -294,6 +294,31 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL):
checkparams={},
)
+ def test_anon_alias(self):
+ people = self.tables.people
+ values = (
+ Values(
+ column("bookcase_id", Integer),
+ column("bookcase_owner_id", Integer),
+ )
+ .data([(1, 1), (2, 1), (3, 2), (3, 3)])
+ .alias()
+ )
+ stmt = select(people, values).select_from(
+ people.join(
+ values, values.c.bookcase_owner_id == people.c.people_id
+ )
+ )
+ self.assert_compile(
+ stmt,
+ "SELECT people.people_id, people.age, people.name, "
+ "anon_1.bookcase_id, anon_1.bookcase_owner_id FROM people "
+ "JOIN (VALUES (:param_1, :param_2), (:param_3, :param_4), "
+ "(:param_5, :param_6), (:param_7, :param_8)) AS anon_1 "
+ "(bookcase_id, bookcase_owner_id) "
+ "ON people.people_id = anon_1.bookcase_owner_id",
+ )
+
def test_with_join_unnamed(self):
people = self.tables.people
values = Values(