summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/_py_util.py25
-rw-r--r--lib/sqlalchemy/sql/annotation.py305
-rw-r--r--lib/sqlalchemy/sql/base.py17
-rw-r--r--lib/sqlalchemy/sql/cache_key.py354
-rw-r--r--lib/sqlalchemy/sql/coercions.py235
-rw-r--r--lib/sqlalchemy/sql/compiler.py447
-rw-r--r--lib/sqlalchemy/sql/dml.py8
-rw-r--r--lib/sqlalchemy/sql/elements.py34
-rw-r--r--lib/sqlalchemy/sql/functions.py3
-rw-r--r--lib/sqlalchemy/sql/roles.py33
-rw-r--r--lib/sqlalchemy/sql/schema.py6
-rw-r--r--lib/sqlalchemy/sql/selectable.py29
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py2
-rw-r--r--lib/sqlalchemy/sql/traversals.py64
-rw-r--r--lib/sqlalchemy/sql/util.py6
-rw-r--r--lib/sqlalchemy/sql/visitors.py598
16 files changed, 1499 insertions, 667 deletions
diff --git a/lib/sqlalchemy/sql/_py_util.py b/lib/sqlalchemy/sql/_py_util.py
index 96e8f6b2c..9f18b882d 100644
--- a/lib/sqlalchemy/sql/_py_util.py
+++ b/lib/sqlalchemy/sql/_py_util.py
@@ -7,7 +7,16 @@
from __future__ import annotations
+import typing
+from typing import Any
from typing import Dict
+from typing import Tuple
+from typing import Union
+
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+ from .cache_key import CacheConst
class prefix_anon_map(Dict[str, str]):
@@ -22,16 +31,18 @@ class prefix_anon_map(Dict[str, str]):
"""
- def __missing__(self, key):
+ def __missing__(self, key: str) -> str:
(ident, derived) = key.split(" ", 1)
anonymous_counter = self.get(derived, 1)
- self[derived] = anonymous_counter + 1
+ self[derived] = anonymous_counter + 1 # type: ignore
value = f"{derived}_{anonymous_counter}"
self[key] = value
return value
-class cache_anon_map(Dict[int, str]):
+class cache_anon_map(
+ Dict[Union[int, "Literal[CacheConst.NO_CACHE]"], Union[Literal[True], str]]
+):
"""A map that creates new keys for missing key access.
Produces an incrementing sequence given a series of unique keys.
@@ -45,11 +56,13 @@ class cache_anon_map(Dict[int, str]):
_index = 0
- def get_anon(self, object_):
+ def get_anon(self, object_: Any) -> Tuple[str, bool]:
idself = id(object_)
if idself in self:
- return self[idself], True
+ s_val = self[idself]
+ assert s_val is not True
+ return s_val, True
else:
# inline of __missing__
self[idself] = id_ = str(self._index)
@@ -57,7 +70,7 @@ class cache_anon_map(Dict[int, str]):
return id_, False
- def __missing__(self, key):
+ def __missing__(self, key: int) -> str:
self[key] = val = str(self._index)
self._index += 1
return val
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index b76393ad6..7afc2de97 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -13,22 +13,77 @@ associations.
from __future__ import annotations
+import typing
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Mapping
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TypeVar
+
from . import operators
-from .base import HasCacheKey
-from .traversals import anon_map
+from .cache_key import HasCacheKey
+from .visitors import anon_map
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
from .. import util
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+ from .visitors import _TraverseInternalsType
+ from ..util.typing import Self
+
+_AnnotationDict = Mapping[str, Any]
+
+EMPTY_ANNOTATIONS: util.immutabledict[str, Any] = util.EMPTY_DICT
+
-EMPTY_ANNOTATIONS = util.immutabledict()
+SelfSupportsAnnotations = TypeVar(
+ "SelfSupportsAnnotations", bound="SupportsAnnotations"
+)
-class SupportsAnnotations:
+class SupportsAnnotations(ExternallyTraversible):
__slots__ = ()
- _annotations = EMPTY_ANNOTATIONS
+ _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS
+ proxy_set: Set[SupportsAnnotations]
+ _is_immutable: bool
+
+ def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations:
+ raise NotImplementedError()
+
+ @overload
+ def _deannotate(
+ self: SelfSupportsAnnotations,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfSupportsAnnotations:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> SupportsAnnotations:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = False,
+ ) -> SupportsAnnotations:
+ raise NotImplementedError()
@util.memoized_property
- def _annotations_cache_key(self):
+ def _annotations_cache_key(self) -> Tuple[Any, ...]:
anon_map_ = anon_map()
return (
"_annotations",
@@ -47,14 +102,22 @@ class SupportsAnnotations:
)
+SelfSupportsCloneAnnotations = TypeVar(
+ "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations"
+)
+
+
class SupportsCloneAnnotations(SupportsAnnotations):
- __slots__ = ()
+ if not typing.TYPE_CHECKING:
+ __slots__ = ()
- _clone_annotations_traverse_internals = [
+ _clone_annotations_traverse_internals: _TraverseInternalsType = [
("_annotations", InternalTraversal.dp_annotations_key)
]
- def _annotate(self, values):
+ def _annotate(
+ self: SelfSupportsCloneAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsCloneAnnotations:
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
@@ -65,7 +128,9 @@ class SupportsCloneAnnotations(SupportsAnnotations):
new.__dict__.pop("_generate_cache_key", None)
return new
- def _with_annotations(self, values):
+ def _with_annotations(
+ self: SelfSupportsCloneAnnotations, values: _AnnotationDict
+ ) -> SelfSupportsCloneAnnotations:
"""return a copy of this ClauseElement with annotations
replaced by the given dictionary.
@@ -76,7 +141,27 @@ class SupportsCloneAnnotations(SupportsAnnotations):
new.__dict__.pop("_generate_cache_key", None)
return new
- def _deannotate(self, values=None, clone=False):
+ @overload
+ def _deannotate(
+ self: SelfSupportsAnnotations,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfSupportsAnnotations:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> SupportsAnnotations:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = False,
+ ) -> SupportsAnnotations:
"""return a copy of this :class:`_expression.ClauseElement`
with annotations
removed.
@@ -96,24 +181,52 @@ class SupportsCloneAnnotations(SupportsAnnotations):
return self
+SelfSupportsWrappingAnnotations = TypeVar(
+ "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations"
+)
+
+
class SupportsWrappingAnnotations(SupportsAnnotations):
__slots__ = ()
- def _annotate(self, values):
+ _constructor: Callable[..., SupportsWrappingAnnotations]
+ entity_namespace: Mapping[str, Any]
+
+ def _annotate(self, values: _AnnotationDict) -> Annotated:
"""return a copy of this ClauseElement with annotations
updated by the given dictionary.
"""
- return Annotated(self, values)
+ return Annotated._as_annotated_instance(self, values)
- def _with_annotations(self, values):
+ def _with_annotations(self, values: _AnnotationDict) -> Annotated:
"""return a copy of this ClauseElement with annotations
replaced by the given dictionary.
"""
- return Annotated(self, values)
-
- def _deannotate(self, values=None, clone=False):
+ return Annotated._as_annotated_instance(self, values)
+
+ @overload
+ def _deannotate(
+ self: SelfSupportsAnnotations,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfSupportsAnnotations:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> SupportsAnnotations:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = False,
+ ) -> SupportsAnnotations:
"""return a copy of this :class:`_expression.ClauseElement`
with annotations
removed.
@@ -129,8 +242,11 @@ class SupportsWrappingAnnotations(SupportsAnnotations):
return self
-class Annotated:
- """clones a SupportsAnnotated and applies an 'annotations' dictionary.
+SelfAnnotated = TypeVar("SelfAnnotated", bound="Annotated")
+
+
+class Annotated(SupportsAnnotations):
+ """clones a SupportsAnnotations and applies an 'annotations' dictionary.
Unlike regular clones, this clone also mimics __hash__() and
__cmp__() of the original element so that it takes its place
@@ -151,21 +267,26 @@ class Annotated:
_is_column_operators = False
- def __new__(cls, *args):
- if not args:
- # clone constructor
- return object.__new__(cls)
- else:
- element, values = args
- # pull appropriate subclass from registry of annotated
- # classes
- try:
- cls = annotated_classes[element.__class__]
- except KeyError:
- cls = _new_annotation_type(element.__class__, cls)
- return object.__new__(cls)
-
- def __init__(self, element, values):
+ @classmethod
+ def _as_annotated_instance(
+ cls, element: SupportsWrappingAnnotations, values: _AnnotationDict
+ ) -> Annotated:
+ try:
+ cls = annotated_classes[element.__class__]
+ except KeyError:
+ cls = _new_annotation_type(element.__class__, cls)
+ return cls(element, values)
+
+ _annotations: util.immutabledict[str, Any]
+ __element: SupportsWrappingAnnotations
+ _hash: int
+
+ def __new__(cls: Type[SelfAnnotated], *args: Any) -> SelfAnnotated:
+ return object.__new__(cls)
+
+ def __init__(
+ self, element: SupportsWrappingAnnotations, values: _AnnotationDict
+ ):
self.__dict__ = element.__dict__.copy()
self.__dict__.pop("_annotations_cache_key", None)
self.__dict__.pop("_generate_cache_key", None)
@@ -173,11 +294,15 @@ class Annotated:
self._annotations = util.immutabledict(values)
self._hash = hash(element)
- def _annotate(self, values):
+ def _annotate(
+ self: SelfAnnotated, values: _AnnotationDict
+ ) -> SelfAnnotated:
_values = self._annotations.union(values)
return self._with_annotations(_values)
- def _with_annotations(self, values):
+ def _with_annotations(
+ self: SelfAnnotated, values: util.immutabledict[str, Any]
+ ) -> SelfAnnotated:
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
clone.__dict__.pop("_annotations_cache_key", None)
@@ -185,7 +310,27 @@ class Annotated:
clone._annotations = values
return clone
- def _deannotate(self, values=None, clone=True):
+ @overload
+ def _deannotate(
+ self: SelfAnnotated,
+ values: Literal[None] = ...,
+ clone: bool = ...,
+ ) -> SelfAnnotated:
+ ...
+
+ @overload
+ def _deannotate(
+ self,
+ values: Sequence[str] = ...,
+ clone: bool = ...,
+ ) -> Annotated:
+ ...
+
+ def _deannotate(
+ self,
+ values: Optional[Sequence[str]] = None,
+ clone: bool = True,
+ ) -> SupportsAnnotations:
if values is None:
return self.__element
else:
@@ -199,14 +344,18 @@ class Annotated:
)
)
- def _compiler_dispatch(self, visitor, **kw):
- return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
+ if not typing.TYPE_CHECKING:
+ # manually proxy some methods that need extra attention
+ def _compiler_dispatch(self, visitor: Any, **kw: Any) -> Any:
+ return self.__element.__class__._compiler_dispatch(
+ self, visitor, **kw
+ )
- @property
- def _constructor(self):
- return self.__element._constructor
+ @property
+ def _constructor(self):
+ return self.__element._constructor
- def _clone(self, **kw):
+ def _clone(self: SelfAnnotated, **kw: Any) -> SelfAnnotated:
clone = self.__element._clone(**kw)
if clone is self.__element:
# detect immutable, don't change anything
@@ -217,22 +366,25 @@ class Annotated:
clone.__dict__.update(self.__dict__)
return self.__class__(clone, self._annotations)
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Type[Annotated], Tuple[Any, ...]]:
return self.__class__, (self.__element, self._annotations)
- def __hash__(self):
+ def __hash__(self) -> int:
return self._hash
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if self._is_column_operators:
return self.__element.__class__.__eq__(self, other)
else:
return hash(other) == hash(self)
@property
- def entity_namespace(self):
+ def entity_namespace(self) -> Mapping[str, Any]:
if "entity_namespace" in self._annotations:
- return self._annotations["entity_namespace"].entity_namespace
+ return cast(
+ SupportsWrappingAnnotations,
+ self._annotations["entity_namespace"],
+ ).entity_namespace
else:
return self.__element.entity_namespace
@@ -242,12 +394,19 @@ class Annotated:
# so that the resulting objects are pickleable; additionally, other
# decisions can be made up front about the type of object being annotated
# just once per class rather than per-instance.
-annotated_classes = {}
+annotated_classes: Dict[
+ Type[SupportsWrappingAnnotations], Type[Annotated]
+] = {}
+
+_SA = TypeVar("_SA", bound="SupportsAnnotations")
def _deep_annotate(
- element, annotations, exclude=None, detect_subquery_cols=False
-):
+ element: _SA,
+ annotations: _AnnotationDict,
+ exclude: Optional[Sequence[SupportsAnnotations]] = None,
+ detect_subquery_cols: bool = False,
+) -> _SA:
"""Deep copy the given ClauseElement, annotating each element
with the given annotations dictionary.
@@ -258,9 +417,9 @@ def _deep_annotate(
# annotated objects hack the __hash__() method so if we want to
# uniquely process them we have to use id()
- cloned_ids = {}
+ cloned_ids: Dict[int, SupportsAnnotations] = {}
- def clone(elem, **kw):
+ def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations:
kw["detect_subquery_cols"] = detect_subquery_cols
id_ = id(elem)
@@ -285,17 +444,20 @@ def _deep_annotate(
return newelem
if element is not None:
- element = clone(element)
- clone = None # remove gc cycles
+ element = cast(_SA, clone(element))
+ clone = None # type: ignore # remove gc cycles
return element
-def _deep_deannotate(element, values=None):
+def _deep_deannotate(
+ element: _SA, values: Optional[Sequence[str]] = None
+) -> _SA:
"""Deep copy the given element, removing annotations."""
- cloned = {}
+ cloned: Dict[Any, SupportsAnnotations] = {}
- def clone(elem, **kw):
+ def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations:
+ key: Any
if values:
key = id(elem)
else:
@@ -310,12 +472,14 @@ def _deep_deannotate(element, values=None):
return cloned[key]
if element is not None:
- element = clone(element)
- clone = None # remove gc cycles
+ element = cast(_SA, clone(element))
+ clone = None # type: ignore # remove gc cycles
return element
-def _shallow_annotate(element, annotations):
+def _shallow_annotate(
+ element: SupportsAnnotations, annotations: _AnnotationDict
+) -> SupportsAnnotations:
"""Annotate the given ClauseElement and copy its internals so that
internal objects refer to the new annotated object.
@@ -328,7 +492,13 @@ def _shallow_annotate(element, annotations):
return element
-def _new_annotation_type(cls, base_cls):
+def _new_annotation_type(
+ cls: Type[SupportsWrappingAnnotations], base_cls: Type[Annotated]
+) -> Type[Annotated]:
+ """Generates a new class that subclasses Annotated and proxies a given
+ element type.
+
+ """
if issubclass(cls, Annotated):
return cls
elif cls in annotated_classes:
@@ -342,8 +512,9 @@ def _new_annotation_type(cls, base_cls):
base_cls = annotated_classes[super_]
break
- annotated_classes[cls] = anno_cls = type(
- "Annotated%s" % cls.__name__, (base_cls, cls), {}
+ annotated_classes[cls] = anno_cls = cast(
+ Type[Annotated],
+ type("Annotated%s" % cls.__name__, (base_cls, cls), {}),
)
globals()["Annotated%s" % cls.__name__] = anno_cls
@@ -359,13 +530,15 @@ def _new_annotation_type(cls, base_cls):
# some classes include this even if they have traverse_internals
# e.g. BindParameter, add it if present.
if cls.__dict__.get("inherit_cache", False):
- anno_cls.inherit_cache = True
+ anno_cls.inherit_cache = True # type: ignore
anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
return anno_cls
-def _prepare_annotations(target_hierarchy, base_cls):
+def _prepare_annotations(
+ target_hierarchy: Type[SupportsAnnotations], base_cls: Type[Annotated]
+) -> None:
for cls in util.walk_subclasses(target_hierarchy):
_new_annotation_type(cls, base_cls)
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index a94590da1..a408a010a 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -19,8 +19,10 @@ from itertools import zip_longest
import operator
import re
import typing
+from typing import MutableMapping
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import TypeVar
from . import roles
@@ -36,14 +38,9 @@ from .. import util
from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
-from ..util._has_cy import HAS_CYEXTENSION
-
-if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
- from ._py_util import prefix_anon_map # noqa
-else:
- from sqlalchemy.cyextension.util import prefix_anon_map # noqa
if typing.TYPE_CHECKING:
+ from .elements import ColumnElement
from ..engine import Connection
from ..engine import Result
from ..engine.interfaces import _CoreMultiExecuteParams
@@ -63,6 +60,8 @@ NO_ARG = util.symbol("NO_ARG")
# symbols, mypy reports: "error: _Fn? not callable"
_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_AmbiguousTableNameMap = MutableMapping[str, str]
+
class Immutable:
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
@@ -87,6 +86,10 @@ class SingletonConstant(Immutable):
_is_singleton_constant = True
+ _singleton: SingletonConstant
+
+ proxy_set: Set[ColumnElement]
+
def __new__(cls, *arg, **kw):
return cls._singleton
@@ -519,6 +522,8 @@ class CompileState:
plugins = {}
+ _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
+
@classmethod
def create_for_statement(cls, statement, compiler, **kw):
# factory construction.
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index ff659b77d..fca58f98e 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -11,21 +11,41 @@ import enum
from itertools import zip_longest
import typing
from typing import Any
-from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
from typing import NamedTuple
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
from typing import Union
from .visitors import anon_map
-from .visitors import ExtendedInternalTraversal
+from .visitors import HasTraversalDispatch
+from .visitors import HasTraverseInternals
from .visitors import InternalTraversal
+from .visitors import prefix_anon_map
from .. import util
from ..inspection import inspect
from ..util import HasMemoized
from ..util.typing import Literal
-
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from .elements import BindParameter
+ from .elements import ClauseElement
+ from .visitors import _TraverseInternalsType
+ from ..engine.base import _CompiledCacheType
+ from ..engine.interfaces import _CoreSingleExecuteParams
+
+
+class _CacheKeyTraversalDispatchType(Protocol):
+ def __call__(
+ s, self: HasCacheKey, visitor: _CacheKeyTraversal
+ ) -> CacheKey:
+ ...
class CacheConst(enum.Enum):
@@ -70,7 +90,9 @@ class HasCacheKey:
__slots__ = ()
- _cache_key_traversal = NO_CACHE
+ _cache_key_traversal: Union[
+ _TraverseInternalsType, Literal[CacheConst.NO_CACHE]
+ ] = NO_CACHE
_is_has_cache_key = True
@@ -83,7 +105,7 @@ class HasCacheKey:
"""
- inherit_cache = None
+ inherit_cache: Optional[bool] = None
"""Indicate if this :class:`.HasCacheKey` instance should make use of the
cache key generation scheme used by its immediate superclass.
@@ -106,8 +128,12 @@ class HasCacheKey:
__slots__ = ()
+ _generated_cache_key_traversal: Any
+
@classmethod
- def _generate_cache_attrs(cls):
+ def _generate_cache_attrs(
+ cls,
+ ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]:
"""generate cache key dispatcher for a new class.
This sets the _generated_cache_key_traversal attribute once called
@@ -121,8 +147,11 @@ class HasCacheKey:
_cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
if _cache_key_traversal is None:
try:
- # this would be HasTraverseInternals
- _cache_key_traversal = cls._traverse_internals
+ # check for _traverse_internals, which is part of
+ # HasTraverseInternals
+ _cache_key_traversal = cast(
+ "Type[HasTraverseInternals]", cls
+ )._traverse_internals
except AttributeError:
cls._generated_cache_key_traversal = NO_CACHE
return NO_CACHE
@@ -138,7 +167,9 @@ class HasCacheKey:
# more complicated, so for the moment this is a little less
# efficient on startup but simpler.
return _cache_key_traversal_visitor.generate_dispatch(
- cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ cls,
+ _cache_key_traversal,
+ "_generated_cache_key_traversal",
)
else:
_cache_key_traversal = cls.__dict__.get(
@@ -170,11 +201,15 @@ class HasCacheKey:
return NO_CACHE
return _cache_key_traversal_visitor.generate_dispatch(
- cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ cls,
+ _cache_key_traversal,
+ "_generated_cache_key_traversal",
)
@util.preload_module("sqlalchemy.sql.elements")
- def _gen_cache_key(self, anon_map, bindparams):
+ def _gen_cache_key(
+ self, anon_map: anon_map, bindparams: List[BindParameter[Any]]
+ ) -> Optional[Tuple[Any, ...]]:
"""return an optional cache key.
The cache key is a tuple which can contain any series of
@@ -202,15 +237,15 @@ class HasCacheKey:
dispatcher: Union[
Literal[CacheConst.NO_CACHE],
- Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"],
+ _CacheKeyTraversalDispatchType,
]
try:
dispatcher = cls.__dict__["_generated_cache_key_traversal"]
except KeyError:
- # most of the dispatchers are generated up front
- # in sqlalchemy/sql/__init__.py ->
- # traversals.py-> _preconfigure_traversals().
+ # traversals.py -> _preconfigure_traversals()
+ # may be used to run these ahead of time, but
+ # is not enabled right now.
# this block will generate any remaining dispatchers.
dispatcher = cls._generate_cache_attrs()
@@ -218,7 +253,7 @@ class HasCacheKey:
anon_map[NO_CACHE] = True
return None
- result = (id_, cls)
+ result: Tuple[Any, ...] = (id_, cls)
# inline of _cache_key_traversal_visitor.run_generated_dispatch()
@@ -268,7 +303,7 @@ class HasCacheKey:
# Columns, this should be long lived. For select()
# statements, not so much, but they usually won't have
# annotations.
- result += self._annotations_cache_key
+ result += self._annotations_cache_key # type: ignore
elif (
meth is InternalTraversal.dp_clauseelement_list
or meth is InternalTraversal.dp_clauseelement_tuple
@@ -290,7 +325,7 @@ class HasCacheKey:
)
return result
- def _generate_cache_key(self):
+ def _generate_cache_key(self) -> Optional[CacheKey]:
"""return a cache key.
The cache key is a tuple which can contain any series of
@@ -322,32 +357,40 @@ class HasCacheKey:
"""
- bindparams = []
+ bindparams: List[BindParameter[Any]] = []
_anon_map = anon_map()
key = self._gen_cache_key(_anon_map, bindparams)
if NO_CACHE in _anon_map:
return None
else:
+ assert key is not None
return CacheKey(key, bindparams)
@classmethod
- def _generate_cache_key_for_object(cls, obj):
- bindparams = []
+ def _generate_cache_key_for_object(
+ cls, obj: HasCacheKey
+ ) -> Optional[CacheKey]:
+ bindparams: List[BindParameter[Any]] = []
_anon_map = anon_map()
key = obj._gen_cache_key(_anon_map, bindparams)
if NO_CACHE in _anon_map:
return None
else:
+ assert key is not None
return CacheKey(key, bindparams)
+class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey):
+ pass
+
+
class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
__slots__ = ()
@HasMemoized.memoized_instancemethod
- def _generate_cache_key(self):
+ def _generate_cache_key(self) -> Optional[CacheKey]:
return HasCacheKey._generate_cache_key(self)
@@ -362,14 +405,22 @@ class CacheKey(NamedTuple):
"""
key: Tuple[Any, ...]
- bindparams: Sequence[BindParameter]
+ bindparams: Sequence[BindParameter[Any]]
- def __hash__(self):
+ # can't set __hash__ attribute because it interferes
+ # with namedtuple
+ # can't use "if not TYPE_CHECKING" because mypy rejects it
+ # inside of a NamedTuple
+ def __hash__(self) -> Optional[int]: # type: ignore
"""CacheKey itself is not hashable - hash the .key portion"""
-
return None
- def to_offline_string(self, statement_cache, statement, parameters):
+ def to_offline_string(
+ self,
+ statement_cache: _CompiledCacheType,
+ statement: ClauseElement,
+ parameters: _CoreSingleExecuteParams,
+ ) -> str:
"""Generate an "offline string" form of this :class:`.CacheKey`
The "offline string" is basically the string SQL for the
@@ -400,21 +451,21 @@ class CacheKey(NamedTuple):
return repr((sql_str, param_tuple))
- def __eq__(self, other):
- return self.key == other.key
+ def __eq__(self, other: Any) -> bool:
+ return bool(self.key == other.key)
@classmethod
- def _diff_tuples(cls, left, right):
+ def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str:
ck1 = CacheKey(left, [])
ck2 = CacheKey(right, [])
return ck1._diff(ck2)
- def _whats_different(self, other):
+ def _whats_different(self, other: CacheKey) -> Iterator[str]:
k1 = self.key
k2 = other.key
- stack = []
+ stack: List[int] = []
pickup_index = 0
while True:
s1, s2 = k1, k2
@@ -440,11 +491,11 @@ class CacheKey(NamedTuple):
pickup_index = stack.pop(-1)
break
- def _diff(self, other):
+ def _diff(self, other: CacheKey) -> str:
return ", ".join(self._whats_different(other))
- def __str__(self):
- stack = [self.key]
+ def __str__(self) -> str:
+ stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key]
output = []
sentinel = object()
@@ -473,15 +524,15 @@ class CacheKey(NamedTuple):
return "CacheKey(key=%s)" % ("\n".join(output),)
- def _generate_param_dict(self):
+ def _generate_param_dict(self) -> Dict[str, Any]:
"""used for testing"""
- from .compiler import prefix_anon_map
-
_anon_map = prefix_anon_map()
return {b.key % _anon_map: b.effective_value for b in self.bindparams}
- def _apply_params_to_element(self, original_cache_key, target_element):
+ def _apply_params_to_element(
+ self, original_cache_key: CacheKey, target_element: ClauseElement
+ ) -> ClauseElement:
translate = {
k.key: v.value
for k, v in zip(original_cache_key.bindparams, self.bindparams)
@@ -490,7 +541,7 @@ class CacheKey(NamedTuple):
return target_element.params(translate)
-class _CacheKeyTraversal(ExtendedInternalTraversal):
+class _CacheKeyTraversal(HasTraversalDispatch):
# very common elements are inlined into the main _get_cache_key() method
# to produce a dramatic savings in Python function call overhead
@@ -512,17 +563,43 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
visit_propagate_attrs = PROPAGATE_ATTRS
def visit_with_context_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple((fn.__code__, c_key) for fn, c_key in obj)
- def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_inspectable(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
- def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_string_list(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple(obj)
- def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_multi(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
obj._gen_cache_key(anon_map, bindparams)
@@ -530,7 +607,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
else obj,
)
- def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_multi_list(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -542,8 +626,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_has_cache_key_tuples(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -558,8 +647,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_has_cache_key_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -568,8 +662,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_executable_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -582,22 +681,37 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_inspectable_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return self.visit_has_cache_key_list(
attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
)
def visit_clauseelement_tuples(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return self.visit_has_cache_key_tuples(
attrname, obj, parent, anon_map, bindparams
)
def visit_fromclause_ordered_set(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
return (
@@ -606,8 +720,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_clauseelement_unordered_set(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
cache_keys = [
@@ -621,13 +740,23 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_named_ddl_element(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, obj.name)
def visit_prefix_sequence(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
@@ -642,8 +771,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_setup_join_tuple(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return tuple(
(
target._gen_cache_key(anon_map, bindparams),
@@ -659,8 +793,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_table_hint_list(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
if not obj:
return ()
@@ -678,12 +817,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
),
)
- def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_plain_dict(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
def visit_dialect_options(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -701,8 +852,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_string_clauseelement_dict(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -712,8 +868,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_string_multi_dict(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -728,8 +889,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_fromclause_canonical_column_collection(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# inlining into the internals of ColumnCollection
return (
attrname,
@@ -740,14 +906,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_unknown_structure(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
anon_map[NO_CACHE] = True
return ()
def visit_dml_ordered_values(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
return (
attrname,
tuple(
@@ -761,7 +937,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
),
)
- def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+ def visit_dml_values(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# in py37 we can assume two dictionaries created in the same
# insert ordering will retain that sorting
return (
@@ -778,8 +961,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal):
)
def visit_dml_multi_values(
- self, attrname, obj, parent, anon_map, bindparams
- ):
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
# multivalues are simply not cacheable right now
anon_map[NO_CACHE] = True
return ()
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index d616417ab..834bfb75d 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -13,6 +13,9 @@ import re
import typing
from typing import Any
from typing import Any as TODO_Any
+from typing import Dict
+from typing import List
+from typing import NoReturn
from typing import Optional
from typing import Type
from typing import TypeVar
@@ -42,6 +45,7 @@ if typing.TYPE_CHECKING:
from . import selectable
from . import traversals
from .elements import ClauseElement
+ from .elements import ColumnClause
_SR = TypeVar("_SR", bound=roles.SQLRole)
_StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole)
@@ -252,7 +256,7 @@ def expect_col_expression_collection(role, expressions):
if isinstance(resolved, str):
strname = resolved = expr
else:
- cols = []
+ cols: List[ColumnClause[Any]] = []
visitors.traverse(resolved, {}, {"column": cols.append})
if cols:
column = cols[0]
@@ -266,7 +270,7 @@ class RoleImpl:
def _literal_coercion(self, element, **kw):
raise NotImplementedError()
- _post_coercion = None
+ _post_coercion: Any = None
_resolve_literal_only = False
_skip_clauseelement_for_target_match = False
@@ -276,19 +280,24 @@ class RoleImpl:
self._use_inspection = issubclass(role_class, roles.UsesInspection)
def _implicit_coercions(
- self, element, resolved, argname=None, **kw
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
) -> Any:
self._raise_for_expected(element, argname, resolved)
def _raise_for_expected(
self,
- element,
- argname=None,
- resolved=None,
- advice=None,
- code=None,
- err=None,
- ):
+ element: Any,
+ argname: Optional[str] = None,
+ resolved: Optional[Any] = None,
+ advice: Optional[str] = None,
+ code: Optional[str] = None,
+ err: Optional[Exception] = None,
+ **kw: Any,
+ ) -> NoReturn:
if resolved is not None and resolved is not element:
got = "%r object resolved from %r object" % (resolved, element)
else:
@@ -324,22 +333,20 @@ class _StringOnly:
_resolve_literal_only = True
-class _ReturnsStringKey:
+class _ReturnsStringKey(RoleImpl):
__slots__ = ()
- def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, str):
- return original_element
+ def _implicit_coercions(self, element, resolved, argname=None, **kw):
+ if isinstance(element, str):
+ return element
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, **kw):
return element
-class _ColumnCoercions:
+class _ColumnCoercions(RoleImpl):
__slots__ = ()
def _warn_for_scalar_subquery_coercion(self):
@@ -368,8 +375,12 @@ class _ColumnCoercions:
def _no_text_coercion(
- element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None
-):
+ element: Any,
+ argname: Optional[str] = None,
+ exc_cls: Type[exc.SQLAlchemyError] = exc.ArgumentError,
+ extra: Optional[str] = None,
+ err: Optional[Exception] = None,
+) -> NoReturn:
raise exc_cls(
"%(extra)sTextual SQL expression %(expr)r %(argname)sshould be "
"explicitly declared as text(%(expr)r)"
@@ -381,7 +392,7 @@ def _no_text_coercion(
) from err
-class _NoTextCoercion:
+class _NoTextCoercion(RoleImpl):
__slots__ = ()
def _literal_coercion(self, element, argname=None, **kw):
@@ -393,7 +404,7 @@ class _NoTextCoercion:
self._raise_for_expected(element, argname)
-class _CoerceLiterals:
+class _CoerceLiterals(RoleImpl):
__slots__ = ()
_coerce_consts = False
_coerce_star = False
@@ -440,12 +451,19 @@ class LiteralValueImpl(RoleImpl):
return element
-class _SelectIsNotFrom:
+class _SelectIsNotFrom(RoleImpl):
__slots__ = ()
def _raise_for_expected(
- self, element, argname=None, resolved=None, advice=None, **kw
- ):
+ self,
+ element: Any,
+ argname: Optional[str] = None,
+ resolved: Optional[Any] = None,
+ advice: Optional[str] = None,
+ code: Optional[str] = None,
+ err: Optional[Exception] = None,
+ **kw: Any,
+ ) -> NoReturn:
if (
not advice
and isinstance(element, roles.SelectStatementRole)
@@ -460,26 +478,33 @@ class _SelectIsNotFrom:
else:
code = None
- return super(_SelectIsNotFrom, self)._raise_for_expected(
+ super()._raise_for_expected(
element,
argname=argname,
resolved=resolved,
advice=advice,
code=code,
+ err=err,
**kw,
)
+ # never reached
+ assert False
class HasCacheKeyImpl(RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, traversals.HasCacheKey):
- return original_element
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, HasCacheKey):
+ return element
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, **kw):
return element
@@ -489,12 +514,16 @@ class ExecutableOptionImpl(RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, ExecutableOption):
- return original_element
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, ExecutableOption):
+ return element
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, **kw):
return element
@@ -560,8 +589,12 @@ class InElementImpl(RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_from_clause:
if (
isinstance(resolved, selectable.Alias)
@@ -573,7 +606,7 @@ class InElementImpl(RoleImpl):
self._warn_for_implicit_coercion(resolved)
return self._post_coercion(resolved.select(), **kw)
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _warn_for_implicit_coercion(self, elem):
util.warn(
@@ -586,12 +619,16 @@ class InElementImpl(RoleImpl):
if isinstance(element, collections_abc.Iterable) and not isinstance(
element, str
):
- non_literal_expressions = {}
+ non_literal_expressions: Dict[
+ Optional[operators.ColumnOperators[Any]],
+ operators.ColumnOperators[Any],
+ ] = {}
element = list(element)
for o in element:
if not _is_literal(o):
if not isinstance(o, operators.ColumnOperators):
self._raise_for_expected(element, **kw)
+
else:
non_literal_expressions[o] = o
elif o is None:
@@ -712,8 +749,12 @@ class GroupByImpl(ByOfImpl, RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if isinstance(resolved, roles.StrictFromClauseRole):
return elements.ClauseList(*resolved.c)
else:
@@ -748,12 +789,16 @@ class TruncatedLabelImpl(_StringOnly, RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
- if isinstance(original_element, str):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, str):
return resolved
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, argname=None, **kw):
"""coerce the given value to :class:`._truncated_label`.
@@ -794,7 +839,13 @@ class DDLReferredColumnImpl(DDLConstraintColumnImpl):
class LimitOffsetImpl(RoleImpl):
__slots__ = ()
- def _implicit_coercions(self, element, resolved, argname=None, **kw):
+ def _implicit_coercions(
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved is None:
return None
else:
@@ -814,18 +865,22 @@ class LabeledColumnExprImpl(ExpressionElementImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if isinstance(resolved, roles.ExpressionElementRole):
return resolved.label(None)
else:
new = super(LabeledColumnExprImpl, self)._implicit_coercions(
- original_element, resolved, argname=argname, **kw
+ element, resolved, argname=argname, **kw
)
if isinstance(new, roles.ExpressionElementRole):
return new.label(None)
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl):
@@ -899,13 +954,17 @@ class StatementImpl(_CoerceLiterals, RoleImpl):
return resolved
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_lambda_element:
return resolved
else:
- return super(StatementImpl, self)._implicit_coercions(
- original_element, resolved, argname=argname, **kw
+ return super()._implicit_coercions(
+ element, resolved, argname=argname, **kw
)
@@ -913,12 +972,16 @@ class SelectStatementImpl(_NoTextCoercion, RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_text_clause:
return resolved.columns()
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class HasCTEImpl(ReturnsRowsImpl):
@@ -938,13 +1001,18 @@ class JoinTargetImpl(RoleImpl):
self._raise_for_expected(element, argname)
def _implicit_coercions(
- self, original_element, resolved, argname=None, legacy=False, **kw
- ):
- if isinstance(original_element, roles.JoinTargetRole):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ legacy: bool = False,
+ **kw: Any,
+ ) -> Any:
+ if isinstance(element, roles.JoinTargetRole):
# note that this codepath no longer occurs as of
# #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match
# were set to False.
- return original_element
+ return element
elif legacy and resolved._is_select_statement:
util.warn_deprecated(
"Implicit coercion of SELECT and textual SELECT "
@@ -959,7 +1027,7 @@ class JoinTargetImpl(RoleImpl):
# in _ORMJoin->Join
return resolved
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
@@ -967,13 +1035,13 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
def _implicit_coercions(
self,
- original_element,
- resolved,
- argname=None,
- explicit_subquery=False,
- allow_select=True,
- **kw,
- ):
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ explicit_subquery: bool = False,
+ allow_select: bool = True,
+ **kw: Any,
+ ) -> Any:
if resolved._is_select_statement:
if explicit_subquery:
return resolved.subquery()
@@ -989,7 +1057,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
elif resolved._is_text_clause:
return resolved
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
def _post_coercion(self, element, deannotate=False, **kw):
if deannotate:
@@ -1003,12 +1071,13 @@ class StrictFromClauseImpl(FromClauseImpl):
def _implicit_coercions(
self,
- original_element,
- resolved,
- argname=None,
- allow_select=False,
- **kw,
- ):
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ explicit_subquery: bool = False,
+ allow_select: bool = False,
+ **kw: Any,
+ ) -> Any:
if resolved._is_select_statement and allow_select:
util.warn_deprecated(
"Implicit coercion of SELECT and textual SELECT constructs "
@@ -1019,7 +1088,7 @@ class StrictFromClauseImpl(FromClauseImpl):
)
return resolved._implicit_subquery
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class AnonymizedFromClauseImpl(StrictFromClauseImpl):
@@ -1045,8 +1114,12 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl):
__slots__ = ()
def _implicit_coercions(
- self, original_element, resolved, argname=None, **kw
- ):
+ self,
+ element: Any,
+ resolved: Any,
+ argname: Optional[str] = None,
+ **kw: Any,
+ ) -> Any:
if resolved._is_from_clause:
if (
isinstance(resolved, selectable.Alias)
@@ -1056,7 +1129,7 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl):
else:
return resolved.select()
else:
- self._raise_for_expected(original_element, argname, resolved)
+ self._raise_for_expected(element, argname, resolved)
class CompoundElementImpl(_NoTextCoercion, RoleImpl):
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 423c3d446..f28dceefc 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -35,14 +35,19 @@ from time import perf_counter
import typing
from typing import Any
from typing import Callable
+from typing import cast
from typing import Dict
+from typing import FrozenSet
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
from typing import Union
from . import base
@@ -54,19 +59,42 @@ from . import operators
from . import schema
from . import selectable
from . import sqltypes
+from .base import _from_objects
from .base import NO_ARG
-from .base import prefix_anon_map
from .elements import quoted_name
from .schema import Column
+from .sqltypes import TupleType
from .type_api import TypeEngine
+from .visitors import prefix_anon_map
from .. import exc
from .. import util
from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
+ from .annotation import _AnnotationDict
+ from .base import _AmbiguousTableNameMap
+ from .base import CompileState
+ from .cache_key import CacheKey
+ from .elements import BindParameter
+ from .elements import ColumnClause
+ from .elements import Label
+ from .functions import Function
+ from .selectable import Alias
+ from .selectable import AliasedReturnsRows
+ from .selectable import CompoundSelectState
from .selectable import CTE
from .selectable import FromClause
+ from .selectable import NamedFromClause
+ from .selectable import ReturnsRows
+ from .selectable import Select
+ from .selectable import SelectState
+ from ..engine.cursor import CursorResultMetaData
from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _MutableCoreSingleExecuteParams
+ from ..engine.interfaces import _SchemaTranslateMapType
from ..engine.result import _ProcessorType
_FromHintsType = Dict["FromClause", str]
@@ -236,7 +264,7 @@ OPERATORS = {
operators.nulls_last_op: " NULLS LAST",
}
-FUNCTIONS = {
+FUNCTIONS: Dict[Type[Function], str] = {
functions.coalesce: "coalesce",
functions.current_date: "CURRENT_DATE",
functions.current_time: "CURRENT_TIME",
@@ -298,8 +326,8 @@ class ResultColumnsEntry(NamedTuple):
name: str
"""column name, may be labeled"""
- objects: List[Any]
- """list of objects that should be able to locate this column
+ objects: Tuple[Any, ...]
+ """sequence of objects that should be able to locate this column
in a RowMapping. This is typically string names and aliases
as well as Column objects.
@@ -313,6 +341,17 @@ class ResultColumnsEntry(NamedTuple):
"""
+class _ResultMapAppender(Protocol):
+ def __call__(
+ self,
+ keyname: str,
+ name: str,
+ objects: Sequence[Any],
+ type_: TypeEngine[Any],
+ ) -> None:
+ ...
+
+
# integer indexes into ResultColumnsEntry used by cursor.py.
# some profiling showed integer access faster than named tuple
RM_RENDERED_NAME: Literal[0] = 0
@@ -321,6 +360,20 @@ RM_OBJECTS: Literal[2] = 2
RM_TYPE: Literal[3] = 3
+class _BaseCompilerStackEntry(TypedDict):
+ asfrom_froms: Set[FromClause]
+ correlate_froms: Set[FromClause]
+ selectable: ReturnsRows
+
+
+class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
+ compile_state: CompileState
+ need_result_map_for_nested: bool
+ need_result_map_for_compound: bool
+ select_0: ReturnsRows
+ insert_from_select: Select
+
+
class ExpandedState(NamedTuple):
statement: str
additional_parameters: _CoreSingleExecuteParams
@@ -427,21 +480,23 @@ class Compiled:
defaults.
"""
- _cached_metadata = None
+ _cached_metadata: Optional[CursorResultMetaData] = None
_result_columns: Optional[List[ResultColumnsEntry]] = None
- schema_translate_map = None
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None
- execution_options = util.EMPTY_DICT
+ execution_options: _ExecuteOptions = util.EMPTY_DICT
"""
Execution options propagated from the statement. In some cases,
sub-elements of the statement can modify these.
"""
- _annotations = util.EMPTY_DICT
+ preparer: IdentifierPreparer
+
+ _annotations: _AnnotationDict = util.EMPTY_DICT
- compile_state = None
+ compile_state: Optional[CompileState] = None
"""Optional :class:`.CompileState` object that maintains additional
state used by the compiler.
@@ -457,9 +512,21 @@ class Compiled:
"""
- cache_key = None
+ cache_key: Optional[CacheKey] = None
+ """The :class:`.CacheKey` that was generated ahead of creating this
+ :class:`.Compiled` object.
+
+ This is used for routines that need access to the original
+ :class:`.CacheKey` instance generated when the :class:`.Compiled`
+ instance was first cached, typically in order to reconcile
+ the original list of :class:`.BindParameter` objects with a
+ per-statement list that's generated on each call.
+
+ """
_gen_time: float
+ """Generation time of this :class:`.Compiled`, used for reporting
+ cache stats."""
def __init__(
self,
@@ -543,7 +610,11 @@ class Compiled:
return self.string or ""
- def construct_params(self, params=None, extracted_parameters=None):
+ def construct_params(
+ self,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ ) -> Optional[_MutableCoreSingleExecuteParams]:
"""Return the bind params for this compiled object.
:param params: a dict of string/object pairs whose values will
@@ -646,6 +717,17 @@ class SQLCompiler(Compiled):
isplaintext: bool = False
+ binds: Dict[str, BindParameter[Any]]
+ """a dictionary of bind parameter keys to BindParameter instances."""
+
+ bind_names: Dict[BindParameter[Any], str]
+ """a dictionary of BindParameter instances to "compiled" names
+ that are actually present in the generated SQL"""
+
+ stack: List[_CompilerStackEntry]
+ """major statements such as SELECT, INSERT, UPDATE, DELETE are
+ tracked in this stack using an entry format."""
+
result_columns: List[ResultColumnsEntry]
"""relates label names in the final SQL to a tuple of local
column/label name, ColumnElement object (if any) and
@@ -709,7 +791,7 @@ class SQLCompiler(Compiled):
"""
- insert_single_values_expr = None
+ insert_single_values_expr: Optional[str] = None
"""When an INSERT is compiled with a single set of parameters inside
a VALUES expression, the string is assigned here, where it can be
used for insert batching schemes to rewrite the VALUES expression.
@@ -718,19 +800,19 @@ class SQLCompiler(Compiled):
"""
- literal_execute_params = frozenset()
+ literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as literal values at statement
execution time.
"""
- post_compile_params = frozenset()
+ post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
"""bindparameter objects that are rendered as bound parameter placeholders
at statement execution time.
"""
- escaped_bind_names = util.EMPTY_DICT
+ escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
"""Late escaping of bound parameter names that has to be converted
to the original name when looking in the parameter dictionary.
@@ -744,14 +826,25 @@ class SQLCompiler(Compiled):
"""if True, and this in insert, use cursor.lastrowid to populate
result.inserted_primary_key. """
- _cache_key_bind_match = None
+ _cache_key_bind_match: Optional[
+ Tuple[
+ Dict[
+ BindParameter[Any],
+ List[BindParameter[Any]],
+ ],
+ Dict[
+ str,
+ BindParameter[Any],
+ ],
+ ]
+ ] = None
"""a mapping that will relate the BindParameter object we compile
to those that are part of the extracted collection of parameters
in the cache key, if we were given a cache key.
"""
- positiontup: Optional[Sequence[str]] = None
+ positiontup: Optional[List[str]] = None
"""for a compiled construct that uses a positional paramstyle, will be
a sequence of strings, indicating the names of bound parameters in order.
@@ -768,6 +861,19 @@ class SQLCompiler(Compiled):
inline: bool = False
+ ctes: Optional[MutableMapping[CTE, str]]
+
+ # Detect same CTE references - Dict[(level, name), cte]
+ # Level is required for supporting nesting
+ ctes_by_level_name: Dict[Tuple[int, str], CTE]
+
+ # To retrieve key/level in ctes_by_level_name -
+ # Dict[cte_reference, (level, cte_name, cte_opts)]
+ level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
+
+ ctes_recursive: bool
+ cte_positional: Dict[CTE, List[str]]
+
def __init__(
self,
dialect,
@@ -804,10 +910,9 @@ class SQLCompiler(Compiled):
self.cache_key = cache_key
if cache_key:
- self._cache_key_bind_match = ckbm = {
- b.key: b for b in cache_key[1]
- }
- ckbm.update({b: [b] for b in cache_key[1]})
+ cksm = {b.key: b for b in cache_key[1]}
+ ckbm = {b: [b] for b in cache_key[1]}
+ self._cache_key_bind_match = (ckbm, cksm)
# compile INSERT/UPDATE defaults/sequences to expect executemany
# style execution, which may mean no pre-execute of defaults,
@@ -911,14 +1016,14 @@ class SQLCompiler(Compiled):
@property
def prefetch(self):
- return list(self.insert_prefetch + self.update_prefetch)
+ return list(self.insert_prefetch) + list(self.update_prefetch)
@util.memoized_property
def _global_attributes(self):
return {}
@util.memoized_instancemethod
- def _init_cte_state(self) -> None:
+ def _init_cte_state(self) -> MutableMapping[CTE, str]:
"""Initialize collections related to CTEs only if
a CTE is located, to save on the overhead of
these collections otherwise.
@@ -926,21 +1031,22 @@ class SQLCompiler(Compiled):
"""
# collect CTEs to tack on top of a SELECT
# To store the query to print - Dict[cte, text_query]
- self.ctes: MutableMapping[CTE, str] = util.OrderedDict()
+ ctes: MutableMapping[CTE, str] = util.OrderedDict()
+ self.ctes = ctes
# Detect same CTE references - Dict[(level, name), cte]
# Level is required for supporting nesting
- self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {}
+ self.ctes_by_level_name = {}
# To retrieve key/level in ctes_by_level_name -
# Dict[cte_reference, (level, cte_name, cte_opts)]
- self.level_name_by_cte: Dict[
- CTE, Tuple[int, str, selectable._CTEOpts]
- ] = {}
+ self.level_name_by_cte = {}
- self.ctes_recursive: bool = False
+ self.ctes_recursive = False
if self.positional:
- self.cte_positional: Dict[CTE, List[str]] = {}
+ self.cte_positional = {}
+
+ return ctes
@contextlib.contextmanager
def _nested_result(self):
@@ -985,7 +1091,7 @@ class SQLCompiler(Compiled):
if not bindparam.type._is_tuple_type
else tuple(
elem_type._cached_bind_processor(self.dialect)
- for elem_type in bindparam.type.types
+ for elem_type in cast(TupleType, bindparam.type).types
),
)
for bindparam in self.bind_names
@@ -1002,11 +1108,11 @@ class SQLCompiler(Compiled):
def construct_params(
self,
- params=None,
- _group_number=None,
- _check=True,
- extracted_parameters=None,
- ):
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ _group_number: Optional[int] = None,
+ _check: bool = True,
+ ) -> _MutableCoreSingleExecuteParams:
"""return a dictionary of bind parameter keys and values"""
has_escaped_names = bool(self.escaped_bind_names)
@@ -1018,15 +1124,17 @@ class SQLCompiler(Compiled):
# way. The parameters present in self.bind_names may be clones of
# these original cache key params in the case of DML but the .key
# will be guaranteed to match.
- try:
- orig_extracted = self.cache_key[1]
- except TypeError as err:
+ if self.cache_key is None:
raise exc.CompileError(
"This compiled object has no original cache key; "
"can't pass extracted_parameters to construct_params"
- ) from err
+ )
+ else:
+ orig_extracted = self.cache_key[1]
- ckbm = self._cache_key_bind_match
+ ckbm_tuple = self._cache_key_bind_match
+ assert ckbm_tuple is not None
+ ckbm, _ = ckbm_tuple
resolved_extracted = {
bind: extracted
for b, extracted in zip(orig_extracted, extracted_parameters)
@@ -1142,7 +1250,8 @@ class SQLCompiler(Compiled):
if bindparam.type._is_tuple_type:
inputsizes[bindparam] = [
- lookup_type(typ) for typ in bindparam.type.types
+ lookup_type(typ)
+ for typ in cast(TupleType, bindparam.type).types
]
else:
inputsizes[bindparam] = lookup_type(bindparam.type)
@@ -1164,7 +1273,7 @@ class SQLCompiler(Compiled):
def _process_parameters_for_postcompile(
self,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_MutableCoreSingleExecuteParams] = None,
_populate_self: bool = False,
) -> ExpandedState:
"""handle special post compile parameters.
@@ -1183,14 +1292,20 @@ class SQLCompiler(Compiled):
parameters = self.construct_params()
expanded_parameters = {}
+ positiontup: Optional[List[str]]
+
if self.positional:
positiontup = []
else:
positiontup = None
processors = self._bind_processors
+ single_processors = cast("Mapping[str, _ProcessorType]", processors)
+ tuple_processors = cast(
+ "Mapping[str, Sequence[_ProcessorType]]", processors
+ )
- new_processors = {}
+ new_processors: Dict[str, _ProcessorType] = {}
if self.positional and self._numeric_binds:
# I'm not familiar with any DBAPI that uses 'numeric'.
@@ -1203,8 +1318,8 @@ class SQLCompiler(Compiled):
"the 'numeric' paramstyle at this time."
)
- replacement_expressions = {}
- to_update_sets = {}
+ replacement_expressions: Dict[str, Any] = {}
+ to_update_sets: Dict[str, Any] = {}
# notes:
# *unescaped* parameter names in:
@@ -1213,9 +1328,12 @@ class SQLCompiler(Compiled):
# *escaped* parameter names in:
# construct_params(), replacement_expressions
- for name in (
- self.positiontup if self.positional else self.bind_names.values()
- ):
+ if self.positional and self.positiontup is not None:
+ names: Iterable[str] = self.positiontup
+ else:
+ names = self.bind_names.values()
+
+ for name in names:
escaped_name = (
self.escaped_bind_names.get(name, name)
if self.escaped_bind_names
@@ -1236,6 +1354,7 @@ class SQLCompiler(Compiled):
if parameter in self.post_compile_params:
if escaped_name in replacement_expressions:
to_update = to_update_sets[escaped_name]
+ values = None
else:
# we are removing the parameter from parameters
# because it is a list value, which is not expected by
@@ -1256,28 +1375,29 @@ class SQLCompiler(Compiled):
if not parameter.literal_execute:
parameters.update(to_update)
if parameter.type._is_tuple_type:
+ assert values is not None
new_processors.update(
(
"%s_%s_%s" % (name, i, j),
- processors[name][j - 1],
+ tuple_processors[name][j - 1],
)
for i, tuple_element in enumerate(values, 1)
- for j, value in enumerate(tuple_element, 1)
- if name in processors
- and processors[name][j - 1] is not None
+ for j, _ in enumerate(tuple_element, 1)
+ if name in tuple_processors
+ and tuple_processors[name][j - 1] is not None
)
else:
new_processors.update(
- (key, processors[name])
- for key, value in to_update
- if name in processors
+ (key, single_processors[name])
+ for key, _ in to_update
+ if name in single_processors
)
- if self.positional:
- positiontup.extend(name for name, value in to_update)
+ if positiontup is not None:
+ positiontup.extend(name for name, _ in to_update)
expanded_parameters[name] = [
- expand_key for expand_key, value in to_update
+ expand_key for expand_key, _ in to_update
]
- elif self.positional:
+ elif positiontup is not None:
positiontup.append(name)
def process_expanding(m):
@@ -1315,7 +1435,7 @@ class SQLCompiler(Compiled):
# special use cases.
self.string = expanded_state.statement
self._bind_processors.update(expanded_state.processors)
- self.positiontup = expanded_state.positiontup
+ self.positiontup = list(expanded_state.positiontup or ())
self.post_compile_params = frozenset()
for key in expanded_state.parameter_expansion:
bind = self.binds.pop(key)
@@ -1338,6 +1458,12 @@ class SQLCompiler(Compiled):
self._result_columns
)
+ _key_getters_for_crud_column: Tuple[
+ Callable[[Union[str, Column[Any]]], str],
+ Callable[[Column[Any]], str],
+ Callable[[Column[Any]], str],
+ ]
+
@util.memoized_property
def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
getter = self._key_getters_for_crud_column[2]
@@ -1398,22 +1524,30 @@ class SQLCompiler(Compiled):
@util.memoized_property
@util.preload_module("sqlalchemy.engine.result")
def _inserted_primary_key_from_returning_getter(self):
- result = util.preloaded.engine_result
+ if typing.TYPE_CHECKING:
+ from ..engine import result
+ else:
+ result = util.preloaded.engine_result
param_key_getter = self._within_exec_param_key_getter
table = self.statement.table
- ret = {col: idx for idx, col in enumerate(self.returning)}
+ returning = self.returning
+ assert returning is not None
+ ret = {col: idx for idx, col in enumerate(returning)}
- getters = [
- (operator.itemgetter(ret[col]), True)
- if col in ret
- else (
- operator.methodcaller("get", param_key_getter(col), None),
- False,
- )
- for col in table.primary_key
- ]
+ getters = cast(
+ "List[Tuple[Callable[[Any], Any], bool]]",
+ [
+ (operator.itemgetter(ret[col]), True)
+ if col in ret
+ else (
+ operator.methodcaller("get", param_key_getter(col), None),
+ False,
+ )
+ for col in table.primary_key
+ ],
+ )
row_fn = result.result_tuple([col.key for col in table.primary_key])
@@ -1444,7 +1578,16 @@ class SQLCompiler(Compiled):
self, element, within_columns_clause=False, **kwargs
):
if self.stack and self.dialect.supports_simple_order_by_label:
- compile_state = self.stack[-1]["compile_state"]
+ try:
+ compile_state = cast(
+ "Union[SelectState, CompoundSelectState]",
+ self.stack[-1]["compile_state"],
+ )
+ except KeyError as ke:
+ raise exc.CompileError(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ) from ke
(
with_cols,
@@ -1485,7 +1628,22 @@ class SQLCompiler(Compiled):
# compiling the element outside of the context of a SELECT
return self.process(element._text_clause)
- compile_state = self.stack[-1]["compile_state"]
+ try:
+ compile_state = cast(
+ "Union[SelectState, CompoundSelectState]",
+ self.stack[-1]["compile_state"],
+ )
+ except KeyError as ke:
+ coercions._no_text_coercion(
+ element.element,
+ extra=(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ),
+ exc_cls=exc.CompileError,
+ err=ke,
+ )
+
with_cols, only_froms, only_cols = compile_state._label_resolve_dict
try:
if within_columns_clause:
@@ -1568,13 +1726,13 @@ class SQLCompiler(Compiled):
def visit_column(
self,
- column,
- add_to_result_map=None,
- include_table=True,
- result_map_targets=(),
- ambiguous_table_name_map=None,
- **kwargs,
- ):
+ column: ColumnClause[Any],
+ add_to_result_map: Optional[_ResultMapAppender] = None,
+ include_table: bool = True,
+ result_map_targets: Tuple[Any, ...] = (),
+ ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
+ **kwargs: Any,
+ ) -> str:
name = orig_name = column.name
if name is None:
name = self._fallback_column_name(column)
@@ -1608,7 +1766,8 @@ class SQLCompiler(Compiled):
)
else:
schema_prefix = ""
- tablename = table.name
+
+ tablename = cast("NamedFromClause", table).name
if (
not effective_schema
@@ -1678,7 +1837,7 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- new_entry = {
+ new_entry: _CompilerStackEntry = {
"correlate_froms": set(),
"asfrom_froms": set(),
"selectable": taf,
@@ -1879,11 +2038,19 @@ class SQLCompiler(Compiled):
compiled_col = self.visit_column(element, **kw)
return "(%s).%s" % (compiled_fn, compiled_col)
- def visit_function(self, func, add_to_result_map=None, **kwargs):
+ def visit_function(
+ self,
+ func: Function,
+ add_to_result_map: Optional[_ResultMapAppender] = None,
+ **kwargs: Any,
+ ) -> str:
if add_to_result_map is not None:
add_to_result_map(func.name, func.name, (), func.type)
disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+
+ text: str
+
if disp:
text = disp(func, **kwargs)
else:
@@ -1964,7 +2131,7 @@ class SQLCompiler(Compiled):
if compound_stmt._independent_ctes:
self._dispatch_independent_ctes(compound_stmt, kwargs)
- keyword = self.compound_keywords.get(cs.keyword)
+ keyword = self.compound_keywords[cs.keyword]
text = (" " + keyword + " ").join(
(
@@ -2591,11 +2758,13 @@ class SQLCompiler(Compiled):
# a different set of parameter values. here, we accommodate for
# parameters that may have been cloned both before and after the cache
# key was been generated.
- ckbm = self._cache_key_bind_match
- if ckbm:
+ ckbm_tuple = self._cache_key_bind_match
+
+ if ckbm_tuple:
+ ckbm, cksm = ckbm_tuple
for bp in bindparam._cloned_set:
- if bp.key in ckbm:
- cb = ckbm[bp.key]
+ if bp.key in cksm:
+ cb = cksm[bp.key]
ckbm[cb].append(bindparam)
if bindparam.isoutparam:
@@ -2720,7 +2889,7 @@ class SQLCompiler(Compiled):
if positional_names is not None:
positional_names.append(name)
else:
- self.positiontup.append(name)
+ self.positiontup.append(name) # type: ignore[union-attr]
elif not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
@@ -2735,9 +2904,9 @@ class SQLCompiler(Compiled):
name = new_name
if escaped_from:
- if not self.escaped_bind_names:
- self.escaped_bind_names = {}
- self.escaped_bind_names[escaped_from] = name
+ self.escaped_bind_names = self.escaped_bind_names.union(
+ {escaped_from: name}
+ )
if post_compile:
return "__[POSTCOMPILE_%s]" % name
@@ -2772,7 +2941,8 @@ class SQLCompiler(Compiled):
cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
**kwargs: Any,
) -> Optional[str]:
- self._init_cte_state()
+ self_ctes = self._init_cte_state()
+ assert self_ctes is self.ctes
kwargs["visiting_cte"] = cte
@@ -2838,7 +3008,7 @@ class SQLCompiler(Compiled):
# we've generated a same-named CTE that is
# enclosed in us - we take precedence, so
# discard the text for the "inner".
- del self.ctes[existing_cte]
+ del self_ctes[existing_cte]
existing_cte_reference_cte = existing_cte._get_reference_cte()
@@ -2875,7 +3045,7 @@ class SQLCompiler(Compiled):
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
- if not cte_pre_alias_name and cte not in self.ctes:
+ if not cte_pre_alias_name and cte not in self_ctes:
if cte.recursive:
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
@@ -2942,14 +3112,14 @@ class SQLCompiler(Compiled):
cte, cte._suffixes, **kwargs
)
- self.ctes[cte] = text
+ self_ctes[cte] = text
if asfrom:
if from_linter:
from_linter.froms[cte] = cte_name
if not is_new_cte and embedded_in_current_named_cte:
- return self.preparer.format_alias(cte, cte_name)
+ return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501
if cte_pre_alias_name:
text = self.preparer.format_alias(cte, cte_pre_alias_name)
@@ -2960,6 +3130,8 @@ class SQLCompiler(Compiled):
else:
return self.preparer.format_alias(cte, cte_name)
+ return None
+
def visit_table_valued_alias(self, element, **kw):
if element._is_lateral:
return self.visit_lateral(element, **kw)
@@ -3143,7 +3315,7 @@ class SQLCompiler(Compiled):
self,
keyname: str,
name: str,
- objects: List[Any],
+ objects: Tuple[Any, ...],
type_: TypeEngine[Any],
) -> None:
if keyname is None or keyname == "*":
@@ -3358,9 +3530,12 @@ class SQLCompiler(Compiled):
def get_statement_hint_text(self, hint_texts):
return " ".join(hint_texts)
- _default_stack_entry = util.immutabledict(
- [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
- )
+ _default_stack_entry: _CompilerStackEntry
+
+ if not typing.TYPE_CHECKING:
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
def _display_froms_for_select(
self, select_stmt, asfrom, lateral=False, **kw
@@ -3391,7 +3566,7 @@ class SQLCompiler(Compiled):
)
return froms
- translate_select_structure = None
+ translate_select_structure: Any = None
"""if not ``None``, should be a callable which accepts ``(select_stmt,
**kw)`` and returns a select object. this is used for structural changes
mostly to accommodate for LIMIT/OFFSET schemes
@@ -3563,7 +3738,9 @@ class SQLCompiler(Compiled):
)
self._result_columns = [
- (key, name, tuple(translate.get(o, o) for o in obj), type_)
+ ResultColumnsEntry(
+ key, name, tuple(translate.get(o, o) for o in obj), type_
+ )
for key, name, obj, type_ in self._result_columns
]
@@ -3660,10 +3837,10 @@ class SQLCompiler(Compiled):
implicit_correlate_froms=asfrom_froms,
)
- new_correlate_froms = set(selectable._from_objects(*froms))
+ new_correlate_froms = set(_from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
- new_entry = {
+ new_entry: _CompilerStackEntry = {
"asfrom_froms": new_correlate_froms,
"correlate_froms": all_correlate_froms,
"selectable": select,
@@ -3734,6 +3911,7 @@ class SQLCompiler(Compiled):
text += " \nWHERE " + t
if warn_linting:
+ assert from_linter is not None
from_linter.warn()
if select._group_by_clauses:
@@ -3781,6 +3959,8 @@ class SQLCompiler(Compiled):
if not self.ctes:
return ""
+ ctes: MutableMapping[CTE, str]
+
if nesting_level and nesting_level > 1:
ctes = util.OrderedDict()
for cte in list(self.ctes.keys()):
@@ -3805,10 +3985,16 @@ class SQLCompiler(Compiled):
ctes_recursive = any([cte.recursive for cte in ctes])
if self.positional:
+ assert self.positiontup is not None
self.positiontup = (
- sum([self.cte_positional[cte] for cte in ctes], [])
+ list(
+ itertools.chain.from_iterable(
+ self.cte_positional[cte] for cte in ctes
+ )
+ )
+ self.positiontup
)
+
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
@@ -4190,7 +4376,7 @@ class SQLCompiler(Compiled):
if is_multitable:
# main table might be a JOIN
- main_froms = set(selectable._from_objects(update_stmt.table))
+ main_froms = set(_from_objects(update_stmt.table))
render_extra_froms = [
f for f in extra_froms if f not in main_froms
]
@@ -4506,7 +4692,11 @@ class DDLCompiler(Compiled):
def type_compiler(self):
return self.dialect.type_compiler
- def construct_params(self, params=None, extracted_parameters=None):
+ def construct_params(
+ self,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+ ) -> Optional[_MutableCoreSingleExecuteParams]:
return None
def visit_ddl(self, ddl, **kwargs):
@@ -5199,6 +5389,11 @@ class StrSQLTypeCompiler(GenericTypeCompiler):
return get_col_spec(**kw)
+class _SchemaForObjectCallable(Protocol):
+ def __call__(self, obj: Any) -> str:
+ ...
+
+
class IdentifierPreparer:
"""Handle quoting and case-folding of identifiers based on options."""
@@ -5209,7 +5404,13 @@ class IdentifierPreparer:
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
- schema_for_object = operator.attrgetter("schema")
+ initial_quote: str
+
+ final_quote: str
+
+ _strings: MutableMapping[str, str]
+
+ schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
"""Return the .schema attribute for an object.
For the default IdentifierPreparer, the schema for an object is always
@@ -5297,7 +5498,7 @@ class IdentifierPreparer:
return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
- def _escape_identifier(self, value):
+ def _escape_identifier(self, value: str) -> str:
"""Escape an identifier.
Subclasses should override this to provide database-dependent
@@ -5309,7 +5510,7 @@ class IdentifierPreparer:
value = value.replace("%", "%%")
return value
- def _unescape_identifier(self, value):
+ def _unescape_identifier(self, value: str) -> str:
"""Canonicalize an escaped identifier.
Subclasses should override this to provide database-dependent
@@ -5336,7 +5537,7 @@ class IdentifierPreparer:
)
return element
- def quote_identifier(self, value):
+ def quote_identifier(self, value: str) -> str:
"""Quote an identifier.
Subclasses should override this to provide database-dependent
@@ -5349,7 +5550,7 @@ class IdentifierPreparer:
+ self.final_quote
)
- def _requires_quotes(self, value):
+ def _requires_quotes(self, value: str) -> bool:
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
return (
@@ -5364,7 +5565,7 @@ class IdentifierPreparer:
not taking case convention into account."""
return not self.legal_characters.match(str(value))
- def quote_schema(self, schema, force=None):
+ def quote_schema(self, schema: str, force: Any = None) -> str:
"""Conditionally quote a schema name.
@@ -5403,7 +5604,7 @@ class IdentifierPreparer:
return self.quote(schema)
- def quote(self, ident, force=None):
+ def quote(self, ident: str, force: Any = None) -> str:
"""Conditionally quote an identifier.
The identifier is quoted if it is a reserved word, contains
@@ -5474,11 +5675,19 @@ class IdentifierPreparer:
name = self.quote_schema(effective_schema) + "." + name
return name
- def format_label(self, label, name=None):
+ def format_label(
+ self, label: Label[Any], name: Optional[str] = None
+ ) -> str:
return self.quote(name or label.name)
- def format_alias(self, alias, name=None):
- return self.quote(name or alias.name)
+ def format_alias(
+ self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
+ ) -> str:
+ if name is None:
+ assert alias is not None
+ return self.quote(alias.name)
+ else:
+ return self.quote(name)
def format_savepoint(self, savepoint, name=None):
# Running the savepoint name through quoting is unnecessary
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 5aded307b..96e90b0ea 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -13,6 +13,10 @@ from __future__ import annotations
import collections.abc as collections_abc
import typing
+from typing import Any
+from typing import List
+from typing import MutableMapping
+from typing import Optional
from . import coercions
from . import roles
@@ -40,8 +44,8 @@ from .. import util
class DMLState(CompileState):
_no_parameters = True
- _dict_parameters = None
- _multi_parameters = None
+ _dict_parameters: Optional[MutableMapping[str, Any]] = None
+ _multi_parameters: Optional[List[MutableMapping[str, Any]]] = None
_ordered_values = None
_parameter_ordering = None
_has_multi_parameters = False
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 168da17cc..08d632afd 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -18,7 +18,9 @@ import re
import typing
from typing import Any
from typing import Callable
+from typing import Dict
from typing import Generic
+from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
@@ -47,6 +49,7 @@ from .coercions import _document_text_coercion # noqa
from .operators import ColumnOperators
from .traversals import HasCopyInternals
from .visitors import cloned_traverse
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
from .visitors import traverse
from .visitors import Visitable
@@ -68,6 +71,8 @@ if typing.TYPE_CHECKING:
from ..engine import Connection
from ..engine import Dialect
from ..engine import Engine
+ from ..engine.base import _CompiledCacheType
+ from ..engine.base import _SchemaTranslateMapType
_NUMERIC = Union[complex, "Decimal"]
@@ -238,6 +243,7 @@ class ClauseElement(
SupportsWrappingAnnotations,
MemoizedHasCacheKey,
HasCopyInternals,
+ ExternallyTraversible,
CompilerElement,
):
"""Base class for elements of a programmatically constructed SQL
@@ -398,7 +404,9 @@ class ClauseElement(
"""
return self._replace_params(True, optionaldict, kwargs)
- def params(self, *optionaldict, **kwargs):
+ def params(
+ self, *optionaldict: Dict[str, Any], **kwargs: Any
+ ) -> ClauseElement:
"""Return a copy with :func:`_expression.bindparam` elements
replaced.
@@ -415,7 +423,12 @@ class ClauseElement(
"""
return self._replace_params(False, optionaldict, kwargs)
- def _replace_params(self, unique, optionaldict, kwargs):
+ def _replace_params(
+ self,
+ unique: bool,
+ optionaldict: Optional[Dict[str, Any]],
+ kwargs: Dict[str, Any],
+ ) -> ClauseElement:
if len(optionaldict) == 1:
kwargs.update(optionaldict[0])
@@ -487,12 +500,12 @@ class ClauseElement(
def _compile_w_cache(
self,
- dialect,
- compiled_cache=None,
- column_keys=None,
- for_executemany=False,
- schema_translate_map=None,
- **kw,
+ dialect: Dialect,
+ compiled_cache: Optional[_CompiledCacheType] = None,
+ column_keys: Optional[List[str]] = None,
+ for_executemany: bool = False,
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+ **kw: Any,
):
if compiled_cache is not None and dialect._supports_statement_cache:
elem_cache_key = self._generate_cache_key()
@@ -1383,7 +1396,7 @@ class ColumnElement(
"""
return Cast(self, type_)
- def label(self, name):
+ def label(self, name: Optional[str]) -> Label[_T]:
"""Produce a column label, i.e. ``<columnname> AS <name>``.
This is a shortcut to the :func:`_expression.label` function.
@@ -1608,6 +1621,9 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
("value", InternalTraversal.dp_plain_obj),
]
+ key: str
+ type: TypeEngine
+
_is_crud = False
_is_bind_parameter = True
_key_is_anon = False
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index eb3d17ee4..6e5eec127 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -12,6 +12,7 @@
from __future__ import annotations
from typing import Any
+from typing import Sequence
from typing import TypeVar
from . import annotation
@@ -839,6 +840,8 @@ class Function(FunctionElement):
identifier: str
+ packagenames: Sequence[str]
+
type: TypeEngine = sqltypes.NULLTYPE
"""A :class:`_types.TypeEngine` object which refers to the SQL return
type represented by this SQL function.
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 64bd4b951..1a7a5f4d4 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -7,14 +7,22 @@
from __future__ import annotations
import typing
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
-from sqlalchemy.util.langhelpers import TypingOnly
from .. import util
-
+from ..util import TypingOnly
+from ..util.typing import Literal
if typing.TYPE_CHECKING:
+ from .base import ColumnCollection
from .elements import ClauseElement
+ from .elements import Label
from .selectable import FromClause
+ from .selectable import Subquery
class SQLRole:
@@ -35,7 +43,7 @@ class SQLRole:
class UsesInspection:
__slots__ = ()
- _post_inspect = None
+ _post_inspect: Literal[None] = None
uses_inspection = True
@@ -96,7 +104,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
_role_name = "Column expression or FROM clause"
@property
- def _select_iterable(self):
+ def _select_iterable(self) -> Sequence[ColumnsClauseRole]:
raise NotImplementedError()
@@ -150,6 +158,9 @@ class ExpressionElementRole(SQLRole):
__slots__ = ()
_role_name = "SQL expression element"
+ def label(self, name: Optional[str]) -> Label[Any]:
+ raise NotImplementedError()
+
class ConstExprRole(ExpressionElementRole):
__slots__ = ()
@@ -187,7 +198,7 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
_is_subquery = False
@property
- def _hide_froms(self):
+ def _hide_froms(self) -> Iterable[FromClause]:
raise NotImplementedError()
@@ -195,8 +206,10 @@ class StrictFromClauseRole(FromClauseRole):
__slots__ = ()
# does not allow text() or select() objects
+ c: ColumnCollection
+
@property
- def description(self):
+ def description(self) -> str:
raise NotImplementedError()
@@ -204,7 +217,9 @@ class AnonymizedFromClauseRole(StrictFromClauseRole):
__slots__ = ()
# calls .alias() as a post processor
- def _anonymous_fromclause(self, name=None, flat=False):
+ def _anonymous_fromclause(
+ self, name: Optional[str] = None, flat: bool = False
+ ) -> FromClause:
raise NotImplementedError()
@@ -220,14 +235,14 @@ class StatementRole(SQLRole):
__slots__ = ()
_role_name = "Executable SQL or text() construct"
- _propagate_attrs = util.immutabledict()
+ _propagate_attrs: Mapping[str, Any] = util.immutabledict()
class SelectStatementRole(StatementRole, ReturnsRowsRole):
__slots__ = ()
_role_name = "SELECT construct or equivalent text() construct"
- def subquery(self):
+ def subquery(self) -> Subquery:
raise NotImplementedError(
"All SelectStatementRole objects should implement a "
".subquery() method."
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index c270e1564..33e300bf6 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -51,7 +51,7 @@ from . import visitors
from .base import DedupeColumnCollection
from .base import DialectKWArgs
from .base import Executable
-from .base import SchemaEventTarget
+from .base import SchemaEventTarget as SchemaEventTarget
from .coercions import _document_text_coercion
from .elements import ClauseElement
from .elements import ColumnClause
@@ -2676,6 +2676,10 @@ class DefaultGenerator(Executable, SchemaItem):
def __init__(self, for_update=False):
self.for_update = for_update
+ @util.memoized_property
+ def is_callable(self):
+ raise NotImplementedError()
+
def _set_parent(self, column, **kw):
self.column = column
if self.for_update:
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index e5c2bef68..09befb078 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -53,7 +53,6 @@ from .base import Generative
from .base import HasCompileState
from .base import HasMemoized
from .base import Immutable
-from .base import prefix_anon_map
from .coercions import _document_text_coercion
from .elements import _anonymous_label
from .elements import BindParameter
@@ -69,10 +68,10 @@ from .elements import literal_column
from .elements import TableValuedColumn
from .elements import UnaryExpression
from .visitors import InternalTraversal
+from .visitors import prefix_anon_map
from .. import exc
from .. import util
-
and_ = BooleanClauseList.and_
_T = TypeVar("_T", bound=Any)
@@ -855,6 +854,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return self.alias(name=name)
+class NamedFromClause(FromClause):
+ named_with_column = True
+
+ name: str
+
+
class SelectLabelStyle(Enum):
"""Label style constants that may be passed to
:meth:`_sql.Select.set_label_style`."""
@@ -1317,15 +1322,16 @@ class NoInit:
# -> Lateral -> FromClause, but we accept SelectBase
# w/ non-deprecated coercion
# -> TableSample -> only for FromClause
-class AliasedReturnsRows(NoInit, FromClause):
+class AliasedReturnsRows(NoInit, NamedFromClause):
"""Base class of aliases against tables, subqueries, and other
selectables."""
_is_from_container = True
- named_with_column = True
_supports_derived_columns = False
+ element: ClauseElement
+
_traverse_internals = [
("element", InternalTraversal.dp_clauseelement),
("name", InternalTraversal.dp_anon_name),
@@ -1423,6 +1429,8 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows):
inherit_cache = True
+ element: FromClause
+
@classmethod
def _factory(cls, selectable, name=None, flat=False):
return coercions.expect(
@@ -1689,6 +1697,8 @@ class CTE(
+ HasSuffixes._has_suffixes_traverse_internals
)
+ element: HasCTE
+
@classmethod
def _factory(cls, selectable, name=None, recursive=False):
r"""Return a new :class:`_expression.CTE`,
@@ -1819,7 +1829,7 @@ class _CTEOpts(NamedTuple):
nesting: bool
-class HasCTE(roles.HasCTERole):
+class HasCTE(roles.HasCTERole, ClauseElement):
"""Mixin that declares a class to include CTE support.
.. versionadded:: 1.1
@@ -2247,6 +2257,8 @@ class Subquery(AliasedReturnsRows):
inherit_cache = True
+ element: Select
+
@classmethod
def _factory(cls, selectable, name=None):
"""Return a :class:`.Subquery` object."""
@@ -2331,7 +2343,7 @@ class FromGrouping(GroupedElement, FromClause):
self.element = state["element"]
-class TableClause(roles.DMLTableRole, Immutable, FromClause):
+class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
"""Represents a minimal "table" construct.
This is a lightweight table object that has only a name, a
@@ -2371,8 +2383,6 @@ class TableClause(roles.DMLTableRole, Immutable, FromClause):
("name", InternalTraversal.dp_string),
]
- named_with_column = True
-
_is_table = True
implicit_returning = False
@@ -2542,7 +2552,7 @@ class ForUpdateArg(ClauseElement):
SelfValues = typing.TypeVar("SelfValues", bound="Values")
-class Values(Generative, FromClause):
+class Values(Generative, NamedFromClause):
"""Represent a ``VALUES`` construct that can be used as a FROM element
in a statement.
@@ -2553,7 +2563,6 @@ class Values(Generative, FromClause):
"""
- named_with_column = True
__visit_name__ = "values"
_data = ()
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 7d21f1262..b2b1d9bc2 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -35,13 +35,13 @@ from .elements import _NONE_NAME
from .elements import quoted_name
from .elements import Slice
from .elements import TypeCoerce as type_coerce # noqa
-from .traversals import InternalTraversal
from .type_api import Emulated
from .type_api import NativeForEmulated # noqa
from .type_api import to_instance
from .type_api import TypeDecorator
from .type_api import TypeEngine
from .type_api import Variant # noqa
+from .visitors import InternalTraversal
from .. import event
from .. import exc
from .. import inspection
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 4fa23d370..cf9487f93 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -15,7 +15,10 @@ import operator
import typing
from typing import Any
from typing import Callable
+from typing import Deque
from typing import Dict
+from typing import Set
+from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -23,9 +26,9 @@ from . import operators
from .cache_key import HasCacheKey
from .visitors import _TraverseInternalsType
from .visitors import anon_map
-from .visitors import ExtendedInternalTraversal
+from .visitors import ExternallyTraversible
+from .visitors import HasTraversalDispatch
from .visitors import HasTraverseInternals
-from .visitors import InternalTraversal
from .. import util
from ..util import langhelpers
@@ -35,6 +38,7 @@ COMPARE_SUCCEEDED = True
def compare(obj1, obj2, **kw):
+ strategy: TraversalComparatorStrategy
if kw.get("use_proxies", False):
strategy = ColIdentityComparatorStrategy()
else:
@@ -45,16 +49,18 @@ def compare(obj1, obj2, **kw):
def _preconfigure_traversals(target_hierarchy):
for cls in util.walk_subclasses(target_hierarchy):
- if hasattr(cls, "_traverse_internals"):
- cls._generate_cache_attrs()
+ if hasattr(cls, "_generate_cache_attrs") and hasattr(
+ cls, "_traverse_internals"
+ ):
+ cls._generate_cache_attrs() # type: ignore
_copy_internals.generate_dispatch(
- cls,
- cls._traverse_internals,
+ cls, # type: ignore
+ cls._traverse_internals, # type: ignore
"_generated_copy_internals_traversal",
)
_get_children.generate_dispatch(
- cls,
- cls._traverse_internals,
+ cls, # type: ignore
+ cls._traverse_internals, # type: ignore
"_generated_get_children_traversal",
)
@@ -125,54 +131,58 @@ class HasShallowCopy(HasTraverseInternals):
meth_text = f"def {method_name}(self, d):\n{code}\n"
return langhelpers._exec_code_in_env(meth_text, {}, method_name)
- def _shallow_from_dict(self, d: Dict) -> None:
+ def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
cls = self.__class__
+ shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
try:
shallow_from_dict = cls.__dict__[
"_generated_shallow_from_dict_traversal"
]
except KeyError:
- shallow_from_dict = (
- cls._generated_shallow_from_dict_traversal # type: ignore
- ) = self._generate_shallow_from_dict(
+ shallow_from_dict = self._generate_shallow_from_dict(
cls._traverse_internals,
"_generated_shallow_from_dict_traversal",
)
+ cls._generated_shallow_from_dict_traversal = shallow_from_dict # type: ignore # noqa E501
+
shallow_from_dict(self, d)
def _shallow_to_dict(self) -> Dict[str, Any]:
cls = self.__class__
+ shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]
+
try:
shallow_to_dict = cls.__dict__[
"_generated_shallow_to_dict_traversal"
]
except KeyError:
- shallow_to_dict = (
- cls._generated_shallow_to_dict_traversal # type: ignore
- ) = self._generate_shallow_to_dict(
+ shallow_to_dict = self._generate_shallow_to_dict(
cls._traverse_internals, "_generated_shallow_to_dict_traversal"
)
+ cls._generated_shallow_to_dict_traversal = shallow_to_dict # type: ignore # noqa E501
return shallow_to_dict(self)
- def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy):
+ def _shallow_copy_to(
+ self: SelfHasShallowCopy, other: SelfHasShallowCopy
+ ) -> None:
cls = self.__class__
+ shallow_copy: Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None]
try:
shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
except KeyError:
- shallow_copy = (
- cls._generated_shallow_copy_traversal # type: ignore
- ) = self._generate_shallow_copy(
+ shallow_copy = self._generate_shallow_copy(
cls._traverse_internals, "_generated_shallow_copy_traversal"
)
+ cls._generated_shallow_copy_traversal = shallow_copy # type: ignore # noqa: E501
shallow_copy(self, other)
- def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy:
+ def _clone(self: SelfHasShallowCopy, **kw: Any) -> SelfHasShallowCopy:
"""Create a shallow copy"""
c = self.__class__.__new__(self.__class__)
self._shallow_copy_to(c)
@@ -246,7 +256,7 @@ class HasCopyInternals(HasTraverseInternals):
setattr(self, attrname, result)
-class _CopyInternalsTraversal(InternalTraversal):
+class _CopyInternalsTraversal(HasTraversalDispatch):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
@@ -381,7 +391,7 @@ def _flatten_clauseelement(element):
return element
-class _GetChildrenTraversal(InternalTraversal):
+class _GetChildrenTraversal(HasTraversalDispatch):
"""Generate a _children_traversal internal traversal dispatch for classes
with a _traverse_internals collection."""
@@ -463,13 +473,13 @@ def _resolve_name_for_compare(element, name, anon_map, **kw):
return name
-class TraversalComparatorStrategy(
- ExtendedInternalTraversal, util.MemoizedSlots
-):
+class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
__slots__ = "stack", "cache", "anon_map"
def __init__(self):
- self.stack = deque()
+ self.stack: Deque[
+ Tuple[ExternallyTraversible, ExternallyTraversible]
+ ] = deque()
self.cache = set()
def _memoized_attr_anon_map(self):
@@ -653,7 +663,7 @@ class TraversalComparatorStrategy(
if seq1 is None:
return seq2 is None
- completed = set()
+ completed: Set[object] = set()
for clause in seq1:
for other_clause in set(seq2).difference(completed):
if self.compare_inner(clause, other_clause, **kw):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index e0248adf0..5114a2431 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -21,9 +21,9 @@ from . import coercions
from . import operators
from . import roles
from . import visitors
-from .annotation import _deep_annotate # noqa
-from .annotation import _deep_deannotate # noqa
-from .annotation import _shallow_annotate # noqa
+from .annotation import _deep_annotate as _deep_annotate
+from .annotation import _deep_deannotate as _deep_deannotate
+from .annotation import _shallow_annotate as _shallow_annotate
from .base import _expand_cloned
from .base import _from_objects
from .base import ColumnSet
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 111ecd32e..0c41e440e 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -7,43 +7,46 @@
"""Visitor/traversal interface and library functions.
-SQLAlchemy schema and expression constructs rely on a Python-centric
-version of the classic "visitor" pattern as the primary way in which
-they apply functionality. The most common use of this pattern
-is statement compilation, where individual expression classes match
-up to rendering methods that produce a string result. Beyond this,
-the visitor system is also used to inspect expressions for various
-information and patterns, as well as for the purposes of applying
-transformations to expressions.
-
-Examples of how the visit system is used can be seen in the source code
-of for example the ``sqlalchemy.sql.util`` and the ``sqlalchemy.sql.compiler``
-modules. Some background on clause adaption is also at
-https://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
"""
from __future__ import annotations
from collections import deque
+from enum import Enum
import itertools
import operator
import typing
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import ClassVar
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import Mapping
+from typing import Optional
from typing import Tuple
+from typing import Type
+from typing import TypeVar
+from typing import Union
from .. import exc
from .. import util
from ..util import langhelpers
-from ..util import symbol
from ..util._has_cy import HAS_CYEXTENSION
-from ..util.langhelpers import _symbol
+from ..util.typing import Protocol
+from ..util.typing import Self
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
- from ._py_util import cache_anon_map as anon_map # noqa
+ from ._py_util import prefix_anon_map as prefix_anon_map
+ from ._py_util import cache_anon_map as anon_map
else:
- from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
+ from sqlalchemy.cyextension.util import prefix_anon_map as prefix_anon_map
+ from sqlalchemy.cyextension.util import cache_anon_map as anon_map
+
__all__ = [
"iterate",
@@ -54,57 +57,23 @@ __all__ = [
"Visitable",
"ExternalTraversal",
"InternalTraversal",
+ "anon_map",
]
-_TraverseInternalsType = List[Tuple[str, _symbol]]
-
-
-class HasTraverseInternals:
- """base for classes that have a "traverse internals" element,
- which defines all kinds of ways of traversing the elements of an object.
-
- """
-
- __slots__ = ()
-
- _traverse_internals: _TraverseInternalsType
-
- @util.preload_module("sqlalchemy.sql.traversals")
- def get_children(self, omit_attrs=(), **kw):
- r"""Return immediate child :class:`.visitors.Visitable`
- elements of this :class:`.visitors.Visitable`.
-
- This is used for visit traversal.
-
- \**kw may contain flags that change the collection that is
- returned, for example to return a subset of items in order to
- cut down on larger traversals, or to return child items from a
- different context (such as schema-level collections instead of
- clause-level).
-
- """
-
- traversals = util.preloaded.sql_traversals
-
- try:
- traverse_internals = self._traverse_internals
- except AttributeError:
- # user-defined classes may not have a _traverse_internals
- return []
- dispatch = traversals._get_children.run_generated_dispatch
- return itertools.chain.from_iterable(
- meth(obj, **kw)
- for attrname, obj, meth in dispatch(
- self, traverse_internals, "_generated_get_children_traversal"
- )
- if attrname not in omit_attrs and obj is not None
- )
+class _CompilerDispatchType(Protocol):
+ def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any:
+ ...
class Visitable:
"""Base class for visitable objects.
+ :class:`.Visitable` is used to implement the SQL compiler dispatch
+ functions. Other forms of traversal such as for cache key generation
+ are implemented separately using the :class:`.HasTraverseInternals`
+ interface.
+
.. versionchanged:: 2.0 The :class:`.Visitable` class was named
:class:`.Traversible` in the 1.4 series; the name is changed back
to :class:`.Visitable` in 2.0 which is what it was prior to 1.4.
@@ -117,32 +86,20 @@ class Visitable:
__visit_name__: str
+ _original_compiler_dispatch: _CompilerDispatchType
+
+ if typing.TYPE_CHECKING:
+
+ def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str:
+ ...
+
def __init_subclass__(cls) -> None:
if "__visit_name__" in cls.__dict__:
cls._generate_compiler_dispatch()
super().__init_subclass__()
@classmethod
- def _generate_compiler_dispatch(cls):
- """Assign dispatch attributes to various kinds of
- "visitable" classes.
-
- Attributes include:
-
- * The ``_compiler_dispatch`` method, corresponding to
- ``__visit_name__``. This is called "external traversal" because the
- caller of each visit() method is responsible for sub-traversing the
- inner elements of each object. This is appropriate for string
- compilers and other traversals that need to call upon the inner
- elements in a specific pattern.
-
- * internal traversal collections ``_children_traversal``,
- ``_cache_key_traversal``, ``_copy_internals_traversal``, generated
- from an optional ``_traverse_internals`` collection of symbols which
- comes from the :class:`.InternalTraversal` list of symbols. This is
- called "internal traversal".
-
- """
+ def _generate_compiler_dispatch(cls) -> None:
visit_name = cls.__visit_name__
if "_compiler_dispatch" in cls.__dict__:
@@ -161,7 +118,9 @@ class Visitable:
name = "visit_%s" % visit_name
getter = operator.attrgetter(name)
- def _compiler_dispatch(self, visitor, **kw):
+ def _compiler_dispatch(
+ self: Visitable, visitor: Any, **kw: Any
+ ) -> str:
"""Look for an attribute named "visit_<visit_name>" on the
visitor, and call it with the same kw params.
@@ -169,105 +128,20 @@ class Visitable:
try:
meth = getter(visitor)
except AttributeError as err:
- return visitor.visit_unsupported_compilation(self, err, **kw)
+ return visitor.visit_unsupported_compilation(self, err, **kw) # type: ignore # noqa E501
else:
- return meth(self, **kw)
+ return meth(self, **kw) # type: ignore # noqa E501
- cls._compiler_dispatch = (
+ cls._compiler_dispatch = ( # type: ignore
cls._original_compiler_dispatch
) = _compiler_dispatch
- def __class_getitem__(cls, key):
+ def __class_getitem__(cls, key: str) -> Any:
# allow generic classes in py3.9+
return cls
-class _HasTraversalDispatch:
- r"""Define infrastructure for the :class:`.InternalTraversal` class.
-
- .. versionadded:: 2.0
-
- """
-
- __slots__ = ()
-
- def __init_subclass__(cls) -> None:
- cls._generate_traversal_dispatch()
- super().__init_subclass__()
-
- def dispatch(self, visit_symbol):
- """Given a method from :class:`._HasTraversalDispatch`, return the
- corresponding method on a subclass.
-
- """
- name = self._dispatch_lookup[visit_symbol]
- return getattr(self, name, None)
-
- def run_generated_dispatch(
- self, target, internal_dispatch, generate_dispatcher_name
- ):
- try:
- dispatcher = target.__class__.__dict__[generate_dispatcher_name]
- except KeyError:
- # most of the dispatchers are generated up front
- # in sqlalchemy/sql/__init__.py ->
- # traversals.py-> _preconfigure_traversals().
- # this block will generate any remaining dispatchers.
- dispatcher = self.generate_dispatch(
- target.__class__, internal_dispatch, generate_dispatcher_name
- )
- return dispatcher(target, self)
-
- def generate_dispatch(
- self, target_cls, internal_dispatch, generate_dispatcher_name
- ):
- dispatcher = self._generate_dispatcher(
- internal_dispatch, generate_dispatcher_name
- )
- # assert isinstance(target_cls, type)
- setattr(target_cls, generate_dispatcher_name, dispatcher)
- return dispatcher
-
- @classmethod
- def _generate_traversal_dispatch(cls):
- lookup = {}
- clsdict = cls.__dict__
- for key, sym in clsdict.items():
- if key.startswith("dp_"):
- visit_key = key.replace("dp_", "visit_")
- sym_name = sym.name
- assert sym_name not in lookup, sym_name
- lookup[sym] = lookup[sym_name] = visit_key
- if hasattr(cls, "_dispatch_lookup"):
- lookup.update(cls._dispatch_lookup)
- cls._dispatch_lookup = lookup
-
- def _generate_dispatcher(self, internal_dispatch, method_name):
- names = []
- for attrname, visit_sym in internal_dispatch:
- meth = self.dispatch(visit_sym)
- if meth:
- visit_name = ExtendedInternalTraversal._dispatch_lookup[
- visit_sym
- ]
- names.append((attrname, visit_name))
-
- code = (
- (" return [\n")
- + (
- ", \n".join(
- " (%r, self.%s, visitor.%s)"
- % (attrname, attrname, visit_name)
- for attrname, visit_name in names
- )
- )
- + ("\n ]\n")
- )
- meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
- return langhelpers._exec_code_in_env(meth_text, {}, method_name)
-
-
-class InternalTraversal(_HasTraversalDispatch):
+class InternalTraversal(Enum):
r"""Defines visitor symbols used for internal traversal.
The :class:`.InternalTraversal` class is used in two ways. One is that
@@ -306,18 +180,16 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- __slots__ = ()
-
- dp_has_cache_key = symbol("HC")
+ dp_has_cache_key = "HC"
"""Visit a :class:`.HasCacheKey` object."""
- dp_has_cache_key_list = symbol("HL")
+ dp_has_cache_key_list = "HL"
"""Visit a list of :class:`.HasCacheKey` objects."""
- dp_clauseelement = symbol("CE")
+ dp_clauseelement = "CE"
"""Visit a :class:`_expression.ClauseElement` object."""
- dp_fromclause_canonical_column_collection = symbol("FC")
+ dp_fromclause_canonical_column_collection = "FC"
"""Visit a :class:`_expression.FromClause` object in the context of the
``columns`` attribute.
@@ -329,30 +201,30 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_clauseelement_tuples = symbol("CTS")
+ dp_clauseelement_tuples = "CTS"
"""Visit a list of tuples which contain :class:`_expression.ClauseElement`
objects.
"""
- dp_clauseelement_list = symbol("CL")
+ dp_clauseelement_list = "CL"
"""Visit a list of :class:`_expression.ClauseElement` objects.
"""
- dp_clauseelement_tuple = symbol("CT")
+ dp_clauseelement_tuple = "CT"
"""Visit a tuple of :class:`_expression.ClauseElement` objects.
"""
- dp_executable_options = symbol("EO")
+ dp_executable_options = "EO"
- dp_with_context_options = symbol("WC")
+ dp_with_context_options = "WC"
- dp_fromclause_ordered_set = symbol("CO")
+ dp_fromclause_ordered_set = "CO"
"""Visit an ordered set of :class:`_expression.FromClause` objects. """
- dp_string = symbol("S")
+ dp_string = "S"
"""Visit a plain string value.
Examples include table and column names, bound parameter keys, special
@@ -363,10 +235,10 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_string_list = symbol("SL")
+ dp_string_list = "SL"
"""Visit a list of strings."""
- dp_anon_name = symbol("AN")
+ dp_anon_name = "AN"
"""Visit a potentially "anonymized" string value.
The string value is considered to be significant for cache key
@@ -374,7 +246,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_boolean = symbol("B")
+ dp_boolean = "B"
"""Visit a boolean value.
The boolean value is considered to be significant for cache key
@@ -382,7 +254,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_operator = symbol("O")
+ dp_operator = "O"
"""Visit an operator.
The operator is a function from the :mod:`sqlalchemy.sql.operators`
@@ -393,7 +265,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_type = symbol("T")
+ dp_type = "T"
"""Visit a :class:`.TypeEngine` object
The type object is considered to be significant for cache key
@@ -401,7 +273,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_plain_dict = symbol("PD")
+ dp_plain_dict = "PD"
"""Visit a dictionary with string keys.
The keys of the dictionary should be strings, the values should
@@ -410,22 +282,22 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_dialect_options = symbol("DO")
+ dp_dialect_options = "DO"
"""Visit a dialect options structure."""
- dp_string_clauseelement_dict = symbol("CD")
+ dp_string_clauseelement_dict = "CD"
"""Visit a dictionary of string keys to :class:`_expression.ClauseElement`
objects.
"""
- dp_string_multi_dict = symbol("MD")
+ dp_string_multi_dict = "MD"
"""Visit a dictionary of string keys to values which may either be
plain immutable/hashable or :class:`.HasCacheKey` objects.
"""
- dp_annotations_key = symbol("AK")
+ dp_annotations_key = "AK"
"""Visit the _annotations_cache_key element.
This is a dictionary of additional information about a ClauseElement
@@ -436,7 +308,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_plain_obj = symbol("PO")
+ dp_plain_obj = "PO"
"""Visit a plain python object.
The value should be immutable and hashable, such as an integer.
@@ -444,7 +316,7 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_named_ddl_element = symbol("DD")
+ dp_named_ddl_element = "DD"
"""Visit a simple named DDL element.
The current object used by this method is the :class:`.Sequence`.
@@ -454,57 +326,56 @@ class InternalTraversal(_HasTraversalDispatch):
"""
- dp_prefix_sequence = symbol("PS")
+ dp_prefix_sequence = "PS"
"""Visit the sequence represented by :class:`_expression.HasPrefixes`
or :class:`_expression.HasSuffixes`.
"""
- dp_table_hint_list = symbol("TH")
+ dp_table_hint_list = "TH"
"""Visit the ``_hints`` collection of a :class:`_expression.Select`
object.
"""
- dp_setup_join_tuple = symbol("SJ")
+ dp_setup_join_tuple = "SJ"
- dp_memoized_select_entities = symbol("ME")
+ dp_memoized_select_entities = "ME"
- dp_statement_hint_list = symbol("SH")
+ dp_statement_hint_list = "SH"
"""Visit the ``_statement_hints`` collection of a
:class:`_expression.Select`
object.
"""
- dp_unknown_structure = symbol("UK")
+ dp_unknown_structure = "UK"
"""Visit an unknown structure.
"""
- dp_dml_ordered_values = symbol("DML_OV")
+ dp_dml_ordered_values = "DML_OV"
"""Visit the values() ordered tuple list of an
:class:`_expression.Update` object."""
- dp_dml_values = symbol("DML_V")
+ dp_dml_values = "DML_V"
"""Visit the values() dictionary of a :class:`.ValuesBase`
(e.g. Insert or Update) object.
"""
- dp_dml_multi_values = symbol("DML_MV")
+ dp_dml_multi_values = "DML_MV"
"""Visit the values() multi-valued list of dictionaries of an
:class:`_expression.Insert` object.
"""
- dp_propagate_attrs = symbol("PA")
+ dp_propagate_attrs = "PA"
"""Visit the propagate attrs dict. This hardcodes to the particular
elements we care about right now."""
-
-class ExtendedInternalTraversal(InternalTraversal):
- """Defines additional symbols that are useful in caching applications.
+ """Symbols that follow are additional symbols that are useful in
+ caching applications.
Traversals for :class:`_expression.ClauseElement` objects only need to use
those symbols present in :class:`.InternalTraversal`. However, for
@@ -513,9 +384,7 @@ class ExtendedInternalTraversal(InternalTraversal):
"""
- __slots__ = ()
-
- dp_ignore = symbol("IG")
+ dp_ignore = "IG"
"""Specify an object that should be ignored entirely.
This currently applies function call argument caching where some
@@ -523,29 +392,235 @@ class ExtendedInternalTraversal(InternalTraversal):
"""
- dp_inspectable = symbol("IS")
+ dp_inspectable = "IS"
"""Visit an inspectable object where the return value is a
:class:`.HasCacheKey` object."""
- dp_multi = symbol("M")
+ dp_multi = "M"
"""Visit an object that may be a :class:`.HasCacheKey` or may be a
plain hashable object."""
- dp_multi_list = symbol("MT")
+ dp_multi_list = "MT"
"""Visit a tuple containing elements that may be :class:`.HasCacheKey` or
may be a plain hashable object."""
- dp_has_cache_key_tuples = symbol("HT")
+ dp_has_cache_key_tuples = "HT"
"""Visit a list of tuples which contain :class:`.HasCacheKey`
objects.
"""
- dp_inspectable_list = symbol("IL")
+ dp_inspectable_list = "IL"
"""Visit a list of inspectable objects which upon inspection are
HasCacheKey objects."""
+_TraverseInternalsType = List[Tuple[str, InternalTraversal]]
+"""a structure that defines how a HasTraverseInternals should be
+traversed.
+
+This structure consists of a list of (attributename, internaltraversal)
+tuples, where the "attributename" refers to the name of an attribute on an
+instance of the HasTraverseInternals object, and "internaltraversal" refers
+to an :class:`.InternalTraversal` enumeration symbol defining what kind
+of data this attribute stores, which indicates to the traverser how it should
+be handled.
+
+"""
+
+
+class HasTraverseInternals:
+ """base for classes that have a "traverse internals" element,
+ which defines all kinds of ways of traversing the elements of an object.
+
+ Compared to :class:`.Visitable`, which relies upon an external visitor to
+ define how the object is travered (i.e. the :class:`.SQLCompiler`), the
+ :class:`.HasTraverseInternals` interface allows classes to define their own
+ traversal, that is, what attributes are accessed and in what order.
+
+ """
+
+ __slots__ = ()
+
+ _traverse_internals: _TraverseInternalsType
+
+ @util.preload_module("sqlalchemy.sql.traversals")
+ def get_children(
+ self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> Iterable[HasTraverseInternals]:
+ r"""Return immediate child :class:`.visitors.HasTraverseInternals`
+ elements of this :class:`.visitors.HasTraverseInternals`.
+
+ This is used for visit traversal.
+
+ \**kw may contain flags that change the collection that is
+ returned, for example to return a subset of items in order to
+ cut down on larger traversals, or to return child items from a
+ different context (such as schema-level collections instead of
+ clause-level).
+
+ """
+
+ traversals = util.preloaded.sql_traversals
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return []
+
+ dispatch = traversals._get_children.run_generated_dispatch
+ return itertools.chain.from_iterable(
+ meth(obj, **kw)
+ for attrname, obj, meth in dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ )
+ if attrname not in omit_attrs and obj is not None
+ )
+
+
+class _InternalTraversalDispatchType(Protocol):
+ def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any:
+ ...
+
+
+class HasTraversalDispatch:
+ r"""Define infrastructure for classes that perform internal traversals
+
+ .. versionadded:: 2.0
+
+ """
+
+ __slots__ = ()
+
+ _dispatch_lookup: ClassVar[Dict[Union[InternalTraversal, str], str]] = {}
+
+ def dispatch(self, visit_symbol: InternalTraversal) -> Callable[..., Any]:
+ """Given a method from :class:`.HasTraversalDispatch`, return the
+ corresponding method on a subclass.
+
+ """
+ name = _dispatch_lookup[visit_symbol]
+ return getattr(self, name, None) # type: ignore
+
+ def run_generated_dispatch(
+ self,
+ target: object,
+ internal_dispatch: _TraverseInternalsType,
+ generate_dispatcher_name: str,
+ ) -> Any:
+ dispatcher: _InternalTraversalDispatchType
+ try:
+ dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+ except KeyError:
+ # traversals.py -> _preconfigure_traversals()
+ # may be used to run these ahead of time, but
+ # is not enabled right now.
+ # this block will generate any remaining dispatchers.
+ dispatcher = self.generate_dispatch(
+ target.__class__, internal_dispatch, generate_dispatcher_name
+ )
+ return dispatcher(target, self)
+
+ def generate_dispatch(
+ self,
+ target_cls: Type[object],
+ internal_dispatch: _TraverseInternalsType,
+ generate_dispatcher_name: str,
+ ) -> _InternalTraversalDispatchType:
+ dispatcher = self._generate_dispatcher(
+ internal_dispatch, generate_dispatcher_name
+ )
+ # assert isinstance(target_cls, type)
+ setattr(target_cls, generate_dispatcher_name, dispatcher)
+ return dispatcher
+
+ def _generate_dispatcher(
+ self, internal_dispatch: _TraverseInternalsType, method_name: str
+ ) -> _InternalTraversalDispatchType:
+ names = []
+ for attrname, visit_sym in internal_dispatch:
+ meth = self.dispatch(visit_sym)
+ if meth:
+ visit_name = _dispatch_lookup[visit_sym]
+ names.append((attrname, visit_name))
+
+ code = (
+ (" return [\n")
+ + (
+ ", \n".join(
+ " (%r, self.%s, visitor.%s)"
+ % (attrname, attrname, visit_name)
+ for attrname, visit_name in names
+ )
+ )
+ + ("\n ]\n")
+ )
+ meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+ return cast(
+ _InternalTraversalDispatchType,
+ langhelpers._exec_code_in_env(meth_text, {}, method_name),
+ )
+
+
+ExtendedInternalTraversal = InternalTraversal
+
+
+def _generate_traversal_dispatch() -> None:
+ lookup = _dispatch_lookup
+
+ for sym in InternalTraversal:
+ key = sym.name
+ if key.startswith("dp_"):
+ visit_key = key.replace("dp_", "visit_")
+ sym_name = sym.value
+ assert sym_name not in lookup, sym_name
+ lookup[sym] = lookup[sym_name] = visit_key
+
+
+_dispatch_lookup = HasTraversalDispatch._dispatch_lookup
+_generate_traversal_dispatch()
+
+
+class ExternallyTraversible(HasTraverseInternals, Visitable):
+ __slots__ = ()
+
+ _annotations: Collection[Any] = ()
+
+ if typing.TYPE_CHECKING:
+
+ def get_children(
+ self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> Iterable[ExternallyTraversible]:
+ ...
+
+ def _clone(self: Self, **kw: Any) -> Self:
+ """clone this element"""
+ raise NotImplementedError()
+
+ def _copy_internals(
+ self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ ) -> Self:
+ """Reassign internal elements to be clones of themselves.
+
+ Called during a copy-and-traverse operation on newly
+ shallow-copied elements to create a deep copy.
+
+ The given clone function should be used, which may be applying
+ additional transformations to the element (i.e. replacement
+ traversal, cloned traversal, annotations).
+
+ """
+ raise NotImplementedError()
+
+
+_ET = TypeVar("_ET", bound=ExternallyTraversible)
+_TraverseCallableType = Callable[[_ET], None]
+_TraverseTransformCallableType = Callable[
+ [ExternallyTraversible], Optional[ExternallyTraversible]
+]
+
+
class ExternalTraversal:
"""Base class for visitor objects which can traverse externally using
the :func:`.visitors.traverse` function.
@@ -555,7 +630,8 @@ class ExternalTraversal:
"""
- __traverse_options__ = {}
+ __traverse_options__: Dict[str, Any] = {}
+ _next: Optional[ExternalTraversal]
def traverse_single(self, obj: Visitable, **kw: Any) -> Any:
for v in self.visitor_iterator:
@@ -563,20 +639,22 @@ class ExternalTraversal:
if meth:
return meth(obj, **kw)
- def iterate(self, obj):
+ def iterate(
+ self, obj: ExternallyTraversible
+ ) -> Iterator[ExternallyTraversible]:
"""Traverse the given expression structure, returning an iterator
of all elements.
"""
return iterate(obj, self.__traverse_options__)
- def traverse(self, obj):
+ def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
"""Traverse and visit the given expression structure."""
return traverse(obj, self.__traverse_options__, self._visitor_dict)
@util.memoized_property
- def _visitor_dict(self):
+ def _visitor_dict(self) -> Dict[str, _TraverseCallableType[Any]]:
visitors = {}
for name in dir(self):
@@ -585,16 +663,16 @@ class ExternalTraversal:
return visitors
@property
- def visitor_iterator(self):
+ def visitor_iterator(self) -> Iterator[ExternalTraversal]:
"""Iterate through this visitor and each 'chained' visitor."""
- v = self
+ v: Optional[ExternalTraversal] = self
while v:
yield v
v = getattr(v, "_next", None)
- def chain(self, visitor):
- """'Chain' an additional ClauseVisitor onto this ClauseVisitor.
+ def chain(self, visitor: ExternalTraversal) -> ExternalTraversal:
+ """'Chain' an additional ExternalTraversal onto this ExternalTraversal
The chained visitor will receive all visit events after this one.
@@ -614,14 +692,16 @@ class CloningExternalTraversal(ExternalTraversal):
"""
- def copy_and_process(self, list_):
+ def copy_and_process(
+ self, list_: List[ExternallyTraversible]
+ ) -> List[ExternallyTraversible]:
"""Apply cloned traversal to the given list of elements, and return
the new list.
"""
return [self.traverse(x) for x in list_]
- def traverse(self, obj):
+ def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
"""Traverse and visit the given expression structure."""
return cloned_traverse(
@@ -638,7 +718,9 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
"""
- def replace(self, elem):
+ def replace(
+ self, elem: ExternallyTraversible
+ ) -> Optional[ExternallyTraversible]:
"""Receive pre-copied elements during a cloning traversal.
If the method returns a new element, the element is used
@@ -647,15 +729,19 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
"""
return None
- def traverse(self, obj):
+ def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
"""Traverse and visit the given expression structure."""
- def replace(elem):
+ def replace(
+ elem: ExternallyTraversible,
+ ) -> Optional[ExternallyTraversible]:
for v in self.visitor_iterator:
- e = v.replace(elem)
+ e = cast(ReplacingExternalTraversal, v).replace(elem)
if e is not None:
return e
+ return None
+
return replacement_traverse(obj, self.__traverse_options__, replace)
@@ -667,7 +753,9 @@ CloningVisitor = CloningExternalTraversal
ReplacingCloningVisitor = ReplacingExternalTraversal
-def iterate(obj, opts=util.immutabledict()):
+def iterate(
+ obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT
+) -> Iterator[ExternallyTraversible]:
r"""Traverse the given expression structure, returning an iterator.
Traversal is configured to be breadth-first.
@@ -702,7 +790,11 @@ def iterate(obj, opts=util.immutabledict()):
stack.append(t.get_children(**opts))
-def traverse_using(iterator, obj, visitors):
+def traverse_using(
+ iterator: Iterable[ExternallyTraversible],
+ obj: ExternallyTraversible,
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> ExternallyTraversible:
"""Visit the given expression structure using the given iterator of
objects.
@@ -734,7 +826,11 @@ def traverse_using(iterator, obj, visitors):
return obj
-def traverse(obj, opts, visitors):
+def traverse(
+ obj: ExternallyTraversible,
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseCallableType[Any]],
+) -> ExternallyTraversible:
"""Traverse and visit the given expression structure using the default
iterator.
@@ -767,7 +863,11 @@ def traverse(obj, opts, visitors):
return traverse_using(iterate(obj, opts), obj, visitors)
-def cloned_traverse(obj, opts, visitors):
+def cloned_traverse(
+ obj: ExternallyTraversible,
+ opts: Mapping[str, Any],
+ visitors: Mapping[str, _TraverseTransformCallableType],
+) -> ExternallyTraversible:
"""Clone the given expression structure, allowing modifications by
visitors.
@@ -794,20 +894,24 @@ def cloned_traverse(obj, opts, visitors):
"""
- cloned = {}
+ cloned: Dict[int, ExternallyTraversible] = {}
stop_on = set(opts.get("stop_on", []))
- def deferred_copy_internals(obj):
+ def deferred_copy_internals(
+ obj: ExternallyTraversible,
+ ) -> ExternallyTraversible:
return cloned_traverse(obj, opts, visitors)
- def clone(elem, **kw):
+ def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible:
if elem in stop_on:
return elem
else:
if id(elem) not in cloned:
if "replace" in kw:
- newelem = kw["replace"](elem)
+ newelem = cast(
+ Optional[ExternallyTraversible], kw["replace"](elem)
+ )
if newelem is not None:
cloned[id(elem)] = newelem
return newelem
@@ -823,11 +927,15 @@ def cloned_traverse(obj, opts, visitors):
obj = clone(
obj, deferred_copy_internals=deferred_copy_internals, **opts
)
- clone = None # remove gc cycles
+ clone = None # type: ignore[assignment] # remove gc cycles
return obj
-def replacement_traverse(obj, opts, replace):
+def replacement_traverse(
+ obj: ExternallyTraversible,
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType,
+) -> ExternallyTraversible:
"""Clone the given expression structure, allowing element
replacement by a given replacement function.
@@ -854,10 +962,12 @@ def replacement_traverse(obj, opts, replace):
cloned = {}
stop_on = {id(x) for x in opts.get("stop_on", [])}
- def deferred_copy_internals(obj):
+ def deferred_copy_internals(
+ obj: ExternallyTraversible,
+ ) -> ExternallyTraversible:
return replacement_traverse(obj, opts, replace)
- def clone(elem, **kw):
+ def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible:
if (
id(elem) in stop_on
or "no_replacement_traverse" in elem._annotations
@@ -888,5 +998,5 @@ def replacement_traverse(obj, opts, replace):
obj = clone(
obj, deferred_copy_internals=deferred_copy_internals, **opts
)
- clone = None # remove gc cycles
+ clone = None # type: ignore[assignment] # remove gc cycles
return obj