summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-30 18:01:58 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-04 09:26:43 -0400
commit3b4d62f4f72e8dfad7f38db192a6a90a8551608c (patch)
treed0334c4bb52f803bd7dad661f2e6a12e25f5880c
parent4e603e23755f31278f27a45449120a8dea470a45 (diff)
downloadsqlalchemy-3b4d62f4f72e8dfad7f38db192a6a90a8551608c.tar.gz
pep484 - sql.selectable
the pep484 task becomes more intense as there is mounting pressure to come up with a consistency in how data moves from end-user to instance variable. current thinking is coming into: 1. there are _typing._XYZArgument objects that represent "what the user sent" 2. there's the roles, which represent a kind of "filter" for different kinds of objects. These are mostly important as the argument we pass to coerce(). 3. there's the thing that coerce() returns, which should be what the construct uses as its internal representation of the thing. This is _typing._XYZElement. but there's some controversy over whether or not we should pass actual ClauseElements around by their role or not. I think we shouldn't at the moment, but this makes the "role-ness" of something a little less portable. Like, we have to set DMLTableRole for TableClause, Join, and Alias, but then also we have to repeat those three types in order to set up _DMLTableElement. Other change introduced here, there was a deannotate=True for the left/right of a sql.join(). All tests pass without that. I'd rather not have that there as if we have a join(A, B) where A, B are mapped classes, we want them inside of the _annotations. The rationale seems to be performance, but this performance can be illustrated to be on the compile side which we hope is cached in the normal case. CTEs now accommodate for text selects including recursive. Get typing to accommodate "util.preloaded" cleanly; add "preloaded" as a real module. This seemed like we would have needed pep562 `__getattr__()` but we don't, just set names in globals() as we import them. References: #6810 Change-Id: I34d17f617de2fe2c086fc556bd55748dc782faf0
-rw-r--r--lib/sqlalchemy/exc.py9
-rw-r--r--lib/sqlalchemy/orm/query.py7
-rw-r--r--lib/sqlalchemy/sql/_dml_constructors.py11
-rw-r--r--lib/sqlalchemy/sql/_elements_constructors.py4
-rw-r--r--lib/sqlalchemy/sql/_selectable_constructors.py122
-rw-r--r--lib/sqlalchemy/sql/_typing.py135
-rw-r--r--lib/sqlalchemy/sql/annotation.py43
-rw-r--r--lib/sqlalchemy/sql/base.py23
-rw-r--r--lib/sqlalchemy/sql/cache_key.py9
-rw-r--r--lib/sqlalchemy/sql/coercions.py158
-rw-r--r--lib/sqlalchemy/sql/compiler.py29
-rw-r--r--lib/sqlalchemy/sql/dml.py49
-rw-r--r--lib/sqlalchemy/sql/elements.py85
-rw-r--r--lib/sqlalchemy/sql/functions.py246
-rw-r--r--lib/sqlalchemy/sql/lambdas.py194
-rw-r--r--lib/sqlalchemy/sql/roles.py8
-rw-r--r--lib/sqlalchemy/sql/schema.py44
-rw-r--r--lib/sqlalchemy/sql/selectable.py1868
-rw-r--r--lib/sqlalchemy/sql/traversals.py2
-rw-r--r--lib/sqlalchemy/sql/type_api.py5
-rw-r--r--lib/sqlalchemy/sql/util.py231
-rw-r--r--lib/sqlalchemy/sql/visitors.py143
-rw-r--r--lib/sqlalchemy/testing/engines.py3
-rw-r--r--lib/sqlalchemy/testing/fixtures.py2
-rw-r--r--lib/sqlalchemy/util/__init__.py5
-rw-r--r--lib/sqlalchemy/util/_collections.py3
-rw-r--r--lib/sqlalchemy/util/preloaded.py (renamed from lib/sqlalchemy/util/_preloaded.py)18
-rw-r--r--pyproject.toml12
-rw-r--r--test/aaa_profiling/test_orm.py75
-rw-r--r--test/base/test_utils.py4
-rw-r--r--test/profiles.txt8
-rw-r--r--test/sql/test_compare.py2
-rw-r--r--test/sql/test_compiler.py13
-rw-r--r--test/sql/test_cte.py31
-rw-r--r--test/sql/test_selectable.py18
-rw-r--r--test/sql/test_text.py43
-rw-r--r--test/sql/test_values.py25
37 files changed, 2568 insertions, 1119 deletions
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py
index 8f4b963eb..9db6f3f52 100644
--- a/lib/sqlalchemy/exc.py
+++ b/lib/sqlalchemy/exc.py
@@ -23,8 +23,8 @@ from typing import Tuple
from typing import Type
from typing import Union
-from .util import _preloaded
from .util import compat
+from .util import preloaded as _preloaded
if typing.TYPE_CHECKING:
from .engine.interfaces import _AnyExecuteParams
@@ -345,6 +345,8 @@ class MultipleResultsFound(InvalidRequestError):
class NoReferenceError(InvalidRequestError):
"""Raised by ``ForeignKey`` to indicate a reference cannot be resolved."""
+ table_name: str
+
class AwaitRequired(InvalidRequestError):
"""Error raised by the async greenlet spawn if no async operation
@@ -501,10 +503,7 @@ class StatementError(SQLAlchemyError):
@_preloaded.preload_module("sqlalchemy.sql.util")
def _sql_message(self) -> str:
- if typing.TYPE_CHECKING:
- from .sql import util
- else:
- util = _preloaded.preloaded.sql_util
+ util = _preloaded.sql_util
details = [self._message()]
if self.statement:
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 18a14012f..b9ced44d5 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -28,6 +28,8 @@ from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
from typing import TypeVar
from . import exc as orm_exc
@@ -78,6 +80,8 @@ from ..sql.selectable import SelectBase
from ..sql.selectable import SelectStatementGrouping
from ..sql.visitors import InternalTraversal
+if TYPE_CHECKING:
+ from ..sql.selectable import _SetupJoinsElement
__all__ = ["Query", "QueryContext"]
@@ -134,7 +138,8 @@ class Query(
_correlate = ()
_auto_correlate = True
_from_obj = ()
- _setup_joins = ()
+ _setup_joins: Tuple[_SetupJoinsElement, ...] = ()
+
_label_style = LABEL_STYLE_LEGACY_ORM
_memoized_select_entities = ()
diff --git a/lib/sqlalchemy/sql/_dml_constructors.py b/lib/sqlalchemy/sql/_dml_constructors.py
index 835819bac..926e5257b 100644
--- a/lib/sqlalchemy/sql/_dml_constructors.py
+++ b/lib/sqlalchemy/sql/_dml_constructors.py
@@ -7,12 +7,17 @@
from __future__ import annotations
+from typing import TYPE_CHECKING
+
from .dml import Delete
from .dml import Insert
from .dml import Update
+if TYPE_CHECKING:
+ from ._typing import _DMLTableArgument
+
-def insert(table):
+def insert(table: _DMLTableArgument) -> Insert:
"""Construct an :class:`_expression.Insert` object.
E.g.::
@@ -82,7 +87,7 @@ def insert(table):
return Insert(table)
-def update(table):
+def update(table: _DMLTableArgument) -> Update:
r"""Construct an :class:`_expression.Update` object.
E.g.::
@@ -122,7 +127,7 @@ def update(table):
return Update(table)
-def delete(table):
+def delete(table: _DMLTableArgument) -> Delete:
r"""Construct :class:`_expression.Delete` object.
E.g.::
diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py
index fc925a8b3..ea21e01c6 100644
--- a/lib/sqlalchemy/sql/_elements_constructors.py
+++ b/lib/sqlalchemy/sql/_elements_constructors.py
@@ -345,8 +345,8 @@ def between(
:meth:`_expression.ColumnElement.between`
"""
- expr = coercions.expect(roles.ExpressionElementRole, expr)
- return expr.between(lower_bound, upper_bound, symmetric=symmetric)
+ col_expr = coercions.expect(roles.ExpressionElementRole, expr)
+ return col_expr.between(lower_bound, upper_bound, symmetric=symmetric)
def outparam(
diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py
index a17ee4ce8..7896c02c2 100644
--- a/lib/sqlalchemy/sql/_selectable_constructors.py
+++ b/lib/sqlalchemy/sql/_selectable_constructors.py
@@ -9,6 +9,8 @@ from __future__ import annotations
from typing import Any
from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from . import coercions
from . import roles
@@ -17,64 +19,65 @@ from .elements import ColumnClause
from .selectable import Alias
from .selectable import CompoundSelect
from .selectable import Exists
+from .selectable import FromClause
from .selectable import Join
from .selectable import Lateral
+from .selectable import LateralFromClause
+from .selectable import NamedFromClause
from .selectable import Select
from .selectable import TableClause
from .selectable import TableSample
from .selectable import Values
+if TYPE_CHECKING:
+ from ._typing import _ColumnsClauseArgument
+ from ._typing import _FromClauseArgument
+ from ._typing import _OnClauseArgument
+ from ._typing import _SelectStatementForCompoundArgument
+ from .functions import Function
+ from .selectable import CTE
+ from .selectable import HasCTE
+ from .selectable import ScalarSelect
+ from .selectable import SelectBase
-def alias(selectable, name=None, flat=False):
- """Return an :class:`_expression.Alias` object.
- An :class:`_expression.Alias` represents any
- :class:`_expression.FromClause`
- with an alternate name assigned within SQL, typically using the ``AS``
- clause when generated, e.g. ``SELECT * FROM table AS aliasname``.
+def alias(
+ selectable: FromClause, name: Optional[str] = None, flat: bool = False
+) -> NamedFromClause:
+ """Return a named alias of the given :class:`.FromClause`.
+
+ For :class:`.Table` and :class:`.Join` objects, the return type is the
+ :class:`_expression.Alias` object. Other kinds of :class:`.NamedFromClause`
+ objects may be returned for other kinds of :class:`.FromClause` objects.
+
+ The named alias represents any :class:`_expression.FromClause` with an
+ alternate name assigned within SQL, typically using the ``AS`` clause when
+ generated, e.g. ``SELECT * FROM table AS aliasname``.
- Similar functionality is available via the
+ Equivalent functionality is available via the
:meth:`_expression.FromClause.alias`
- method available on all :class:`_expression.FromClause` subclasses.
- In terms of
- a SELECT object as generated from the :func:`_expression.select`
- function, the :meth:`_expression.SelectBase.alias` method returns an
- :class:`_expression.Alias` or similar object which represents a named,
- parenthesized subquery.
-
- When an :class:`_expression.Alias` is created from a
- :class:`_schema.Table` object,
- this has the effect of the table being rendered
- as ``tablename AS aliasname`` in a SELECT statement.
-
- For :func:`_expression.select` objects, the effect is that of
- creating a named subquery, i.e. ``(select ...) AS aliasname``.
-
- The ``name`` parameter is optional, and provides the name
- to use in the rendered SQL. If blank, an "anonymous" name
- will be deterministically generated at compile time.
- Deterministic means the name is guaranteed to be unique against
- other constructs used in the same statement, and will also be the
- same name for each successive compilation of the same statement
- object.
+ method available on all :class:`_expression.FromClause` objects.
:param selectable: any :class:`_expression.FromClause` subclass,
such as a table, select statement, etc.
:param name: string name to be assigned as the alias.
- If ``None``, a name will be deterministically generated
- at compile time.
+ If ``None``, a name will be deterministically generated at compile
+ time. Deterministic means the name is guaranteed to be unique against
+ other constructs used in the same statement, and will also be the same
+ name for each successive compilation of the same statement object.
:param flat: Will be passed through to if the given selectable
is an instance of :class:`_expression.Join` - see
- :meth:`_expression.Join.alias`
- for details.
+ :meth:`_expression.Join.alias` for details.
"""
return Alias._factory(selectable, name=name, flat=flat)
-def cte(selectable, name=None, recursive=False):
+def cte(
+ selectable: HasCTE, name: Optional[str] = None, recursive: bool = False
+) -> CTE:
r"""Return a new :class:`_expression.CTE`,
or Common Table Expression instance.
@@ -86,7 +89,7 @@ def cte(selectable, name=None, recursive=False):
)
-def except_(*selects):
+def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
r"""Return an ``EXCEPT`` of multiple selectables.
The returned object is an instance of
@@ -99,7 +102,9 @@ def except_(*selects):
return CompoundSelect._create_except(*selects)
-def except_all(*selects):
+def except_all(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return an ``EXCEPT ALL`` of multiple selectables.
The returned object is an instance of
@@ -112,7 +117,11 @@ def except_all(*selects):
return CompoundSelect._create_except_all(*selects)
-def exists(__argument=None):
+def exists(
+ __argument: Optional[
+ Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]]
+ ] = None,
+) -> Exists:
"""Construct a new :class:`_expression.Exists` construct.
The :func:`_sql.exists` can be invoked by itself to produce an
@@ -153,7 +162,7 @@ def exists(__argument=None):
return Exists(__argument)
-def intersect(*selects):
+def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
r"""Return an ``INTERSECT`` of multiple selectables.
The returned object is an instance of
@@ -166,7 +175,9 @@ def intersect(*selects):
return CompoundSelect._create_intersect(*selects)
-def intersect_all(*selects):
+def intersect_all(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return an ``INTERSECT ALL`` of multiple selectables.
The returned object is an instance of
@@ -180,7 +191,13 @@ def intersect_all(*selects):
return CompoundSelect._create_intersect_all(*selects)
-def join(left, right, onclause=None, isouter=False, full=False):
+def join(
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ isouter: bool = False,
+ full: bool = False,
+) -> Join:
"""Produce a :class:`_expression.Join` object, given two
:class:`_expression.FromClause`
expressions.
@@ -232,7 +249,10 @@ def join(left, right, onclause=None, isouter=False, full=False):
return Join(left, right, onclause, isouter, full)
-def lateral(selectable, name=None):
+def lateral(
+ selectable: Union[SelectBase, _FromClauseArgument],
+ name: Optional[str] = None,
+) -> LateralFromClause:
"""Return a :class:`_expression.Lateral` object.
:class:`_expression.Lateral` is an :class:`_expression.Alias`
@@ -255,7 +275,12 @@ def lateral(selectable, name=None):
return Lateral._factory(selectable, name=name)
-def outerjoin(left, right, onclause=None, full=False):
+def outerjoin(
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ full: bool = False,
+) -> Join:
"""Return an ``OUTER JOIN`` clause element.
The returned object is an instance of :class:`_expression.Join`.
@@ -349,7 +374,12 @@ def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause:
return TableClause(name, *columns, **kw)
-def tablesample(selectable, sampling, name=None, seed=None):
+def tablesample(
+ selectable: _FromClauseArgument,
+ sampling: Union[float, Function[Any]],
+ name: Optional[str] = None,
+ seed: Optional[roles.ExpressionElementRole[Any]] = None,
+) -> TableSample:
"""Return a :class:`_expression.TableSample` object.
:class:`_expression.TableSample` is an :class:`_expression.Alias`
@@ -395,7 +425,7 @@ def tablesample(selectable, sampling, name=None, seed=None):
return TableSample._factory(selectable, sampling, name=name, seed=seed)
-def union(*selects, **kwargs):
+def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
r"""Return a ``UNION`` of multiple selectables.
The returned object is an instance of
@@ -412,10 +442,10 @@ def union(*selects, **kwargs):
:func:`select`.
"""
- return CompoundSelect._create_union(*selects, **kwargs)
+ return CompoundSelect._create_union(*selects)
-def union_all(*selects):
+def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
r"""Return a ``UNION ALL`` of multiple selectables.
The returned object is an instance of
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index a5da87802..0a72a93c5 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -1,7 +1,7 @@
from __future__ import annotations
+import operator
from typing import Any
-from typing import Iterable
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -24,9 +24,16 @@ if TYPE_CHECKING:
from .roles import FromClauseRole
from .schema import DefaultGenerator
from .schema import Sequence
+ from .selectable import Alias
from .selectable import FromClause
+ from .selectable import Join
from .selectable import NamedFromClause
+ from .selectable import ReturnsRows
+ from .selectable import Select
+ from .selectable import SelectBase
+ from .selectable import Subquery
from .selectable import TableClause
+ from .sqltypes import TableValueType
from .sqltypes import TupleType
from .type_api import TypeEngine
from ..util.typing import TypeGuard
@@ -47,6 +54,14 @@ class _HasClauseElement(Protocol):
# the coercions system is responsible for converting from XYZArgument to
# XYZElement.
+_TextCoercedExpressionArgument = Union[
+ str,
+ "TextClause",
+ "ColumnElement[_T]",
+ _HasClauseElement,
+ roles.ExpressionElementRole[_T],
+]
+
_ColumnsClauseArgument = Union[
Literal["*", 1],
roles.ColumnsClauseRole,
@@ -54,8 +69,31 @@ _ColumnsClauseArgument = Union[
Inspectable[_HasClauseElement],
_HasClauseElement,
]
+"""open-ended SELECT columns clause argument.
+
+Includes column expressions, tables, ORM mapped entities, a few literal values.
+
+This type is used for lists of columns / entities to be returned in result
+sets; select(...), insert().returning(...), etc.
+
+
+"""
+
+_ColumnExpressionArgument = Union[
+ "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T]
+]
+"""narrower "column expression" argument.
+
+This type is used for all the other "column" kinds of expressions that
+typically represent a single SQL column expression, not a set of columns the
+way a table or ORM entity does.
+
+This includes ColumnElement, or ORM-mapped attributes that will have a
+`__clause_element__()` method, it also has the ExpressionElementRole
+overall which brings in the TextClause object also.
+
+"""
-_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]]
_FromClauseArgument = Union[
roles.FromClauseRole,
@@ -63,28 +101,99 @@ _FromClauseArgument = Union[
Inspectable[_HasClauseElement],
_HasClauseElement,
]
+"""A FROM clause, like we would send to select().select_from().
-_ColumnExpressionArgument = Union[
- "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T]
+Also accommodates ORM entities and related constructs.
+
+"""
+
+_JoinTargetArgument = Union[_FromClauseArgument, roles.JoinTargetRole]
+"""target for join() builds on _FromClauseArgument to include additional
+join target roles such as those which come from the ORM.
+
+"""
+
+_OnClauseArgument = Union[_ColumnExpressionArgument[Any], roles.OnClauseRole]
+"""target for an ON clause, includes additional roles such as those which
+come from the ORM.
+
+"""
+
+_SelectStatementForCompoundArgument = Union[
+ "SelectBase", roles.CompoundElementRole
+]
+"""SELECT statement acceptable by ``union()`` and other SQL set operations"""
+
+_DMLColumnArgument = Union[
+ str, "ColumnClause[Any]", _HasClauseElement, roles.DMLColumnRole
]
+"""A DML column expression. This is a "key" inside of insert().values(),
+update().values(), and related.
+
+These are usually strings or SQL table columns.
+
+There's also edge cases like JSON expression assignment, which we would want
+the DMLColumnRole to be able to accommodate.
-_DMLColumnArgument = Union[str, "ColumnClause[Any]", _HasClauseElement]
+"""
+
+
+_DMLTableArgument = Union[
+ "TableClause",
+ "Join",
+ "Alias",
+ Type[Any],
+ Inspectable[_HasClauseElement],
+ _HasClauseElement,
+]
_PropagateAttrsType = util.immutabledict[str, Any]
_TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
+if TYPE_CHECKING:
-def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]:
- return t.named_with_column
+ def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]:
+ ...
+ def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]:
+ ...
-def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]:
- return c._is_column_element
+ def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]:
+ ...
+ def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]:
+ ...
-def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]:
- return c._is_text_clause
+ def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]:
+ ...
+
+ def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]:
+ ...
+
+ def is_select_base(t: ReturnsRows) -> TypeGuard[SelectBase]:
+ ...
+
+ def is_select_statement(t: ReturnsRows) -> TypeGuard[Select]:
+ ...
+
+ def is_table(t: FromClause) -> TypeGuard[TableClause]:
+ ...
+
+ def is_subquery(t: FromClause) -> TypeGuard[Subquery]:
+ ...
+
+else:
+ is_named_from_clause = operator.attrgetter("named_with_column")
+ is_column_element = operator.attrgetter("_is_column_element")
+ is_text_clause = operator.attrgetter("_is_text_clause")
+ is_from_clause = operator.attrgetter("_is_from_clause")
+ is_tuple_type = operator.attrgetter("_is_tuple_type")
+ is_table_value_type = operator.attrgetter("_is_table_value")
+ is_select_base = operator.attrgetter("_is_select_base")
+ is_select_statement = operator.attrgetter("_is_select_statement")
+ is_table = operator.attrgetter("_is_table")
+ is_subquery = operator.attrgetter("_is_subquery")
def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]:
@@ -95,9 +204,5 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]:
return hasattr(s, "quote")
-def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]:
- return t._is_tuple_type
-
-
def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]:
return hasattr(s, "__clause_element__")
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index f1919d1d3..fa36c09fc 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -59,7 +59,9 @@ class SupportsAnnotations(ExternallyTraversible):
_is_immutable: bool
- def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations:
+ def _annotate(
+ self: SelfSupportsAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsAnnotations:
raise NotImplementedError()
@overload
@@ -105,11 +107,6 @@ class SupportsAnnotations(ExternallyTraversible):
)
-SelfSupportsCloneAnnotations = TypeVar(
- "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations"
-)
-
-
class SupportsCloneAnnotations(SupportsAnnotations):
if not typing.TYPE_CHECKING:
__slots__ = ()
@@ -119,8 +116,8 @@ class SupportsCloneAnnotations(SupportsAnnotations):
]
def _annotate(
- self: SelfSupportsCloneAnnotations, values: _AnnotationDict
- ) -> SelfSupportsCloneAnnotations:
+ self: SelfSupportsAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsAnnotations:
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -132,8 +129,8 @@ class SupportsCloneAnnotations(SupportsAnnotations):
return new
def _with_annotations(
- self: SelfSupportsCloneAnnotations, values: _AnnotationDict
- ) -> SelfSupportsCloneAnnotations:
+ self: SelfSupportsAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsAnnotations:
"""return a copy of this ClauseElement with annotations
replaced by the given dictionary.
@@ -184,11 +181,6 @@ class SupportsCloneAnnotations(SupportsAnnotations):
return self
-SelfSupportsWrappingAnnotations = TypeVar(
- "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations"
-)
-
-
class SupportsWrappingAnnotations(SupportsAnnotations):
__slots__ = ()
@@ -200,19 +192,23 @@ class SupportsWrappingAnnotations(SupportsAnnotations):
def entity_namespace(self) -> _EntityNamespace:
...
- def _annotate(self, values: _AnnotationDict) -> Annotated:
+ def _annotate(
+ self: SelfSupportsAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsAnnotations:
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
"""
- return Annotated._as_annotated_instance(self, values)
+ return Annotated._as_annotated_instance(self, values) # type: ignore
- def _with_annotations(self, values: _AnnotationDict) -> Annotated:
+ def _with_annotations(
+ self: SelfSupportsAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsAnnotations:
"""return a copy of this ClauseElement with annotations
replaced by the given dictionary.
"""
- return Annotated._as_annotated_instance(self, values)
+ return Annotated._as_annotated_instance(self, values) # type: ignore
@overload
def _deannotate(
@@ -306,16 +302,17 @@ class Annotated(SupportsAnnotations):
self: SelfAnnotated, values: _AnnotationDict
) -> SelfAnnotated:
_values = self._annotations.union(values)
- return self._with_annotations(_values)
+ new: SelfAnnotated = self._with_annotations(_values) # type: ignore
+ return new
def _with_annotations(
- self: SelfAnnotated, values: util.immutabledict[str, Any]
- ) -> SelfAnnotated:
+ self: SelfAnnotated, values: _AnnotationDict
+ ) -> SupportsAnnotations:
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
clone.__dict__.pop("_annotations_cache_key", None)
clone.__dict__.pop("_generate_cache_key", None)
- clone._annotations = values
+ clone._annotations = util.immutabledict(values)
return clone
@overload
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 19e4c13d2..6b25d8fcd 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -63,12 +63,14 @@ if TYPE_CHECKING:
from . import elements
from . import type_api
from ._typing import _ColumnsClauseArgument
- from ._typing import _SelectIterable
from .elements import BindParameter
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import NamedColumn
from .elements import SQLCoreOperations
+ from .elements import TextClause
+ from .selectable import _JoinTargetElement
+ from .selectable import _SelectIterable
from .selectable import FromClause
from ..engine import Connection
from ..engine import Result
@@ -167,7 +169,11 @@ class SingletonConstant(Immutable):
cls._singleton = obj
-def _from_objects(*elements: ColumnElement[Any]) -> Iterator[FromClause]:
+def _from_objects(
+ *elements: Union[
+ ColumnElement[Any], FromClause, TextClause, _JoinTargetElement
+ ]
+) -> Iterator[FromClause]:
return itertools.chain.from_iterable(
[element._from_objects for element in elements]
)
@@ -255,6 +261,11 @@ def _expand_cloned(elements):
predecessors.
"""
+ # TODO: cython candidate
+ # and/or change approach: in
+ # https://gerrit.sqlalchemy.org/c/sqlalchemy/sqlalchemy/+/3712 we propose
+ # getting rid of _cloned_set.
+ # turning this into chain.from_iterable adds all kinds of callcount
return itertools.chain(*[x._cloned_set for x in elements])
@@ -1559,6 +1570,11 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
was moved onto the :class:`_expression.ColumnCollection` itself.
"""
+ # TODO: cython candidate
+
+ # don't dig around if the column is locally present
+ if column in self._colset:
+ return column
def embedded(expanded_proxy_set, target_set):
for t in target_set.difference(expanded_proxy_set):
@@ -1568,9 +1584,6 @@ class ColumnCollection(Generic[_COLKEY, _COL]):
return False
return True
- # don't dig around if the column is locally present
- if column in self._colset:
- return column
col, intersect = None, None
target_set = column.proxy_set
cols = [c for (k, c) in self._collection]
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index 19a232c56..1f8b9c19e 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -54,6 +54,11 @@ class CacheConst(enum.Enum):
NO_CACHE = CacheConst.NO_CACHE
+_CacheKeyTraversalType = Union[
+ "_TraverseInternalsType", Literal[CacheConst.NO_CACHE], Literal[None]
+]
+
+
class CacheTraverseTarget(enum.Enum):
CACHE_IN_PLACE = 0
CALL_GEN_CACHE_KEY = 1
@@ -89,9 +94,7 @@ class HasCacheKey:
__slots__ = ()
- _cache_key_traversal: Union[
- _TraverseInternalsType, Literal[CacheConst.NO_CACHE], Literal[None]
- ] = NO_CACHE
+ _cache_key_traversal: _CacheKeyTraversalType = NO_CACHE
_is_has_cache_key = True
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index ccc8fba8d..4c71ca38b 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -12,7 +12,6 @@ import numbers
import re
import typing
from typing import Any
-from typing import Any as TODO_Any
from typing import Callable
from typing import Dict
from typing import List
@@ -20,7 +19,9 @@ from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from . import operators
from . import roles
@@ -32,6 +33,7 @@ from .visitors import Visitable
from .. import exc
from .. import inspection
from .. import util
+from ..util.typing import Literal
if not typing.TYPE_CHECKING:
elements = None
@@ -46,12 +48,26 @@ if typing.TYPE_CHECKING:
from . import schema
from . import selectable
from . import traversals
+ from ._typing import _ColumnExpressionArgument
from ._typing import _ColumnsClauseArgument
+ from ._typing import _DMLTableArgument
+ from ._typing import _FromClauseArgument
+ from .dml import _DMLTableElement
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
+ from .elements import DQLDMLClauseElement
from .elements import SQLCoreOperations
-
+ from .schema import Column
+ from .selectable import _ColumnsClauseElement
+ from .selectable import _JoinTargetElement
+ from .selectable import _JoinTargetProtocol
+ from .selectable import _OnClauseElement
+ from .selectable import FromClause
+ from .selectable import HasCTE
+ from .selectable import SelectBase
+ from .selectable import Subquery
+ from .visitors import _TraverseCallableType
_SR = TypeVar("_SR", bound=roles.SQLRole)
_F = TypeVar("_F", bound=Callable[..., Any])
@@ -143,10 +159,6 @@ def _expression_collection_was_a_list(attrname, fnname, args):
def expect(
role: Type[roles.TruncatedLabelRole],
element: Any,
- *,
- apply_propagate_attrs: Optional[ClauseElement] = None,
- argname: Optional[str] = None,
- post_inspect: bool = False,
**kw: Any,
) -> str:
...
@@ -154,12 +166,30 @@ def expect(
@overload
def expect(
- role: Type[roles.ExpressionElementRole[_T]],
+ role: Type[roles.StatementOptionRole],
element: Any,
- *,
- apply_propagate_attrs: Optional[ClauseElement] = None,
- argname: Optional[str] = None,
- post_inspect: bool = False,
+ **kw: Any,
+) -> DQLDMLClauseElement:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.DDLReferredColumnRole],
+ element: Any,
+ **kw: Any,
+) -> Column[Any]:
+ ...
+
+
+@overload
+def expect(
+ role: Union[
+ Type[roles.ExpressionElementRole[Any]],
+ Type[roles.LimitOffsetRole],
+ Type[roles.WhereHavingRole],
+ ],
+ element: _ColumnExpressionArgument[_T],
**kw: Any,
) -> ColumnElement[_T]:
...
@@ -167,40 +197,89 @@ def expect(
@overload
def expect(
- role: Type[roles.DMLTableRole],
+ role: Union[
+ Type[roles.ExpressionElementRole[Any]],
+ Type[roles.LimitOffsetRole],
+ Type[roles.WhereHavingRole],
+ ],
element: Any,
+ **kw: Any,
+) -> ColumnElement[Any]:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.DMLTableRole],
+ element: _DMLTableArgument,
+ **kw: Any,
+) -> _DMLTableElement:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.HasCTERole],
+ element: HasCTE,
+ **kw: Any,
+) -> HasCTE:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.SelectStatementRole],
+ element: SelectBase,
+ **kw: Any,
+) -> SelectBase:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.FromClauseRole],
+ element: _FromClauseArgument,
+ **kw: Any,
+) -> FromClause:
+ ...
+
+
+@overload
+def expect(
+ role: Type[roles.FromClauseRole],
+ element: SelectBase,
*,
- apply_propagate_attrs: Optional[ClauseElement] = None,
- argname: Optional[str] = None,
- post_inspect: bool = False,
+ explicit_subquery: Literal[True] = ...,
**kw: Any,
-) -> roles.DMLTableRole:
+) -> Subquery:
...
@overload
def expect(
role: Type[roles.ColumnsClauseRole],
- element: Any,
- *,
- apply_propagate_attrs: Optional[ClauseElement] = None,
- argname: Optional[str] = None,
- post_inspect: bool = False,
+ element: _ColumnsClauseArgument,
+ **kw: Any,
+) -> _ColumnsClauseElement:
+ ...
+
+
+@overload
+def expect(
+ role: Union[Type[roles.JoinTargetRole], Type[roles.OnClauseRole]],
+ element: _JoinTargetProtocol,
**kw: Any,
-) -> roles.ColumnsClauseRole:
+) -> _JoinTargetProtocol:
...
+# catchall for not-yet-implemented overloads
@overload
def expect(
role: Type[_SR],
element: Any,
- *,
- apply_propagate_attrs: Optional[ClauseElement] = None,
- argname: Optional[str] = None,
- post_inspect: bool = False,
**kw: Any,
-) -> TODO_Any:
+) -> Any:
...
@@ -212,7 +291,7 @@ def expect(
argname: Optional[str] = None,
post_inspect: bool = False,
**kw: Any,
-) -> TODO_Any:
+) -> Any:
if (
role.allows_lambda
# note callable() will not invoke a __getattr__() method, whereas
@@ -329,7 +408,8 @@ def expect_col_expression_collection(role, expressions):
strname = resolved = expr
else:
cols: List[ColumnClause[Any]] = []
- visitors.traverse(resolved, {}, {"column": cols.append})
+ col_append: _TraverseCallableType[ColumnClause[Any]] = cols.append
+ visitors.traverse(resolved, {}, {"column": col_append})
if cols:
column = cols[0]
add_element = column if column is not None else strname
@@ -432,7 +512,7 @@ class _ColumnCoercions(RoleImpl):
original_element = element
if not getattr(resolved, "is_clause_element", False):
self._raise_for_expected(original_element, argname, resolved)
- elif resolved._is_select_statement:
+ elif resolved._is_select_base:
self._warn_for_scalar_subquery_coercion()
return resolved.scalar_subquery()
elif resolved._is_from_clause and isinstance(
@@ -670,7 +750,7 @@ class InElementImpl(RoleImpl):
if resolved._is_from_clause:
if (
isinstance(resolved, selectable.Alias)
- and resolved.element._is_select_statement
+ and resolved.element._is_select_base
):
self._warn_for_implicit_coercion(resolved)
return self._post_coercion(resolved.element, **kw)
@@ -722,7 +802,7 @@ class InElementImpl(RoleImpl):
self._raise_for_expected(element, **kw)
def _post_coercion(self, element, expr, operator, **kw):
- if element._is_select_statement:
+ if element._is_select_base:
# for IN, we are doing scalar_subquery() coercion without
# a warning
return element.scalar_subquery()
@@ -1085,7 +1165,7 @@ class JoinTargetImpl(RoleImpl):
# #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match
# were set to False.
return element
- elif legacy and resolved._is_select_statement:
+ elif legacy and resolved._is_select_base:
util.warn_deprecated(
"Implicit coercion of SELECT and textual SELECT "
"constructs into FROM clauses is deprecated; please call "
@@ -1114,7 +1194,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
allow_select: bool = True,
**kw: Any,
) -> Any:
- if resolved._is_select_statement:
+ if resolved._is_select_base:
if explicit_subquery:
return resolved.subquery()
elif allow_select:
@@ -1150,7 +1230,7 @@ class StrictFromClauseImpl(FromClauseImpl):
allow_select: bool = False,
**kw: Any,
) -> Any:
- if resolved._is_select_statement and allow_select:
+ if resolved._is_select_base and allow_select:
util.warn_deprecated(
"Implicit coercion of SELECT and textual SELECT constructs "
"into FROM clauses is deprecated; please call .subquery() "
@@ -1195,7 +1275,7 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl):
if resolved._is_from_clause:
if (
isinstance(resolved, selectable.Alias)
- and resolved.element._is_select_statement
+ and resolved.element._is_select_base
):
return resolved.element
else:
@@ -1235,3 +1315,9 @@ for name in dir(roles):
if name in globals():
impl = globals()[name](cls)
_impl_lookup[cls] = impl
+
+if not TYPE_CHECKING:
+ ee_impl = _impl_lookup[roles.ExpressionElementRole]
+
+ for py_type in (int, bool, str, float):
+ _impl_lookup[roles.ExpressionElementRole[py_type]] = ee_impl
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b7f6d11f6..6ecfbf986 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -313,12 +313,12 @@ EXTRACT_MAP = {
}
COMPOUND_KEYWORDS = {
- selectable.CompoundSelect.UNION: "UNION",
- selectable.CompoundSelect.UNION_ALL: "UNION ALL",
- selectable.CompoundSelect.EXCEPT: "EXCEPT",
- selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL",
- selectable.CompoundSelect.INTERSECT: "INTERSECT",
- selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL",
+ selectable._CompoundSelectKeyword.UNION: "UNION",
+ selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL",
+ selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT",
+ selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL",
+ selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT",
+ selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL",
}
@@ -1468,6 +1468,10 @@ class SQLCompiler(Compiled):
self.post_compile_params = frozenset()
for key in expanded_state.parameter_expansion:
bind = self.binds.pop(key)
+
+ if TYPE_CHECKING:
+ assert bind.value is not None
+
self.bind_names.pop(bind)
for value, expanded_key in zip(
bind.value, expanded_state.parameter_expansion[key]
@@ -3089,12 +3093,7 @@ class SQLCompiler(Compiled):
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
if cte.recursive:
- if isinstance(cte.element, selectable.Select):
- col_source = cte.element
- elif isinstance(cte.element, selectable.CompoundSelect):
- col_source = cte.element.selects[0]
- else:
- assert False, "cte should only be against SelectBase"
+ col_source = cte.element
# TODO: can we get at the .columns_plus_names collection
# that is already (or will be?) generated for the SELECT
@@ -3315,7 +3314,9 @@ class SQLCompiler(Compiled):
for elem in chunk
)
- if isinstance(element.name, elements._truncated_label):
+ if element._unnamed:
+ name = None
+ elif isinstance(element.name, elements._truncated_label):
name = self._truncated_identifier("values", element.name)
else:
name = element.name
@@ -3980,7 +3981,7 @@ class SQLCompiler(Compiled):
clause = " ".join(
prefix._compiler_dispatch(self, **kw)
for prefix, dialect_name in prefixes
- if dialect_name is None or dialect_name == self.dialect.name
+ if dialect_name in (None, "*") or dialect_name == self.dialect.name
)
if clause:
clause += " "
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 0c9056aee..8a3a1b38f 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -45,11 +45,14 @@ from .base import Executable
from .base import HasCompileState
from .elements import BooleanClauseList
from .elements import ClauseElement
+from .elements import ColumnClause
from .elements import ColumnElement
from .elements import Null
+from .selectable import Alias
from .selectable import FromClause
from .selectable import HasCTE
from .selectable import HasPrefixes
+from .selectable import Join
from .selectable import ReturnsRows
from .selectable import TableClause
from .sqltypes import NullType
@@ -59,15 +62,15 @@ from .. import util
from ..util.typing import TypeGuard
if TYPE_CHECKING:
-
+ from ._typing import _ColumnExpressionArgument
from ._typing import _ColumnsClauseArgument
from ._typing import _DMLColumnArgument
+ from ._typing import _DMLTableArgument
from ._typing import _FromClauseArgument
- from ._typing import _HasClauseElement
- from ._typing import _SelectIterable
from .base import ReadOnlyColumnCollection
from .compiler import SQLCompiler
- from .elements import ColumnClause
+ from .selectable import _ColumnsClauseElement
+ from .selectable import _SelectIterable
from .selectable import Select
def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]:
@@ -85,7 +88,8 @@ else:
isinsert = operator.attrgetter("isinsert")
-_DMLColumnElement = Union[str, "ColumnClause[Any]"]
+_DMLColumnElement = Union[str, ColumnClause[Any]]
+_DMLTableElement = Union[TableClause, Alias, Join]
class DMLState(CompileState):
@@ -132,7 +136,7 @@ class DMLState(CompileState):
]
@property
- def dml_table(self) -> roles.DMLTableRole:
+ def dml_table(self) -> _DMLTableElement:
return self.statement.table
if TYPE_CHECKING:
@@ -322,17 +326,17 @@ class UpdateBase(
__visit_name__ = "update_base"
_hints: util.immutabledict[
- Tuple[roles.DMLTableRole, str], str
+ Tuple[_DMLTableElement, str], str
] = util.EMPTY_DICT
named_with_column = False
- table: roles.DMLTableRole
+ table: _DMLTableElement
_return_defaults = False
_return_defaults_columns: Optional[
- Tuple[roles.ColumnsClauseRole, ...]
+ Tuple[_ColumnsClauseElement, ...]
] = None
- _returning: Tuple[roles.ColumnsClauseRole, ...] = ()
+ _returning: Tuple[_ColumnsClauseElement, ...] = ()
is_dml = True
@@ -483,7 +487,7 @@ class UpdateBase(
def with_hint(
self: SelfUpdateBase,
text: str,
- selectable: Optional[roles.DMLTableRole] = None,
+ selectable: Optional[_DMLTableArgument] = None,
dialect_name: str = "*",
) -> SelfUpdateBase:
"""Add a table hint for a single table to this
@@ -517,7 +521,8 @@ class UpdateBase(
"""
if selectable is None:
selectable = self.table
-
+ else:
+ selectable = coercions.expect(roles.DMLTableRole, selectable)
self._hints = self._hints.union({(selectable, dialect_name): text})
return self
@@ -636,9 +641,9 @@ class ValuesBase(UpdateBase):
_select_names: Optional[List[str]] = None
_inline: bool = False
- _returning: Tuple[roles.ColumnsClauseRole, ...] = ()
+ _returning: Tuple[_ColumnsClauseElement, ...] = ()
- def __init__(self, table: _FromClauseArgument):
+ def __init__(self, table: _DMLTableArgument):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)
@@ -970,7 +975,7 @@ class Insert(ValuesBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table: roles.FromClauseRole):
+ def __init__(self, table: _DMLTableArgument):
super(Insert, self).__init__(table)
@_generative
@@ -1066,12 +1071,12 @@ SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase")
class DMLWhereBase:
- table: roles.DMLTableRole
+ table: _DMLTableElement
_where_criteria: Tuple[ColumnElement[Any], ...] = ()
@_generative
def where(
- self: SelfDMLWhereBase, *whereclause: roles.ExpressionElementRole[Any]
+ self: SelfDMLWhereBase, *whereclause: _ColumnExpressionArgument[bool]
) -> SelfDMLWhereBase:
"""Return a new construct with the given expression(s) added to
its WHERE clause, joined to the existing clause via AND, if any.
@@ -1104,7 +1109,9 @@ class DMLWhereBase:
"""
for criterion in whereclause:
- where_criteria = coercions.expect(roles.WhereHavingRole, criterion)
+ where_criteria: ColumnElement[Any] = coercions.expect(
+ roles.WhereHavingRole, criterion
+ )
self._where_criteria += (where_criteria,)
return self
@@ -1119,7 +1126,7 @@ class DMLWhereBase:
return self.where(*criteria)
- def _filter_by_zero(self) -> roles.DMLTableRole:
+ def _filter_by_zero(self) -> _DMLTableElement:
return self.table
def filter_by(self: SelfDMLWhereBase, **kwargs: Any) -> SelfDMLWhereBase:
@@ -1189,7 +1196,7 @@ class Update(DMLWhereBase, ValuesBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table: roles.FromClauseRole):
+ def __init__(self, table: _DMLTableArgument):
super(Update, self).__init__(table)
@_generative
@@ -1279,7 +1286,7 @@ class Delete(DMLWhereBase, UpdateBase):
+ HasCTE._has_ctes_traverse_internals
)
- def __init__(self, table: roles.FromClauseRole):
+ def __init__(self, table: _DMLTableArgument):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index c735085f8..aec29d1b2 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -26,6 +26,7 @@ from typing import Dict
from typing import FrozenSet
from typing import Generic
from typing import Iterable
+from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
@@ -77,8 +78,8 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ._typing import _PropagateAttrsType
- from ._typing import _SelectIterable
from ._typing import _TypeEngineArgument
+ from .cache_key import _CacheKeyTraversalType
from .cache_key import CacheKey
from .compiler import Compiled
from .compiler import SQLCompiler
@@ -88,6 +89,7 @@ if typing.TYPE_CHECKING:
from .schema import DefaultGenerator
from .schema import FetchedValue
from .schema import ForeignKey
+ from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import NamedFromClause
from .selectable import ReturnsRows
@@ -96,6 +98,7 @@ if typing.TYPE_CHECKING:
from .sqltypes import Boolean
from .sqltypes import TupleType
from .type_api import TypeEngine
+ from .visitors import _CloneCallableType
from .visitors import _TraverseInternalsType
from ..engine import Connection
from ..engine import Dialect
@@ -310,6 +313,7 @@ class ClauseElement(
_is_text_clause = False
_is_from_container = False
_is_select_container = False
+ _is_select_base = False
_is_select_statement = False
_is_bind_parameter = False
_is_clause_list = False
@@ -321,7 +325,7 @@ class ClauseElement(
def _order_by_label_element(self) -> Optional[Label[Any]]:
return None
- _cache_key_traversal = None
+ _cache_key_traversal: _CacheKeyTraversalType = None
negation_clause: ColumnElement[bool]
@@ -528,7 +532,7 @@ class ClauseElement(
"""
return traversals.compare(self, other, **kw)
- def self_group(self, against=None):
+ def self_group(self, against: Optional[OperatorType] = None) -> Any:
"""Apply a 'grouping' to this :class:`_expression.ClauseElement`.
This method is overridden by subclasses to return a "grouping"
@@ -637,9 +641,9 @@ class ClauseElement(
return self._negate()
def _negate(self) -> ClauseElement:
- return UnaryExpression(
- self.self_group(against=operators.inv), operator=operators.inv
- )
+ grouped = self.self_group(against=operators.inv)
+ assert isinstance(grouped, ColumnElement)
+ return UnaryExpression(grouped, operator=operators.inv)
def __bool__(self):
raise TypeError("Boolean value of this clause is not defined")
@@ -1290,12 +1294,6 @@ class ColumnElement(
@overload
def self_group(
- self: ColumnElement[bool], against: Optional[OperatorType] = None
- ) -> ColumnElement[bool]:
- ...
-
- @overload
- def self_group(
self: ColumnElement[Any], against: Optional[OperatorType] = None
) -> ColumnElement[Any]:
...
@@ -1764,6 +1762,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
key: str
type: TypeEngine[_T]
+ value: Optional[_T]
_is_crud = False
_is_bind_parameter = True
@@ -1883,7 +1882,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
return cloned
@property
- def effective_value(self):
+ def effective_value(self) -> Optional[_T]:
"""Return the value of this bound parameter,
taking into account if the ``callable`` parameter
was set.
@@ -1893,11 +1892,12 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
"""
if self.callable:
- return self.callable()
+ # TODO: set up protocol for bind parameter callable
+ return self.callable() # type: ignore
else:
return self.value
- def render_literal_execute(self):
+ def render_literal_execute(self) -> BindParameter[_T]:
"""Produce a copy of this bound parameter that will enable the
:paramref:`_sql.BindParameter.literal_execute` flag.
@@ -2513,8 +2513,10 @@ class ClauseList(
self.operator = operator
self.group = group
self.group_contents = group_contents
+ clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses
if _flatten_sub_clauses:
- clauses = util.flatten_iterator(clauses)
+ clauses_iterator = util.flatten_iterator(clauses_iterator)
+
self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
@@ -2523,31 +2525,35 @@ class ClauseList(
coercions.expect(
text_converter_role, clause, apply_propagate_attrs=self
).self_group(against=self.operator)
- for clause in clauses
+ for clause in clauses_iterator
]
else:
self.clauses = [
coercions.expect(
text_converter_role, clause, apply_propagate_attrs=self
)
- for clause in clauses
+ for clause in clauses_iterator
]
self._is_implicitly_boolean = operators.is_boolean(self.operator)
@classmethod
- def _construct_raw(cls, operator, clauses=None):
+ def _construct_raw(
+ cls,
+ operator: OperatorType,
+ clauses: Optional[Sequence[ColumnElement[Any]]] = None,
+ ) -> ClauseList:
self = cls.__new__(cls)
- self.clauses = clauses if clauses else []
+ self.clauses = list(clauses) if clauses else []
self.group = True
self.operator = operator
self.group_contents = True
self._is_implicitly_boolean = False
return self
- def __iter__(self):
+ def __iter__(self) -> Iterator[ColumnElement[Any]]:
return iter(self.clauses)
- def __len__(self):
+ def __len__(self) -> int:
return len(self.clauses)
@property
@@ -2708,10 +2714,10 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
def _construct_raw(
cls,
operator: OperatorType,
- clauses: Optional[List[ColumnElement[Any]]] = None,
+ clauses: Optional[Sequence[ColumnElement[Any]]] = None,
) -> BooleanClauseList:
self = cls.__new__(cls)
- self.clauses = clauses if clauses else []
+ self.clauses = list(clauses) if clauses else []
self.group = True
self.operator = operator
self.group_contents = True
@@ -2781,7 +2787,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]):
sqltypes = util.preloaded.sql_sqltypes
if types is None:
- init_clauses = [
+ init_clauses: List[ColumnElement[Any]] = [
coercions.expect(roles.ExpressionElementRole, c)
for c in clauses
]
@@ -2908,7 +2914,7 @@ class Case(ColumnElement[_T]):
]
if whenlist:
- type_ = list(whenlist[-1])[-1].type
+ type_ = whenlist[-1][-1].type
else:
type_ = None
@@ -3098,6 +3104,8 @@ class _label_reference(ColumnElement[_T]):
("element", InternalTraversal.dp_clauseelement)
]
+ element: ColumnElement[_T]
+
def __init__(self, element: ColumnElement[_T]):
self.element = element
@@ -3212,7 +3220,9 @@ class UnaryExpression(ColumnElement[_T]):
cls,
expr: _ColumnExpressionArgument[_T],
) -> UnaryExpression[_T]:
- col_expr = coercions.expect(roles.ExpressionElementRole, expr)
+ col_expr: ColumnElement[_T] = coercions.expect(
+ roles.ExpressionElementRole, expr
+ )
return UnaryExpression(
col_expr,
operator=operators.distinct_op,
@@ -3265,7 +3275,7 @@ class CollectionAggregate(UnaryExpression[_T]):
def _create_any(
cls, expr: _ColumnExpressionArgument[_T]
) -> CollectionAggregate[bool]:
- col_expr = coercions.expect(
+ col_expr: ColumnElement[_T] = coercions.expect(
roles.ExpressionElementRole,
expr,
)
@@ -3281,7 +3291,7 @@ class CollectionAggregate(UnaryExpression[_T]):
def _create_all(
cls, expr: _ColumnExpressionArgument[_T]
) -> CollectionAggregate[bool]:
- col_expr = coercions.expect(
+ col_expr: ColumnElement[_T] = coercions.expect(
roles.ExpressionElementRole,
expr,
)
@@ -3374,6 +3384,9 @@ class BinaryExpression(ColumnElement[_T]):
modifiers: Optional[Mapping[str, Any]]
+ left: ColumnElement[Any]
+ right: Union[ColumnElement[Any], ClauseList]
+
def __init__(
self,
left: ColumnElement[Any],
@@ -4147,7 +4160,13 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
def foreign_keys(self):
return self.element.foreign_keys
- def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+ def _copy_internals(
+ self,
+ *,
+ clone: _CloneCallableType = _clone,
+ anonymize_labels: bool = False,
+ **kw: Any,
+ ) -> None:
self._reset_memoizations()
self._element = clone(self._element, **kw)
if anonymize_labels:
@@ -4447,7 +4466,9 @@ class TableValuedColumn(NamedColumn[_T]):
self.key = self.name = scalar_alias.name
self.type = type_
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(
+ self, clone: _CloneCallableType = _clone, **kw: Any
+ ) -> None:
self.scalar_alias = clone(self.scalar_alias, **kw)
self.key = self.name = self.scalar_alias.name
@@ -4467,7 +4488,7 @@ class CollationClause(ColumnElement[str]):
def _create_collation_expression(
cls, expression: _ColumnExpressionArgument[str], collation: str
) -> BinaryExpression[str]:
- expr = coercions.expect(roles.ExpressionElementRole, expression)
+ expr = coercions.expect(roles.ExpressionElementRole[str], expression)
return BinaryExpression(
expr,
CollationClause(collation),
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 3bca8b502..db4bb5837 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -11,10 +11,15 @@
from __future__ import annotations
+import datetime
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Mapping
from typing import Optional
from typing import overload
-from typing import Sequence
+from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -24,7 +29,9 @@ from . import operators
from . import roles
from . import schema
from . import sqltypes
+from . import type_api
from . import util as sqlutil
+from ._typing import is_table_value_type
from .base import _entity_namespace
from .base import ColumnCollection
from .base import Executable
@@ -46,16 +53,21 @@ from .elements import WithinGroup
from .selectable import FromClause
from .selectable import Select
from .selectable import TableValuedAlias
+from .sqltypes import _N
+from .sqltypes import TableValueType
from .type_api import TypeEngine
from .visitors import InternalTraversal
from .. import util
+
if TYPE_CHECKING:
from ._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
-_registry = util.defaultdict(dict)
+_registry: util.defaultdict[
+ str, Dict[str, Type[Function[Any]]]
+] = util.defaultdict(dict)
def register_function(identifier, fn, package="_default"):
@@ -103,11 +115,18 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
("_table_value_type", InternalTraversal.dp_has_cache_key),
]
- packagenames = ()
+ packagenames: Tuple[str, ...] = ()
_has_args = False
_with_ordinality = False
- _table_value_type = None
+ _table_value_type: Optional[TableValueType] = None
+
+ # some attributes that are defined between both ColumnElement and
+ # FromClause are set to Any here to avoid typing errors
+ primary_key: Any
+ _is_clone_of: Any
+
+ clause_expr: Grouping[Any]
def __init__(self, *clauses: Any):
r"""Construct a :class:`.FunctionElement`.
@@ -135,9 +154,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
for c in clauses
]
self._has_args = self._has_args or bool(args)
- self.clause_expr = ClauseList(
- operator=operators.comma_op, group_contents=True, *args
- ).self_group()
+ self.clause_expr = Grouping(
+ ClauseList(operator=operators.comma_op, group_contents=True, *args)
+ )
_non_anon_label = None
@@ -263,9 +282,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
expr += (with_ordinality,)
new_func._with_ordinality = True
- new_func.type = new_func._table_value_type = sqltypes.TableValueType(
- *expr
- )
+ new_func.type = new_func._table_value_type = TableValueType(*expr)
return new_func.alias(name=name, joins_implicitly=joins_implicitly)
@@ -332,7 +349,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
@property
def _all_selected_columns(self):
- if self.type._is_table_value:
+ if is_table_value_type(self.type):
cols = self.type._elements
else:
cols = [self.label(None)]
@@ -344,12 +361,12 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
return self.columns
@HasMemoized.memoized_attribute
- def clauses(self):
+ def clauses(self) -> ClauseList:
"""Return the underlying :class:`.ClauseList` which contains
the arguments for this :class:`.FunctionElement`.
"""
- return self.clause_expr.element
+ return cast(ClauseList, self.clause_expr.element)
def over(self, partition_by=None, order_by=None, rows=None, range_=None):
"""Produce an OVER clause against this function.
@@ -647,7 +664,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
return _entity_namespace(self.clause_expr)
-class FunctionAsBinary(BinaryExpression):
+class FunctionAsBinary(BinaryExpression[Any]):
_traverse_internals = [
("sql_function", InternalTraversal.dp_clauseelement),
("left_index", InternalTraversal.dp_plain_obj),
@@ -655,10 +672,16 @@ class FunctionAsBinary(BinaryExpression):
("modifiers", InternalTraversal.dp_plain_dict),
]
+ sql_function: FunctionElement[Any]
+ left_index: int
+ right_index: int
+
def _gen_cache_key(self, anon_map, bindparams):
return ColumnElement._gen_cache_key(self, anon_map, bindparams)
- def __init__(self, fn, left_index, right_index):
+ def __init__(
+ self, fn: FunctionElement[Any], left_index: int, right_index: int
+ ):
self.sql_function = fn
self.left_index = left_index
self.right_index = right_index
@@ -670,23 +693,30 @@ class FunctionAsBinary(BinaryExpression):
self.modifiers = {}
@property
- def left(self):
+ def left_expr(self) -> ColumnElement[Any]:
return self.sql_function.clauses.clauses[self.left_index - 1]
- @left.setter
- def left(self, value):
+ @left_expr.setter
+ def left_expr(self, value: ColumnElement[Any]) -> None:
self.sql_function.clauses.clauses[self.left_index - 1] = value
@property
- def right(self):
+ def right_expr(self) -> ColumnElement[Any]:
return self.sql_function.clauses.clauses[self.right_index - 1]
- @right.setter
- def right(self, value):
+ @right_expr.setter
+ def right_expr(self, value: ColumnElement[Any]) -> None:
self.sql_function.clauses.clauses[self.right_index - 1] = value
+ if not TYPE_CHECKING:
+ # mypy can't accommodate @property to replace an instance
+ # variable
+
+ left = left_expr
+ right = right_expr
+
-class ScalarFunctionColumn(NamedColumn):
+class ScalarFunctionColumn(NamedColumn[_T]):
__visit_name__ = "scalar_function_column"
_traverse_internals = [
@@ -698,10 +728,18 @@ class ScalarFunctionColumn(NamedColumn):
is_literal = False
table = None
- def __init__(self, fn, name, type_=None):
+ def __init__(
+ self,
+ fn: FunctionElement[_T],
+ name: str,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ ):
self.fn = fn
self.name = name
- self.type = sqltypes.to_instance(type_)
+
+ # if type is None, we get NULLTYPE, which is our _T. But I don't
+ # know how to get the overloads to express that correctly
+ self.type = type_api.to_instance(type_) # type: ignore
class _FunctionGenerator:
@@ -789,7 +827,7 @@ class _FunctionGenerator:
# passthru __ attributes; fixes pydoc
if name.startswith("__"):
try:
- return self.__dict__[name]
+ return self.__dict__[name] # type: ignore
except KeyError:
raise AttributeError(name)
@@ -883,8 +921,6 @@ class Function(FunctionElement[_T]):
identifier: str
- packagenames: Sequence[str]
-
type: TypeEngine[_T]
"""A :class:`_types.TypeEngine` object which refers to the SQL return
type represented by this SQL function.
@@ -907,7 +943,7 @@ class Function(FunctionElement[_T]):
name: str,
*clauses: Any,
type_: Optional[_TypeEngineArgument[_T]] = None,
- packagenames: Optional[Sequence[str]] = None,
+ packagenames: Optional[Tuple[str, ...]] = None,
):
"""Construct a :class:`.Function`.
@@ -918,7 +954,9 @@ class Function(FunctionElement[_T]):
self.packagenames = packagenames or ()
self.name = name
- self.type = sqltypes.to_instance(type_)
+ # if type is None, we get NULLTYPE, which is our _T. But I don't
+ # know how to get the overloads to express that correctly
+ self.type = type_api.to_instance(type_) # type: ignore
FunctionElement.__init__(self, *clauses)
@@ -934,7 +972,7 @@ class Function(FunctionElement[_T]):
)
-class GenericFunction(Function):
+class GenericFunction(Function[_T]):
"""Define a 'generic' function.
A generic function is a pre-established :class:`.Function`
@@ -957,7 +995,7 @@ class GenericFunction(Function):
from sqlalchemy.types import DateTime
class as_utc(GenericFunction):
- type = DateTime
+ type = DateTime()
inherit_cache = True
print(select(func.as_utc()))
@@ -971,7 +1009,7 @@ class GenericFunction(Function):
"time"::
class as_utc(GenericFunction):
- type = DateTime
+ type = DateTime()
package = "time"
inherit_cache = True
@@ -987,7 +1025,7 @@ class GenericFunction(Function):
the usage of ``name`` as the rendered name::
class GeoBuffer(GenericFunction):
- type = Geometry
+ type = Geometry()
package = "geo"
name = "ST_Buffer"
identifier = "buffer"
@@ -1006,7 +1044,7 @@ class GenericFunction(Function):
from sqlalchemy.sql import quoted_name
class GeoBuffer(GenericFunction):
- type = Geometry
+ type = Geometry()
package = "geo"
name = quoted_name("ST_Buffer", True)
identifier = "buffer"
@@ -1028,6 +1066,8 @@ class GenericFunction(Function):
coerce_arguments = True
inherit_cache = True
+ _register: bool
+
name = "GenericFunction"
def __init_subclass__(cls) -> None:
@@ -1036,7 +1076,9 @@ class GenericFunction(Function):
super().__init_subclass__()
@classmethod
- def _register_generic_function(cls, clsname, clsdict):
+ def _register_generic_function(
+ cls, clsname: str, clsdict: Mapping[str, Any]
+ ) -> None:
cls.name = name = clsdict.get("name", clsname)
cls.identifier = identifier = clsdict.get("identifier", name)
package = clsdict.get("package", "_default")
@@ -1068,11 +1110,14 @@ class GenericFunction(Function):
]
self._has_args = self._has_args or bool(parsed_args)
self.packagenames = ()
- self.clause_expr = ClauseList(
- operator=operators.comma_op, group_contents=True, *parsed_args
- ).self_group()
- self.type = sqltypes.to_instance(
+ self.clause_expr = Grouping(
+ ClauseList(
+ operator=operators.comma_op, group_contents=True, *parsed_args
+ )
+ )
+
+ self.type = type_api.to_instance( # type: ignore
kwargs.pop("type_", None) or getattr(self, "type", None)
)
@@ -1081,7 +1126,7 @@ register_function("cast", Cast)
register_function("extract", Extract)
-class next_value(GenericFunction):
+class next_value(GenericFunction[int]):
"""Represent the 'next value', given a :class:`.Sequence`
as its single argument.
@@ -1103,7 +1148,7 @@ class next_value(GenericFunction):
seq, schema.Sequence
), "next_value() accepts a Sequence object as input."
self.sequence = seq
- self.type = sqltypes.to_instance(
+ self.type = sqltypes.to_instance( # type: ignore
seq.data_type or getattr(self, "type", None)
)
@@ -1118,7 +1163,7 @@ class next_value(GenericFunction):
return []
-class AnsiFunction(GenericFunction):
+class AnsiFunction(GenericFunction[_T]):
"""Define a function in "ansi" format, which doesn't render parenthesis."""
inherit_cache = True
@@ -1127,13 +1172,13 @@ class AnsiFunction(GenericFunction):
GenericFunction.__init__(self, *args, **kwargs)
-class ReturnTypeFromArgs(GenericFunction):
+class ReturnTypeFromArgs(GenericFunction[_T]):
"""Define a function whose return type is the same as its arguments."""
inherit_cache = True
def __init__(self, *args, **kwargs):
- args = [
+ fn_args = [
coercions.expect(
roles.ExpressionElementRole,
c,
@@ -1142,35 +1187,35 @@ class ReturnTypeFromArgs(GenericFunction):
)
for c in args
]
- kwargs.setdefault("type_", _type_from_args(args))
- kwargs["_parsed_args"] = args
- super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
+ kwargs.setdefault("type_", _type_from_args(fn_args))
+ kwargs["_parsed_args"] = fn_args
+ super(ReturnTypeFromArgs, self).__init__(*fn_args, **kwargs)
-class coalesce(ReturnTypeFromArgs):
+class coalesce(ReturnTypeFromArgs[_T]):
_has_args = True
inherit_cache = True
-class max(ReturnTypeFromArgs): # noqa A001
+class max(ReturnTypeFromArgs[_T]): # noqa A001
"""The SQL MAX() aggregate function."""
inherit_cache = True
-class min(ReturnTypeFromArgs): # noqa A001
+class min(ReturnTypeFromArgs[_T]): # noqa A001
"""The SQL MIN() aggregate function."""
inherit_cache = True
-class sum(ReturnTypeFromArgs): # noqa A001
+class sum(ReturnTypeFromArgs[_T]): # noqa A001
"""The SQL SUM() aggregate function."""
inherit_cache = True
-class now(GenericFunction):
+class now(GenericFunction[datetime.datetime]):
"""The SQL now() datetime function.
SQLAlchemy dialects will usually render this particular function
@@ -1178,11 +1223,11 @@ class now(GenericFunction):
"""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class concat(GenericFunction):
+class concat(GenericFunction[str]):
"""The SQL CONCAT() function, which concatenates strings.
E.g.::
@@ -1200,28 +1245,30 @@ class concat(GenericFunction):
"""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class char_length(GenericFunction):
+class char_length(GenericFunction[int]):
"""The CHAR_LENGTH() SQL function."""
- type = sqltypes.Integer
+ type = sqltypes.Integer()
inherit_cache = True
- def __init__(self, arg, **kwargs):
- GenericFunction.__init__(self, arg, **kwargs)
+ def __init__(self, arg, **kw):
+ # slight hack to limit to just one positional argument
+ # not sure why this one function has this special treatment
+ super().__init__(arg, **kw)
-class random(GenericFunction):
+class random(GenericFunction[float]):
"""The RANDOM() SQL function."""
_has_args = True
inherit_cache = True
-class count(GenericFunction):
+class count(GenericFunction[int]):
r"""The ANSI COUNT aggregate function. With no arguments,
emits COUNT \*.
@@ -1242,7 +1289,7 @@ class count(GenericFunction):
"""
- type = sqltypes.Integer
+ type = sqltypes.Integer()
inherit_cache = True
def __init__(self, expression=None, **kwargs):
@@ -1251,70 +1298,70 @@ class count(GenericFunction):
super(count, self).__init__(expression, **kwargs)
-class current_date(AnsiFunction):
+class current_date(AnsiFunction[datetime.date]):
"""The CURRENT_DATE() SQL function."""
- type = sqltypes.Date
+ type = sqltypes.Date()
inherit_cache = True
-class current_time(AnsiFunction):
+class current_time(AnsiFunction[datetime.time]):
"""The CURRENT_TIME() SQL function."""
- type = sqltypes.Time
+ type = sqltypes.Time()
inherit_cache = True
-class current_timestamp(AnsiFunction):
+class current_timestamp(AnsiFunction[datetime.datetime]):
"""The CURRENT_TIMESTAMP() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class current_user(AnsiFunction):
+class current_user(AnsiFunction[str]):
"""The CURRENT_USER() SQL function."""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class localtime(AnsiFunction):
+class localtime(AnsiFunction[datetime.datetime]):
"""The localtime() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class localtimestamp(AnsiFunction):
+class localtimestamp(AnsiFunction[datetime.datetime]):
"""The localtimestamp() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class session_user(AnsiFunction):
+class session_user(AnsiFunction[str]):
"""The SESSION_USER() SQL function."""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class sysdate(AnsiFunction):
+class sysdate(AnsiFunction[datetime.datetime]):
"""The SYSDATE() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class user(AnsiFunction):
+class user(AnsiFunction[str]):
"""The USER() SQL function."""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class array_agg(GenericFunction):
+class array_agg(GenericFunction[_T]):
"""Support for the ARRAY_AGG function.
The ``func.array_agg(expr)`` construct returns an expression of
@@ -1334,11 +1381,10 @@ class array_agg(GenericFunction):
"""
- type = sqltypes.ARRAY
inherit_cache = True
def __init__(self, *args, **kwargs):
- args = [
+ fn_args = [
coercions.expect(
roles.ExpressionElementRole, c, apply_propagate_attrs=self
)
@@ -1348,16 +1394,16 @@ class array_agg(GenericFunction):
default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
if "type_" not in kwargs:
- type_from_args = _type_from_args(args)
+ type_from_args = _type_from_args(fn_args)
if isinstance(type_from_args, sqltypes.ARRAY):
kwargs["type_"] = type_from_args
else:
kwargs["type_"] = default_array_type(type_from_args)
- kwargs["_parsed_args"] = args
- super(array_agg, self).__init__(*args, **kwargs)
+ kwargs["_parsed_args"] = fn_args
+ super(array_agg, self).__init__(*fn_args, **kwargs)
-class OrderedSetAgg(GenericFunction):
+class OrderedSetAgg(GenericFunction[_T]):
"""Define a function where the return type is based on the sort
expression type as defined by the expression passed to the
:meth:`.FunctionElement.within_group` method."""
@@ -1366,7 +1412,7 @@ class OrderedSetAgg(GenericFunction):
inherit_cache = True
def within_group_type(self, within_group):
- func_clauses = self.clause_expr.element
+ func_clauses = cast(ClauseList, self.clause_expr.element)
order_by = sqlutil.unwrap_order_by(within_group.order_by)
if self.array_for_multi_clause and len(func_clauses.clauses) > 1:
return sqltypes.ARRAY(order_by[0].type)
@@ -1374,7 +1420,7 @@ class OrderedSetAgg(GenericFunction):
return order_by[0].type
-class mode(OrderedSetAgg):
+class mode(OrderedSetAgg[_T]):
"""Implement the ``mode`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1389,7 +1435,7 @@ class mode(OrderedSetAgg):
inherit_cache = True
-class percentile_cont(OrderedSetAgg):
+class percentile_cont(OrderedSetAgg[_T]):
"""Implement the ``percentile_cont`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1407,7 +1453,7 @@ class percentile_cont(OrderedSetAgg):
inherit_cache = True
-class percentile_disc(OrderedSetAgg):
+class percentile_disc(OrderedSetAgg[_T]):
"""Implement the ``percentile_disc`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1425,7 +1471,7 @@ class percentile_disc(OrderedSetAgg):
inherit_cache = True
-class rank(GenericFunction):
+class rank(GenericFunction[int]):
"""Implement the ``rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1441,7 +1487,7 @@ class rank(GenericFunction):
inherit_cache = True
-class dense_rank(GenericFunction):
+class dense_rank(GenericFunction[int]):
"""Implement the ``dense_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1457,7 +1503,7 @@ class dense_rank(GenericFunction):
inherit_cache = True
-class percent_rank(GenericFunction):
+class percent_rank(GenericFunction[_N]):
"""Implement the ``percent_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1469,11 +1515,11 @@ class percent_rank(GenericFunction):
"""
- type = sqltypes.Numeric()
+ type: sqltypes.Numeric[_N] = sqltypes.Numeric()
inherit_cache = True
-class cume_dist(GenericFunction):
+class cume_dist(GenericFunction[_N]):
"""Implement the ``cume_dist`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1485,11 +1531,11 @@ class cume_dist(GenericFunction):
"""
- type = sqltypes.Numeric()
+ type: sqltypes.Numeric[_N] = sqltypes.Numeric()
inherit_cache = True
-class cube(GenericFunction):
+class cube(GenericFunction[_T]):
r"""Implement the ``CUBE`` grouping operation.
This function is used as part of the GROUP BY of a statement,
@@ -1506,7 +1552,7 @@ class cube(GenericFunction):
inherit_cache = True
-class rollup(GenericFunction):
+class rollup(GenericFunction[_T]):
r"""Implement the ``ROLLUP`` grouping operation.
This function is used as part of the GROUP BY of a statement,
@@ -1523,7 +1569,7 @@ class rollup(GenericFunction):
inherit_cache = True
-class grouping_sets(GenericFunction):
+class grouping_sets(GenericFunction[_T]):
r"""Implement the ``GROUPING SETS`` grouping operation.
This function is used as part of the GROUP BY of a statement,
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index 9d011ef53..da15c305f 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -12,6 +12,18 @@ import inspect
import itertools
import operator
import types
+from types import CodeType
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Iterable
+from typing import List
+from typing import MutableMapping
+from typing import Optional
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
import weakref
from . import cache_key as _cache_key
@@ -19,37 +31,62 @@ from . import coercions
from . import elements
from . import roles
from . import schema
-from . import traversals
from . import type_api
from . import visitors
from .base import _clone
+from .base import Executable
from .base import Options
+from .cache_key import CacheConst
from .operators import ColumnOperators
from .. import exc
from .. import inspection
from .. import util
+from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import Self
-_closure_per_cache_key = util.LRUCache(1000)
+if TYPE_CHECKING:
+ from .cache_key import CacheConst
+ from .cache_key import NO_CACHE
+ from .elements import BindParameter
+ from .elements import ClauseElement
+ from .roles import SQLRole
+ from .visitors import _CloneCallableType
+
+_LambdaCacheType = MutableMapping[
+ Tuple[Any, ...], Union["NonAnalyzedFunction", "AnalyzedFunction"]
+]
+_BoundParameterGetter = Callable[..., Any]
+
+_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000)
+
+
+class _LambdaType(Protocol):
+ __code__: CodeType
+ __closure__: Iterable[Tuple[Any, Any]]
+
+ def __call__(self, *arg: Any, **kw: Any) -> ClauseElement:
+ ...
class LambdaOptions(Options):
enable_tracking = True
track_closure_variables = True
- track_on = None
+ track_on: Optional[object] = None
global_track_bound_values = True
track_bound_values = True
- lambda_cache = None
+ lambda_cache: Optional[_LambdaCacheType] = None
def lambda_stmt(
- lmb,
- enable_tracking=True,
- track_closure_variables=True,
- track_on=None,
- global_track_bound_values=True,
- track_bound_values=True,
- lambda_cache=None,
-):
+ lmb: _LambdaType,
+ enable_tracking: bool = True,
+ track_closure_variables: bool = True,
+ track_on: Optional[object] = None,
+ global_track_bound_values: bool = True,
+ track_bound_values: bool = True,
+ lambda_cache: Optional[_LambdaCacheType] = None,
+) -> StatementLambdaElement:
"""Produce a SQL statement that is cached as a lambda.
The Python code object within the lambda is scanned for both Python
@@ -142,15 +179,28 @@ class LambdaElement(elements.ClauseElement):
("_resolved", visitors.InternalTraversal.dp_clauseelement)
]
- _transforms = ()
+ _transforms: Tuple[_CloneCallableType, ...] = ()
- parent_lambda = None
+ _resolved_bindparams: List[BindParameter[Any]]
+ parent_lambda: Optional[StatementLambdaElement] = None
+ closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
+ role: Type[SQLRole]
+ _rec: Union[AnalyzedFunction, NonAnalyzedFunction]
+ fn: _LambdaType
+ tracker_key: Tuple[CodeType, ...]
def __repr__(self):
- return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
+ return "%s(%r)" % (
+ self.__class__.__name__,
+ self.fn.__code__,
+ )
def __init__(
- self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
+ self,
+ fn: _LambdaType,
+ role: Type[SQLRole],
+ opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
+ apply_propagate_attrs: Optional[ClauseElement] = None,
):
self.fn = fn
self.role = role
@@ -182,6 +232,7 @@ class LambdaElement(elements.ClauseElement):
opts,
)
+ bindparams: List[BindParameter[Any]]
self._resolved_bindparams = bindparams = []
if self.parent_lambda is not None:
@@ -189,8 +240,10 @@ class LambdaElement(elements.ClauseElement):
else:
parent_closure_cache_key = ()
+ cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
+
if parent_closure_cache_key is not _cache_key.NO_CACHE:
- anon_map = traversals.anon_map()
+ anon_map = visitors.anon_map()
cache_key = tuple(
[
getter(closure, opts, anon_map, bindparams)
@@ -241,7 +294,7 @@ class LambdaElement(elements.ClauseElement):
if self.parent_lambda is not None:
bindparams[:0] = self.parent_lambda._resolved_bindparams
- lambda_element = self
+ lambda_element: Optional[LambdaElement] = self
while lambda_element is not None:
rec = lambda_element._rec
if rec.bindparam_trackers:
@@ -289,17 +342,21 @@ class LambdaElement(elements.ClauseElement):
def _setup_binds_for_tracked_expr(self, expr):
bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
- def replace(thing):
- if isinstance(thing, elements.BindParameter):
+ def replace(
+ element: Optional[visitors.ExternallyTraversible], **kw: Any
+ ) -> Optional[visitors.ExternallyTraversible]:
+ if isinstance(element, elements.BindParameter):
- if thing.key in bindparam_lookup:
- bind = bindparam_lookup[thing.key]
- if thing.expanding:
+ if element.key in bindparam_lookup:
+ bind = bindparam_lookup[element.key]
+ if element.expanding:
bind.expanding = True
- bind.expand_op = thing.expand_op
- bind.type = thing.type
+ bind.expand_op = element.expand_op
+ bind.type = element.type
return bind
+ return None
+
if self._rec.is_sequence:
expr = [
visitors.replacement_traverse(sub_expr, {}, replace)
@@ -311,8 +368,11 @@ class LambdaElement(elements.ClauseElement):
return expr
def _copy_internals(
- self, clone=_clone, deferred_copy_internals=None, **kw
- ):
+ self: Self,
+ clone: _CloneCallableType = _clone,
+ deferred_copy_internals: Optional[_CloneCallableType] = None,
+ **kw: Any,
+ ) -> None:
# TODO: this needs A LOT of tests
self._resolved = clone(
self._resolved,
@@ -340,9 +400,15 @@ class LambdaElement(elements.ClauseElement):
) + self.closure_cache_key
parent = self.parent_lambda
+
while parent is not None:
+ assert parent.closure_cache_key is not CacheConst.NO_CACHE
+ parent_closure_cache_key: Tuple[
+ Any, ...
+ ] = parent.closure_cache_key
+
cache_key = (
- (parent.fn.__code__,) + parent.closure_cache_key + cache_key
+ (parent.fn.__code__,) + parent_closure_cache_key + cache_key
)
parent = parent.parent_lambda
@@ -351,7 +417,7 @@ class LambdaElement(elements.ClauseElement):
bindparams.extend(self._resolved_bindparams)
return cache_key
- def _invoke_user_fn(self, fn, *arg):
+ def _invoke_user_fn(self, fn: _LambdaType, *arg: Any) -> ClauseElement:
return fn()
@@ -365,7 +431,13 @@ class DeferredLambdaElement(LambdaElement):
"""
- def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
+ def __init__(
+ self,
+ fn: _LambdaType,
+ role: Type[roles.SQLRole],
+ opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
+ lambda_args: Tuple[Any, ...] = (),
+ ):
self.lambda_args = lambda_args
super(DeferredLambdaElement, self).__init__(fn, role, opts)
@@ -373,6 +445,7 @@ class DeferredLambdaElement(LambdaElement):
return fn(*self.lambda_args)
def _resolve_with_args(self, *lambda_args):
+ assert isinstance(self._rec, AnalyzedFunction)
tracker_fn = self._rec.tracker_instrumented_fn
expr = tracker_fn(*lambda_args)
@@ -506,6 +579,8 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
def _execute_on_connection(
self, connection, distilled_params, execution_options
):
+ if TYPE_CHECKING:
+ assert isinstance(self._rec.expected_expr, ClauseElement)
if self._rec.expected_expr.supports_execution:
return connection._execute_clauseelement(
self, distilled_params, execution_options
@@ -515,14 +590,20 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
@property
def _with_options(self):
+ if TYPE_CHECKING:
+ assert isinstance(self._rec.expected_expr, Executable)
return self._rec.expected_expr._with_options
@property
def _effective_plugin_target(self):
+ if TYPE_CHECKING:
+ assert isinstance(self._rec.expected_expr, Executable)
return self._rec.expected_expr._effective_plugin_target
@property
def _execution_options(self):
+ if TYPE_CHECKING:
+ assert isinstance(self._rec.expected_expr, Executable)
return self._rec.expected_expr._execution_options
def spoil(self):
@@ -583,9 +664,14 @@ class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
class LinkedLambdaElement(StatementLambdaElement):
"""Represent subsequent links of a :class:`.StatementLambdaElement`."""
- role = None
+ parent_lambda: StatementLambdaElement
- def __init__(self, fn, parent_lambda, opts):
+ def __init__(
+ self,
+ fn: _LambdaType,
+ parent_lambda: StatementLambdaElement,
+ opts: Union[Type[LambdaOptions], LambdaOptions],
+ ):
self.opts = opts
self.fn = fn
self.parent_lambda = parent_lambda
@@ -606,7 +692,9 @@ class AnalyzedCode:
"closure_trackers",
"build_py_wrappers",
)
- _fns = weakref.WeakKeyDictionary()
+ _fns: weakref.WeakKeyDictionary[
+ CodeType, AnalyzedCode
+ ] = weakref.WeakKeyDictionary()
@classmethod
def get(cls, fn, lambda_element, lambda_kw, **kw):
@@ -615,6 +703,8 @@ class AnalyzedCode:
return cls._fns[fn.__code__]
except KeyError:
pass
+
+ analyzed: AnalyzedCode
cls._fns[fn.__code__] = analyzed = AnalyzedCode(
fn, lambda_element, lambda_kw, **kw
)
@@ -947,14 +1037,18 @@ class AnalyzedCode:
class NonAnalyzedFunction:
__slots__ = ("expr",)
- closure_bindparams = None
- bindparam_trackers = None
+ closure_bindparams: Optional[List[BindParameter[Any]]] = None
+ bindparam_trackers: Optional[List[_BoundParameterGetter]] = None
+
+ is_sequence = False
+
+ expr: ClauseElement
- def __init__(self, expr):
+ def __init__(self, expr: ClauseElement):
self.expr = expr
@property
- def expected_expr(self):
+ def expected_expr(self) -> ClauseElement:
return self.expr
@@ -972,6 +1066,10 @@ class AnalyzedFunction:
"closure_bindparams",
)
+ closure_bindparams: Optional[List[BindParameter[Any]]]
+ expected_expr: Union[ClauseElement, List[ClauseElement]]
+ bindparam_trackers: Optional[List[_BoundParameterGetter]]
+
def __init__(
self,
analyzed_code,
@@ -1071,19 +1169,25 @@ class AnalyzedFunction:
if parent_lambda is None:
if isinstance(expr, collections_abc.Sequence):
self.expected_expr = [
- coercions.expect(
- lambda_element.role,
- sub_expr,
- apply_propagate_attrs=apply_propagate_attrs,
+ cast(
+ "ClauseElement",
+ coercions.expect(
+ lambda_element.role,
+ sub_expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ ),
)
for sub_expr in expr
]
self.is_sequence = True
else:
- self.expected_expr = coercions.expect(
- lambda_element.role,
- expr,
- apply_propagate_attrs=apply_propagate_attrs,
+ self.expected_expr = cast(
+ "ClauseElement",
+ coercions.expect(
+ lambda_element.role,
+ expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ ),
)
self.is_sequence = False
else:
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 86725f86f..577d868fd 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -19,7 +19,6 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from ._typing import _PropagateAttrsType
- from ._typing import _SelectIterable
from .base import _EntityNamespace
from .base import ColumnCollection
from .base import ReadOnlyColumnCollection
@@ -28,6 +27,7 @@ if TYPE_CHECKING:
from .elements import ColumnElement
from .elements import Label
from .elements import NamedColumn
+ from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import Subquery
@@ -164,6 +164,12 @@ class WhereHavingRole(OnClauseRole):
class ExpressionElementRole(Generic[_T], SQLRole):
+ # note when using generics for ExpressionElementRole,
+ # the generic type needs to be in
+ # sqlalchemy.sql.coercions._impl_lookup mapping also.
+ # these are set up for basic types like int, bool, str, float
+ # right now
+
__slots__ = ()
_role_name = "SQL expression element"
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index cbd0c77f4..883439ca5 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -36,6 +36,7 @@ import operator
import typing
from typing import Any
from typing import Callable
+from typing import cast
from typing import Dict
from typing import Iterator
from typing import List
@@ -68,6 +69,7 @@ from .elements import SQLCoreOperations
from .elements import TextClause
from .selectable import TableClause
from .type_api import to_instance
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
from .. import event
from .. import exc
@@ -131,21 +133,33 @@ def _get_table_key(name: str, schema: Optional[str]) -> str:
# this should really be in sql/util.py but we'd have to
# break an import cycle
-def _copy_expression(expression, source_table, target_table):
+def _copy_expression(
+ expression: ColumnElement[Any],
+ source_table: Optional[Table],
+ target_table: Optional[Table],
+) -> ColumnElement[Any]:
if source_table is None or target_table is None:
return expression
- def replace(col):
+ fixed_source_table = source_table
+ fixed_target_table = target_table
+
+ def replace(
+ element: ExternallyTraversible, **kw: Any
+ ) -> Optional[ExternallyTraversible]:
if (
- isinstance(col, Column)
- and col.table is source_table
- and col.key in source_table.c
+ isinstance(element, Column)
+ and element.table is fixed_source_table
+ and element.key in fixed_source_table.c
):
- return target_table.c[col.key]
+ return fixed_target_table.c[element.key]
else:
return None
- return visitors.replacement_traverse(expression, {}, replace)
+ return cast(
+ ColumnElement[Any],
+ visitors.replacement_traverse(expression, {}, replace),
+ )
@inspection._self_inspects
@@ -911,8 +925,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
def _reset_exported(self):
pass
- @property
- def _autoincrement_column(self):
+ @util.ro_non_memoized_property
+ def _autoincrement_column(self) -> Optional[Column[Any]]:
return self.primary_key._autoincrement_column
@property
@@ -2308,6 +2322,8 @@ class ForeignKey(DialectKWArgs, SchemaItem):
parent: Column[Any]
+ _table_column: Optional[Column[Any]]
+
def __init__(
self,
column: Union[str, Column[Any], SQLCoreOperations[Any]],
@@ -4290,11 +4306,11 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
self._columns.extend(columns)
- PrimaryKeyConstraint._autoincrement_column._reset(self)
+ PrimaryKeyConstraint._autoincrement_column._reset(self) # type: ignore
self._set_parent_with_dispatch(self.table)
def _replace(self, col):
- PrimaryKeyConstraint._autoincrement_column._reset(self)
+ PrimaryKeyConstraint._autoincrement_column._reset(self) # type: ignore
self._columns.replace(col)
self.dispatch._sa_event_column_added_to_pk_constraint(self, col)
@@ -4308,8 +4324,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
else:
return list(self._columns)
- @util.memoized_property
- def _autoincrement_column(self):
+ @util.ro_memoized_property
+ def _autoincrement_column(self) -> Optional[Column[Any]]:
def _validate_autoinc(col, autoinc_true):
if col.type._type_affinity is None or not issubclass(
col.type._type_affinity, type_api.INTEGERTYPE._type_affinity
@@ -4350,6 +4366,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
"ignore_fk",
) and _validate_autoinc(col, False):
return col
+ else:
+ return None
else:
autoinc = None
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 4f6e3795e..6504449f1 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -17,16 +17,26 @@ import collections
from enum import Enum
import itertools
import typing
+from typing import AbstractSet
from typing import Any as TODO_Any
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
from typing import Iterable
+from typing import Iterator
from typing import List
from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from . import cache_key
from . import coercions
@@ -37,6 +47,9 @@ from . import type_api
from . import visitors
from ._typing import _ColumnsClauseArgument
from ._typing import is_column_element
+from ._typing import is_select_statement
+from ._typing import is_subquery
+from ._typing import is_table
from .annotation import Annotated
from .annotation import SupportsCloneAnnotations
from .base import _clone
@@ -68,32 +81,80 @@ from .elements import ColumnClause
from .elements import ColumnElement
from .elements import DQLDMLClauseElement
from .elements import GroupedElement
-from .elements import Grouping
from .elements import literal_column
from .elements import TableValuedColumn
from .elements import UnaryExpression
+from .operators import OperatorType
+from .visitors import _TraverseInternalsType
from .visitors import InternalTraversal
from .visitors import prefix_anon_map
from .. import exc
from .. import util
+from ..util import HasMemoized_ro_memoized_attribute
+from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import Self
and_ = BooleanClauseList.and_
_T = TypeVar("_T", bound=Any)
if TYPE_CHECKING:
- from ._typing import _SelectIterable
+ from ._typing import _ColumnExpressionArgument
+ from ._typing import _FromClauseArgument
+ from ._typing import _JoinTargetArgument
+ from ._typing import _OnClauseArgument
+ from ._typing import _SelectStatementForCompoundArgument
+ from ._typing import _TextCoercedExpressionArgument
+ from ._typing import _TypeEngineArgument
+ from .base import _AmbiguousTableNameMap
+ from .base import ExecutableOption
from .base import ReadOnlyColumnCollection
+ from .cache_key import _CacheKeyTraversalType
+ from .compiler import SQLCompiler
+ from .dml import Delete
+ from .dml import Insert
+ from .dml import Update
from .elements import NamedColumn
+ from .elements import TextClause
+ from .functions import Function
+ from .schema import Column
from .schema import ForeignKey
- from .schema import PrimaryKeyConstraint
+ from .schema import ForeignKeyConstraint
+ from .type_api import TypeEngine
+ from .util import ClauseAdapter
+ from .visitors import _CloneCallableType
-class _OffsetLimitParam(BindParameter):
+_ColumnsClauseElement = Union["FromClause", ColumnElement[Any], "TextClause"]
+
+
+class _JoinTargetProtocol(Protocol):
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
+ ...
+
+
+_JoinTargetElement = Union["FromClause", _JoinTargetProtocol]
+_OnClauseElement = Union["ColumnElement[bool]", _JoinTargetProtocol]
+
+
+_SetupJoinsElement = Tuple[
+ _JoinTargetElement,
+ Optional[_OnClauseElement],
+ Optional["FromClause"],
+ Dict[str, Any],
+]
+
+
+_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]]
+
+
+class _OffsetLimitParam(BindParameter[int]):
inherit_cache = True
@property
- def _limit_offset_value(self):
+ def _limit_offset_value(self) -> Optional[int]:
return self.effective_value
@@ -114,11 +175,12 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
# sub-elements of returns_rows
_is_from_clause = False
+ _is_select_base = False
_is_select_statement = False
_is_lateral = False
@property
- def selectable(self):
+ def selectable(self) -> ReturnsRows:
return self
@util.non_memoized_property
@@ -133,8 +195,28 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
"""
raise NotImplementedError()
+ def is_derived_from(self, fromclause: FromClause) -> bool:
+ """Return ``True`` if this :class:`.ReturnsRows` is
+ 'derived' from the given :class:`.FromClause`.
+
+ An example would be an Alias of a Table is derived from that Table.
+
+ """
+ raise NotImplementedError()
+
+ def _generate_fromclause_column_proxies(
+ self, fromclause: FromClause
+ ) -> None:
+ """Populate columns into an :class:`.AliasedReturnsRows` object."""
+
+ raise NotImplementedError()
+
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
+ """reset internal collections for an incoming column being added."""
+ raise NotImplementedError()
+
@property
- def exported_columns(self):
+ def exported_columns(self) -> ReadOnlyColumnCollection[Any, Any]:
"""A :class:`_expression.ColumnCollection`
that represents the "exported"
columns of this :class:`_expression.ReturnsRows`.
@@ -160,6 +242,9 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
raise NotImplementedError()
+SelfSelectable = TypeVar("SelfSelectable", bound="Selectable")
+
+
class Selectable(ReturnsRows):
"""Mark a class as being selectable."""
@@ -167,10 +252,10 @@ class Selectable(ReturnsRows):
is_selectable = True
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
raise NotImplementedError()
- def lateral(self, name=None):
+ def lateral(self, name: Optional[str] = None) -> LateralFromClause:
"""Return a LATERAL alias of this :class:`_expression.Selectable`.
The return value is the :class:`_expression.Lateral` construct also
@@ -192,15 +277,21 @@ class Selectable(ReturnsRows):
"functionality is available via the sqlalchemy.sql.visitors module.",
)
@util.preload_module("sqlalchemy.sql.util")
- def replace_selectable(self, old, alias):
+ def replace_selectable(
+ self: SelfSelectable, old: FromClause, alias: Alias
+ ) -> SelfSelectable:
"""Replace all occurrences of :class:`_expression.FromClause`
'old' with the given :class:`_expression.Alias`
object, returning a copy of this :class:`_expression.FromClause`.
"""
- return util.preloaded.sql_util.ClauseAdapter(alias).traverse(self)
+ return util.preloaded.sql_util.ClauseAdapter(alias).traverse( # type: ignore # noqa E501
+ self
+ )
- def corresponding_column(self, column, require_embedded=False):
+ def corresponding_column(
+ self, column: ColumnElement[Any], require_embedded: bool = False
+ ) -> Optional[ColumnElement[Any]]:
"""Given a :class:`_expression.ColumnElement`, return the exported
:class:`_expression.ColumnElement` object from the
:attr:`_expression.Selectable.exported_columns`
@@ -242,19 +333,23 @@ SelfHasPrefixes = typing.TypeVar("SelfHasPrefixes", bound="HasPrefixes")
class HasPrefixes:
- _prefixes = ()
+ _prefixes: Tuple[Tuple[DQLDMLClauseElement, str], ...] = ()
- _has_prefixes_traverse_internals = [
+ _has_prefixes_traverse_internals: _TraverseInternalsType = [
("_prefixes", InternalTraversal.dp_prefix_sequence)
]
@_generative
@_document_text_coercion(
- "expr",
+ "prefixes",
":meth:`_expression.HasPrefixes.prefix_with`",
- ":paramref:`.HasPrefixes.prefix_with.*expr`",
+ ":paramref:`.HasPrefixes.prefix_with.*prefixes`",
)
- def prefix_with(self: SelfHasPrefixes, *expr, **kw) -> SelfHasPrefixes:
+ def prefix_with(
+ self: SelfHasPrefixes,
+ *prefixes: _TextCoercedExpressionArgument[Any],
+ dialect: str = "*",
+ ) -> SelfHasPrefixes:
r"""Add one or more expressions following the statement keyword, i.e.
SELECT, INSERT, UPDATE, or DELETE. Generative.
@@ -272,49 +367,44 @@ class HasPrefixes:
Multiple prefixes can be specified by multiple calls
to :meth:`_expression.HasPrefixes.prefix_with`.
- :param \*expr: textual or :class:`_expression.ClauseElement`
+ :param \*prefixes: textual or :class:`_expression.ClauseElement`
construct which
will be rendered following the INSERT, UPDATE, or DELETE
keyword.
- :param \**kw: A single keyword 'dialect' is accepted. This is an
- optional string dialect name which will
+ :param dialect: optional string dialect name which will
limit rendering of this prefix to only that dialect.
"""
- dialect = kw.pop("dialect", None)
- if kw:
- raise exc.ArgumentError(
- "Unsupported argument(s): %s" % ",".join(kw)
- )
- self._setup_prefixes(expr, dialect)
- return self
-
- def _setup_prefixes(self, prefixes, dialect=None):
self._prefixes = self._prefixes + tuple(
[
(coercions.expect(roles.StatementOptionRole, p), dialect)
for p in prefixes
]
)
+ return self
SelfHasSuffixes = typing.TypeVar("SelfHasSuffixes", bound="HasSuffixes")
class HasSuffixes:
- _suffixes = ()
+ _suffixes: Tuple[Tuple[DQLDMLClauseElement, str], ...] = ()
- _has_suffixes_traverse_internals = [
+ _has_suffixes_traverse_internals: _TraverseInternalsType = [
("_suffixes", InternalTraversal.dp_prefix_sequence)
]
@_generative
@_document_text_coercion(
- "expr",
+ "suffixes",
":meth:`_expression.HasSuffixes.suffix_with`",
- ":paramref:`.HasSuffixes.suffix_with.*expr`",
+ ":paramref:`.HasSuffixes.suffix_with.*suffixes`",
)
- def suffix_with(self: SelfHasSuffixes, *expr, **kw) -> SelfHasSuffixes:
+ def suffix_with(
+ self: SelfHasSuffixes,
+ *suffixes: _TextCoercedExpressionArgument[Any],
+ dialect: str = "*",
+ ) -> SelfHasSuffixes:
r"""Add one or more expressions following the statement as a whole.
This is used to support backend-specific suffix keywords on
@@ -328,44 +418,39 @@ class HasSuffixes:
Multiple suffixes can be specified by multiple calls
to :meth:`_expression.HasSuffixes.suffix_with`.
- :param \*expr: textual or :class:`_expression.ClauseElement`
+ :param \*suffixes: textual or :class:`_expression.ClauseElement`
construct which
will be rendered following the target clause.
- :param \**kw: A single keyword 'dialect' is accepted. This is an
- optional string dialect name which will
+ :param dialect: Optional string dialect name which will
limit rendering of this suffix to only that dialect.
"""
- dialect = kw.pop("dialect", None)
- if kw:
- raise exc.ArgumentError(
- "Unsupported argument(s): %s" % ",".join(kw)
- )
- self._setup_suffixes(expr, dialect)
- return self
-
- def _setup_suffixes(self, suffixes, dialect=None):
self._suffixes = self._suffixes + tuple(
[
(coercions.expect(roles.StatementOptionRole, p), dialect)
for p in suffixes
]
)
+ return self
SelfHasHints = typing.TypeVar("SelfHasHints", bound="HasHints")
class HasHints:
- _hints = util.immutabledict()
- _statement_hints = ()
+ _hints: util.immutabledict[
+ Tuple[FromClause, str], str
+ ] = util.immutabledict()
+ _statement_hints: Tuple[Tuple[str, str], ...] = ()
- _has_hints_traverse_internals = [
+ _has_hints_traverse_internals: _TraverseInternalsType = [
("_statement_hints", InternalTraversal.dp_statement_hint_list),
("_hints", InternalTraversal.dp_table_hint_list),
]
- def with_statement_hint(self, text, dialect_name="*"):
+ def with_statement_hint(
+ self: SelfHasHints, text: str, dialect_name: str = "*"
+ ) -> SelfHasHints:
"""Add a statement hint to this :class:`_expression.Select` or
other selectable object.
@@ -389,11 +474,14 @@ class HasHints:
MySQL optimizer hints
"""
- return self.with_hint(None, text, dialect_name)
+ return self._with_hint(None, text, dialect_name)
@_generative
def with_hint(
- self: SelfHasHints, selectable, text, dialect_name="*"
+ self: SelfHasHints,
+ selectable: _FromClauseArgument,
+ text: str,
+ dialect_name: str = "*",
) -> SelfHasHints:
r"""Add an indexing or other executional context hint for the given
selectable to this :class:`_expression.Select` or other selectable
@@ -429,6 +517,15 @@ class HasHints:
:meth:`_expression.Select.with_statement_hint`
"""
+
+ return self._with_hint(selectable, text, dialect_name)
+
+ def _with_hint(
+ self: SelfHasHints,
+ selectable: Optional[_FromClauseArgument],
+ text: str,
+ dialect_name: str,
+ ) -> SelfHasHints:
if selectable is None:
self._statement_hints += ((dialect_name, text),)
else:
@@ -443,6 +540,9 @@ class HasHints:
return self
+SelfFromClause = TypeVar("SelfFromClause", bound="FromClause")
+
+
class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""Represent an element that can be used within the ``FROM``
clause of a ``SELECT`` statement.
@@ -473,6 +573,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
_is_clone_of: Optional[FromClause]
+ _columns: ColumnCollection[Any, Any]
+
schema: Optional[str] = None
"""Define the 'schema' attribute for this :class:`_expression.FromClause`.
@@ -488,7 +590,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
_use_schema_map = False
- def select(self) -> "Select":
+ def select(self) -> Select:
r"""Return a SELECT of this :class:`_expression.FromClause`.
@@ -504,7 +606,13 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return Select(self)
- def join(self, right, onclause=None, isouter=False, full=False):
+ def join(
+ self,
+ right: _FromClauseArgument,
+ onclause: Optional[_ColumnExpressionArgument[bool]] = None,
+ isouter: bool = False,
+ full: bool = False,
+ ) -> Join:
"""Return a :class:`_expression.Join` from this
:class:`_expression.FromClause`
to another :class:`FromClause`.
@@ -550,7 +658,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return Join(self, right, onclause, isouter, full)
- def outerjoin(self, right, onclause=None, full=False):
+ def outerjoin(
+ self,
+ right: _FromClauseArgument,
+ onclause: Optional[_ColumnExpressionArgument[bool]] = None,
+ full: bool = False,
+ ) -> Join:
"""Return a :class:`_expression.Join` from this
:class:`_expression.FromClause`
to another :class:`FromClause`, with the "isouter" flag set to
@@ -596,7 +709,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return Join(self, right, onclause, True, full)
- def alias(self, name=None, flat=False):
+ def alias(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> NamedFromClause:
"""Return an alias of this :class:`_expression.FromClause`.
E.g.::
@@ -617,35 +732,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return Alias._construct(self, name)
- @util.preload_module("sqlalchemy.sql.sqltypes")
- def table_valued(self):
- """Return a :class:`_sql.TableValuedColumn` object for this
- :class:`_expression.FromClause`.
-
- A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that
- represents a complete row in a table. Support for this construct is
- backend dependent, and is supported in various forms by backends
- such as PostgreSQL, Oracle and SQL Server.
-
- E.g.::
-
- >>> from sqlalchemy import select, column, func, table
- >>> a = table("a", column("id"), column("x"), column("y"))
- >>> stmt = select(func.row_to_json(a.table_valued()))
- >>> print(stmt)
- SELECT row_to_json(a) AS row_to_json_1
- FROM a
-
- .. versionadded:: 1.4.0b2
-
- .. seealso::
-
- :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
-
- """
- return TableValuedColumn(self, type_api.TABLEVALUE)
-
- def tablesample(self, sampling, name=None, seed=None):
+ def tablesample(
+ self,
+ sampling: Union[float, Function[Any]],
+ name: Optional[str] = None,
+ seed: Optional[roles.ExpressionElementRole[Any]] = None,
+ ) -> TableSample:
"""Return a TABLESAMPLE alias of this :class:`_expression.FromClause`.
The return value is the :class:`_expression.TableSample`
@@ -661,7 +753,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return TableSample._construct(self, sampling, name, seed)
- def is_derived_from(self, fromclause):
+ def is_derived_from(self, fromclause: FromClause) -> bool:
"""Return ``True`` if this :class:`_expression.FromClause` is
'derived' from the given ``FromClause``.
@@ -673,7 +765,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
# contained elements.
return fromclause in self._cloned_set
- def _is_lexical_equivalent(self, other):
+ def _is_lexical_equivalent(self, other: FromClause) -> bool:
"""Return ``True`` if this :class:`_expression.FromClause` and
the other represent the same lexical identity.
@@ -681,9 +773,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
if they are the same via annotation identity.
"""
- return self._cloned_set.intersection(other._cloned_set)
+ return bool(self._cloned_set.intersection(other._cloned_set))
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def description(self) -> str:
"""A brief description of this :class:`_expression.FromClause`.
@@ -692,13 +784,15 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return getattr(self, "name", self.__class__.__name__ + " object")
- def _generate_fromclause_column_proxies(self, fromclause):
+ def _generate_fromclause_column_proxies(
+ self, fromclause: FromClause
+ ) -> None:
fromclause._columns._populate_separate_keys(
col._make_proxy(fromclause) for col in self.c
)
@property
- def exported_columns(self):
+ def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]:
"""A :class:`_expression.ColumnCollection`
that represents the "exported"
columns of this :class:`_expression.Selectable`.
@@ -796,7 +890,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
self._populate_column_collection()
return self.foreign_keys
- def _reset_column_collection(self):
+ def _reset_column_collection(self) -> None:
"""Reset the attributes linked to the ``FromClause.c`` attribute.
This collection is separate from all the other memoized things
@@ -817,7 +911,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
def _select_iterable(self) -> _SelectIterable:
return self.c
- def _init_collections(self):
+ def _init_collections(self) -> None:
assert "_columns" not in self.__dict__
assert "primary_key" not in self.__dict__
assert "foreign_keys" not in self.__dict__
@@ -827,10 +921,10 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
self.foreign_keys = set() # type: ignore
@property
- def _cols_populated(self):
+ def _cols_populated(self) -> bool:
return "_columns" in self.__dict__
- def _populate_column_collection(self):
+ def _populate_column_collection(self) -> None:
"""Called on subclasses to establish the .c collection.
Each implementation has a different way of establishing
@@ -838,7 +932,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
"""Given a column added to the .c collection of an underlying
selectable, produce the local version of that column, assuming this
selectable ultimately should proxy this column.
@@ -865,15 +959,60 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
self._reset_column_collection()
- def _anonymous_fromclause(self, name=None, flat=False):
+ def _anonymous_fromclause(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> NamedFromClause:
return self.alias(name=name)
+ if TYPE_CHECKING:
+
+ def self_group(
+ self: Self, against: Optional[OperatorType] = None
+ ) -> Union[FromGrouping, Self]:
+ ...
+
class NamedFromClause(FromClause):
+ """A :class:`.FromClause` that has a name.
+
+ Examples include tables, subqueries, CTEs, aliased tables.
+
+ .. versionadded:: 2.0
+
+ """
+
named_with_column = True
name: str
+ @util.preload_module("sqlalchemy.sql.sqltypes")
+ def table_valued(self) -> TableValuedColumn[Any]:
+ """Return a :class:`_sql.TableValuedColumn` object for this
+ :class:`_expression.FromClause`.
+
+ A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that
+ represents a complete row in a table. Support for this construct is
+ backend dependent, and is supported in various forms by backends
+ such as PostgreSQL, Oracle and SQL Server.
+
+ E.g.::
+
+ >>> from sqlalchemy import select, column, func, table
+ >>> a = table("a", column("id"), column("x"), column("y"))
+ >>> stmt = select(func.row_to_json(a.table_valued()))
+ >>> print(stmt)
+ SELECT row_to_json(a) AS row_to_json_1
+ FROM a
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
+
+ """
+ return TableValuedColumn(self, type_api.TABLEVALUE)
+
class SelectLabelStyle(Enum):
"""Label style constants that may be passed to
@@ -992,7 +1131,7 @@ class Join(roles.DMLTableRole, FromClause):
__visit_name__ = "join"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("left", InternalTraversal.dp_clauseelement),
("right", InternalTraversal.dp_clauseelement),
("onclause", InternalTraversal.dp_clauseelement),
@@ -1002,7 +1141,20 @@ class Join(roles.DMLTableRole, FromClause):
_is_join = True
- def __init__(self, left, right, onclause=None, isouter=False, full=False):
+ left: FromClause
+ right: FromClause
+ onclause: Optional[ColumnElement[bool]]
+ isouter: bool
+ full: bool
+
+ def __init__(
+ self,
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ isouter: bool = False,
+ full: bool = False,
+ ):
"""Construct a new :class:`_expression.Join`.
The usual entrypoint here is the :func:`_expression.join`
@@ -1010,11 +1162,23 @@ class Join(roles.DMLTableRole, FromClause):
:class:`_expression.FromClause` object.
"""
+
+ # when deannotate was removed here, callcounts went up for ORM
+ # compilation of eager joins, since there were more comparisons of
+ # annotated objects. test_orm.py -> test_fetch_results
+ # was therefore changed to show a more real-world use case, where the
+ # compilation is cached; there's no change in post-cache callcounts.
+ # callcounts for a single compilation in that particular test
+ # that includes about eight joins about 1100 extra fn calls, from
+ # 29200 -> 30373
+
self.left = coercions.expect(
- roles.FromClauseRole, left, deannotate=True
+ roles.FromClauseRole,
+ left,
)
self.right = coercions.expect(
- roles.FromClauseRole, right, deannotate=True
+ roles.FromClauseRole,
+ right,
).self_group()
if onclause is None:
@@ -1029,7 +1193,7 @@ class Join(roles.DMLTableRole, FromClause):
self.isouter = isouter
self.full = full
- @property
+ @util.ro_non_memoized_property
def description(self) -> str:
return "Join object on %s(%d) and %s(%d)" % (
self.left.description,
@@ -1038,7 +1202,7 @@ class Join(roles.DMLTableRole, FromClause):
id(self.right),
)
- def is_derived_from(self, fromclause):
+ def is_derived_from(self, fromclause: FromClause) -> bool:
return (
# use hash() to ensure direct comparison to annotated works
# as well
@@ -1047,7 +1211,10 @@ class Join(roles.DMLTableRole, FromClause):
or self.right.is_derived_from(fromclause)
)
- def self_group(self, against=None):
+ def self_group(
+ self, against: Optional[OperatorType] = None
+ ) -> FromGrouping:
+ ...
return FromGrouping(self)
@util.preload_module("sqlalchemy.sql.util")
@@ -1055,7 +1222,7 @@ class Join(roles.DMLTableRole, FromClause):
sqlutil = util.preloaded.sql_util
columns = [c for c in self.left.c] + [c for c in self.right.c]
- self.primary_key.extend(
+ self.primary_key.extend( # type: ignore
sqlutil.reduce_columns(
(c for c in columns if c.primary_key), self.onclause
)
@@ -1063,11 +1230,13 @@ class Join(roles.DMLTableRole, FromClause):
self._columns._populate_separate_keys(
(col._tq_key_label, col) for col in columns
)
- self.foreign_keys.update(
+ self.foreign_keys.update( # type: ignore
itertools.chain(*[col.foreign_keys for col in columns])
)
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(
+ self, clone: _CloneCallableType = _clone, **kw: Any
+ ) -> None:
# see Select._copy_internals() for similar concept
# here we pre-clone "left" and "right" so that we can
@@ -1100,12 +1269,14 @@ class Join(roles.DMLTableRole, FromClause):
self._reset_memoizations()
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
super(Join, self)._refresh_for_new_column(column)
self.left._refresh_for_new_column(column)
self.right._refresh_for_new_column(column)
- def _match_primaries(self, left, right):
+ def _match_primaries(
+ self, left: FromClause, right: FromClause
+ ) -> ColumnElement[bool]:
if isinstance(left, Join):
left_right = left.right
else:
@@ -1114,8 +1285,15 @@ class Join(roles.DMLTableRole, FromClause):
@classmethod
def _join_condition(
- cls, a, b, a_subset=None, consider_as_foreign_keys=None
- ):
+ cls,
+ a: FromClause,
+ b: FromClause,
+ *,
+ a_subset: Optional[FromClause] = None,
+ consider_as_foreign_keys: Optional[
+ AbstractSet[ColumnClause[Any]]
+ ] = None,
+ ) -> ColumnElement[bool]:
"""Create a join condition between two tables or selectables.
See sqlalchemy.sql.util.join_condition() for full docs.
@@ -1151,7 +1329,15 @@ class Join(roles.DMLTableRole, FromClause):
return and_(*crit)
@classmethod
- def _can_join(cls, left, right, consider_as_foreign_keys=None):
+ def _can_join(
+ cls,
+ left: FromClause,
+ right: FromClause,
+ *,
+ consider_as_foreign_keys: Optional[
+ AbstractSet[ColumnClause[Any]]
+ ] = None,
+ ) -> bool:
if isinstance(left, Join):
left_right = left.right
else:
@@ -1169,20 +1355,31 @@ class Join(roles.DMLTableRole, FromClause):
@classmethod
@util.preload_module("sqlalchemy.sql.util")
def _joincond_scan_left_right(
- cls, a, a_subset, b, consider_as_foreign_keys
- ):
+ cls,
+ a: FromClause,
+ a_subset: Optional[FromClause],
+ b: FromClause,
+ consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]],
+ ) -> collections.defaultdict[
+ Optional[ForeignKeyConstraint],
+ List[Tuple[ColumnClause[Any], ColumnClause[Any]]],
+ ]:
sql_util = util.preloaded.sql_util
a = coercions.expect(roles.FromClauseRole, a)
b = coercions.expect(roles.FromClauseRole, b)
- constraints = collections.defaultdict(list)
+ constraints: collections.defaultdict[
+ Optional[ForeignKeyConstraint],
+ List[Tuple[ColumnClause[Any], ColumnClause[Any]]],
+ ] = collections.defaultdict(list)
for left in (a_subset, a):
if left is None:
continue
for fk in sorted(
- b.foreign_keys, key=lambda fk: fk.parent._creation_order
+ b.foreign_keys,
+ key=lambda fk: fk.parent._creation_order, # type: ignore
):
if (
consider_as_foreign_keys is not None
@@ -1202,7 +1399,8 @@ class Join(roles.DMLTableRole, FromClause):
constraints[fk.constraint].append((col, fk.parent))
if left is not b:
for fk in sorted(
- left.foreign_keys, key=lambda fk: fk.parent._creation_order
+ left.foreign_keys,
+ key=lambda fk: fk.parent._creation_order, # type: ignore
):
if (
consider_as_foreign_keys is not None
@@ -1309,7 +1507,8 @@ class Join(roles.DMLTableRole, FromClause):
@util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
- return [self] + self.left._from_objects + self.right._from_objects
+ self_list: List[FromClause] = [self]
+ return self_list + self.left._from_objects + self.right._from_objects
class NoInit:
@@ -1327,6 +1526,14 @@ class NoInit:
)
+class LateralFromClause(NamedFromClause):
+ """mark a FROM clause as being able to render directly as LATERAL"""
+
+
+_SelfAliasedReturnsRows = TypeVar(
+ "_SelfAliasedReturnsRows", bound="AliasedReturnsRows"
+)
+
# FromClause ->
# AliasedReturnsRows
# -> Alias only for FromClause
@@ -1335,6 +1542,8 @@ class NoInit:
# -> Lateral -> FromClause, but we accept SelectBase
# w/ non-deprecated coercion
# -> TableSample -> only for FromClause
+
+
class AliasedReturnsRows(NoInit, NamedFromClause):
"""Base class of aliases against tables, subqueries, and other
selectables."""
@@ -1343,24 +1552,21 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
_supports_derived_columns = False
- element: ClauseElement
+ element: ReturnsRows
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement),
("name", InternalTraversal.dp_anon_name),
]
@classmethod
- def _construct(cls, *arg, **kw):
+ def _construct(
+ cls: Type[_SelfAliasedReturnsRows], *arg: Any, **kw: Any
+ ) -> _SelfAliasedReturnsRows:
obj = cls.__new__(cls)
obj._init(*arg, **kw)
return obj
- @classmethod
- def _factory(cls, returnsrows, name=None):
- """Base factory method. Subclasses need to provide this."""
- raise NotImplementedError()
-
def _init(self, selectable, name=None):
self.element = coercions.expect(
roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self
@@ -1378,11 +1584,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
name = _anonymous_label.safe_construct(id(self), name or "anon")
self.name = name
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
super(AliasedReturnsRows, self)._refresh_for_new_column(column)
self.element._refresh_for_new_column(column)
- @property
+ def _populate_column_collection(self):
+ self.element._generate_fromclause_column_proxies(self)
+
+ @util.ro_non_memoized_property
def description(self) -> str:
name = self.name
if isinstance(name, _anonymous_label):
@@ -1395,15 +1604,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
"""Legacy for dialects that are referring to Alias.original."""
return self.element
- def is_derived_from(self, fromclause):
+ def is_derived_from(self, fromclause: FromClause) -> bool:
if fromclause in self._cloned_set:
return True
return self.element.is_derived_from(fromclause)
- def _populate_column_collection(self):
- self.element._generate_fromclause_column_proxies(self)
-
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(
+ self, clone: _CloneCallableType = _clone, **kw: Any
+ ) -> None:
existing_element = self.element
super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw)
@@ -1420,7 +1628,11 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
return [self]
-class Alias(roles.DMLTableRole, AliasedReturnsRows):
+class FromClauseAlias(AliasedReturnsRows):
+ element: FromClause
+
+
+class Alias(roles.DMLTableRole, FromClauseAlias):
"""Represents an table or selectable alias (AS).
Represents an alias, as typically applied to any table or
@@ -1445,13 +1657,18 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows):
element: FromClause
@classmethod
- def _factory(cls, selectable, name=None, flat=False):
+ def _factory(
+ cls,
+ selectable: FromClause,
+ name: Optional[str] = None,
+ flat: bool = False,
+ ) -> NamedFromClause:
return coercions.expect(
roles.FromClauseRole, selectable, allow_select=True
).alias(name=name, flat=flat)
-class TableValuedAlias(Alias):
+class TableValuedAlias(LateralFromClause, Alias):
"""An alias against a "table valued" SQL function.
This construct provides for a SQL function that returns columns
@@ -1480,7 +1697,7 @@ class TableValuedAlias(Alias):
_render_derived_w_types = False
joins_implicitly = False
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement),
("name", InternalTraversal.dp_anon_name),
("_tableval_type", InternalTraversal.dp_type),
@@ -1526,7 +1743,9 @@ class TableValuedAlias(Alias):
return TableValuedColumn(self, self._tableval_type)
- def alias(self, name=None):
+ def alias(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> TableValuedAlias:
"""Return a new alias of this :class:`_sql.TableValuedAlias`.
This creates a distinct FROM object that will be distinguished
@@ -1547,7 +1766,7 @@ class TableValuedAlias(Alias):
return tva
- def lateral(self, name=None):
+ def lateral(self, name: Optional[str] = None) -> LateralFromClause:
"""Return a new :class:`_sql.TableValuedAlias` with the lateral flag set,
so that it renders as LATERAL.
@@ -1619,7 +1838,7 @@ class TableValuedAlias(Alias):
return new_alias
-class Lateral(AliasedReturnsRows):
+class Lateral(FromClauseAlias, LateralFromClause):
"""Represent a LATERAL subquery.
This object is constructed from the :func:`_expression.lateral` module
@@ -1644,13 +1863,17 @@ class Lateral(AliasedReturnsRows):
inherit_cache = True
@classmethod
- def _factory(cls, selectable, name=None):
+ def _factory(
+ cls,
+ selectable: Union[SelectBase, _FromClauseArgument],
+ name: Optional[str] = None,
+ ) -> LateralFromClause:
return coercions.expect(
roles.FromClauseRole, selectable, explicit_subquery=True
).lateral(name=name)
-class TableSample(AliasedReturnsRows):
+class TableSample(FromClauseAlias):
"""Represent a TABLESAMPLE clause.
This object is constructed from the :func:`_expression.tablesample` module
@@ -1668,13 +1891,22 @@ class TableSample(AliasedReturnsRows):
__visit_name__ = "tablesample"
- _traverse_internals = AliasedReturnsRows._traverse_internals + [
- ("sampling", InternalTraversal.dp_clauseelement),
- ("seed", InternalTraversal.dp_clauseelement),
- ]
+ _traverse_internals: _TraverseInternalsType = (
+ AliasedReturnsRows._traverse_internals
+ + [
+ ("sampling", InternalTraversal.dp_clauseelement),
+ ("seed", InternalTraversal.dp_clauseelement),
+ ]
+ )
@classmethod
- def _factory(cls, selectable, sampling, name=None, seed=None):
+ def _factory(
+ cls,
+ selectable: _FromClauseArgument,
+ sampling: Union[float, Function[Any]],
+ name: Optional[str] = None,
+ seed: Optional[roles.ExpressionElementRole[Any]] = None,
+ ) -> TableSample:
return coercions.expect(roles.FromClauseRole, selectable).tablesample(
sampling, name=name, seed=seed
)
@@ -1721,7 +1953,7 @@ class CTE(
__visit_name__ = "cte"
- _traverse_internals = (
+ _traverse_internals: _TraverseInternalsType = (
AliasedReturnsRows._traverse_internals
+ [
("_cte_alias", InternalTraversal.dp_clauseelement),
@@ -1736,7 +1968,12 @@ class CTE(
element: HasCTE
@classmethod
- def _factory(cls, selectable, name=None, recursive=False):
+ def _factory(
+ cls,
+ selectable: HasCTE,
+ name: Optional[str] = None,
+ recursive: bool = False,
+ ) -> CTE:
r"""Return a new :class:`_expression.CTE`,
or Common Table Expression instance.
@@ -1775,7 +2012,9 @@ class CTE(
else:
self.element._generate_fromclause_column_proxies(self)
- def alias(self, name=None, flat=False):
+ def alias(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> NamedFromClause:
"""Return an :class:`_expression.Alias` of this
:class:`_expression.CTE`.
@@ -1814,6 +2053,10 @@ class CTE(
:meth:`_sql.HasCTE.cte` - examples of calling styles
"""
+ assert is_select_statement(
+ self.element
+ ), f"CTE element f{self.element} does not support union()"
+
return CTE._construct(
self.element.union(*other),
name=self.name,
@@ -1839,6 +2082,11 @@ class CTE(
:meth:`_sql.HasCTE.cte` - examples of calling styles
"""
+
+ assert is_select_statement(
+ self.element
+ ), f"CTE element f{self.element} does not support union_all()"
+
return CTE._construct(
self.element.union_all(*other),
name=self.name,
@@ -1865,23 +2113,229 @@ class _CTEOpts(NamedTuple):
nesting: bool
-class HasCTE(roles.HasCTERole, ClauseElement):
+class _ColumnsPlusNames(NamedTuple):
+ required_label_name: Optional[str]
+ """
+ string label name, if non-None, must be rendered as a
+ label, i.e. "AS <name>"
+ """
+
+ proxy_key: Optional[str]
+ """
+ proxy_key that is to be part of the result map for this
+ col. this is also the key in a fromclause.c or
+ select.selected_columns collection
+ """
+
+ fallback_label_name: Optional[str]
+ """
+ name that can be used to render an "AS <name>" when
+ we have to render a label even though
+ required_label_name was not given
+ """
+
+ column: Union[ColumnElement[Any], TextClause]
+ """
+ the ColumnElement itself
+ """
+
+ repeated: bool
+ """
+ True if this is a duplicate of a previous column
+ in the list of columns
+ """
+
+
+class SelectsRows(ReturnsRows):
+ """Sub-base of ReturnsRows for elements that deliver rows
+ directly, namely SELECT and INSERT/UPDATE/DELETE..RETURNING"""
+
+ _label_style: SelectLabelStyle = LABEL_STYLE_NONE
+
+ def _generate_columns_plus_names(
+ self, anon_for_dupe_key: bool
+ ) -> List[_ColumnsPlusNames]:
+ """Generate column names as rendered in a SELECT statement by
+ the compiler.
+
+ This is distinct from the _column_naming_convention generator that's
+ intended for population of .c collections and similar, which has
+ different rules. the collection returned here calls upon the
+ _column_naming_convention as well.
+
+ """
+ cols = self._all_selected_columns
+
+ key_naming_convention = SelectState._column_naming_convention(
+ self._label_style
+ )
+
+ names = {}
+
+ result: List[_ColumnsPlusNames] = []
+ result_append = result.append
+
+ table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+ label_style_none = self._label_style is LABEL_STYLE_NONE
+
+ # a counter used for "dedupe" labels, which have double underscores
+ # in them and are never referred by name; they only act
+ # as positional placeholders. they need only be unique within
+ # the single columns clause they're rendered within (required by
+ # some dbs such as mysql). So their anon identity is tracked against
+ # a fixed counter rather than hash() identity.
+ dedupe_hash = 1
+
+ for c in cols:
+ repeated = False
+
+ if not c._render_label_in_columns_clause:
+ effective_name = (
+ required_label_name
+ ) = fallback_label_name = None
+ elif label_style_none:
+ if TYPE_CHECKING:
+ assert is_column_element(c)
+
+ effective_name = required_label_name = None
+ fallback_label_name = c._non_anon_label or c._anon_name_label
+ else:
+ if TYPE_CHECKING:
+ assert is_column_element(c)
+
+ if table_qualified:
+ required_label_name = (
+ effective_name
+ ) = fallback_label_name = c._tq_label
+ else:
+ effective_name = fallback_label_name = c._non_anon_label
+ required_label_name = None
+
+ if effective_name is None:
+ # it seems like this could be _proxy_key and we would
+ # not need _expression_label but it isn't
+ # giving us a clue when to use anon_label instead
+ expr_label = c._expression_label
+ if expr_label is None:
+ repeated = c._anon_name_label in names
+ names[c._anon_name_label] = c
+ effective_name = required_label_name = None
+
+ if repeated:
+ # here, "required_label_name" is sent as
+ # "None" and "fallback_label_name" is sent.
+ if table_qualified:
+ fallback_label_name = (
+ c._dedupe_anon_tq_label_idx(dedupe_hash)
+ )
+ dedupe_hash += 1
+ else:
+ fallback_label_name = c._dedupe_anon_label_idx(
+ dedupe_hash
+ )
+ dedupe_hash += 1
+ else:
+ fallback_label_name = c._anon_name_label
+ else:
+ required_label_name = (
+ effective_name
+ ) = fallback_label_name = expr_label
+
+ if effective_name is not None:
+ if TYPE_CHECKING:
+ assert is_column_element(c)
+
+ if effective_name in names:
+ # when looking to see if names[name] is the same column as
+ # c, use hash(), so that an annotated version of the column
+ # is seen as the same as the non-annotated
+ if hash(names[effective_name]) != hash(c):
+
+ # different column under the same name. apply
+ # disambiguating label
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._anon_tq_label
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._anon_name_label
+
+ if anon_for_dupe_key and required_label_name in names:
+ # here, c._anon_tq_label is definitely unique to
+ # that column identity (or annotated version), so
+ # this should always be true.
+ # this is also an infrequent codepath because
+ # you need two levels of duplication to be here
+ assert hash(names[required_label_name]) == hash(c)
+
+ # the column under the disambiguating label is
+ # already present. apply the "dedupe" label to
+ # subsequent occurrences of the column so that the
+ # original stays non-ambiguous
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ repeated = True
+ else:
+ names[required_label_name] = c
+ elif anon_for_dupe_key:
+ # same column under the same name. apply the "dedupe"
+ # label so that the original stays non-ambiguous
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ repeated = True
+ else:
+ names[effective_name] = c
+
+ result_append(
+ _ColumnsPlusNames(
+ required_label_name,
+ key_naming_convention(c),
+ fallback_label_name,
+ c,
+ repeated,
+ )
+ )
+
+ return result
+
+
+class HasCTE(roles.HasCTERole, SelectsRows):
"""Mixin that declares a class to include CTE support.
.. versionadded:: 1.1
"""
- _has_ctes_traverse_internals = [
+ _has_ctes_traverse_internals: _TraverseInternalsType = [
("_independent_ctes", InternalTraversal.dp_clauseelement_list),
("_independent_ctes_opts", InternalTraversal.dp_plain_obj),
]
- _independent_ctes = ()
- _independent_ctes_opts = ()
+ _independent_ctes: Tuple[CTE, ...] = ()
+ _independent_ctes_opts: Tuple[_CTEOpts, ...] = ()
@_generative
- def add_cte(self: SelfHasCTE, *ctes, nest_here=False) -> SelfHasCTE:
+ def add_cte(
+ self: SelfHasCTE, *ctes: CTE, nest_here: bool = False
+ ) -> SelfHasCTE:
r"""Add one or more :class:`_sql.CTE` constructs to this statement.
This method will associate the given :class:`_sql.CTE` constructs with
@@ -1985,7 +2439,12 @@ class HasCTE(roles.HasCTERole, ClauseElement):
self._independent_ctes_opts += (opt,)
return self
- def cte(self, name=None, recursive=False, nesting=False):
+ def cte(
+ self,
+ name: Optional[str] = None,
+ recursive: bool = False,
+ nesting: bool = False,
+ ) -> CTE:
r"""Return a new :class:`_expression.CTE`,
or Common Table Expression instance.
@@ -2293,10 +2752,12 @@ class Subquery(AliasedReturnsRows):
inherit_cache = True
- element: Select
+ element: SelectBase
@classmethod
- def _factory(cls, selectable, name=None):
+ def _factory(
+ cls, selectable: SelectBase, name: Optional[str] = None
+ ) -> Subquery:
"""Return a :class:`.Subquery` object."""
return coercions.expect(
roles.SelectStatementRole, selectable
@@ -2335,11 +2796,13 @@ class Subquery(AliasedReturnsRows):
class FromGrouping(GroupedElement, FromClause):
"""Represent a grouping of a FROM clause"""
- _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+ _traverse_internals: _TraverseInternalsType = [
+ ("element", InternalTraversal.dp_clauseelement)
+ ]
element: FromClause
- def __init__(self, element):
+ def __init__(self, element: FromClause):
self.element = coercions.expect(roles.FromClauseRole, element)
def _init_collections(self):
@@ -2361,11 +2824,13 @@ class FromGrouping(GroupedElement, FromClause):
def foreign_keys(self):
return self.element.foreign_keys
- def is_derived_from(self, element):
- return self.element.is_derived_from(element)
+ def is_derived_from(self, fromclause: FromClause) -> bool:
+ return self.element.is_derived_from(fromclause)
- def alias(self, **kw):
- return FromGrouping(self.element.alias(**kw))
+ def alias(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> NamedFromGrouping:
+ return NamedFromGrouping(self.element.alias(name=name, flat=flat))
def _anonymous_fromclause(self, **kw):
return FromGrouping(self.element._anonymous_fromclause(**kw))
@@ -2385,6 +2850,16 @@ class FromGrouping(GroupedElement, FromClause):
self.element = state["element"]
+class NamedFromGrouping(FromGrouping, NamedFromClause):
+ """represent a grouping of a named FROM clause
+
+ .. versionadded:: 2.0
+
+ """
+
+ inherit_cache = True
+
+
class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
"""Represents a minimal "table" construct.
@@ -2417,7 +2892,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
__visit_name__ = "table"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
(
"columns",
InternalTraversal.dp_fromclause_canonical_column_collection,
@@ -2434,15 +2909,17 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
doesn't support having a primary key or column
-level defaults, so implicit returning doesn't apply."""
- _autoincrement_column = None
- """No PK or default support so no autoincrement column."""
+ @util.ro_memoized_property
+ def _autoincrement_column(self) -> Optional[ColumnClause[Any]]:
+ """No PK or default support so no autoincrement column."""
+ return None
- def __init__(self, name, *columns, **kw):
+ def __init__(self, name: str, *columns: ColumnClause[Any], **kw: Any):
super(TableClause, self).__init__()
self.name = name
self._columns = DedupeColumnCollection()
- self.primary_key = ColumnSet()
- self.foreign_keys = set()
+ self.primary_key = ColumnSet() # type: ignore
+ self.foreign_keys = set() # type: ignore
for c in columns:
self.append_column(c)
@@ -2466,23 +2943,23 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
...
- def __str__(self):
+ def __str__(self) -> str:
if self.schema is not None:
return self.schema + "." + self.name
else:
return self.name
- def _refresh_for_new_column(self, column):
+ def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
pass
- def _init_collections(self):
+ def _init_collections(self) -> None:
pass
- @util.memoized_property
+ @util.ro_memoized_property
def description(self) -> str:
return self.name
- def append_column(self, c, **kw):
+ def append_column(self, c: ColumnClause[Any]) -> None:
existing = c.table
if existing is not None and existing is not self:
raise exc.ArgumentError(
@@ -2494,7 +2971,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
c.table = self
@util.preload_module("sqlalchemy.sql.dml")
- def insert(self):
+ def insert(self) -> Insert:
"""Generate an :func:`_expression.insert` construct against this
:class:`_expression.TableClause`.
@@ -2505,10 +2982,11 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
See :func:`_expression.insert` for argument and usage information.
"""
+
return util.preloaded.sql_dml.Insert(self)
@util.preload_module("sqlalchemy.sql.dml")
- def update(self):
+ def update(self) -> Update:
"""Generate an :func:`_expression.update` construct against this
:class:`_expression.TableClause`.
@@ -2524,7 +3002,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
)
@util.preload_module("sqlalchemy.sql.dml")
- def delete(self):
+ def delete(self) -> Delete:
"""Generate a :func:`_expression.delete` construct against this
:class:`_expression.TableClause`.
@@ -2543,13 +3021,18 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
class ForUpdateArg(ClauseElement):
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("of", InternalTraversal.dp_clauseelement_list),
("nowait", InternalTraversal.dp_boolean),
("read", InternalTraversal.dp_boolean),
("skip_locked", InternalTraversal.dp_boolean),
]
+ of: Optional[Sequence[ClauseElement]]
+ nowait: bool
+ read: bool
+ skip_locked: bool
+
@classmethod
def _from_argument(cls, with_for_update):
if isinstance(with_for_update, ForUpdateArg):
@@ -2606,7 +3089,7 @@ class ForUpdateArg(ClauseElement):
SelfValues = typing.TypeVar("SelfValues", bound="Values")
-class Values(Generative, NamedFromClause):
+class Values(Generative, LateralFromClause):
"""Represent a ``VALUES`` construct that can be used as a FROM element
in a statement.
@@ -2619,28 +3102,42 @@ class Values(Generative, NamedFromClause):
__visit_name__ = "values"
- _data = ()
+ _data: Tuple[List[Tuple[Any, ...]], ...] = ()
- _traverse_internals = [
+ _unnamed: bool
+ _traverse_internals: _TraverseInternalsType = [
("_column_args", InternalTraversal.dp_clauseelement_list),
("_data", InternalTraversal.dp_dml_multi_values),
("name", InternalTraversal.dp_string),
("literal_binds", InternalTraversal.dp_boolean),
]
- def __init__(self, *columns, name=None, literal_binds=False):
+ def __init__(
+ self,
+ *columns: ColumnClause[Any],
+ name: Optional[str] = None,
+ literal_binds: bool = False,
+ ):
super(Values, self).__init__()
self._column_args = columns
- self.name = name
+ if name is None:
+ self._unnamed = True
+ self.name = _anonymous_label.safe_construct(id(self), "anon")
+ else:
+ self._unnamed = False
+ self.name = name
self.literal_binds = literal_binds
- self.named_with_column = self.name is not None
+ self.named_with_column = not self._unnamed
@property
def _column_types(self):
return [col.type for col in self._column_args]
@_generative
- def alias(self: SelfValues, name, **kw) -> SelfValues:
+ def alias(
+ self: SelfValues, name: Optional[str] = None, flat: bool = False
+ ) -> SelfValues:
+
"""Return a new :class:`_expression.Values`
construct that is a copy of this
one with the given name.
@@ -2655,12 +3152,20 @@ class Values(Generative, NamedFromClause):
:func:`_expression.alias`
"""
- self.name = name
- self.named_with_column = self.name is not None
+ non_none_name: str
+
+ if name is None:
+ non_none_name = _anonymous_label.safe_construct(id(self), "anon")
+ else:
+ non_none_name = name
+
+ self.name = non_none_name
+ self.named_with_column = True
+ self._unnamed = False
return self
@_generative
- def lateral(self: SelfValues, name=None) -> SelfValues:
+ def lateral(self, name: Optional[str] = None) -> LateralFromClause:
"""Return a new :class:`_expression.Values` with the lateral flag set,
so that
it renders as LATERAL.
@@ -2670,13 +3175,20 @@ class Values(Generative, NamedFromClause):
:func:`_expression.lateral`
"""
+ non_none_name: str
+
+ if name is None:
+ non_none_name = self.name
+ else:
+ non_none_name = name
+
self._is_lateral = True
- if name is not None:
- self.name = name
+ self.name = non_none_name
+ self._unnamed = False
return self
@_generative
- def data(self: SelfValues, values) -> SelfValues:
+ def data(self: SelfValues, values: List[Tuple[Any, ...]]) -> SelfValues:
"""Return a new :class:`_expression.Values` construct,
adding the given data
to the data list.
@@ -2694,7 +3206,7 @@ class Values(Generative, NamedFromClause):
self._data += (values,)
return self
- def _populate_column_collection(self):
+ def _populate_column_collection(self) -> None:
for c in self._column_args:
self._columns.add(c)
c.table = self
@@ -2727,32 +3239,16 @@ class SelectBase(
"""
- _is_select_statement = True
+ _is_select_base = True
is_select = True
- def _generate_fromclause_column_proxies(
- self, fromclause: FromClause
- ) -> None:
- raise NotImplementedError()
+ _label_style: SelectLabelStyle = LABEL_STYLE_NONE
def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None:
self._reset_memoizations()
- def _generate_columns_plus_names(
- self, anon_for_dupe_key: bool
- ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
- raise NotImplementedError()
-
- def set_label_style(
- self: SelfSelectBase, label_style: SelectLabelStyle
- ) -> SelfSelectBase:
- raise NotImplementedError()
-
- def get_label_style(self) -> SelectLabelStyle:
- raise NotImplementedError()
-
- @property
- def selected_columns(self):
+ @util.ro_non_memoized_property
+ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set.
@@ -2797,7 +3293,7 @@ class SelectBase(
raise NotImplementedError()
@property
- def exported_columns(self):
+ def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]:
"""A :class:`_expression.ColumnCollection`
that represents the "exported"
columns of this :class:`_expression.Selectable`, not including
@@ -2819,7 +3315,7 @@ class SelectBase(
"""
- return self.selected_columns
+ return self.selected_columns.as_readonly()
@util.deprecated_property(
"1.4",
@@ -2841,6 +3337,26 @@ class SelectBase(
def columns(self):
return self.c
+ def get_label_style(self) -> SelectLabelStyle:
+ """
+ Retrieve the current label style.
+
+ Implemented by subclasses.
+
+ """
+ raise NotImplementedError()
+
+ def set_label_style(
+ self: SelfSelectBase, style: SelectLabelStyle
+ ) -> SelfSelectBase:
+ """Return a new selectable with the specified label style.
+
+ Implemented by subclasses.
+
+ """
+
+ raise NotImplementedError()
+
@util.deprecated(
"1.4",
"The :meth:`_expression.SelectBase.select` method is deprecated "
@@ -2857,6 +3373,9 @@ class SelectBase(
def _implicit_subquery(self):
return self.subquery()
+ def _scalar_type(self) -> TypeEngine[Any]:
+ raise NotImplementedError()
+
@util.deprecated(
"1.4",
"The :meth:`_expression.SelectBase.as_scalar` "
@@ -2926,7 +3445,7 @@ class SelectBase(
"""
return self.scalar_subquery().label(name)
- def lateral(self, name=None):
+ def lateral(self, name: Optional[str] = None) -> LateralFromClause:
"""Return a LATERAL alias of this :class:`_expression.Selectable`.
The return value is the :class:`_expression.Lateral` construct also
@@ -2941,11 +3460,7 @@ class SelectBase(
"""
return Lateral._factory(self, name)
- @util.ro_non_memoized_property
- def _from_objects(self) -> List[FromClause]:
- return [self]
-
- def subquery(self, name=None):
+ def subquery(self, name: Optional[str] = None) -> Subquery:
"""Return a subquery of this :class:`_expression.SelectBase`.
A subquery is from a SQL perspective a parenthesized, named
@@ -2995,7 +3510,9 @@ class SelectBase(
raise NotImplementedError()
- def alias(self, name=None, flat=False):
+ def alias(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> Subquery:
"""Return a named subquery against this
:class:`_expression.SelectBase`.
@@ -3023,7 +3540,9 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
"""
__visit_name__ = "select_statement_grouping"
- _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+ _traverse_internals: _TraverseInternalsType = [
+ ("element", InternalTraversal.dp_clauseelement)
+ ]
_is_select_container = True
@@ -3053,13 +3572,14 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
def select_statement(self):
return self.element
- def self_group(self, against=None):
+ def self_group(self: Self, against: Optional[OperatorType] = None) -> Self:
+ ...
return self
- def _generate_columns_plus_names(
- self, anon_for_dupe_key: bool
- ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
- return self.element._generate_columns_plus_names(anon_for_dupe_key)
+ # def _generate_columns_plus_names(
+ # self, anon_for_dupe_key: bool
+ # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
+ # return self.element._generate_columns_plus_names(anon_for_dupe_key)
def _generate_fromclause_column_proxies(
self, subquery: FromClause
@@ -3070,8 +3590,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
def _all_selected_columns(self) -> _SelectIterable:
return self.element._all_selected_columns
- @property
- def selected_columns(self):
+ @util.ro_non_memoized_property
+ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
the embedded SELECT statement returns in its result set, not including
@@ -3112,25 +3632,30 @@ class GenerativeSelect(SelectBase):
"""
- _order_by_clauses = ()
- _group_by_clauses = ()
- _limit_clause = None
- _offset_clause = None
- _fetch_clause = None
- _fetch_clause_options = None
- _for_update_arg = None
+ _order_by_clauses: Tuple[ColumnElement[Any], ...] = ()
+ _group_by_clauses: Tuple[ColumnElement[Any], ...] = ()
+ _limit_clause: Optional[ColumnElement[Any]] = None
+ _offset_clause: Optional[ColumnElement[Any]] = None
+ _fetch_clause: Optional[ColumnElement[Any]] = None
+ _fetch_clause_options: Optional[Dict[str, bool]] = None
+ _for_update_arg: Optional[ForUpdateArg] = None
- def __init__(self, _label_style=LABEL_STYLE_DEFAULT):
+ def __init__(self, _label_style: SelectLabelStyle = LABEL_STYLE_DEFAULT):
self._label_style = _label_style
@_generative
def with_for_update(
self: SelfGenerativeSelect,
- nowait=False,
- read=False,
- of=None,
- skip_locked=False,
- key_share=False,
+ nowait: bool = False,
+ read: bool = False,
+ of: Optional[
+ Union[
+ _ColumnExpressionArgument[Any],
+ Sequence[_ColumnExpressionArgument[Any]],
+ ]
+ ] = None,
+ skip_locked: bool = False,
+ key_share: bool = False,
) -> SelfGenerativeSelect:
"""Specify a ``FOR UPDATE`` clause for this
:class:`_expression.GenerativeSelect`.
@@ -3241,20 +3766,25 @@ class GenerativeSelect(SelectBase):
return self
@property
- def _group_by_clause(self):
+ def _group_by_clause(self) -> ClauseList:
"""ClauseList access to group_by_clauses for legacy dialects"""
return ClauseList._construct_raw(
operators.comma_op, self._group_by_clauses
)
@property
- def _order_by_clause(self):
+ def _order_by_clause(self) -> ClauseList:
"""ClauseList access to order_by_clauses for legacy dialects"""
return ClauseList._construct_raw(
operators.comma_op, self._order_by_clauses
)
- def _offset_or_limit_clause(self, element, name=None, type_=None):
+ def _offset_or_limit_clause(
+ self,
+ element: Union[int, _ColumnExpressionArgument[Any]],
+ name: Optional[str] = None,
+ type_: Optional[_TypeEngineArgument[int]] = None,
+ ) -> ColumnElement[Any]:
"""Convert the given value to an "offset or limit" clause.
This handles incoming integers and converts to an expression; if
@@ -3265,7 +3795,21 @@ class GenerativeSelect(SelectBase):
roles.LimitOffsetRole, element, name=name, type_=type_
)
- def _offset_or_limit_clause_asint(self, clause, attrname):
+ @overload
+ def _offset_or_limit_clause_asint(
+ self, clause: ColumnElement[Any], attrname: str
+ ) -> NoReturn:
+ ...
+
+ @overload
+ def _offset_or_limit_clause_asint(
+ self, clause: Optional[_OffsetLimitParam], attrname: str
+ ) -> Optional[int]:
+ ...
+
+ def _offset_or_limit_clause_asint(
+ self, clause: Optional[ColumnElement[Any]], attrname: str
+ ) -> Union[NoReturn, Optional[int]]:
"""Convert the "offset or limit" clause of a select construct to an
integer.
@@ -3286,7 +3830,7 @@ class GenerativeSelect(SelectBase):
return util.asint(value)
@property
- def _limit(self):
+ def _limit(self) -> Optional[int]:
"""Get an integer value for the limit. This should only be used
by code that cannot support a limit as a BindParameter or
other custom clause as it will throw an exception if the limit
@@ -3295,14 +3839,14 @@ class GenerativeSelect(SelectBase):
"""
return self._offset_or_limit_clause_asint(self._limit_clause, "limit")
- def _simple_int_clause(self, clause):
+ def _simple_int_clause(self, clause: ClauseElement) -> bool:
"""True if the clause is a simple integer, False
if it is not present or is a SQL expression.
"""
return isinstance(clause, _OffsetLimitParam)
@property
- def _offset(self):
+ def _offset(self) -> Optional[int]:
"""Get an integer value for the offset. This should only be used
by code that cannot support an offset as a BindParameter or
other custom clause as it will throw an exception if the
@@ -3314,7 +3858,7 @@ class GenerativeSelect(SelectBase):
)
@property
- def _has_row_limiting_clause(self):
+ def _has_row_limiting_clause(self) -> bool:
return (
self._limit_clause is not None
or self._offset_clause is not None
@@ -3322,7 +3866,10 @@ class GenerativeSelect(SelectBase):
)
@_generative
- def limit(self: SelfGenerativeSelect, limit) -> SelfGenerativeSelect:
+ def limit(
+ self: SelfGenerativeSelect,
+ limit: Union[int, _ColumnExpressionArgument[int]],
+ ) -> SelfGenerativeSelect:
"""Return a new selectable with the given LIMIT criterion
applied.
@@ -3356,7 +3903,10 @@ class GenerativeSelect(SelectBase):
@_generative
def fetch(
- self: SelfGenerativeSelect, count, with_ties=False, percent=False
+ self: SelfGenerativeSelect,
+ count: Union[int, _ColumnExpressionArgument[int]],
+ with_ties: bool = False,
+ percent: bool = False,
) -> SelfGenerativeSelect:
"""Return a new selectable with the given FETCH FIRST criterion
applied.
@@ -3408,7 +3958,10 @@ class GenerativeSelect(SelectBase):
return self
@_generative
- def offset(self: SelfGenerativeSelect, offset) -> SelfGenerativeSelect:
+ def offset(
+ self: SelfGenerativeSelect,
+ offset: Union[int, _ColumnExpressionArgument[int]],
+ ) -> SelfGenerativeSelect:
"""Return a new selectable with the given OFFSET criterion
applied.
@@ -3438,7 +3991,11 @@ class GenerativeSelect(SelectBase):
@_generative
@util.preload_module("sqlalchemy.sql.util")
- def slice(self: SelfGenerativeSelect, start, stop) -> SelfGenerativeSelect:
+ def slice(
+ self: SelfGenerativeSelect,
+ start: int,
+ stop: int,
+ ) -> SelfGenerativeSelect:
"""Apply LIMIT / OFFSET to this statement based on a slice.
The start and stop indices behave like the argument to Python's
@@ -3485,7 +4042,9 @@ class GenerativeSelect(SelectBase):
return self
@_generative
- def order_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect:
+ def order_by(
+ self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+ ) -> SelfGenerativeSelect:
r"""Return a new selectable with the given list of ORDER BY
criteria applied.
@@ -3522,7 +4081,9 @@ class GenerativeSelect(SelectBase):
return self
@_generative
- def group_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect:
+ def group_by(
+ self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+ ) -> SelfGenerativeSelect:
r"""Return a new selectable with the given list of GROUP BY
criterion applied.
@@ -3567,6 +4128,15 @@ class CompoundSelectState(CompileState):
return d, d, d
+class _CompoundSelectKeyword(Enum):
+ UNION = "UNION"
+ UNION_ALL = "UNION ALL"
+ EXCEPT = "EXCEPT"
+ EXCEPT_ALL = "EXCEPT ALL"
+ INTERSECT = "INTERSECT"
+ INTERSECT_ALL = "INTERSECT ALL"
+
+
class CompoundSelect(HasCompileState, GenerativeSelect):
"""Forms the basis of ``UNION``, ``UNION ALL``, and other
SELECT-based set operations.
@@ -3590,7 +4160,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
__visit_name__ = "compound_select"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("selects", InternalTraversal.dp_clauseelement_list),
("_limit_clause", InternalTraversal.dp_clauseelement),
("_offset_clause", InternalTraversal.dp_clauseelement),
@@ -3602,17 +4172,16 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
("keyword", InternalTraversal.dp_string),
] + SupportsCloneAnnotations._clone_annotations_traverse_internals
- UNION = util.symbol("UNION")
- UNION_ALL = util.symbol("UNION ALL")
- EXCEPT = util.symbol("EXCEPT")
- EXCEPT_ALL = util.symbol("EXCEPT ALL")
- INTERSECT = util.symbol("INTERSECT")
- INTERSECT_ALL = util.symbol("INTERSECT ALL")
+ selects: List[SelectBase]
_is_from_container = True
_auto_correlate = False
- def __init__(self, keyword, *selects):
+ def __init__(
+ self,
+ keyword: _CompoundSelectKeyword,
+ *selects: _SelectStatementForCompoundArgument,
+ ):
self.keyword = keyword
self.selects = [
coercions.expect(roles.CompoundElementRole, s).self_group(
@@ -3624,36 +4193,50 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
GenerativeSelect.__init__(self)
@classmethod
- def _create_union(cls, *selects, **kwargs):
- return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs)
+ def _create_union(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.UNION, *selects)
@classmethod
- def _create_union_all(cls, *selects):
- return CompoundSelect(CompoundSelect.UNION_ALL, *selects)
+ def _create_union_all(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.UNION_ALL, *selects)
@classmethod
- def _create_except(cls, *selects):
- return CompoundSelect(CompoundSelect.EXCEPT, *selects)
+ def _create_except(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.EXCEPT, *selects)
@classmethod
- def _create_except_all(cls, *selects):
- return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects)
+ def _create_except_all(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.EXCEPT_ALL, *selects)
@classmethod
- def _create_intersect(cls, *selects):
- return CompoundSelect(CompoundSelect.INTERSECT, *selects)
+ def _create_intersect(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.INTERSECT, *selects)
@classmethod
- def _create_intersect_all(cls, *selects):
- return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects)
+ def _create_intersect_all(
+ cls, *selects: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
+ return CompoundSelect(_CompoundSelectKeyword.INTERSECT_ALL, *selects)
- def _scalar_type(self):
+ def _scalar_type(self) -> TypeEngine[Any]:
return self.selects[0]._scalar_type()
- def self_group(self, against=None):
+ def self_group(
+ self, against: Optional[OperatorType] = None
+ ) -> GroupedElement:
return SelectStatementGrouping(self)
- def is_derived_from(self, fromclause):
+ def is_derived_from(self, fromclause: FromClause) -> bool:
for s in self.selects:
if s.is_derived_from(fromclause):
return True
@@ -3675,7 +4258,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
return self
- def _generate_fromclause_column_proxies(self, subquery):
+ def _generate_fromclause_column_proxies(
+ self, subquery: FromClause
+ ) -> None:
# this is a slightly hacky thing - the union exports a
# column that resembles just that of the *first* selectable.
@@ -3716,8 +4301,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
def _all_selected_columns(self) -> _SelectIterable:
return self.selects[0]._all_selected_columns
- @property
- def selected_columns(self):
+ @util.ro_non_memoized_property
+ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
@@ -3739,6 +4324,11 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
return self.selects[0].selected_columns
+# backwards compat
+for elem in _CompoundSelectKeyword:
+ setattr(CompoundSelect, elem.name, elem)
+
+
@CompileState.plugin_for("default", "select")
class SelectState(util.MemoizedSlots, CompileState):
__slots__ = (
@@ -3758,10 +4348,12 @@ class SelectState(util.MemoizedSlots, CompileState):
if TYPE_CHECKING:
@classmethod
- def get_plugin_class(cls, statement: Select) -> SelectState:
+ def get_plugin_class(cls, statement: Executable) -> Type[SelectState]:
...
- def __init__(self, statement, compiler, **kw):
+ def __init__(
+ self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any
+ ):
self.statement = statement
self.from_clauses = statement._from_obj
@@ -3778,14 +4370,16 @@ class SelectState(util.MemoizedSlots, CompileState):
self.columns_plus_names = statement._generate_columns_plus_names(True)
@classmethod
- def _plugin_not_implemented(cls):
+ def _plugin_not_implemented(cls) -> NoReturn:
raise NotImplementedError(
"The default SELECT construct without plugins does not "
"implement this method."
)
@classmethod
- def get_column_descriptions(cls, statement):
+ def get_column_descriptions(
+ cls, statement: Select
+ ) -> List[Dict[str, Any]]:
return [
{
"name": name,
@@ -3798,11 +4392,13 @@ class SelectState(util.MemoizedSlots, CompileState):
]
@classmethod
- def from_statement(cls, statement, from_statement):
+ def from_statement(
+ cls, statement: Select, from_statement: ReturnsRows
+ ) -> Any:
cls._plugin_not_implemented()
@classmethod
- def get_columns_clause_froms(cls, statement):
+ def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]:
return cls._normalize_froms(
itertools.chain.from_iterable(
element._from_objects for element in statement._raw_columns
@@ -3810,7 +4406,9 @@ class SelectState(util.MemoizedSlots, CompileState):
)
@classmethod
- def _column_naming_convention(cls, label_style):
+ def _column_naming_convention(
+ cls, label_style: SelectLabelStyle
+ ) -> Callable[[Union[ColumnElement[Any], TextClause]], Optional[str]]:
table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL
dedupe = label_style is not LABEL_STYLE_NONE
@@ -3850,7 +4448,8 @@ class SelectState(util.MemoizedSlots, CompileState):
return go
- def _get_froms(self, statement):
+ def _get_froms(self, statement: Select) -> List[FromClause]:
+ ambiguous_table_name_map: _AmbiguousTableNameMap
self._ambiguous_table_name_map = ambiguous_table_name_map = {}
return self._normalize_froms(
@@ -3876,10 +4475,10 @@ class SelectState(util.MemoizedSlots, CompileState):
@classmethod
def _normalize_froms(
cls,
- iterable_of_froms,
- check_statement=None,
- ambiguous_table_name_map=None,
- ):
+ iterable_of_froms: Iterable[FromClause],
+ check_statement: Optional[Select] = None,
+ ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
+ ) -> List[FromClause]:
"""given an iterable of things to select FROM, reduce them to what
would actually render in the FROM clause of a SELECT.
@@ -3888,12 +4487,12 @@ class SelectState(util.MemoizedSlots, CompileState):
etc.
"""
- seen = set()
- froms = []
+ seen: Set[FromClause] = set()
+ froms: List[FromClause] = []
for item in iterable_of_froms:
- if item._is_subquery and item.element is check_statement:
+ if is_subquery(item) and item.element is check_statement:
raise exc.InvalidRequestError(
"select() construct refers to itself as a FROM"
)
@@ -3923,7 +4522,7 @@ class SelectState(util.MemoizedSlots, CompileState):
)
for item in froms
for fr in item._from_objects
- if fr._is_table
+ if is_table(fr)
and fr.schema
and fr.name not in ambiguous_table_name_map
)
@@ -3931,8 +4530,10 @@ class SelectState(util.MemoizedSlots, CompileState):
return froms
def _get_display_froms(
- self, explicit_correlate_froms=None, implicit_correlate_froms=None
- ):
+ self,
+ explicit_correlate_froms: Optional[Sequence[FromClause]] = None,
+ implicit_correlate_froms: Optional[Sequence[FromClause]] = None,
+ ) -> List[FromClause]:
"""Return the full list of 'from' clauses to be displayed.
Takes into account a set of existing froms which may be
@@ -3998,25 +4599,33 @@ class SelectState(util.MemoizedSlots, CompileState):
return froms
- def _memoized_attr__label_resolve_dict(self):
- with_cols = dict(
- (c._tq_label or c.key, c)
+ def _memoized_attr__label_resolve_dict(
+ self,
+ ) -> Tuple[
+ Dict[str, ColumnElement[Any]],
+ Dict[str, ColumnElement[Any]],
+ Dict[str, ColumnElement[Any]],
+ ]:
+ with_cols: Dict[str, ColumnElement[Any]] = dict(
+ (c._tq_label or c.key, c) # type: ignore
for c in self.statement._all_selected_columns
if c._allow_label_resolve
)
- only_froms = dict(
- (c.key, c)
+ only_froms: Dict[str, ColumnElement[Any]] = dict(
+ (c.key, c) # type: ignore
for c in _select_iterables(self.froms)
if c._allow_label_resolve
)
- only_cols = with_cols.copy()
+ only_cols: Dict[str, ColumnElement[Any]] = with_cols.copy()
for key, value in only_froms.items():
with_cols.setdefault(key, value)
return with_cols, only_froms, only_cols
@classmethod
- def determine_last_joined_entity(cls, stmt):
+ def determine_last_joined_entity(
+ cls, stmt: Select
+ ) -> Optional[_JoinTargetElement]:
if stmt._setup_joins:
return stmt._setup_joins[-1][0]
else:
@@ -4026,8 +4635,16 @@ class SelectState(util.MemoizedSlots, CompileState):
def all_selected_columns(cls, statement: Select) -> _SelectIterable:
return [c for c in _select_iterables(statement._raw_columns)]
- def _setup_joins(self, args, raw_columns):
+ def _setup_joins(
+ self,
+ args: Tuple[_SetupJoinsElement, ...],
+ raw_columns: List[_ColumnsClauseElement],
+ ) -> None:
for (right, onclause, left, flags) in args:
+ if TYPE_CHECKING:
+ if onclause is not None:
+ assert isinstance(onclause, ColumnElement)
+
isouter = flags["isouter"]
full = flags["full"]
@@ -4043,6 +4660,16 @@ class SelectState(util.MemoizedSlots, CompileState):
left
)
+ # these assertions can be made here, as if the right/onclause
+ # contained ORM elements, the select() statement would have been
+ # upgraded to an ORM select, and this method would not be called;
+ # orm.context.ORMSelectCompileState._join() would be
+ # used instead.
+ if TYPE_CHECKING:
+ assert isinstance(right, FromClause)
+ if onclause is not None:
+ assert isinstance(onclause, ColumnElement)
+
if replace_from_obj_index is not None:
# splice into an existing element in the
# self._from_obj list
@@ -4062,15 +4689,19 @@ class SelectState(util.MemoizedSlots, CompileState):
+ self.from_clauses[replace_from_obj_index + 1 :]
)
else:
-
+ assert left is not None
self.from_clauses = self.from_clauses + (
Join(left, right, onclause, isouter=isouter, full=full),
)
@util.preload_module("sqlalchemy.sql.util")
def _join_determine_implicit_left_side(
- self, raw_columns, left, right, onclause
- ):
+ self,
+ raw_columns: List[_ColumnsClauseElement],
+ left: Optional[FromClause],
+ right: _JoinTargetElement,
+ onclause: Optional[ColumnElement[Any]],
+ ) -> Tuple[Optional[FromClause], Optional[int]]:
"""When join conditions don't express the left side explicitly,
determine if an existing FROM or entity in this query
can serve as the left hand side.
@@ -4079,13 +4710,13 @@ class SelectState(util.MemoizedSlots, CompileState):
sql_util = util.preloaded.sql_util
- replace_from_obj_index = None
+ replace_from_obj_index: Optional[int] = None
from_clauses = self.from_clauses
if from_clauses:
- indexes = sql_util.find_left_clause_to_join_from(
+ indexes: List[int] = sql_util.find_left_clause_to_join_from(
from_clauses, right, onclause
)
@@ -4138,15 +4769,17 @@ class SelectState(util.MemoizedSlots, CompileState):
return left, replace_from_obj_index
@util.preload_module("sqlalchemy.sql.util")
- def _join_place_explicit_left_side(self, left):
- replace_from_obj_index = None
+ def _join_place_explicit_left_side(
+ self, left: FromClause
+ ) -> Optional[int]:
+ replace_from_obj_index: Optional[int] = None
sql_util = util.preloaded.sql_util
from_clauses = list(self.statement._iterate_from_elements())
if from_clauses:
- indexes = sql_util.find_left_clause_that_matches_given(
+ indexes: List[int] = sql_util.find_left_clause_that_matches_given(
self.from_clauses, left
)
else:
@@ -4171,7 +4804,13 @@ class SelectState(util.MemoizedSlots, CompileState):
class _SelectFromElements:
- def _iterate_from_elements(self):
+ __slots__ = ()
+
+ _raw_columns: List[_ColumnsClauseElement]
+ _where_criteria: Tuple[ColumnElement[Any], ...]
+ _from_obj: Tuple[FromClause, ...]
+
+ def _iterate_from_elements(self) -> Iterator[FromClause]:
# note this does not include elements
# in _setup_joins
@@ -4195,28 +4834,58 @@ class _SelectFromElements:
yield element
+Self_MemoizedSelectEntities = TypeVar("Self_MemoizedSelectEntities", bound=Any)
+
+
class _MemoizedSelectEntities(
cache_key.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
):
+ """represents partial state from a Select object, for the case
+ where Select.columns() has redefined the set of columns/entities the
+ statement will be SELECTing from. This object represents
+ the entities from the SELECT before that transformation was applied,
+ so that transformations that were made in terms of the SELECT at that
+ time, such as join() as well as options(), can access the correct context.
+
+ In previous SQLAlchemy versions, this wasn't needed because these
+ constructs calculated everything up front, like when you called join()
+ or options(), it did everything to figure out how that would translate
+ into specific SQL constructs that would be ready to send directly to the
+ SQL compiler when needed. But as of
+ 1.4, all of that stuff is done in the compilation phase, during the
+ "compile state" portion of the process, so that the work can all be
+ cached. So it needs to be able to resolve joins/options2 based on what
+ the list of entities was when those methods were called.
+
+
+ """
+
__visit_name__ = "memoized_select_entities"
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("_raw_columns", InternalTraversal.dp_clauseelement_list),
("_setup_joins", InternalTraversal.dp_setup_join_tuple),
("_with_options", InternalTraversal.dp_executable_options),
]
+ _is_clone_of: Optional[ClauseElement]
+ _raw_columns: List[_ColumnsClauseElement]
+ _setup_joins: Tuple[_SetupJoinsElement, ...]
+ _with_options: Tuple[ExecutableOption, ...]
+
_annotations = util.EMPTY_DICT
- def _clone(self, **kw):
+ def _clone(
+ self: Self_MemoizedSelectEntities, **kw: Any
+ ) -> Self_MemoizedSelectEntities:
c = self.__class__.__new__(self.__class__)
c.__dict__ = {k: v for k, v in self.__dict__.items()}
c._is_clone_of = self.__dict__.get("_is_clone_of", self)
- return c
+ return c # type: ignore
@classmethod
- def _generate_for_statement(cls, select_stmt):
+ def _generate_for_statement(cls, select_stmt: Select) -> None:
if select_stmt._setup_joins or select_stmt._with_options:
self = _MemoizedSelectEntities()
self._raw_columns = select_stmt._raw_columns
@@ -4224,12 +4893,10 @@ class _MemoizedSelectEntities(
self._with_options = select_stmt._with_options
select_stmt._memoized_select_entities += (self,)
- select_stmt._raw_columns = (
- select_stmt._setup_joins
- ) = select_stmt._with_options = ()
+ select_stmt._raw_columns = []
+ select_stmt._setup_joins = select_stmt._with_options = ()
-# TODO: use pep-673 when feasible
SelfSelect = typing.TypeVar("SelfSelect", bound="Select")
@@ -4258,9 +4925,11 @@ class Select(
__visit_name__ = "select"
- _setup_joins: Tuple[TODO_Any, ...] = ()
+ _setup_joins: Tuple[_SetupJoinsElement, ...] = ()
_memoized_select_entities: Tuple[TODO_Any, ...] = ()
+ _raw_columns: List[_ColumnsClauseElement]
+
_distinct = False
_distinct_on: Tuple[ColumnElement[Any], ...] = ()
_correlate: Tuple[FromClause, ...] = ()
@@ -4269,12 +4938,12 @@ class Select(
_having_criteria: Tuple[ColumnElement[Any], ...] = ()
_from_obj: Tuple[FromClause, ...] = ()
_auto_correlate = True
-
+ _is_select_statement = True
_compile_options: CacheableOptions = (
SelectState.default_select_compile_options
)
- _traverse_internals = (
+ _traverse_internals: _TraverseInternalsType = (
[
("_raw_columns", InternalTraversal.dp_clauseelement_list),
(
@@ -4306,12 +4975,14 @@ class Select(
+ Executable._executable_traverse_internals
)
- _cache_key_traversal = _traverse_internals + [
+ _cache_key_traversal: _CacheKeyTraversalType = _traverse_internals + [
("_compile_options", InternalTraversal.dp_has_cache_key)
]
+ _compile_state_factory: Type[SelectState]
+
@classmethod
- def _create_raw_select(cls, **kw) -> "Select":
+ def _create_raw_select(cls, **kw: Any) -> Select:
"""Create a :class:`.Select` using raw ``__new__`` with no coercions.
Used internally to build up :class:`.Select` constructs with
@@ -4330,6 +5001,12 @@ class Select(
:func:`_sql.select` function.
"""
+ things = [
+ coercions.expect(
+ roles.ColumnsClauseRole, ent, apply_propagate_attrs=self
+ )
+ for ent in entities
+ ]
self._raw_columns = [
coercions.expect(
@@ -4340,7 +5017,7 @@ class Select(
GenerativeSelect.__init__(self)
- def _scalar_type(self):
+ def _scalar_type(self) -> TypeEngine[Any]:
elem = self._raw_columns[0]
cols = list(elem._select_iterable)
return cols[0].type
@@ -4446,7 +5123,12 @@ class Select(
@_generative
def join(
- self: SelfSelect, target, onclause=None, *, isouter=False, full=False
+ self: SelfSelect,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ *,
+ isouter: bool = False,
+ full: bool = False,
) -> SelfSelect:
r"""Create a SQL JOIN against this :class:`_expression.Select`
object's criterion
@@ -4505,17 +5187,32 @@ class Select(
:meth:`_expression.Select.outerjoin`
""" # noqa: E501
- target = coercions.expect(
+ join_target = coercions.expect(
roles.JoinTargetRole, target, apply_propagate_attrs=self
)
if onclause is not None:
- onclause = coercions.expect(roles.OnClauseRole, onclause)
+ onclause_element = coercions.expect(roles.OnClauseRole, onclause)
+ else:
+ onclause_element = None
+
self._setup_joins += (
- (target, onclause, None, {"isouter": isouter, "full": full}),
+ (
+ join_target,
+ onclause_element,
+ None,
+ {"isouter": isouter, "full": full},
+ ),
)
return self
- def outerjoin_from(self, from_, target, onclause=None, *, full=False):
+ def outerjoin_from(
+ self: SelfSelect,
+ from_: _FromClauseArgument,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ *,
+ full: bool = False,
+ ) -> SelfSelect:
r"""Create a SQL LEFT OUTER JOIN against this :class:`_expression.Select`
object's criterion
and apply generatively, returning the newly resulting
@@ -4531,12 +5228,12 @@ class Select(
@_generative
def join_from(
self: SelfSelect,
- from_,
- target,
- onclause=None,
+ from_: _FromClauseArgument,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
*,
- isouter=False,
- full=False,
+ isouter: bool = False,
+ full: bool = False,
) -> SelfSelect:
r"""Create a SQL JOIN against this :class:`_expression.Select`
object's criterion
@@ -4586,18 +5283,31 @@ class Select(
from_ = coercions.expect(
roles.FromClauseRole, from_, apply_propagate_attrs=self
)
- target = coercions.expect(
+ join_target = coercions.expect(
roles.JoinTargetRole, target, apply_propagate_attrs=self
)
if onclause is not None:
- onclause = coercions.expect(roles.OnClauseRole, onclause)
+ onclause_element = coercions.expect(roles.OnClauseRole, onclause)
+ else:
+ onclause_element = None
self._setup_joins += (
- (target, onclause, from_, {"isouter": isouter, "full": full}),
+ (
+ join_target,
+ onclause_element,
+ from_,
+ {"isouter": isouter, "full": full},
+ ),
)
return self
- def outerjoin(self, target, onclause=None, *, full=False):
+ def outerjoin(
+ self: SelfSelect,
+ target: _JoinTargetArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ *,
+ full: bool = False,
+ ) -> SelfSelect:
"""Create a left outer join.
Parameters are the same as that of :meth:`_expression.Select.join`.
@@ -4634,7 +5344,7 @@ class Select(
"""
return self.join(target, onclause=onclause, isouter=True, full=full)
- def get_final_froms(self):
+ def get_final_froms(self) -> Sequence[FromClause]:
"""Compute the final displayed list of :class:`_expression.FromClause`
elements.
@@ -4671,6 +5381,7 @@ class Select(
:attr:`_sql.Select.columns_clause_froms`
"""
+
return self._compile_state_factory(self, None)._get_display_froms()
@util.deprecated_property(
@@ -4678,7 +5389,7 @@ class Select(
"The :attr:`_expression.Select.froms` attribute is moved to "
"the :meth:`_expression.Select.get_final_froms` method.",
)
- def froms(self):
+ def froms(self) -> Sequence[FromClause]:
"""Return the displayed list of :class:`_expression.FromClause`
elements.
@@ -4687,7 +5398,7 @@ class Select(
return self.get_final_froms()
@property
- def columns_clause_froms(self):
+ def columns_clause_froms(self) -> List[FromClause]:
"""Return the set of :class:`_expression.FromClause` objects implied
by the columns clause of this SELECT statement.
@@ -4720,7 +5431,7 @@ class Select(
return iter(self._all_selected_columns)
- def is_derived_from(self, fromclause):
+ def is_derived_from(self, fromclause: FromClause) -> bool:
if self in fromclause._cloned_set:
return True
@@ -4729,7 +5440,9 @@ class Select(
return True
return False
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(
+ self, clone: _CloneCallableType = _clone, **kw: Any
+ ) -> None:
# Select() object has been cloned and probably adapted by the
# given clone function. Apply the cloning function to internal
# objects
@@ -4786,13 +5499,15 @@ class Select(
def get_children(self, **kwargs):
return itertools.chain(
super(Select, self).get_children(
- omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
+ omit_attrs=("_from_obj", "_correlate", "_correlate_except")
),
self._iterate_from_elements(),
)
@_generative
- def add_columns(self: SelfSelect, *columns) -> SelfSelect:
+ def add_columns(
+ self: SelfSelect, *columns: _ColumnsClauseArgument
+ ) -> SelfSelect:
"""Return a new :func:`_expression.select` construct with
the given column expressions added to its columns clause.
@@ -4816,7 +5531,9 @@ class Select(
]
return self
- def _set_entities(self, entities):
+ def _set_entities(
+ self, entities: Iterable[_ColumnsClauseArgument]
+ ) -> None:
self._raw_columns = [
coercions.expect(
roles.ColumnsClauseRole, ent, apply_propagate_attrs=self
@@ -4830,7 +5547,7 @@ class Select(
"be removed in a future release. Please use "
":meth:`_expression.Select.add_columns`",
)
- def column(self, column):
+ def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect:
"""Return a new :func:`_expression.select` construct with
the given column expression added to its columns clause.
@@ -4847,7 +5564,9 @@ class Select(
return self.add_columns(column)
@util.preload_module("sqlalchemy.sql.util")
- def reduce_columns(self, only_synonyms=True):
+ def reduce_columns(
+ self: SelfSelect, only_synonyms: bool = True
+ ) -> SelfSelect:
"""Return a new :func:`_expression.select` construct with redundantly
named, equivalently-valued columns removed from the columns clause.
@@ -4880,7 +5599,9 @@ class Select(
@_generative
def with_only_columns(
- self: SelfSelect, *columns, maintain_column_froms=False
+ self: SelfSelect,
+ *columns: _ColumnsClauseArgument,
+ maintain_column_froms: bool = False,
) -> SelfSelect:
r"""Return a new :func:`_expression.select` construct with its columns
clause replaced with the given columns.
@@ -4941,7 +5662,9 @@ class Select(
self._assert_no_memoizations()
if maintain_column_froms:
- self.select_from.non_generative(self, *self.columns_clause_froms)
+ self.select_from.non_generative( # type: ignore
+ self, *self.columns_clause_froms
+ )
# then memoize the FROMs etc.
_MemoizedSelectEntities._generate_for_statement(self)
@@ -4974,7 +5697,9 @@ class Select(
_whereclause = whereclause
@_generative
- def where(self: SelfSelect, *whereclause) -> SelfSelect:
+ def where(
+ self: SelfSelect, *whereclause: _ColumnExpressionArgument[bool]
+ ) -> SelfSelect:
"""Return a new :func:`_expression.select` construct with
the given expression added to
its WHERE clause, joined to the existing clause via AND, if any.
@@ -4984,24 +5709,33 @@ class Select(
assert isinstance(self._where_criteria, tuple)
for criterion in whereclause:
- where_criteria = coercions.expect(roles.WhereHavingRole, criterion)
+ where_criteria: ColumnElement[Any] = coercions.expect(
+ roles.WhereHavingRole, criterion
+ )
self._where_criteria += (where_criteria,)
return self
@_generative
- def having(self: SelfSelect, having) -> SelfSelect:
+ def having(
+ self: SelfSelect, *having: _ColumnExpressionArgument[bool]
+ ) -> SelfSelect:
"""Return a new :func:`_expression.select` construct with
the given expression added to
its HAVING clause, joined to the existing clause via AND, if any.
"""
- self._having_criteria += (
- coercions.expect(roles.WhereHavingRole, having),
- )
+
+ for criterion in having:
+ having_criteria = coercions.expect(
+ roles.WhereHavingRole, criterion
+ )
+ self._having_criteria += (having_criteria,)
return self
@_generative
- def distinct(self: SelfSelect, *expr) -> SelfSelect:
+ def distinct(
+ self: SelfSelect, *expr: _ColumnExpressionArgument[Any]
+ ) -> SelfSelect:
r"""Return a new :func:`_expression.select` construct which
will apply DISTINCT to its columns clause.
@@ -5023,7 +5757,9 @@ class Select(
return self
@_generative
- def select_from(self: SelfSelect, *froms) -> SelfSelect:
+ def select_from(
+ self: SelfSelect, *froms: _FromClauseArgument
+ ) -> SelfSelect:
r"""Return a new :func:`_expression.select` construct with the
given FROM expression(s)
merged into its list of FROM objects.
@@ -5067,7 +5803,10 @@ class Select(
return self
@_generative
- def correlate(self: SelfSelect, *fromclauses) -> SelfSelect:
+ def correlate(
+ self: SelfSelect,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfSelect:
r"""Return a new :class:`_expression.Select`
which will correlate the given FROM
clauses to that of an enclosing :class:`_expression.Select`.
@@ -5106,10 +5845,10 @@ class Select(
none of its FROM entries, and all will render unconditionally
in the local FROM clause.
- :param \*fromclauses: a list of one or more
- :class:`_expression.FromClause`
- constructs, or other compatible constructs (i.e. ORM-mapped
- classes) to become part of the correlate collection.
+ :param \*fromclauses: one or more :class:`.FromClause` or other
+ FROM-compatible construct such as an ORM mapped entity to become part
+ of the correlate collection; alternatively pass a single value
+ ``None`` to remove all existing correlations.
.. seealso::
@@ -5119,8 +5858,16 @@ class Select(
"""
+ # tests failing when we try to change how these
+ # arguments are passed
+
self._auto_correlate = False
- if fromclauses and fromclauses[0] in {None, False}:
+ if not fromclauses or fromclauses[0] in {None, False}:
+ if len(fromclauses) > 1:
+ raise exc.ArgumentError(
+ "additional FROM objects not accepted when "
+ "passing None/False to correlate()"
+ )
self._correlate = ()
else:
self._correlate = self._correlate + tuple(
@@ -5129,7 +5876,10 @@ class Select(
return self
@_generative
- def correlate_except(self: SelfSelect, *fromclauses) -> SelfSelect:
+ def correlate_except(
+ self: SelfSelect,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfSelect:
r"""Return a new :class:`_expression.Select`
which will omit the given FROM
clauses from the auto-correlation process.
@@ -5141,9 +5891,9 @@ class Select(
all other FROM elements remain subject to normal auto-correlation
behaviors.
- If ``None`` is passed, the :class:`_expression.Select`
- object will correlate
- all of its FROM entries.
+ If ``None`` is passed, or no arguments are passed,
+ the :class:`_expression.Select` object will correlate all of its
+ FROM entries.
:param \*fromclauses: a list of one or more
:class:`_expression.FromClause`
@@ -5159,16 +5909,22 @@ class Select(
"""
self._auto_correlate = False
- if fromclauses and fromclauses[0] in {None, False}:
+ if not fromclauses or fromclauses[0] in {None, False}:
+ if len(fromclauses) > 1:
+ raise exc.ArgumentError(
+ "additional FROM objects not accepted when "
+ "passing None/False to correlate_except()"
+ )
self._correlate_except = ()
else:
self._correlate_except = (self._correlate_except or ()) + tuple(
coercions.expect(roles.FromClauseRole, f) for f in fromclauses
)
+
return self
- @HasMemoized.memoized_attribute
- def selected_columns(self):
+ @HasMemoized_ro_memoized_attribute
+ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
@@ -5214,18 +5970,22 @@ class Select(
# generates the actual names used in the SELECT string. that
# method is more complex because it also renders columns that are
# fully ambiguous, e.g. same column more than once.
- conv = SelectState._column_naming_convention(self._label_style)
+ conv = cast(
+ "Callable[[Any], str]",
+ SelectState._column_naming_convention(self._label_style),
+ )
- return ColumnCollection(
+ cc: ColumnCollection[str, ColumnElement[Any]] = ColumnCollection(
[
(conv(c), c)
for c in self._all_selected_columns
if is_column_element(c)
]
- ).as_readonly()
+ )
+ return cc.as_readonly()
@HasMemoized.memoized_attribute
- def _all_selected_columns(self) -> Sequence[ColumnElement[Any]]:
+ def _all_selected_columns(self) -> _SelectIterable:
meth = SelectState.get_plugin_class(self).all_selected_columns
return list(meth(self))
@@ -5234,173 +5994,9 @@ class Select(
self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
return self
- def _generate_columns_plus_names(
- self, anon_for_dupe_key: bool
- ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
- """Generate column names as rendered in a SELECT statement by
- the compiler.
-
- This is distinct from the _column_naming_convention generator that's
- intended for population of .c collections and similar, which has
- different rules. the collection returned here calls upon the
- _column_naming_convention as well.
-
- """
- cols = self._all_selected_columns
-
- key_naming_convention = SelectState._column_naming_convention(
- self._label_style
- )
-
- names = {}
-
- result = []
- result_append = result.append
-
- table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
- label_style_none = self._label_style is LABEL_STYLE_NONE
-
- # a counter used for "dedupe" labels, which have double underscores
- # in them and are never referred by name; they only act
- # as positional placeholders. they need only be unique within
- # the single columns clause they're rendered within (required by
- # some dbs such as mysql). So their anon identity is tracked against
- # a fixed counter rather than hash() identity.
- dedupe_hash = 1
-
- for c in cols:
- repeated = False
-
- if not c._render_label_in_columns_clause:
- effective_name = (
- required_label_name
- ) = fallback_label_name = None
- elif label_style_none:
- effective_name = required_label_name = None
- fallback_label_name = c._non_anon_label or c._anon_name_label
- else:
- if table_qualified:
- required_label_name = (
- effective_name
- ) = fallback_label_name = c._tq_label
- else:
- effective_name = fallback_label_name = c._non_anon_label
- required_label_name = None
-
- if effective_name is None:
- # it seems like this could be _proxy_key and we would
- # not need _expression_label but it isn't
- # giving us a clue when to use anon_label instead
- expr_label = c._expression_label
- if expr_label is None:
- repeated = c._anon_name_label in names
- names[c._anon_name_label] = c
- effective_name = required_label_name = None
-
- if repeated:
- # here, "required_label_name" is sent as
- # "None" and "fallback_label_name" is sent.
- if table_qualified:
- fallback_label_name = (
- c._dedupe_anon_tq_label_idx(dedupe_hash)
- )
- dedupe_hash += 1
- else:
- fallback_label_name = c._dedupe_anon_label_idx(
- dedupe_hash
- )
- dedupe_hash += 1
- else:
- fallback_label_name = c._anon_name_label
- else:
- required_label_name = (
- effective_name
- ) = fallback_label_name = expr_label
-
- if effective_name is not None:
- if effective_name in names:
- # when looking to see if names[name] is the same column as
- # c, use hash(), so that an annotated version of the column
- # is seen as the same as the non-annotated
- if hash(names[effective_name]) != hash(c):
-
- # different column under the same name. apply
- # disambiguating label
- if table_qualified:
- required_label_name = (
- fallback_label_name
- ) = c._anon_tq_label
- else:
- required_label_name = (
- fallback_label_name
- ) = c._anon_name_label
-
- if anon_for_dupe_key and required_label_name in names:
- # here, c._anon_tq_label is definitely unique to
- # that column identity (or annotated version), so
- # this should always be true.
- # this is also an infrequent codepath because
- # you need two levels of duplication to be here
- assert hash(names[required_label_name]) == hash(c)
-
- # the column under the disambiguating label is
- # already present. apply the "dedupe" label to
- # subsequent occurrences of the column so that the
- # original stays non-ambiguous
- if table_qualified:
- required_label_name = (
- fallback_label_name
- ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
- dedupe_hash += 1
- else:
- required_label_name = (
- fallback_label_name
- ) = c._dedupe_anon_label_idx(dedupe_hash)
- dedupe_hash += 1
- repeated = True
- else:
- names[required_label_name] = c
- elif anon_for_dupe_key:
- # same column under the same name. apply the "dedupe"
- # label so that the original stays non-ambiguous
- if table_qualified:
- required_label_name = (
- fallback_label_name
- ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
- dedupe_hash += 1
- else:
- required_label_name = (
- fallback_label_name
- ) = c._dedupe_anon_label_idx(dedupe_hash)
- dedupe_hash += 1
- repeated = True
- else:
- names[effective_name] = c
-
- result_append(
- (
- # string label name, if non-None, must be rendered as a
- # label, i.e. "AS <name>"
- required_label_name,
- # proxy_key that is to be part of the result map for this
- # col. this is also the key in a fromclause.c or
- # select.selected_columns collection
- key_naming_convention(c),
- # name that can be used to render an "AS <name>" when
- # we have to render a label even though
- # required_label_name was not given
- fallback_label_name,
- # the ColumnElement itself
- c,
- # True if this is a duplicate of a previous column
- # in the list of columns
- repeated,
- )
- )
-
- return result
-
- def _generate_fromclause_column_proxies(self, subquery):
+ def _generate_fromclause_column_proxies(
+ self, subquery: FromClause
+ ) -> None:
"""Generate column proxies to place in the exported ``.c``
collection of a subquery."""
@@ -5418,7 +6014,7 @@ class Select(
c,
repeated,
) in (self._generate_columns_plus_names(False))
- if not c._is_text_clause
+ if is_column_element(c)
]
subquery._columns._populate_separate_keys(prox)
@@ -5428,7 +6024,10 @@ class Select(
self._order_by_clause.clauses
)
- def self_group(self, against=None):
+ def self_group(
+ self: Self, against: Optional[OperatorType] = None
+ ) -> Union[SelectStatementGrouping, Self]:
+ ...
"""Return a 'grouping' construct as per the
:class:`_expression.ClauseElement` specification.
@@ -5445,7 +6044,9 @@ class Select(
else:
return SelectStatementGrouping(self)
- def union(self, *other, **kwargs):
+ def union(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``UNION`` of this select() construct against
the given selectables provided as positional arguments.
@@ -5460,9 +6061,11 @@ class Select(
for the newly created :class:`_sql.CompoundSelect` object.
"""
- return CompoundSelect._create_union(self, *other, **kwargs)
+ return CompoundSelect._create_union(self, *other)
- def union_all(self, *other, **kwargs):
+ def union_all(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``UNION ALL`` of this select() construct against
the given selectables provided as positional arguments.
@@ -5477,9 +6080,11 @@ class Select(
for the newly created :class:`_sql.CompoundSelect` object.
"""
- return CompoundSelect._create_union_all(self, *other, **kwargs)
+ return CompoundSelect._create_union_all(self, *other)
- def except_(self, *other, **kwargs):
+ def except_(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``EXCEPT`` of this select() construct against
the given selectable provided as positional arguments.
@@ -5490,13 +6095,12 @@ class Select(
multiple elements are now accepted.
- :param \**kwargs: keyword arguments are forwarded to the constructor
- for the newly created :class:`_sql.CompoundSelect` object.
-
"""
- return CompoundSelect._create_except(self, *other, **kwargs)
+ return CompoundSelect._create_except(self, *other)
- def except_all(self, *other, **kwargs):
+ def except_all(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``EXCEPT ALL`` of this select() construct against
the given selectables provided as positional arguments.
@@ -5507,13 +6111,12 @@ class Select(
multiple elements are now accepted.
- :param \**kwargs: keyword arguments are forwarded to the constructor
- for the newly created :class:`_sql.CompoundSelect` object.
-
"""
- return CompoundSelect._create_except_all(self, *other, **kwargs)
+ return CompoundSelect._create_except_all(self, *other)
- def intersect(self, *other, **kwargs):
+ def intersect(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``INTERSECT`` of this select() construct against
the given selectables provided as positional arguments.
@@ -5528,9 +6131,11 @@ class Select(
for the newly created :class:`_sql.CompoundSelect` object.
"""
- return CompoundSelect._create_intersect(self, *other, **kwargs)
+ return CompoundSelect._create_intersect(self, *other)
- def intersect_all(self, *other, **kwargs):
+ def intersect_all(
+ self, *other: _SelectStatementForCompoundArgument
+ ) -> CompoundSelect:
r"""Return a SQL ``INTERSECT ALL`` of this select() construct
against the given selectables provided as positional arguments.
@@ -5545,13 +6150,17 @@ class Select(
for the newly created :class:`_sql.CompoundSelect` object.
"""
- return CompoundSelect._create_intersect_all(self, *other, **kwargs)
+ return CompoundSelect._create_intersect_all(self, *other)
-SelfScalarSelect = typing.TypeVar("SelfScalarSelect", bound="ScalarSelect")
+SelfScalarSelect = typing.TypeVar(
+ "SelfScalarSelect", bound="ScalarSelect[Any]"
+)
-class ScalarSelect(roles.InElementRole, Generative, Grouping):
+class ScalarSelect(
+ roles.InElementRole, Generative, GroupedElement, ColumnElement[_T]
+):
"""Represent a scalar subquery.
@@ -5570,15 +6179,33 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
"""
- _from_objects = []
+ _traverse_internals: _TraverseInternalsType = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ _from_objects: List[FromClause] = []
_is_from_container = True
- _is_implicitly_boolean = False
+ if not TYPE_CHECKING:
+ _is_implicitly_boolean = False
inherit_cache = True
+ element: SelectBase
+
def __init__(self, element):
self.element = element
self.type = element._scalar_type()
+ def __getattr__(self, attr):
+ return getattr(self.element, attr)
+
+ def __getstate__(self):
+ return {"element": self.element, "type": self.type}
+
+ def __setstate__(self, state):
+ self.element = state["element"]
+ self.type = state["type"]
+
@property
def columns(self):
raise exc.InvalidRequestError(
@@ -5590,19 +6217,39 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
c = columns
@_generative
- def where(self: SelfScalarSelect, crit) -> SelfScalarSelect:
+ def where(
+ self: SelfScalarSelect, crit: _ColumnExpressionArgument[bool]
+ ) -> SelfScalarSelect:
"""Apply a WHERE clause to the SELECT statement referred to
by this :class:`_expression.ScalarSelect`.
"""
- self.element = self.element.where(crit)
+ self.element = cast(Select, self.element).where(crit)
return self
- def self_group(self, **kwargs):
+ @overload
+ def self_group(
+ self: ScalarSelect[Any], against: Optional[OperatorType] = None
+ ) -> ScalarSelect[Any]:
+ ...
+
+ @overload
+ def self_group(
+ self: ColumnElement[Any], against: Optional[OperatorType] = None
+ ) -> ColumnElement[Any]:
+ ...
+
+ def self_group(
+ self, against: Optional[OperatorType] = None
+ ) -> ColumnElement[Any]:
+
return self
@_generative
- def correlate(self: SelfScalarSelect, *fromclauses) -> SelfScalarSelect:
+ def correlate(
+ self: SelfScalarSelect,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfScalarSelect:
r"""Return a new :class:`_expression.ScalarSelect`
which will correlate the given FROM
clauses to that of an enclosing :class:`_expression.Select`.
@@ -5631,12 +6278,13 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
"""
- self.element = self.element.correlate(*fromclauses)
+ self.element = cast(Select, self.element).correlate(*fromclauses)
return self
@_generative
def correlate_except(
- self: SelfScalarSelect, *fromclauses
+ self: SelfScalarSelect,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
) -> SelfScalarSelect:
r"""Return a new :class:`_expression.ScalarSelect`
which will omit the given FROM
@@ -5668,11 +6316,16 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
"""
- self.element = self.element.correlate_except(*fromclauses)
+ self.element = cast(Select, self.element).correlate_except(
+ *fromclauses
+ )
return self
-class Exists(UnaryExpression[_T]):
+SelfExists = TypeVar("SelfExists", bound="Exists")
+
+
+class Exists(UnaryExpression[bool]):
"""Represent an ``EXISTS`` clause.
See :func:`_sql.exists` for a description of usage.
@@ -5682,10 +6335,14 @@ class Exists(UnaryExpression[_T]):
"""
- _from_objects = ()
inherit_cache = True
- def __init__(self, __argument=None):
+ def __init__(
+ self,
+ __argument: Optional[
+ Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]]
+ ] = None,
+ ):
if __argument is None:
s = Select(literal_column("*")).scalar_subquery()
elif isinstance(__argument, (SelectBase, ScalarSelect)):
@@ -5701,12 +6358,16 @@ class Exists(UnaryExpression[_T]):
wraps_column_expression=True,
)
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
+ return []
+
def _regroup(self, fn):
element = self.element._ungroup()
element = fn(element)
return element.self_group(against=operators.exists)
- def select(self) -> "Select":
+ def select(self) -> Select:
r"""Return a SELECT of this :class:`_expression.Exists`.
e.g.::
@@ -5726,7 +6387,10 @@ class Exists(UnaryExpression[_T]):
return Select(self)
- def correlate(self, *fromclause):
+ def correlate(
+ self: SelfExists,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfExists:
"""Apply correlation to the subquery noted by this :class:`_sql.Exists`.
.. seealso::
@@ -5736,11 +6400,14 @@ class Exists(UnaryExpression[_T]):
"""
e = self._clone()
e.element = self._regroup(
- lambda element: element.correlate(*fromclause)
+ lambda element: element.correlate(*fromclauses)
)
return e
- def correlate_except(self, *fromclause):
+ def correlate_except(
+ self: SelfExists,
+ *fromclauses: Union[Literal[None, False], _FromClauseArgument],
+ ) -> SelfExists:
"""Apply correlation to the subquery noted by this :class:`_sql.Exists`.
.. seealso::
@@ -5751,11 +6418,11 @@ class Exists(UnaryExpression[_T]):
e = self._clone()
e.element = self._regroup(
- lambda element: element.correlate_except(*fromclause)
+ lambda element: element.correlate_except(*fromclauses)
)
return e
- def select_from(self, *froms):
+ def select_from(self: SelfExists, *froms: FromClause) -> SelfExists:
"""Return a new :class:`_expression.Exists` construct,
applying the given
expression to the :meth:`_expression.Select.select_from`
@@ -5772,7 +6439,9 @@ class Exists(UnaryExpression[_T]):
e.element = self._regroup(lambda element: element.select_from(*froms))
return e
- def where(self, *clause):
+ def where(
+ self: SelfExists, *clause: _ColumnExpressionArgument[bool]
+ ) -> SelfExists:
"""Return a new :func:`_expression.exists` construct with the
given expression added to
its WHERE clause, joined to the existing clause via AND, if any.
@@ -5824,7 +6493,7 @@ class TextualSelect(SelectBase):
_label_style = LABEL_STYLE_NONE
- _traverse_internals = [
+ _traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement),
("column_args", InternalTraversal.dp_clauseelement_list),
] + SupportsCloneAnnotations._clone_annotations_traverse_internals
@@ -5842,8 +6511,8 @@ class TextualSelect(SelectBase):
]
self.positional = positional
- @HasMemoized.memoized_attribute
- def selected_columns(self):
+ @HasMemoized_ro_memoized_attribute
+ def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
@@ -5868,6 +6537,13 @@ class TextualSelect(SelectBase):
(c.key, c) for c in self.column_args
).as_readonly()
+ # def _generate_columns_plus_names(
+ # self, anon_for_dupe_key: bool
+ # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
+ # return Select._generate_columns_plus_names(
+ # self, anon_for_dupe_key=anon_for_dupe_key
+ # )
+
@util.non_memoized_property
def _all_selected_columns(self) -> _SelectIterable:
return self.column_args
@@ -5880,7 +6556,9 @@ class TextualSelect(SelectBase):
@_generative
def bindparams(
- self: SelfTextualSelect, *binds, **bind_as_values
+ self: SelfTextualSelect,
+ *binds: BindParameter[Any],
+ **bind_as_values: Any,
) -> SelfTextualSelect:
self.element = self.element.bindparams(*binds, **bind_as_values)
return self
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 1f3d50876..c3653c264 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -228,7 +228,7 @@ class HasCopyInternals(HasTraverseInternals):
raise NotImplementedError()
def _copy_internals(
- self, omit_attrs: Iterable[str] = (), **kw: Any
+ self, *, omit_attrs: Iterable[str] = (), **kw: Any
) -> None:
"""Reassign internal elements to be clones of themselves.
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 9a934a50b..82adf4a4f 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -1008,10 +1008,7 @@ class TypeEngine(Visitable, Generic[_T]):
@util.preload_module("sqlalchemy.engine.default")
def _default_dialect(self) -> Dialect:
- if TYPE_CHECKING:
- from ..engine import default
- else:
- default = util.preloaded.engine_default
+ default = util.preloaded.engine_default
# dmypy / mypy seems to sporadically keep thinking this line is
# returning Any, which seems to be caused by the @deprecated_params
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index cdce49f7b..80711c4b5 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -13,10 +13,20 @@ from __future__ import annotations
from collections import deque
from itertools import chain
import typing
+from typing import AbstractSet
from typing import Any
+from typing import Callable
from typing import cast
+from typing import Dict
from typing import Iterator
+from typing import List
from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from . import coercions
from . import operators
@@ -49,11 +59,22 @@ from .selectable import Join
from .selectable import ScalarSelect
from .selectable import SelectBase
from .selectable import TableClause
+from .visitors import _ET
from .. import exc
from .. import util
+from ..util.typing import Literal
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _ColumnExpressionArgument
+ from ._typing import _TypeEngineArgument
from .roles import FromClauseRole
+ from .selectable import _JoinTargetElement
+ from .selectable import _OnClauseElement
+ from .selectable import Selectable
+ from .visitors import _TraverseCallableType
+ from .visitors import ExternallyTraversible
+ from .visitors import ExternalTraversal
from ..engine.interfaces import _AnyExecuteParams
from ..engine.interfaces import _AnyMultiExecuteParams
from ..engine.interfaces import _AnySingleExecuteParams
@@ -160,7 +181,11 @@ def find_left_clause_that_matches_given(clauses, join_from):
return liberal_idx
-def find_left_clause_to_join_from(clauses, join_to, onclause):
+def find_left_clause_to_join_from(
+ clauses: Sequence[FromClause],
+ join_to: _JoinTargetElement,
+ onclause: Optional[ColumnElement[Any]],
+) -> List[int]:
"""Given a list of FROM clauses, a selectable,
and optional ON clause, return a list of integer indexes from the
clauses list indicating the clauses that can be joined from.
@@ -189,6 +214,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause):
for i, f in enumerate(clauses):
for s in selectables.difference([f]):
if resolve_ambiguity:
+ assert cols_in_onclause is not None
if set(f.c).union(s.c).issuperset(cols_in_onclause):
idx.append(i)
break
@@ -207,7 +233,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause):
# onclause was given and none of them resolved, so assume
# all indexes can match
if not idx and onclause is not None:
- return range(len(clauses))
+ return list(range(len(clauses)))
else:
return idx
@@ -247,7 +273,7 @@ def visit_binary_product(fn, expr):
a binary comparison is passed as pairs.
"""
- stack = []
+ stack: List[ClauseElement] = []
def visit(element):
if isinstance(element, ScalarSelect):
@@ -272,21 +298,22 @@ def visit_binary_product(fn, expr):
yield e
list(visit(expr))
- visit = None # remove gc cycles
+ visit = None # type: ignore # remove gc cycles
def find_tables(
- clause,
- check_columns=False,
- include_aliases=False,
- include_joins=False,
- include_selects=False,
- include_crud=False,
-):
+ clause: ClauseElement,
+ *,
+ check_columns: bool = False,
+ include_aliases: bool = False,
+ include_joins: bool = False,
+ include_selects: bool = False,
+ include_crud: bool = False,
+) -> List[TableClause]:
"""locate Table objects within the given expression."""
- tables = []
- _visitors = {}
+ tables: List[TableClause] = []
+ _visitors: Dict[str, _TraverseCallableType[Any]] = {}
if include_selects:
_visitors["select"] = _visitors["compound_select"] = tables.append
@@ -335,7 +362,7 @@ def unwrap_order_by(clause):
t = stack.popleft()
if isinstance(t, ColumnElement) and (
not isinstance(t, UnaryExpression)
- or not operators.is_ordering_modifier(t.modifier)
+ or not operators.is_ordering_modifier(t.modifier) # type: ignore
):
if isinstance(t, Label) and not isinstance(
t.element, ScalarSelect
@@ -365,9 +392,14 @@ def unwrap_order_by(clause):
def unwrap_label_reference(element):
- def replace(elem):
- if isinstance(elem, (_label_reference, _textual_label_reference)):
- return elem.element
+ def replace(
+ element: ExternallyTraversible, **kw: Any
+ ) -> Optional[ExternallyTraversible]:
+ if isinstance(element, _label_reference):
+ return element.element
+ elif isinstance(element, _textual_label_reference):
+ assert False, "can't unwrap a textual label reference"
+ return None
return visitors.replacement_traverse(element, {}, replace)
@@ -407,7 +439,7 @@ def clause_is_present(clause, search):
return False
-def tables_from_leftmost(clause: FromClauseRole) -> Iterator[FromClause]:
+def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
if isinstance(clause, Join):
for t in tables_from_leftmost(clause.left):
yield t
@@ -509,6 +541,8 @@ class _repr_base:
__slots__ = ("max_chars",)
+ max_chars: int
+
def trunc(self, value: Any) -> str:
rep = repr(value)
lenrep = len(rep)
@@ -612,7 +646,7 @@ class _repr_params(_repr_base):
def _repr_multi(
self,
multi_params: _AnyMultiExecuteParams,
- typ,
+ typ: int,
) -> str:
if multi_params:
if isinstance(multi_params[0], list):
@@ -639,7 +673,7 @@ class _repr_params(_repr_base):
def _repr_params(
self,
- params: Optional[_AnySingleExecuteParams],
+ params: _AnySingleExecuteParams,
typ: int,
) -> str:
trunc = self.trunc
@@ -653,9 +687,10 @@ class _repr_params(_repr_base):
)
)
elif typ is self._TUPLE:
+ seq_params = cast("Sequence[Any]", params)
return "(%s%s)" % (
- ", ".join(trunc(value) for value in params),
- "," if len(params) == 1 else "",
+ ", ".join(trunc(value) for value in seq_params),
+ "," if len(seq_params) == 1 else "",
)
else:
return "[%s]" % (", ".join(trunc(value) for value in params))
@@ -688,11 +723,15 @@ def adapt_criterion_to_null(crit, nulls):
return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
-def splice_joins(left, right, stop_on=None):
+def splice_joins(
+ left: Optional[FromClause],
+ right: Optional[FromClause],
+ stop_on: Optional[FromClause] = None,
+) -> Optional[FromClause]:
if left is None:
return right
- stack = [(right, None)]
+ stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)]
adapter = ClauseAdapter(left)
ret = None
@@ -705,6 +744,7 @@ def splice_joins(left, right, stop_on=None):
else:
right = adapter.traverse(right)
if prevright is not None:
+ assert right is not None
prevright.left = right
if ret is None:
ret = right
@@ -845,11 +885,14 @@ def criterion_as_pairs(
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
- pairs = []
+ pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []
visitors.traverse(expression, {}, {"binary": visit_binary})
return pairs
+_CE = TypeVar("_CE", bound="ClauseElement")
+
+
class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
@@ -879,13 +922,15 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
def __init__(
self,
- selectable,
- equivalents=None,
- include_fn=None,
- exclude_fn=None,
- adapt_on_names=False,
- anonymize_labels=False,
- adapt_from_selectables=None,
+ selectable: Selectable,
+ equivalents: Optional[
+ Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
+ ] = None,
+ include_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ adapt_on_names: bool = False,
+ anonymize_labels: bool = False,
+ adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
):
self.__traverse_options__ = {
"stop_on": [selectable],
@@ -898,6 +943,29 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
self.adapt_on_names = adapt_on_names
self.adapt_from_selectables = adapt_from_selectables
+ if TYPE_CHECKING:
+
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ # note this specializes the ReplacingExternalTraversal.traverse()
+ # method to state
+ # that we will return the same kind of ExternalTraversal object as
+ # we were given. This is probably not 100% true, such as it's
+ # possible for us to swap out Alias for Table at the top level.
+ # Ideally there could be overloads specific to ColumnElement and
+ # FromClause but Mypy is not accepting those as compatible with
+ # the base ReplacingExternalTraversal
+ @overload
+ def traverse(self, obj: _ET) -> _ET:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
+ ...
+
def _corresponding_column(
self, col, require_embedded, _seen=util.EMPTY_SET
):
@@ -919,9 +987,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
return newcol
@util.preload_module("sqlalchemy.sql.functions")
- def replace(self, col, _include_singleton_constants=False):
+ def replace(
+ self, col: _ET, _include_singleton_constants: bool = False
+ ) -> Optional[_ET]:
functions = util.preloaded.sql_functions
+ # TODO: cython candidate
+
if isinstance(col, FromClause) and not isinstance(
col, functions.FunctionElement
):
@@ -933,7 +1005,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
break
else:
return None
- return self.selectable
+ return self.selectable # type: ignore
elif isinstance(col, Alias) and isinstance(
col.element, TableClause
):
@@ -944,7 +1016,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
# we are an alias of a table and we are not derived from an
# alias of a table (which nonetheless may be the same table
# as ours) so, same thing
- return col
+ return col # type: ignore
else:
# other cases where we are a selectable and the element
# is another join or selectable that contains a table which our
@@ -972,12 +1044,22 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
else:
return None
+ if TYPE_CHECKING:
+ assert isinstance(col, ColumnElement)
+
if self.include_fn and not self.include_fn(col):
return None
elif self.exclude_fn and self.exclude_fn(col):
return None
else:
- return self._corresponding_column(col, True)
+ return self._corresponding_column(col, True) # type: ignore
+
+
+class _ColumnLookup(Protocol):
+ def __getitem__(
+ self, key: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
+ ...
class ColumnAdapter(ClauseAdapter):
@@ -1011,17 +1093,21 @@ class ColumnAdapter(ClauseAdapter):
"""
+ columns: _ColumnLookup
+
def __init__(
self,
- selectable,
- equivalents=None,
- adapt_required=False,
- include_fn=None,
- exclude_fn=None,
- adapt_on_names=False,
- allow_label_resolve=True,
- anonymize_labels=False,
- adapt_from_selectables=None,
+ selectable: Selectable,
+ equivalents: Optional[
+ Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]]
+ ] = None,
+ adapt_required: bool = False,
+ include_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
+ adapt_on_names: bool = False,
+ allow_label_resolve: bool = True,
+ anonymize_labels: bool = False,
+ adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
):
ClauseAdapter.__init__(
self,
@@ -1034,7 +1120,7 @@ class ColumnAdapter(ClauseAdapter):
adapt_from_selectables=adapt_from_selectables,
)
- self.columns = util.WeakPopulateDict(self._locate_col)
+ self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore
if self.include_fn or self.exclude_fn:
self.columns = self._IncludeExcludeMapping(self, self.columns)
self.adapt_required = adapt_required
@@ -1060,7 +1146,7 @@ class ColumnAdapter(ClauseAdapter):
ac = self.__class__.__new__(self.__class__)
ac.__dict__.update(self.__dict__)
ac._wrap = adapter
- ac.columns = util.WeakPopulateDict(ac._locate_col)
+ ac.columns = util.WeakPopulateDict(ac._locate_col) # type: ignore
if ac.include_fn or ac.exclude_fn:
ac.columns = self._IncludeExcludeMapping(ac, ac.columns)
@@ -1069,6 +1155,17 @@ class ColumnAdapter(ClauseAdapter):
def traverse(self, obj):
return self.columns[obj]
+ def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
+ assert isinstance(visitor, ColumnAdapter)
+
+ return super().chain(visitor)
+
+ if TYPE_CHECKING:
+
+ @property
+ def visitor_iterator(self) -> Iterator[ColumnAdapter]:
+ ...
+
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
@@ -1080,7 +1177,9 @@ class ColumnAdapter(ClauseAdapter):
return newcol
- def _locate_col(self, col):
+ def _locate_col(
+ self, col: ColumnElement[Any]
+ ) -> Optional[ColumnElement[Any]]:
# both replace and traverse() are overly complicated for what
# we are doing here and we would do better to have an inlined
# version that doesn't build up as much overhead. the issue is that
@@ -1120,10 +1219,14 @@ class ColumnAdapter(ClauseAdapter):
def __setstate__(self, state):
self.__dict__.update(state)
- self.columns = util.WeakPopulateDict(self._locate_col)
+ self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore
-def _offset_or_limit_clause(element, name=None, type_=None):
+def _offset_or_limit_clause(
+ element: Union[int, _ColumnExpressionArgument[int]],
+ name: Optional[str] = None,
+ type_: Optional[_TypeEngineArgument[int]] = None,
+) -> ColumnElement[int]:
"""Convert the given value to an "offset or limit" clause.
This handles incoming integers and converts to an expression; if
@@ -1135,7 +1238,9 @@ def _offset_or_limit_clause(element, name=None, type_=None):
)
-def _offset_or_limit_clause_asint_if_possible(clause):
+def _offset_or_limit_clause_asint_if_possible(
+ clause: Optional[Union[int, _ColumnExpressionArgument[int]]]
+) -> Optional[Union[int, _ColumnExpressionArgument[int]]]:
"""Return the offset or limit clause as a simple integer if possible,
else return the clause.
@@ -1143,18 +1248,27 @@ def _offset_or_limit_clause_asint_if_possible(clause):
if clause is None:
return None
if hasattr(clause, "_limit_offset_value"):
- value = clause._limit_offset_value
+ value = clause._limit_offset_value # type: ignore
return util.asint(value)
else:
return clause
-def _make_slice(limit_clause, offset_clause, start, stop):
+def _make_slice(
+ limit_clause: Optional[Union[int, _ColumnExpressionArgument[int]]],
+ offset_clause: Optional[Union[int, _ColumnExpressionArgument[int]]],
+ start: int,
+ stop: int,
+) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
"""Compute LIMIT/OFFSET in terms of slice start/end"""
# for calculated limit/offset, try to do the addition of
# values to offset in Python, however if a SQL clause is present
# then the addition has to be on the SQL side.
+
+ # TODO: typing is finding a few gaps in here, see if they can be
+ # closed up
+
if start is not None and stop is not None:
offset_clause = _offset_or_limit_clause_asint_if_possible(
offset_clause
@@ -1163,11 +1277,12 @@ def _make_slice(limit_clause, offset_clause, start, stop):
offset_clause = 0
if start != 0:
- offset_clause = offset_clause + start
+ offset_clause = offset_clause + start # type: ignore
if offset_clause == 0:
offset_clause = None
else:
+ assert offset_clause is not None
offset_clause = _offset_or_limit_clause(offset_clause)
limit_clause = _offset_or_limit_clause(stop - start)
@@ -1182,11 +1297,13 @@ def _make_slice(limit_clause, offset_clause, start, stop):
offset_clause = 0
if start != 0:
- offset_clause = offset_clause + start
+ offset_clause = offset_clause + start # type: ignore
if offset_clause == 0:
offset_clause = None
else:
- offset_clause = _offset_or_limit_clause(offset_clause)
+ offset_clause = _offset_or_limit_clause(
+ offset_clause # type: ignore
+ )
- return limit_clause, offset_clause
+ return limit_clause, offset_clause # type: ignore
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 903aae648..081faf1e9 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -28,6 +28,7 @@ from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
+from typing import overload
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -37,6 +38,7 @@ from .. import exc
from .. import util
from ..util import langhelpers
from ..util._has_cy import HAS_CYEXTENSION
+from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
@@ -599,8 +601,8 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
raise NotImplementedError()
def _copy_internals(
- self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any
- ) -> Self:
+ self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> None:
"""Reassign internal elements to be clones of themselves.
Called during a copy-and-traverse operation on newly
@@ -615,10 +617,24 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
_ET = TypeVar("_ET", bound=ExternallyTraversible)
+
+
_TraverseCallableType = Callable[[_ET], None]
-_TraverseTransformCallableType = Callable[
- [ExternallyTraversible], Optional[ExternallyTraversible]
-]
+
+
+class _CloneCallableType(Protocol):
+ def __call__(self, element: _ET, **kw: Any) -> _ET:
+ ...
+
+
+class _TraverseTransformCallableType(Protocol):
+ def __call__(
+ self, element: ExternallyTraversible, **kw: Any
+ ) -> Optional[ExternallyTraversible]:
+ ...
+
+
+_ExtT = TypeVar("_ExtT", bound="ExternalTraversal")
class ExternalTraversal:
@@ -640,7 +656,7 @@ class ExternalTraversal:
return meth(obj, **kw)
def iterate(
- self, obj: ExternallyTraversible
+ self, obj: Optional[ExternallyTraversible]
) -> Iterator[ExternallyTraversible]:
"""Traverse the given expression structure, returning an iterator
of all elements.
@@ -648,7 +664,17 @@ class ExternalTraversal:
"""
return iterate(obj, self.__traverse_options__)
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure."""
return traverse(obj, self.__traverse_options__, self._visitor_dict)
@@ -671,7 +697,7 @@ class ExternalTraversal:
yield v
v = getattr(v, "_next", None)
- def chain(self, visitor: ExternalTraversal) -> ExternalTraversal:
+ def chain(self: _ExtT, visitor: ExternalTraversal) -> _ExtT:
"""'Chain' an additional ExternalTraversal onto this ExternalTraversal
The chained visitor will receive all visit events after this one.
@@ -701,7 +727,17 @@ class CloningExternalTraversal(ExternalTraversal):
"""
return [self.traverse(x) for x in list_]
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure."""
return cloned_traverse(
@@ -729,14 +765,25 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
"""
return None
+ @overload
+ def traverse(self, obj: Literal[None]) -> None:
+ ...
+
+ @overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
+ ...
+
+ def traverse(
+ self, obj: Optional[ExternallyTraversible]
+ ) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure."""
def replace(
- elem: ExternallyTraversible,
+ element: ExternallyTraversible,
+ **kw: Any,
) -> Optional[ExternallyTraversible]:
for v in self.visitor_iterator:
- e = cast(ReplacingExternalTraversal, v).replace(elem)
+ e = cast(ReplacingExternalTraversal, v).replace(element)
if e is not None:
return e
@@ -754,7 +801,8 @@ ReplacingCloningVisitor = ReplacingExternalTraversal
def iterate(
- obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any] = util.EMPTY_DICT,
) -> Iterator[ExternallyTraversible]:
r"""Traverse the given expression structure, returning an iterator.
@@ -776,6 +824,9 @@ def iterate(
empty in modern usage.
"""
+ if obj is None:
+ return
+
yield obj
children = obj.get_children(**opts)
@@ -790,11 +841,29 @@ def iterate(
stack.append(t.get_children(**opts))
+@overload
+def traverse_using(
+ iterator: Iterable[ExternallyTraversible],
+ obj: Literal[None],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> None:
+ ...
+
+
+@overload
def traverse_using(
iterator: Iterable[ExternallyTraversible],
obj: ExternallyTraversible,
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
+ ...
+
+
+def traverse_using(
+ iterator: Iterable[ExternallyTraversible],
+ obj: Optional[ExternallyTraversible],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> Optional[ExternallyTraversible]:
"""Visit the given expression structure using the given iterator of
objects.
@@ -826,11 +895,29 @@ def traverse_using(
return obj
+@overload
+def traverse(
+ obj: Literal[None],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> None:
+ ...
+
+
+@overload
def traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
+ ...
+
+
+def traverse(
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> Optional[ExternallyTraversible]:
"""Traverse and visit the given expression structure using the default
iterator.
@@ -863,11 +950,29 @@ def traverse(
return traverse_using(iterate(obj, opts), obj, visitors)
+@overload
+def cloned_traverse(
+ obj: Literal[None],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> None:
+ ...
+
+
+@overload
def cloned_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
+ ...
+
+
+def cloned_traverse(
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> Optional[ExternallyTraversible]:
"""Clone the given expression structure, allowing modifications by
visitors.
@@ -931,11 +1036,29 @@ def cloned_traverse(
return obj
+@overload
+def replacement_traverse(
+ obj: Literal[None],
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType,
+) -> None:
+ ...
+
+
+@overload
def replacement_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
replace: _TraverseTransformCallableType,
) -> ExternallyTraversible:
+ ...
+
+
+def replacement_traverse(
+ obj: Optional[ExternallyTraversible],
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType,
+) -> Optional[ExternallyTraversible]:
"""Clone the given expression structure, allowing element
replacement by a given replacement function.
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index 4496b8ded..f1bf5c0c4 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -306,6 +306,7 @@ def testing_engine(
options=None,
asyncio=False,
transfer_staticpool=False,
+ share_pool=False,
_sqlite_savepoint=False,
):
if asyncio:
@@ -356,6 +357,8 @@ def testing_engine(
if config.db is not None and isinstance(config.db.pool, StaticPool):
use_reaper = False
engine.pool._transfer_from(config.db.pool)
+ elif share_pool:
+ engine.pool = config.db.pool
if scope == "global":
if asyncio:
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index 6c6b21fce..8c0120bcc 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -121,6 +121,7 @@ class TestBase:
future=None,
asyncio=False,
transfer_staticpool=False,
+ share_pool=False,
):
if options is None:
options = {}
@@ -130,6 +131,7 @@ class TestBase:
options=options,
asyncio=asyncio,
transfer_staticpool=transfer_staticpool,
+ share_pool=share_pool,
)
yield gen_testing_engine
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 406c8af24..c0c2e7dfb 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -9,7 +9,9 @@
from collections import defaultdict as defaultdict
from functools import partial as partial
from functools import update_wrapper as update_wrapper
+from typing import TYPE_CHECKING
+from . import preloaded as preloaded
from ._collections import coerce_generator_arg as coerce_generator_arg
from ._collections import coerce_to_immutabledict as coerce_to_immutabledict
from ._collections import column_dict as column_dict
@@ -44,8 +46,6 @@ from ._collections import UniqueAppender as UniqueAppender
from ._collections import update_copy as update_copy
from ._collections import WeakPopulateDict as WeakPopulateDict
from ._collections import WeakSequence as WeakSequence
-from ._preloaded import preload_module as preload_module
-from ._preloaded import preloaded as preloaded
from .compat import arm as arm
from .compat import b as b
from .compat import b64decode as b64decode
@@ -148,3 +148,4 @@ from .langhelpers import warn as warn
from .langhelpers import warn_exception as warn_exception
from .langhelpers import warn_limited as warn_limited
from .langhelpers import wrap_callable as wrap_callable
+from .preloaded import preload_module as preload_module
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index bd73bf714..bcb2ad423 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -463,11 +463,12 @@ def update_copy(d, _new=None, **kw):
return d
-def flatten_iterator(x):
+def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
"""Given an iterator of which further sub-elements may also be
iterators, flatten the sub-elements into a single iterator.
"""
+ elem: _T
for elem in x:
if not isinstance(elem, str) and hasattr(elem, "__iter__"):
for y in flatten_iterator(elem):
diff --git a/lib/sqlalchemy/util/_preloaded.py b/lib/sqlalchemy/util/preloaded.py
index 511b93351..c861c83b3 100644
--- a/lib/sqlalchemy/util/_preloaded.py
+++ b/lib/sqlalchemy/util/preloaded.py
@@ -16,10 +16,16 @@ from types import ModuleType
import typing
from typing import Any
from typing import Callable
+from typing import TYPE_CHECKING
from typing import TypeVar
_FN = TypeVar("_FN", bound=Callable[..., Any])
+if TYPE_CHECKING:
+ from sqlalchemy.engine import default as engine_default
+ from sqlalchemy.sql import dml as sql_dml
+ from sqlalchemy.sql import util as sql_util
+
class _ModuleRegistry:
"""Registry of modules to load in a package init file.
@@ -67,7 +73,7 @@ class _ModuleRegistry:
not path or module.startswith(path)
) and key not in self.__dict__:
__import__(module, globals(), locals())
- self.__dict__[key] = sys.modules[module]
+ self.__dict__[key] = globals()[key] = sys.modules[module]
if typing.TYPE_CHECKING:
@@ -75,5 +81,11 @@ class _ModuleRegistry:
...
-preloaded = _ModuleRegistry()
-preload_module = preloaded.preload_module
+_reg = _ModuleRegistry()
+preload_module = _reg.preload_module
+import_prefix = _reg.import_prefix
+
+if TYPE_CHECKING:
+
+ def __getattr__(key: str) -> ModuleType:
+ ...
diff --git a/pyproject.toml b/pyproject.toml
index cc79e8646..012f1bffa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,9 +68,6 @@ module = [
"sqlalchemy.ext.mutable",
"sqlalchemy.ext.horizontal_shard",
- "sqlalchemy.sql._selectable_constructors",
- "sqlalchemy.sql._dml_constructors",
-
# TODO for non-strict:
"sqlalchemy.ext.baked",
"sqlalchemy.ext.instrumentation",
@@ -78,11 +75,6 @@ module = [
"sqlalchemy.ext.orderinglist",
"sqlalchemy.ext.serializer",
- "sqlalchemy.sql.selectable", # would be nice as strict
- "sqlalchemy.sql.functions", # would be nice as strict
- "sqlalchemy.sql.lambdas",
- "sqlalchemy.sql.util",
-
# not yet classified:
"sqlalchemy.orm.*",
"sqlalchemy.dialects.*",
@@ -132,10 +124,14 @@ module = [
"sqlalchemy.sql.crud",
"sqlalchemy.sql.ddl", # would be nice as strict
"sqlalchemy.sql.elements", # would be nice as strict
+ "sqlalchemy.sql.functions", # would be nice as strict, requires sqltypes
+ "sqlalchemy.sql.lambdas",
"sqlalchemy.sql.naming",
+ "sqlalchemy.sql.selectable", # would be nice as strict
"sqlalchemy.sql.schema", # would be nice as strict
"sqlalchemy.sql.sqltypes", # would be nice as strict
"sqlalchemy.sql.traversals",
+ "sqlalchemy.sql.util",
"sqlalchemy.util.*",
]
diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py
index 06be9fbd8..e03a8415d 100644
--- a/test/aaa_profiling/test_orm.py
+++ b/test/aaa_profiling/test_orm.py
@@ -835,27 +835,14 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest):
)
s.commit()
- def test_build_query(self):
+ def test_fetch_results_integrated(self, testing_engine):
A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G")
- sess = fixture_session()
+ # this test has been reworked to use the compiled cache again,
+ # as a real-world scenario.
- @profiling.function_call_count()
- def go():
- for i in range(100):
- q = sess.query(A).options(
- joinedload(A.bs).joinedload(B.cs).joinedload(C.ds),
- joinedload(A.es).joinedload(E.fs),
- defaultload(A.es).joinedload(E.gs),
- )
- q._compile_context()
-
- go()
-
- def test_fetch_results(self):
- A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G")
-
- sess = Session(testing.db)
+ eng = testing_engine(share_pool=True)
+ sess = Session(eng)
q = sess.query(A).options(
joinedload(A.bs).joinedload(B.cs).joinedload(C.ds),
@@ -863,47 +850,27 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest):
defaultload(A.es).joinedload(E.gs),
)
- compile_state = q._compile_state()
+ @profiling.function_call_count()
+ def initial_run():
+ list(q.all())
- from sqlalchemy.orm.context import ORMCompileState
+ initial_run()
+ sess.close()
- @profiling.function_call_count(warmup=1)
- def go():
- for i in range(100):
- # NOTE: this test was broken in
- # 77f1b7d236dba6b1c859bb428ef32d118ec372e6 because we started
- # clearing out the attributes after the first iteration. make
- # sure the attributes are there every time.
- assert compile_state.attributes
- exec_opts = {}
- bind_arguments = {}
- ORMCompileState.orm_pre_session_exec(
- sess,
- compile_state.select_statement,
- {},
- exec_opts,
- bind_arguments,
- is_reentrant_invoke=False,
- )
+ @profiling.function_call_count()
+ def subsequent_run():
+ list(q.all())
- r = sess.connection().execute(
- compile_state.statement,
- execution_options=exec_opts,
- )
+ subsequent_run()
+ sess.close()
- r.context.compiled.compile_state = compile_state
- obj = ORMCompileState.orm_setup_cursor_result(
- sess,
- compile_state.statement,
- {},
- exec_opts,
- {},
- r,
- )
- list(obj.unique())
- sess.close()
+ @profiling.function_call_count()
+ def more_runs():
+ for i in range(100):
+ list(q.all())
- go()
+ more_runs()
+ sess.close()
class JoinConditionTest(NoCache, fixtures.DeclarativeMappedTest):
diff --git a/test/base/test_utils.py b/test/base/test_utils.py
index fc61e39b6..e22340da6 100644
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -25,11 +25,11 @@ from sqlalchemy.testing import mock
from sqlalchemy.testing import ne_
from sqlalchemy.testing.util import gc_collect
from sqlalchemy.testing.util import picklers
-from sqlalchemy.util import _preloaded
from sqlalchemy.util import classproperty
from sqlalchemy.util import compat
from sqlalchemy.util import get_callable_argspec
from sqlalchemy.util import langhelpers
+from sqlalchemy.util import preloaded
from sqlalchemy.util import WeakSequence
from sqlalchemy.util._collections import merge_lists_w_ordering
@@ -3187,7 +3187,7 @@ class TestModuleRegistry(fixtures.TestBase):
for m in ("xml.dom", "wsgiref.simple_server"):
to_restore.append((m, sys.modules.pop(m, None)))
try:
- mr = _preloaded._ModuleRegistry()
+ mr = preloaded._ModuleRegistry()
ret = mr.preload_module(
"xml.dom", "wsgiref.simple_server", "sqlalchemy.sql.util"
diff --git a/test/profiles.txt b/test/profiles.txt
index 31f72bd16..7b4f37734 100644
--- a/test/profiles.txt
+++ b/test/profiles.txt
@@ -196,16 +196,16 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3
test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96844
test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102344
-# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query
-
-test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 520615
-test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 522475
# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results
test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 440705
test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 458805
+# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated
+
+test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 30373,1014,96450
+
# TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity
test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 22984
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index dd073d2a5..8d6dc7553 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -208,6 +208,8 @@ class CoreFixtures:
column("q") == column("x"),
column("q") == column("y"),
column("z") == column("x"),
+ (column("z") == column("x")).self_group(),
+ (column("q") == column("x")).self_group(),
column("z") + column("x"),
column("z") - column("x"),
column("x") - column("z"),
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index d20037e92..6ca06dc0e 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -82,7 +82,6 @@ from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import CompilerColumnElement
from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.sql.expression import ClauseList
-from sqlalchemy.sql.expression import HasPrefixes
from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from sqlalchemy.testing import assert_raises
@@ -270,18 +269,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
"columns",
)
- def test_prefix_constructor(self):
- class Pref(HasPrefixes):
- def _generate(self):
- return self
-
- assert_raises(
- exc.ArgumentError,
- Pref().prefix_with,
- "some prefix",
- not_a_dialect=True,
- )
-
def test_table_select(self):
self.assert_compile(
table1.select(),
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py
index 686d4928d..d1d01a5c7 100644
--- a/test/sql/test_cte.py
+++ b/test/sql/test_cte.py
@@ -1186,6 +1186,37 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
dialect="postgresql",
)
+ def test_recursive_dml_syntax(self):
+ orders = table(
+ "orders",
+ column("region"),
+ column("amount"),
+ column("product"),
+ column("quantity"),
+ )
+
+ upsert = (
+ orders.update()
+ .where(orders.c.region == "Region1")
+ .values(amount=1.0, product="Product1", quantity=1)
+ .returning(*(orders.c._all_columns))
+ .cte("upsert", recursive=True)
+ )
+ stmt = select(upsert)
+
+ # This statement probably makes no sense, just want to see that the
+ # column generation aspect needed by RECURSIVE works (new in 2.0)
+ self.assert_compile(
+ stmt,
+ "WITH RECURSIVE upsert(region, amount, product, quantity) "
+ "AS (UPDATE orders SET amount=:param_1, product=:param_2, "
+ "quantity=:param_3 WHERE orders.region = :region_1 "
+ "RETURNING orders.region, orders.amount, orders.product, "
+ "orders.quantity) "
+ "SELECT upsert.region, upsert.amount, upsert.product, "
+ "upsert.quantity FROM upsert",
+ )
+
def test_upsert_from_select(self):
orders = table(
"orders",
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index ca5f43bb6..9fdc51938 100644
--- a/test/sql/test_selectable.py
+++ b/test/sql/test_selectable.py
@@ -432,6 +432,24 @@ class SelectableTest(
):
select(stmt.subquery()).compile()
+ def test_correlate_none_arg_error(self):
+ stmt = select(table1)
+ with expect_raises_message(
+ exc.ArgumentError,
+ "additional FROM objects not accepted when passing "
+ "None/False to correlate",
+ ):
+ stmt.correlate(None, table2)
+
+ def test_correlate_except_none_arg_error(self):
+ stmt = select(table1)
+ with expect_raises_message(
+ exc.ArgumentError,
+ "additional FROM objects not accepted when passing "
+ "None/False to correlate_except",
+ ):
+ stmt.correlate_except(None, table2)
+
def test_select_label_grouped_still_corresponds(self):
label = select(table1.c.col1).label("foo")
label2 = label.self_group()
diff --git a/test/sql/test_text.py b/test/sql/test_text.py
index 81b20f86f..0f645a2d2 100644
--- a/test/sql/test_text.py
+++ b/test/sql/test_text.py
@@ -688,6 +688,19 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL):
mapping = self._mapping(s)
assert x not in mapping
+ def test_subquery_accessors(self):
+ t = self._xy_table_fixture()
+
+ s = text("SELECT x from t").columns(t.c.x)
+
+ self.assert_compile(
+ select(s.scalar_subquery()), "SELECT (SELECT x from t) AS anon_1"
+ )
+ self.assert_compile(
+ select(s.subquery()),
+ "SELECT anon_1.x FROM (SELECT x from t) AS anon_1",
+ )
+
def test_select_label_alt_name_table_alias_column(self):
t = self._xy_table_fixture()
x = t.c.x
@@ -716,6 +729,36 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL):
"FROM mytable, t WHERE mytable.myid = t.id",
)
+ def test_cte_recursive(self):
+ t = (
+ text("select id, name from user")
+ .columns(id=Integer, name=String)
+ .cte("t", recursive=True)
+ )
+
+ s = select(table1).where(table1.c.myid == t.c.id)
+ self.assert_compile(
+ s,
+ "WITH RECURSIVE t(id, name) AS (select id, name from user) "
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable, t WHERE mytable.myid = t.id",
+ )
+
+ def test_unions(self):
+ s1 = text("select id, name from user where id > 5").columns(
+ id=Integer, name=String
+ )
+ s2 = text("select id, name from user where id < 15").columns(
+ id=Integer, name=String
+ )
+ stmt = union(s1, s2)
+ eq_(stmt.selected_columns.keys(), ["id", "name"])
+ self.assert_compile(
+ stmt,
+ "select id, name from user where id > 5 UNION "
+ "select id, name from user where id < 15",
+ )
+
def test_subquery(self):
t = (
text("select id, name from user")
diff --git a/test/sql/test_values.py b/test/sql/test_values.py
index f5ae9ea53..d14de9aee 100644
--- a/test/sql/test_values.py
+++ b/test/sql/test_values.py
@@ -294,6 +294,31 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL):
checkparams={},
)
+ def test_anon_alias(self):
+ people = self.tables.people
+ values = (
+ Values(
+ column("bookcase_id", Integer),
+ column("bookcase_owner_id", Integer),
+ )
+ .data([(1, 1), (2, 1), (3, 2), (3, 3)])
+ .alias()
+ )
+ stmt = select(people, values).select_from(
+ people.join(
+ values, values.c.bookcase_owner_id == people.c.people_id
+ )
+ )
+ self.assert_compile(
+ stmt,
+ "SELECT people.people_id, people.age, people.name, "
+ "anon_1.bookcase_id, anon_1.bookcase_owner_id FROM people "
+ "JOIN (VALUES (:param_1, :param_2), (:param_3, :param_4), "
+ "(:param_5, :param_6), (:param_7, :param_8)) AS anon_1 "
+ "(bookcase_id, bookcase_owner_id) "
+ "ON people.people_id = anon_1.bookcase_owner_id",
+ )
+
def test_with_join_unnamed(self):
people = self.tables.people
values = Values(