diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/_py_util.py | 25 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 305 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/cache_key.py | 354 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 235 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 447 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 34 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/roles.py | 33 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 29 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 64 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 598 |
16 files changed, 1499 insertions, 667 deletions
diff --git a/lib/sqlalchemy/sql/_py_util.py b/lib/sqlalchemy/sql/_py_util.py index 96e8f6b2c..9f18b882d 100644 --- a/lib/sqlalchemy/sql/_py_util.py +++ b/lib/sqlalchemy/sql/_py_util.py @@ -7,7 +7,16 @@ from __future__ import annotations +import typing +from typing import Any from typing import Dict +from typing import Tuple +from typing import Union + +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .cache_key import CacheConst class prefix_anon_map(Dict[str, str]): @@ -22,16 +31,18 @@ class prefix_anon_map(Dict[str, str]): """ - def __missing__(self, key): + def __missing__(self, key: str) -> str: (ident, derived) = key.split(" ", 1) anonymous_counter = self.get(derived, 1) - self[derived] = anonymous_counter + 1 + self[derived] = anonymous_counter + 1 # type: ignore value = f"{derived}_{anonymous_counter}" self[key] = value return value -class cache_anon_map(Dict[int, str]): +class cache_anon_map( + Dict[Union[int, "Literal[CacheConst.NO_CACHE]"], Union[Literal[True], str]] +): """A map that creates new keys for missing key access. Produces an incrementing sequence given a series of unique keys. @@ -45,11 +56,13 @@ class cache_anon_map(Dict[int, str]): _index = 0 - def get_anon(self, object_): + def get_anon(self, object_: Any) -> Tuple[str, bool]: idself = id(object_) if idself in self: - return self[idself], True + s_val = self[idself] + assert s_val is not True + return s_val, True else: # inline of __missing__ self[idself] = id_ = str(self._index) @@ -57,7 +70,7 @@ class cache_anon_map(Dict[int, str]): return id_, False - def __missing__(self, key): + def __missing__(self, key: int) -> str: self[key] = val = str(self._index) self._index += 1 return val diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index b76393ad6..7afc2de97 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -13,22 +13,77 @@ associations. from __future__ import annotations +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar + from . import operators -from .base import HasCacheKey -from .traversals import anon_map +from .cache_key import HasCacheKey +from .visitors import anon_map +from .visitors import ExternallyTraversible from .visitors import InternalTraversal from .. import util +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .visitors import _TraverseInternalsType + from ..util.typing import Self + +_AnnotationDict = Mapping[str, Any] + +EMPTY_ANNOTATIONS: util.immutabledict[str, Any] = util.EMPTY_DICT + -EMPTY_ANNOTATIONS = util.immutabledict() +SelfSupportsAnnotations = TypeVar( + "SelfSupportsAnnotations", bound="SupportsAnnotations" +) -class SupportsAnnotations: +class SupportsAnnotations(ExternallyTraversible): __slots__ = () - _annotations = EMPTY_ANNOTATIONS + _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS + proxy_set: Set[SupportsAnnotations] + _is_immutable: bool + + def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations: + raise NotImplementedError() + + @overload + def _deannotate( + self: SelfSupportsAnnotations, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfSupportsAnnotations: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: + raise NotImplementedError() @util.memoized_property - def _annotations_cache_key(self): + def _annotations_cache_key(self) -> Tuple[Any, ...]: anon_map_ = anon_map() return ( "_annotations", @@ -47,14 +102,22 @@ class SupportsAnnotations: ) +SelfSupportsCloneAnnotations = TypeVar( + "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations" +) + + class SupportsCloneAnnotations(SupportsAnnotations): - __slots__ = () + if not typing.TYPE_CHECKING: + __slots__ = () - _clone_annotations_traverse_internals = [ + _clone_annotations_traverse_internals: _TraverseInternalsType = [ ("_annotations", InternalTraversal.dp_annotations_key) ] - def _annotate(self, values): + def _annotate( + self: SelfSupportsCloneAnnotations, values: _AnnotationDict + ) -> SelfSupportsCloneAnnotations: """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -65,7 +128,9 @@ class SupportsCloneAnnotations(SupportsAnnotations): new.__dict__.pop("_generate_cache_key", None) return new - def _with_annotations(self, values): + def _with_annotations( + self: SelfSupportsCloneAnnotations, values: _AnnotationDict + ) -> SelfSupportsCloneAnnotations: """return a copy of this ClauseElement with annotations replaced by the given dictionary. @@ -76,7 +141,27 @@ class SupportsCloneAnnotations(SupportsAnnotations): new.__dict__.pop("_generate_cache_key", None) return new - def _deannotate(self, values=None, clone=False): + @overload + def _deannotate( + self: SelfSupportsAnnotations, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfSupportsAnnotations: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: """return a copy of this :class:`_expression.ClauseElement` with annotations removed. @@ -96,24 +181,52 @@ class SupportsCloneAnnotations(SupportsAnnotations): return self +SelfSupportsWrappingAnnotations = TypeVar( + "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations" +) + + class SupportsWrappingAnnotations(SupportsAnnotations): __slots__ = () - def _annotate(self, values): + _constructor: Callable[..., SupportsWrappingAnnotations] + entity_namespace: Mapping[str, Any] + + def _annotate(self, values: _AnnotationDict) -> Annotated: """return a copy of this ClauseElement with annotations updated by the given dictionary. """ - return Annotated(self, values) + return Annotated._as_annotated_instance(self, values) - def _with_annotations(self, values): + def _with_annotations(self, values: _AnnotationDict) -> Annotated: """return a copy of this ClauseElement with annotations replaced by the given dictionary. """ - return Annotated(self, values) - - def _deannotate(self, values=None, clone=False): + return Annotated._as_annotated_instance(self, values) + + @overload + def _deannotate( + self: SelfSupportsAnnotations, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfSupportsAnnotations: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: """return a copy of this :class:`_expression.ClauseElement` with annotations removed. @@ -129,8 +242,11 @@ class SupportsWrappingAnnotations(SupportsAnnotations): return self -class Annotated: - """clones a SupportsAnnotated and applies an 'annotations' dictionary. +SelfAnnotated = TypeVar("SelfAnnotated", bound="Annotated") + + +class Annotated(SupportsAnnotations): + """clones a SupportsAnnotations and applies an 'annotations' dictionary. Unlike regular clones, this clone also mimics __hash__() and __cmp__() of the original element so that it takes its place @@ -151,21 +267,26 @@ class Annotated: _is_column_operators = False - def __new__(cls, *args): - if not args: - # clone constructor - return object.__new__(cls) - else: - element, values = args - # pull appropriate subclass from registry of annotated - # classes - try: - cls = annotated_classes[element.__class__] - except KeyError: - cls = _new_annotation_type(element.__class__, cls) - return object.__new__(cls) - - def __init__(self, element, values): + @classmethod + def _as_annotated_instance( + cls, element: SupportsWrappingAnnotations, values: _AnnotationDict + ) -> Annotated: + try: + cls = annotated_classes[element.__class__] + except KeyError: + cls = _new_annotation_type(element.__class__, cls) + return cls(element, values) + + _annotations: util.immutabledict[str, Any] + __element: SupportsWrappingAnnotations + _hash: int + + def __new__(cls: Type[SelfAnnotated], *args: Any) -> SelfAnnotated: + return object.__new__(cls) + + def __init__( + self, element: SupportsWrappingAnnotations, values: _AnnotationDict + ): self.__dict__ = element.__dict__.copy() self.__dict__.pop("_annotations_cache_key", None) self.__dict__.pop("_generate_cache_key", None) @@ -173,11 +294,15 @@ class Annotated: self._annotations = util.immutabledict(values) self._hash = hash(element) - def _annotate(self, values): + def _annotate( + self: SelfAnnotated, values: _AnnotationDict + ) -> SelfAnnotated: _values = self._annotations.union(values) return self._with_annotations(_values) - def _with_annotations(self, values): + def _with_annotations( + self: SelfAnnotated, values: util.immutabledict[str, Any] + ) -> SelfAnnotated: clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() clone.__dict__.pop("_annotations_cache_key", None) @@ -185,7 +310,27 @@ class Annotated: clone._annotations = values return clone - def _deannotate(self, values=None, clone=True): + @overload + def _deannotate( + self: SelfAnnotated, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfAnnotated: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> Annotated: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = True, + ) -> SupportsAnnotations: if values is None: return self.__element else: @@ -199,14 +344,18 @@ class Annotated: ) ) - def _compiler_dispatch(self, visitor, **kw): - return self.__element.__class__._compiler_dispatch(self, visitor, **kw) + if not typing.TYPE_CHECKING: + # manually proxy some methods that need extra attention + def _compiler_dispatch(self, visitor: Any, **kw: Any) -> Any: + return self.__element.__class__._compiler_dispatch( + self, visitor, **kw + ) - @property - def _constructor(self): - return self.__element._constructor + @property + def _constructor(self): + return self.__element._constructor - def _clone(self, **kw): + def _clone(self: SelfAnnotated, **kw: Any) -> SelfAnnotated: clone = self.__element._clone(**kw) if clone is self.__element: # detect immutable, don't change anything @@ -217,22 +366,25 @@ class Annotated: clone.__dict__.update(self.__dict__) return self.__class__(clone, self._annotations) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type[Annotated], Tuple[Any, ...]]: return self.__class__, (self.__element, self._annotations) - def __hash__(self): + def __hash__(self) -> int: return self._hash - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if self._is_column_operators: return self.__element.__class__.__eq__(self, other) else: return hash(other) == hash(self) @property - def entity_namespace(self): + def entity_namespace(self) -> Mapping[str, Any]: if "entity_namespace" in self._annotations: - return self._annotations["entity_namespace"].entity_namespace + return cast( + SupportsWrappingAnnotations, + self._annotations["entity_namespace"], + ).entity_namespace else: return self.__element.entity_namespace @@ -242,12 +394,19 @@ class Annotated: # so that the resulting objects are pickleable; additionally, other # decisions can be made up front about the type of object being annotated # just once per class rather than per-instance. -annotated_classes = {} +annotated_classes: Dict[ + Type[SupportsWrappingAnnotations], Type[Annotated] +] = {} + +_SA = TypeVar("_SA", bound="SupportsAnnotations") def _deep_annotate( - element, annotations, exclude=None, detect_subquery_cols=False -): + element: _SA, + annotations: _AnnotationDict, + exclude: Optional[Sequence[SupportsAnnotations]] = None, + detect_subquery_cols: bool = False, +) -> _SA: """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. @@ -258,9 +417,9 @@ def _deep_annotate( # annotated objects hack the __hash__() method so if we want to # uniquely process them we have to use id() - cloned_ids = {} + cloned_ids: Dict[int, SupportsAnnotations] = {} - def clone(elem, **kw): + def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: kw["detect_subquery_cols"] = detect_subquery_cols id_ = id(elem) @@ -285,17 +444,20 @@ def _deep_annotate( return newelem if element is not None: - element = clone(element) - clone = None # remove gc cycles + element = cast(_SA, clone(element)) + clone = None # type: ignore # remove gc cycles return element -def _deep_deannotate(element, values=None): +def _deep_deannotate( + element: _SA, values: Optional[Sequence[str]] = None +) -> _SA: """Deep copy the given element, removing annotations.""" - cloned = {} + cloned: Dict[Any, SupportsAnnotations] = {} - def clone(elem, **kw): + def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: + key: Any if values: key = id(elem) else: @@ -310,12 +472,14 @@ def _deep_deannotate(element, values=None): return cloned[key] if element is not None: - element = clone(element) - clone = None # remove gc cycles + element = cast(_SA, clone(element)) + clone = None # type: ignore # remove gc cycles return element -def _shallow_annotate(element, annotations): +def _shallow_annotate( + element: SupportsAnnotations, annotations: _AnnotationDict +) -> SupportsAnnotations: """Annotate the given ClauseElement and copy its internals so that internal objects refer to the new annotated object. @@ -328,7 +492,13 @@ def _shallow_annotate(element, annotations): return element -def _new_annotation_type(cls, base_cls): +def _new_annotation_type( + cls: Type[SupportsWrappingAnnotations], base_cls: Type[Annotated] +) -> Type[Annotated]: + """Generates a new class that subclasses Annotated and proxies a given + element type. + + """ if issubclass(cls, Annotated): return cls elif cls in annotated_classes: @@ -342,8 +512,9 @@ def _new_annotation_type(cls, base_cls): base_cls = annotated_classes[super_] break - annotated_classes[cls] = anno_cls = type( - "Annotated%s" % cls.__name__, (base_cls, cls), {} + annotated_classes[cls] = anno_cls = cast( + Type[Annotated], + type("Annotated%s" % cls.__name__, (base_cls, cls), {}), ) globals()["Annotated%s" % cls.__name__] = anno_cls @@ -359,13 +530,15 @@ def _new_annotation_type(cls, base_cls): # some classes include this even if they have traverse_internals # e.g. BindParameter, add it if present. if cls.__dict__.get("inherit_cache", False): - anno_cls.inherit_cache = True + anno_cls.inherit_cache = True # type: ignore anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators) return anno_cls -def _prepare_annotations(target_hierarchy, base_cls): +def _prepare_annotations( + target_hierarchy: Type[SupportsAnnotations], base_cls: Type[Annotated] +) -> None: for cls in util.walk_subclasses(target_hierarchy): _new_annotation_type(cls, base_cls) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a94590da1..a408a010a 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -19,8 +19,10 @@ from itertools import zip_longest import operator import re import typing +from typing import MutableMapping from typing import Optional from typing import Sequence +from typing import Set from typing import TypeVar from . import roles @@ -36,14 +38,9 @@ from .. import util from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util import typing as compat_typing -from ..util._has_cy import HAS_CYEXTENSION - -if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_util import prefix_anon_map # noqa -else: - from sqlalchemy.cyextension.util import prefix_anon_map # noqa if typing.TYPE_CHECKING: + from .elements import ColumnElement from ..engine import Connection from ..engine import Result from ..engine.interfaces import _CoreMultiExecuteParams @@ -63,6 +60,8 @@ NO_ARG = util.symbol("NO_ARG") # symbols, mypy reports: "error: _Fn? not callable" _Fn = typing.TypeVar("_Fn", bound=typing.Callable) +_AmbiguousTableNameMap = MutableMapping[str, str] + class Immutable: """mark a ClauseElement as 'immutable' when expressions are cloned.""" @@ -87,6 +86,10 @@ class SingletonConstant(Immutable): _is_singleton_constant = True + _singleton: SingletonConstant + + proxy_set: Set[ColumnElement] + def __new__(cls, *arg, **kw): return cls._singleton @@ -519,6 +522,8 @@ class CompileState: plugins = {} + _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] + @classmethod def create_for_statement(cls, statement, compiler, **kw): # factory construction. diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index ff659b77d..fca58f98e 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -11,21 +11,41 @@ import enum from itertools import zip_longest import typing from typing import Any -from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from typing import Union from .visitors import anon_map -from .visitors import ExtendedInternalTraversal +from .visitors import HasTraversalDispatch +from .visitors import HasTraverseInternals from .visitors import InternalTraversal +from .visitors import prefix_anon_map from .. import util from ..inspection import inspect from ..util import HasMemoized from ..util.typing import Literal - +from ..util.typing import Protocol if typing.TYPE_CHECKING: from .elements import BindParameter + from .elements import ClauseElement + from .visitors import _TraverseInternalsType + from ..engine.base import _CompiledCacheType + from ..engine.interfaces import _CoreSingleExecuteParams + + +class _CacheKeyTraversalDispatchType(Protocol): + def __call__( + s, self: HasCacheKey, visitor: _CacheKeyTraversal + ) -> CacheKey: + ... class CacheConst(enum.Enum): @@ -70,7 +90,9 @@ class HasCacheKey: __slots__ = () - _cache_key_traversal = NO_CACHE + _cache_key_traversal: Union[ + _TraverseInternalsType, Literal[CacheConst.NO_CACHE] + ] = NO_CACHE _is_has_cache_key = True @@ -83,7 +105,7 @@ class HasCacheKey: """ - inherit_cache = None + inherit_cache: Optional[bool] = None """Indicate if this :class:`.HasCacheKey` instance should make use of the cache key generation scheme used by its immediate superclass. @@ -106,8 +128,12 @@ class HasCacheKey: __slots__ = () + _generated_cache_key_traversal: Any + @classmethod - def _generate_cache_attrs(cls): + def _generate_cache_attrs( + cls, + ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]: """generate cache key dispatcher for a new class. This sets the _generated_cache_key_traversal attribute once called @@ -121,8 +147,11 @@ class HasCacheKey: _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) if _cache_key_traversal is None: try: - # this would be HasTraverseInternals - _cache_key_traversal = cls._traverse_internals + # check for _traverse_internals, which is part of + # HasTraverseInternals + _cache_key_traversal = cast( + "Type[HasTraverseInternals]", cls + )._traverse_internals except AttributeError: cls._generated_cache_key_traversal = NO_CACHE return NO_CACHE @@ -138,7 +167,9 @@ class HasCacheKey: # more complicated, so for the moment this is a little less # efficient on startup but simpler. return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", ) else: _cache_key_traversal = cls.__dict__.get( @@ -170,11 +201,15 @@ class HasCacheKey: return NO_CACHE return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", ) @util.preload_module("sqlalchemy.sql.elements") - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Optional[Tuple[Any, ...]]: """return an optional cache key. The cache key is a tuple which can contain any series of @@ -202,15 +237,15 @@ class HasCacheKey: dispatcher: Union[ Literal[CacheConst.NO_CACHE], - Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"], + _CacheKeyTraversalDispatchType, ] try: dispatcher = cls.__dict__["_generated_cache_key_traversal"] except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). + # traversals.py -> _preconfigure_traversals() + # may be used to run these ahead of time, but + # is not enabled right now. # this block will generate any remaining dispatchers. dispatcher = cls._generate_cache_attrs() @@ -218,7 +253,7 @@ class HasCacheKey: anon_map[NO_CACHE] = True return None - result = (id_, cls) + result: Tuple[Any, ...] = (id_, cls) # inline of _cache_key_traversal_visitor.run_generated_dispatch() @@ -268,7 +303,7 @@ class HasCacheKey: # Columns, this should be long lived. For select() # statements, not so much, but they usually won't have # annotations. - result += self._annotations_cache_key + result += self._annotations_cache_key # type: ignore elif ( meth is InternalTraversal.dp_clauseelement_list or meth is InternalTraversal.dp_clauseelement_tuple @@ -290,7 +325,7 @@ class HasCacheKey: ) return result - def _generate_cache_key(self): + def _generate_cache_key(self) -> Optional[CacheKey]: """return a cache key. The cache key is a tuple which can contain any series of @@ -322,32 +357,40 @@ class HasCacheKey: """ - bindparams = [] + bindparams: List[BindParameter[Any]] = [] _anon_map = anon_map() key = self._gen_cache_key(_anon_map, bindparams) if NO_CACHE in _anon_map: return None else: + assert key is not None return CacheKey(key, bindparams) @classmethod - def _generate_cache_key_for_object(cls, obj): - bindparams = [] + def _generate_cache_key_for_object( + cls, obj: HasCacheKey + ) -> Optional[CacheKey]: + bindparams: List[BindParameter[Any]] = [] _anon_map = anon_map() key = obj._gen_cache_key(_anon_map, bindparams) if NO_CACHE in _anon_map: return None else: + assert key is not None return CacheKey(key, bindparams) +class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey): + pass + + class MemoizedHasCacheKey(HasCacheKey, HasMemoized): __slots__ = () @HasMemoized.memoized_instancemethod - def _generate_cache_key(self): + def _generate_cache_key(self) -> Optional[CacheKey]: return HasCacheKey._generate_cache_key(self) @@ -362,14 +405,22 @@ class CacheKey(NamedTuple): """ key: Tuple[Any, ...] - bindparams: Sequence[BindParameter] + bindparams: Sequence[BindParameter[Any]] - def __hash__(self): + # can't set __hash__ attribute because it interferes + # with namedtuple + # can't use "if not TYPE_CHECKING" because mypy rejects it + # inside of a NamedTuple + def __hash__(self) -> Optional[int]: # type: ignore """CacheKey itself is not hashable - hash the .key portion""" - return None - def to_offline_string(self, statement_cache, statement, parameters): + def to_offline_string( + self, + statement_cache: _CompiledCacheType, + statement: ClauseElement, + parameters: _CoreSingleExecuteParams, + ) -> str: """Generate an "offline string" form of this :class:`.CacheKey` The "offline string" is basically the string SQL for the @@ -400,21 +451,21 @@ class CacheKey(NamedTuple): return repr((sql_str, param_tuple)) - def __eq__(self, other): - return self.key == other.key + def __eq__(self, other: Any) -> bool: + return bool(self.key == other.key) @classmethod - def _diff_tuples(cls, left, right): + def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str: ck1 = CacheKey(left, []) ck2 = CacheKey(right, []) return ck1._diff(ck2) - def _whats_different(self, other): + def _whats_different(self, other: CacheKey) -> Iterator[str]: k1 = self.key k2 = other.key - stack = [] + stack: List[int] = [] pickup_index = 0 while True: s1, s2 = k1, k2 @@ -440,11 +491,11 @@ class CacheKey(NamedTuple): pickup_index = stack.pop(-1) break - def _diff(self, other): + def _diff(self, other: CacheKey) -> str: return ", ".join(self._whats_different(other)) - def __str__(self): - stack = [self.key] + def __str__(self) -> str: + stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key] output = [] sentinel = object() @@ -473,15 +524,15 @@ class CacheKey(NamedTuple): return "CacheKey(key=%s)" % ("\n".join(output),) - def _generate_param_dict(self): + def _generate_param_dict(self) -> Dict[str, Any]: """used for testing""" - from .compiler import prefix_anon_map - _anon_map = prefix_anon_map() return {b.key % _anon_map: b.effective_value for b in self.bindparams} - def _apply_params_to_element(self, original_cache_key, target_element): + def _apply_params_to_element( + self, original_cache_key: CacheKey, target_element: ClauseElement + ) -> ClauseElement: translate = { k.key: v.value for k, v in zip(original_cache_key.bindparams, self.bindparams) @@ -490,7 +541,7 @@ class CacheKey(NamedTuple): return target_element.params(translate) -class _CacheKeyTraversal(ExtendedInternalTraversal): +class _CacheKeyTraversal(HasTraversalDispatch): # very common elements are inlined into the main _get_cache_key() method # to produce a dramatic savings in Python function call overhead @@ -512,17 +563,43 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): visit_propagate_attrs = PROPAGATE_ATTRS def visit_with_context_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple((fn.__code__, c_key) for fn, c_key in obj) - def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): + def visit_inspectable( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) - def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): + def visit_string_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple(obj) - def visit_multi(self, attrname, obj, parent, anon_map, bindparams): + def visit_multi( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, obj._gen_cache_key(anon_map, bindparams) @@ -530,7 +607,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): else obj, ) - def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): + def visit_multi_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -542,8 +626,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_has_cache_key_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -558,8 +647,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_has_cache_key_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -568,8 +662,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_executable_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -582,22 +681,37 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_inspectable_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return self.visit_has_cache_key_list( attrname, [inspect(o) for o in obj], parent, anon_map, bindparams ) def visit_clauseelement_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return self.visit_has_cache_key_tuples( attrname, obj, parent, anon_map, bindparams ) def visit_fromclause_ordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () return ( @@ -606,8 +720,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_clauseelement_unordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () cache_keys = [ @@ -621,13 +740,23 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_named_ddl_element( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, obj.name) def visit_prefix_sequence( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () @@ -642,8 +771,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_setup_join_tuple( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return tuple( ( target._gen_cache_key(anon_map, bindparams), @@ -659,8 +793,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_table_hint_list( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: if not obj: return () @@ -678,12 +817,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ), ) - def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): + def visit_plain_dict( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) def visit_dialect_options( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -701,8 +852,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_string_clauseelement_dict( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -712,8 +868,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_string_multi_dict( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -728,8 +889,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_fromclause_canonical_column_collection( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # inlining into the internals of ColumnCollection return ( attrname, @@ -740,14 +906,24 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_unknown_structure( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: anon_map[NO_CACHE] = True return () def visit_dml_ordered_values( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: return ( attrname, tuple( @@ -761,7 +937,14 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ), ) - def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): + def visit_dml_values( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # in py37 we can assume two dictionaries created in the same # insert ordering will retain that sorting return ( @@ -778,8 +961,13 @@ class _CacheKeyTraversal(ExtendedInternalTraversal): ) def visit_dml_multi_values( - self, attrname, obj, parent, anon_map, bindparams - ): + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: # multivalues are simply not cacheable right now anon_map[NO_CACHE] = True return () diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index d616417ab..834bfb75d 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -13,6 +13,9 @@ import re import typing from typing import Any from typing import Any as TODO_Any +from typing import Dict +from typing import List +from typing import NoReturn from typing import Optional from typing import Type from typing import TypeVar @@ -42,6 +45,7 @@ if typing.TYPE_CHECKING: from . import selectable from . import traversals from .elements import ClauseElement + from .elements import ColumnClause _SR = TypeVar("_SR", bound=roles.SQLRole) _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole) @@ -252,7 +256,7 @@ def expect_col_expression_collection(role, expressions): if isinstance(resolved, str): strname = resolved = expr else: - cols = [] + cols: List[ColumnClause[Any]] = [] visitors.traverse(resolved, {}, {"column": cols.append}) if cols: column = cols[0] @@ -266,7 +270,7 @@ class RoleImpl: def _literal_coercion(self, element, **kw): raise NotImplementedError() - _post_coercion = None + _post_coercion: Any = None _resolve_literal_only = False _skip_clauseelement_for_target_match = False @@ -276,19 +280,24 @@ class RoleImpl: self._use_inspection = issubclass(role_class, roles.UsesInspection) def _implicit_coercions( - self, element, resolved, argname=None, **kw + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, ) -> Any: self._raise_for_expected(element, argname, resolved) def _raise_for_expected( self, - element, - argname=None, - resolved=None, - advice=None, - code=None, - err=None, - ): + element: Any, + argname: Optional[str] = None, + resolved: Optional[Any] = None, + advice: Optional[str] = None, + code: Optional[str] = None, + err: Optional[Exception] = None, + **kw: Any, + ) -> NoReturn: if resolved is not None and resolved is not element: got = "%r object resolved from %r object" % (resolved, element) else: @@ -324,22 +333,20 @@ class _StringOnly: _resolve_literal_only = True -class _ReturnsStringKey: +class _ReturnsStringKey(RoleImpl): __slots__ = () - def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): - if isinstance(original_element, str): - return original_element + def _implicit_coercions(self, element, resolved, argname=None, **kw): + if isinstance(element, str): + return element else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _literal_coercion(self, element, **kw): return element -class _ColumnCoercions: +class _ColumnCoercions(RoleImpl): __slots__ = () def _warn_for_scalar_subquery_coercion(self): @@ -368,8 +375,12 @@ class _ColumnCoercions: def _no_text_coercion( - element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None -): + element: Any, + argname: Optional[str] = None, + exc_cls: Type[exc.SQLAlchemyError] = exc.ArgumentError, + extra: Optional[str] = None, + err: Optional[Exception] = None, +) -> NoReturn: raise exc_cls( "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " "explicitly declared as text(%(expr)r)" @@ -381,7 +392,7 @@ def _no_text_coercion( ) from err -class _NoTextCoercion: +class _NoTextCoercion(RoleImpl): __slots__ = () def _literal_coercion(self, element, argname=None, **kw): @@ -393,7 +404,7 @@ class _NoTextCoercion: self._raise_for_expected(element, argname) -class _CoerceLiterals: +class _CoerceLiterals(RoleImpl): __slots__ = () _coerce_consts = False _coerce_star = False @@ -440,12 +451,19 @@ class LiteralValueImpl(RoleImpl): return element -class _SelectIsNotFrom: +class _SelectIsNotFrom(RoleImpl): __slots__ = () def _raise_for_expected( - self, element, argname=None, resolved=None, advice=None, **kw - ): + self, + element: Any, + argname: Optional[str] = None, + resolved: Optional[Any] = None, + advice: Optional[str] = None, + code: Optional[str] = None, + err: Optional[Exception] = None, + **kw: Any, + ) -> NoReturn: if ( not advice and isinstance(element, roles.SelectStatementRole) @@ -460,26 +478,33 @@ class _SelectIsNotFrom: else: code = None - return super(_SelectIsNotFrom, self)._raise_for_expected( + super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, code=code, + err=err, **kw, ) + # never reached + assert False class HasCacheKeyImpl(RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): - if isinstance(original_element, traversals.HasCacheKey): - return original_element + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, HasCacheKey): + return element else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _literal_coercion(self, element, **kw): return element @@ -489,12 +514,16 @@ class ExecutableOptionImpl(RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): - if isinstance(original_element, ExecutableOption): - return original_element + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, ExecutableOption): + return element else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _literal_coercion(self, element, **kw): return element @@ -560,8 +589,12 @@ class InElementImpl(RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if resolved._is_from_clause: if ( isinstance(resolved, selectable.Alias) @@ -573,7 +606,7 @@ class InElementImpl(RoleImpl): self._warn_for_implicit_coercion(resolved) return self._post_coercion(resolved.select(), **kw) else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _warn_for_implicit_coercion(self, elem): util.warn( @@ -586,12 +619,16 @@ class InElementImpl(RoleImpl): if isinstance(element, collections_abc.Iterable) and not isinstance( element, str ): - non_literal_expressions = {} + non_literal_expressions: Dict[ + Optional[operators.ColumnOperators[Any]], + operators.ColumnOperators[Any], + ] = {} element = list(element) for o in element: if not _is_literal(o): if not isinstance(o, operators.ColumnOperators): self._raise_for_expected(element, **kw) + else: non_literal_expressions[o] = o elif o is None: @@ -712,8 +749,12 @@ class GroupByImpl(ByOfImpl, RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if isinstance(resolved, roles.StrictFromClauseRole): return elements.ClauseList(*resolved.c) else: @@ -748,12 +789,16 @@ class TruncatedLabelImpl(_StringOnly, RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): - if isinstance(original_element, str): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, str): return resolved else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _literal_coercion(self, element, argname=None, **kw): """coerce the given value to :class:`._truncated_label`. @@ -794,7 +839,13 @@ class DDLReferredColumnImpl(DDLConstraintColumnImpl): class LimitOffsetImpl(RoleImpl): __slots__ = () - def _implicit_coercions(self, element, resolved, argname=None, **kw): + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if resolved is None: return None else: @@ -814,18 +865,22 @@ class LabeledColumnExprImpl(ExpressionElementImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if isinstance(resolved, roles.ExpressionElementRole): return resolved.label(None) else: new = super(LabeledColumnExprImpl, self)._implicit_coercions( - original_element, resolved, argname=argname, **kw + element, resolved, argname=argname, **kw ) if isinstance(new, roles.ExpressionElementRole): return new.label(None) else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): @@ -899,13 +954,17 @@ class StatementImpl(_CoerceLiterals, RoleImpl): return resolved def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if resolved._is_lambda_element: return resolved else: - return super(StatementImpl, self)._implicit_coercions( - original_element, resolved, argname=argname, **kw + return super()._implicit_coercions( + element, resolved, argname=argname, **kw ) @@ -913,12 +972,16 @@ class SelectStatementImpl(_NoTextCoercion, RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if resolved._is_text_clause: return resolved.columns() else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) class HasCTEImpl(ReturnsRowsImpl): @@ -938,13 +1001,18 @@ class JoinTargetImpl(RoleImpl): self._raise_for_expected(element, argname) def _implicit_coercions( - self, original_element, resolved, argname=None, legacy=False, **kw - ): - if isinstance(original_element, roles.JoinTargetRole): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + legacy: bool = False, + **kw: Any, + ) -> Any: + if isinstance(element, roles.JoinTargetRole): # note that this codepath no longer occurs as of # #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match # were set to False. - return original_element + return element elif legacy and resolved._is_select_statement: util.warn_deprecated( "Implicit coercion of SELECT and textual SELECT " @@ -959,7 +1027,7 @@ class JoinTargetImpl(RoleImpl): # in _ORMJoin->Join return resolved else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): @@ -967,13 +1035,13 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): def _implicit_coercions( self, - original_element, - resolved, - argname=None, - explicit_subquery=False, - allow_select=True, - **kw, - ): + element: Any, + resolved: Any, + argname: Optional[str] = None, + explicit_subquery: bool = False, + allow_select: bool = True, + **kw: Any, + ) -> Any: if resolved._is_select_statement: if explicit_subquery: return resolved.subquery() @@ -989,7 +1057,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): elif resolved._is_text_clause: return resolved else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) def _post_coercion(self, element, deannotate=False, **kw): if deannotate: @@ -1003,12 +1071,13 @@ class StrictFromClauseImpl(FromClauseImpl): def _implicit_coercions( self, - original_element, - resolved, - argname=None, - allow_select=False, - **kw, - ): + element: Any, + resolved: Any, + argname: Optional[str] = None, + explicit_subquery: bool = False, + allow_select: bool = False, + **kw: Any, + ) -> Any: if resolved._is_select_statement and allow_select: util.warn_deprecated( "Implicit coercion of SELECT and textual SELECT constructs " @@ -1019,7 +1088,7 @@ class StrictFromClauseImpl(FromClauseImpl): ) return resolved._implicit_subquery else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) class AnonymizedFromClauseImpl(StrictFromClauseImpl): @@ -1045,8 +1114,12 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl): __slots__ = () def _implicit_coercions( - self, original_element, resolved, argname=None, **kw - ): + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: if resolved._is_from_clause: if ( isinstance(resolved, selectable.Alias) @@ -1056,7 +1129,7 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl): else: return resolved.select() else: - self._raise_for_expected(original_element, argname, resolved) + self._raise_for_expected(element, argname, resolved) class CompoundElementImpl(_NoTextCoercion, RoleImpl): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 423c3d446..f28dceefc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -35,14 +35,19 @@ from time import perf_counter import typing from typing import Any from typing import Callable +from typing import cast from typing import Dict +from typing import FrozenSet +from typing import Iterable from typing import List from typing import Mapping from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type from typing import Union from . import base @@ -54,19 +59,42 @@ from . import operators from . import schema from . import selectable from . import sqltypes +from .base import _from_objects from .base import NO_ARG -from .base import prefix_anon_map from .elements import quoted_name from .schema import Column +from .sqltypes import TupleType from .type_api import TypeEngine +from .visitors import prefix_anon_map from .. import exc from .. import util from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import TypedDict if typing.TYPE_CHECKING: + from .annotation import _AnnotationDict + from .base import _AmbiguousTableNameMap + from .base import CompileState + from .cache_key import CacheKey + from .elements import BindParameter + from .elements import ColumnClause + from .elements import Label + from .functions import Function + from .selectable import Alias + from .selectable import AliasedReturnsRows + from .selectable import CompoundSelectState from .selectable import CTE from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import ReturnsRows + from .selectable import Select + from .selectable import SelectState + from ..engine.cursor import CursorResultMetaData from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _MutableCoreSingleExecuteParams + from ..engine.interfaces import _SchemaTranslateMapType from ..engine.result import _ProcessorType _FromHintsType = Dict["FromClause", str] @@ -236,7 +264,7 @@ OPERATORS = { operators.nulls_last_op: " NULLS LAST", } -FUNCTIONS = { +FUNCTIONS: Dict[Type[Function], str] = { functions.coalesce: "coalesce", functions.current_date: "CURRENT_DATE", functions.current_time: "CURRENT_TIME", @@ -298,8 +326,8 @@ class ResultColumnsEntry(NamedTuple): name: str """column name, may be labeled""" - objects: List[Any] - """list of objects that should be able to locate this column + objects: Tuple[Any, ...] + """sequence of objects that should be able to locate this column in a RowMapping. This is typically string names and aliases as well as Column objects. @@ -313,6 +341,17 @@ class ResultColumnsEntry(NamedTuple): """ +class _ResultMapAppender(Protocol): + def __call__( + self, + keyname: str, + name: str, + objects: Sequence[Any], + type_: TypeEngine[Any], + ) -> None: + ... + + # integer indexes into ResultColumnsEntry used by cursor.py. # some profiling showed integer access faster than named tuple RM_RENDERED_NAME: Literal[0] = 0 @@ -321,6 +360,20 @@ RM_OBJECTS: Literal[2] = 2 RM_TYPE: Literal[3] = 3 +class _BaseCompilerStackEntry(TypedDict): + asfrom_froms: Set[FromClause] + correlate_froms: Set[FromClause] + selectable: ReturnsRows + + +class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): + compile_state: CompileState + need_result_map_for_nested: bool + need_result_map_for_compound: bool + select_0: ReturnsRows + insert_from_select: Select + + class ExpandedState(NamedTuple): statement: str additional_parameters: _CoreSingleExecuteParams @@ -427,21 +480,23 @@ class Compiled: defaults. """ - _cached_metadata = None + _cached_metadata: Optional[CursorResultMetaData] = None _result_columns: Optional[List[ResultColumnsEntry]] = None - schema_translate_map = None + schema_translate_map: Optional[_SchemaTranslateMapType] = None - execution_options = util.EMPTY_DICT + execution_options: _ExecuteOptions = util.EMPTY_DICT """ Execution options propagated from the statement. In some cases, sub-elements of the statement can modify these. """ - _annotations = util.EMPTY_DICT + preparer: IdentifierPreparer + + _annotations: _AnnotationDict = util.EMPTY_DICT - compile_state = None + compile_state: Optional[CompileState] = None """Optional :class:`.CompileState` object that maintains additional state used by the compiler. @@ -457,9 +512,21 @@ class Compiled: """ - cache_key = None + cache_key: Optional[CacheKey] = None + """The :class:`.CacheKey` that was generated ahead of creating this + :class:`.Compiled` object. + + This is used for routines that need access to the original + :class:`.CacheKey` instance generated when the :class:`.Compiled` + instance was first cached, typically in order to reconcile + the original list of :class:`.BindParameter` objects with a + per-statement list that's generated on each call. + + """ _gen_time: float + """Generation time of this :class:`.Compiled`, used for reporting + cache stats.""" def __init__( self, @@ -543,7 +610,11 @@ class Compiled: return self.string or "" - def construct_params(self, params=None, extracted_parameters=None): + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + ) -> Optional[_MutableCoreSingleExecuteParams]: """Return the bind params for this compiled object. :param params: a dict of string/object pairs whose values will @@ -646,6 +717,17 @@ class SQLCompiler(Compiled): isplaintext: bool = False + binds: Dict[str, BindParameter[Any]] + """a dictionary of bind parameter keys to BindParameter instances.""" + + bind_names: Dict[BindParameter[Any], str] + """a dictionary of BindParameter instances to "compiled" names + that are actually present in the generated SQL""" + + stack: List[_CompilerStackEntry] + """major statements such as SELECT, INSERT, UPDATE, DELETE are + tracked in this stack using an entry format.""" + result_columns: List[ResultColumnsEntry] """relates label names in the final SQL to a tuple of local column/label name, ColumnElement object (if any) and @@ -709,7 +791,7 @@ class SQLCompiler(Compiled): """ - insert_single_values_expr = None + insert_single_values_expr: Optional[str] = None """When an INSERT is compiled with a single set of parameters inside a VALUES expression, the string is assigned here, where it can be used for insert batching schemes to rewrite the VALUES expression. @@ -718,19 +800,19 @@ class SQLCompiler(Compiled): """ - literal_execute_params = frozenset() + literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset() """bindparameter objects that are rendered as literal values at statement execution time. """ - post_compile_params = frozenset() + post_compile_params: FrozenSet[BindParameter[Any]] = frozenset() """bindparameter objects that are rendered as bound parameter placeholders at statement execution time. """ - escaped_bind_names = util.EMPTY_DICT + escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT """Late escaping of bound parameter names that has to be converted to the original name when looking in the parameter dictionary. @@ -744,14 +826,25 @@ class SQLCompiler(Compiled): """if True, and this in insert, use cursor.lastrowid to populate result.inserted_primary_key. """ - _cache_key_bind_match = None + _cache_key_bind_match: Optional[ + Tuple[ + Dict[ + BindParameter[Any], + List[BindParameter[Any]], + ], + Dict[ + str, + BindParameter[Any], + ], + ] + ] = None """a mapping that will relate the BindParameter object we compile to those that are part of the extracted collection of parameters in the cache key, if we were given a cache key. """ - positiontup: Optional[Sequence[str]] = None + positiontup: Optional[List[str]] = None """for a compiled construct that uses a positional paramstyle, will be a sequence of strings, indicating the names of bound parameters in order. @@ -768,6 +861,19 @@ class SQLCompiler(Compiled): inline: bool = False + ctes: Optional[MutableMapping[CTE, str]] + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + ctes_by_level_name: Dict[Tuple[int, str], CTE] + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name, cte_opts)] + level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]] + + ctes_recursive: bool + cte_positional: Dict[CTE, List[str]] + def __init__( self, dialect, @@ -804,10 +910,9 @@ class SQLCompiler(Compiled): self.cache_key = cache_key if cache_key: - self._cache_key_bind_match = ckbm = { - b.key: b for b in cache_key[1] - } - ckbm.update({b: [b] for b in cache_key[1]}) + cksm = {b.key: b for b in cache_key[1]} + ckbm = {b: [b] for b in cache_key[1]} + self._cache_key_bind_match = (ckbm, cksm) # compile INSERT/UPDATE defaults/sequences to expect executemany # style execution, which may mean no pre-execute of defaults, @@ -911,14 +1016,14 @@ class SQLCompiler(Compiled): @property def prefetch(self): - return list(self.insert_prefetch + self.update_prefetch) + return list(self.insert_prefetch) + list(self.update_prefetch) @util.memoized_property def _global_attributes(self): return {} @util.memoized_instancemethod - def _init_cte_state(self) -> None: + def _init_cte_state(self) -> MutableMapping[CTE, str]: """Initialize collections related to CTEs only if a CTE is located, to save on the overhead of these collections otherwise. @@ -926,21 +1031,22 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT # To store the query to print - Dict[cte, text_query] - self.ctes: MutableMapping[CTE, str] = util.OrderedDict() + ctes: MutableMapping[CTE, str] = util.OrderedDict() + self.ctes = ctes # Detect same CTE references - Dict[(level, name), cte] # Level is required for supporting nesting - self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {} + self.ctes_by_level_name = {} # To retrieve key/level in ctes_by_level_name - # Dict[cte_reference, (level, cte_name, cte_opts)] - self.level_name_by_cte: Dict[ - CTE, Tuple[int, str, selectable._CTEOpts] - ] = {} + self.level_name_by_cte = {} - self.ctes_recursive: bool = False + self.ctes_recursive = False if self.positional: - self.cte_positional: Dict[CTE, List[str]] = {} + self.cte_positional = {} + + return ctes @contextlib.contextmanager def _nested_result(self): @@ -985,7 +1091,7 @@ class SQLCompiler(Compiled): if not bindparam.type._is_tuple_type else tuple( elem_type._cached_bind_processor(self.dialect) - for elem_type in bindparam.type.types + for elem_type in cast(TupleType, bindparam.type).types ), ) for bindparam in self.bind_names @@ -1002,11 +1108,11 @@ class SQLCompiler(Compiled): def construct_params( self, - params=None, - _group_number=None, - _check=True, - extracted_parameters=None, - ): + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + _group_number: Optional[int] = None, + _check: bool = True, + ) -> _MutableCoreSingleExecuteParams: """return a dictionary of bind parameter keys and values""" has_escaped_names = bool(self.escaped_bind_names) @@ -1018,15 +1124,17 @@ class SQLCompiler(Compiled): # way. The parameters present in self.bind_names may be clones of # these original cache key params in the case of DML but the .key # will be guaranteed to match. - try: - orig_extracted = self.cache_key[1] - except TypeError as err: + if self.cache_key is None: raise exc.CompileError( "This compiled object has no original cache key; " "can't pass extracted_parameters to construct_params" - ) from err + ) + else: + orig_extracted = self.cache_key[1] - ckbm = self._cache_key_bind_match + ckbm_tuple = self._cache_key_bind_match + assert ckbm_tuple is not None + ckbm, _ = ckbm_tuple resolved_extracted = { bind: extracted for b, extracted in zip(orig_extracted, extracted_parameters) @@ -1142,7 +1250,8 @@ class SQLCompiler(Compiled): if bindparam.type._is_tuple_type: inputsizes[bindparam] = [ - lookup_type(typ) for typ in bindparam.type.types + lookup_type(typ) + for typ in cast(TupleType, bindparam.type).types ] else: inputsizes[bindparam] = lookup_type(bindparam.type) @@ -1164,7 +1273,7 @@ class SQLCompiler(Compiled): def _process_parameters_for_postcompile( self, - parameters: Optional[_CoreSingleExecuteParams] = None, + parameters: Optional[_MutableCoreSingleExecuteParams] = None, _populate_self: bool = False, ) -> ExpandedState: """handle special post compile parameters. @@ -1183,14 +1292,20 @@ class SQLCompiler(Compiled): parameters = self.construct_params() expanded_parameters = {} + positiontup: Optional[List[str]] + if self.positional: positiontup = [] else: positiontup = None processors = self._bind_processors + single_processors = cast("Mapping[str, _ProcessorType]", processors) + tuple_processors = cast( + "Mapping[str, Sequence[_ProcessorType]]", processors + ) - new_processors = {} + new_processors: Dict[str, _ProcessorType] = {} if self.positional and self._numeric_binds: # I'm not familiar with any DBAPI that uses 'numeric'. @@ -1203,8 +1318,8 @@ class SQLCompiler(Compiled): "the 'numeric' paramstyle at this time." ) - replacement_expressions = {} - to_update_sets = {} + replacement_expressions: Dict[str, Any] = {} + to_update_sets: Dict[str, Any] = {} # notes: # *unescaped* parameter names in: @@ -1213,9 +1328,12 @@ class SQLCompiler(Compiled): # *escaped* parameter names in: # construct_params(), replacement_expressions - for name in ( - self.positiontup if self.positional else self.bind_names.values() - ): + if self.positional and self.positiontup is not None: + names: Iterable[str] = self.positiontup + else: + names = self.bind_names.values() + + for name in names: escaped_name = ( self.escaped_bind_names.get(name, name) if self.escaped_bind_names @@ -1236,6 +1354,7 @@ class SQLCompiler(Compiled): if parameter in self.post_compile_params: if escaped_name in replacement_expressions: to_update = to_update_sets[escaped_name] + values = None else: # we are removing the parameter from parameters # because it is a list value, which is not expected by @@ -1256,28 +1375,29 @@ class SQLCompiler(Compiled): if not parameter.literal_execute: parameters.update(to_update) if parameter.type._is_tuple_type: + assert values is not None new_processors.update( ( "%s_%s_%s" % (name, i, j), - processors[name][j - 1], + tuple_processors[name][j - 1], ) for i, tuple_element in enumerate(values, 1) - for j, value in enumerate(tuple_element, 1) - if name in processors - and processors[name][j - 1] is not None + for j, _ in enumerate(tuple_element, 1) + if name in tuple_processors + and tuple_processors[name][j - 1] is not None ) else: new_processors.update( - (key, processors[name]) - for key, value in to_update - if name in processors + (key, single_processors[name]) + for key, _ in to_update + if name in single_processors ) - if self.positional: - positiontup.extend(name for name, value in to_update) + if positiontup is not None: + positiontup.extend(name for name, _ in to_update) expanded_parameters[name] = [ - expand_key for expand_key, value in to_update + expand_key for expand_key, _ in to_update ] - elif self.positional: + elif positiontup is not None: positiontup.append(name) def process_expanding(m): @@ -1315,7 +1435,7 @@ class SQLCompiler(Compiled): # special use cases. self.string = expanded_state.statement self._bind_processors.update(expanded_state.processors) - self.positiontup = expanded_state.positiontup + self.positiontup = list(expanded_state.positiontup or ()) self.post_compile_params = frozenset() for key in expanded_state.parameter_expansion: bind = self.binds.pop(key) @@ -1338,6 +1458,12 @@ class SQLCompiler(Compiled): self._result_columns ) + _key_getters_for_crud_column: Tuple[ + Callable[[Union[str, Column[Any]]], str], + Callable[[Column[Any]], str], + Callable[[Column[Any]], str], + ] + @util.memoized_property def _within_exec_param_key_getter(self) -> Callable[[Any], str]: getter = self._key_getters_for_crud_column[2] @@ -1398,22 +1524,30 @@ class SQLCompiler(Compiled): @util.memoized_property @util.preload_module("sqlalchemy.engine.result") def _inserted_primary_key_from_returning_getter(self): - result = util.preloaded.engine_result + if typing.TYPE_CHECKING: + from ..engine import result + else: + result = util.preloaded.engine_result param_key_getter = self._within_exec_param_key_getter table = self.statement.table - ret = {col: idx for idx, col in enumerate(self.returning)} + returning = self.returning + assert returning is not None + ret = {col: idx for idx, col in enumerate(returning)} - getters = [ - (operator.itemgetter(ret[col]), True) - if col in ret - else ( - operator.methodcaller("get", param_key_getter(col), None), - False, - ) - for col in table.primary_key - ] + getters = cast( + "List[Tuple[Callable[[Any], Any], bool]]", + [ + (operator.itemgetter(ret[col]), True) + if col in ret + else ( + operator.methodcaller("get", param_key_getter(col), None), + False, + ) + for col in table.primary_key + ], + ) row_fn = result.result_tuple([col.key for col in table.primary_key]) @@ -1444,7 +1578,16 @@ class SQLCompiler(Compiled): self, element, within_columns_clause=False, **kwargs ): if self.stack and self.dialect.supports_simple_order_by_label: - compile_state = self.stack[-1]["compile_state"] + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + raise exc.CompileError( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ) from ke ( with_cols, @@ -1485,7 +1628,22 @@ class SQLCompiler(Compiled): # compiling the element outside of the context of a SELECT return self.process(element._text_clause) - compile_state = self.stack[-1]["compile_state"] + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + coercions._no_text_coercion( + element.element, + extra=( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ), + exc_cls=exc.CompileError, + err=ke, + ) + with_cols, only_froms, only_cols = compile_state._label_resolve_dict try: if within_columns_clause: @@ -1568,13 +1726,13 @@ class SQLCompiler(Compiled): def visit_column( self, - column, - add_to_result_map=None, - include_table=True, - result_map_targets=(), - ambiguous_table_name_map=None, - **kwargs, - ): + column: ColumnClause[Any], + add_to_result_map: Optional[_ResultMapAppender] = None, + include_table: bool = True, + result_map_targets: Tuple[Any, ...] = (), + ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, + **kwargs: Any, + ) -> str: name = orig_name = column.name if name is None: name = self._fallback_column_name(column) @@ -1608,7 +1766,8 @@ class SQLCompiler(Compiled): ) else: schema_prefix = "" - tablename = table.name + + tablename = cast("NamedFromClause", table).name if ( not effective_schema @@ -1678,7 +1837,7 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - new_entry = { + new_entry: _CompilerStackEntry = { "correlate_froms": set(), "asfrom_froms": set(), "selectable": taf, @@ -1879,11 +2038,19 @@ class SQLCompiler(Compiled): compiled_col = self.visit_column(element, **kw) return "(%s).%s" % (compiled_fn, compiled_col) - def visit_function(self, func, add_to_result_map=None, **kwargs): + def visit_function( + self, + func: Function, + add_to_result_map: Optional[_ResultMapAppender] = None, + **kwargs: Any, + ) -> str: if add_to_result_map is not None: add_to_result_map(func.name, func.name, (), func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + + text: str + if disp: text = disp(func, **kwargs) else: @@ -1964,7 +2131,7 @@ class SQLCompiler(Compiled): if compound_stmt._independent_ctes: self._dispatch_independent_ctes(compound_stmt, kwargs) - keyword = self.compound_keywords.get(cs.keyword) + keyword = self.compound_keywords[cs.keyword] text = (" " + keyword + " ").join( ( @@ -2591,11 +2758,13 @@ class SQLCompiler(Compiled): # a different set of parameter values. here, we accommodate for # parameters that may have been cloned both before and after the cache # key was been generated. - ckbm = self._cache_key_bind_match - if ckbm: + ckbm_tuple = self._cache_key_bind_match + + if ckbm_tuple: + ckbm, cksm = ckbm_tuple for bp in bindparam._cloned_set: - if bp.key in ckbm: - cb = ckbm[bp.key] + if bp.key in cksm: + cb = cksm[bp.key] ckbm[cb].append(bindparam) if bindparam.isoutparam: @@ -2720,7 +2889,7 @@ class SQLCompiler(Compiled): if positional_names is not None: positional_names.append(name) else: - self.positiontup.append(name) + self.positiontup.append(name) # type: ignore[union-attr] elif not escaped_from: if _BIND_TRANSLATE_RE.search(name): @@ -2735,9 +2904,9 @@ class SQLCompiler(Compiled): name = new_name if escaped_from: - if not self.escaped_bind_names: - self.escaped_bind_names = {} - self.escaped_bind_names[escaped_from] = name + self.escaped_bind_names = self.escaped_bind_names.union( + {escaped_from: name} + ) if post_compile: return "__[POSTCOMPILE_%s]" % name @@ -2772,7 +2941,8 @@ class SQLCompiler(Compiled): cte_opts: selectable._CTEOpts = selectable._CTEOpts(False), **kwargs: Any, ) -> Optional[str]: - self._init_cte_state() + self_ctes = self._init_cte_state() + assert self_ctes is self.ctes kwargs["visiting_cte"] = cte @@ -2838,7 +3008,7 @@ class SQLCompiler(Compiled): # we've generated a same-named CTE that is # enclosed in us - we take precedence, so # discard the text for the "inner". - del self.ctes[existing_cte] + del self_ctes[existing_cte] existing_cte_reference_cte = existing_cte._get_reference_cte() @@ -2875,7 +3045,7 @@ class SQLCompiler(Compiled): if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) - if not cte_pre_alias_name and cte not in self.ctes: + if not cte_pre_alias_name and cte not in self_ctes: if cte.recursive: self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) @@ -2942,14 +3112,14 @@ class SQLCompiler(Compiled): cte, cte._suffixes, **kwargs ) - self.ctes[cte] = text + self_ctes[cte] = text if asfrom: if from_linter: from_linter.froms[cte] = cte_name if not is_new_cte and embedded_in_current_named_cte: - return self.preparer.format_alias(cte, cte_name) + return self.preparer.format_alias(cte, cte_name) # type: ignore[no-any-return] # noqa: E501 if cte_pre_alias_name: text = self.preparer.format_alias(cte, cte_pre_alias_name) @@ -2960,6 +3130,8 @@ class SQLCompiler(Compiled): else: return self.preparer.format_alias(cte, cte_name) + return None + def visit_table_valued_alias(self, element, **kw): if element._is_lateral: return self.visit_lateral(element, **kw) @@ -3143,7 +3315,7 @@ class SQLCompiler(Compiled): self, keyname: str, name: str, - objects: List[Any], + objects: Tuple[Any, ...], type_: TypeEngine[Any], ) -> None: if keyname is None or keyname == "*": @@ -3358,9 +3530,12 @@ class SQLCompiler(Compiled): def get_statement_hint_text(self, hint_texts): return " ".join(hint_texts) - _default_stack_entry = util.immutabledict( - [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] - ) + _default_stack_entry: _CompilerStackEntry + + if not typing.TYPE_CHECKING: + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) def _display_froms_for_select( self, select_stmt, asfrom, lateral=False, **kw @@ -3391,7 +3566,7 @@ class SQLCompiler(Compiled): ) return froms - translate_select_structure = None + translate_select_structure: Any = None """if not ``None``, should be a callable which accepts ``(select_stmt, **kw)`` and returns a select object. this is used for structural changes mostly to accommodate for LIMIT/OFFSET schemes @@ -3563,7 +3738,9 @@ class SQLCompiler(Compiled): ) self._result_columns = [ - (key, name, tuple(translate.get(o, o) for o in obj), type_) + ResultColumnsEntry( + key, name, tuple(translate.get(o, o) for o in obj), type_ + ) for key, name, obj, type_ in self._result_columns ] @@ -3660,10 +3837,10 @@ class SQLCompiler(Compiled): implicit_correlate_froms=asfrom_froms, ) - new_correlate_froms = set(selectable._from_objects(*froms)) + new_correlate_froms = set(_from_objects(*froms)) all_correlate_froms = new_correlate_froms.union(correlate_froms) - new_entry = { + new_entry: _CompilerStackEntry = { "asfrom_froms": new_correlate_froms, "correlate_froms": all_correlate_froms, "selectable": select, @@ -3734,6 +3911,7 @@ class SQLCompiler(Compiled): text += " \nWHERE " + t if warn_linting: + assert from_linter is not None from_linter.warn() if select._group_by_clauses: @@ -3781,6 +3959,8 @@ class SQLCompiler(Compiled): if not self.ctes: return "" + ctes: MutableMapping[CTE, str] + if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): @@ -3805,10 +3985,16 @@ class SQLCompiler(Compiled): ctes_recursive = any([cte.recursive for cte in ctes]) if self.positional: + assert self.positiontup is not None self.positiontup = ( - sum([self.cte_positional[cte] for cte in ctes], []) + list( + itertools.chain.from_iterable( + self.cte_positional[cte] for cte in ctes + ) + ) + self.positiontup ) + cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " @@ -4190,7 +4376,7 @@ class SQLCompiler(Compiled): if is_multitable: # main table might be a JOIN - main_froms = set(selectable._from_objects(update_stmt.table)) + main_froms = set(_from_objects(update_stmt.table)) render_extra_froms = [ f for f in extra_froms if f not in main_froms ] @@ -4506,7 +4692,11 @@ class DDLCompiler(Compiled): def type_compiler(self): return self.dialect.type_compiler - def construct_params(self, params=None, extracted_parameters=None): + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + ) -> Optional[_MutableCoreSingleExecuteParams]: return None def visit_ddl(self, ddl, **kwargs): @@ -5199,6 +5389,11 @@ class StrSQLTypeCompiler(GenericTypeCompiler): return get_col_spec(**kw) +class _SchemaForObjectCallable(Protocol): + def __call__(self, obj: Any) -> str: + ... + + class IdentifierPreparer: """Handle quoting and case-folding of identifiers based on options.""" @@ -5209,7 +5404,13 @@ class IdentifierPreparer: illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - schema_for_object = operator.attrgetter("schema") + initial_quote: str + + final_quote: str + + _strings: MutableMapping[str, str] + + schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema") """Return the .schema attribute for an object. For the default IdentifierPreparer, the schema for an object is always @@ -5297,7 +5498,7 @@ class IdentifierPreparer: return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement) - def _escape_identifier(self, value): + def _escape_identifier(self, value: str) -> str: """Escape an identifier. Subclasses should override this to provide database-dependent @@ -5309,7 +5510,7 @@ class IdentifierPreparer: value = value.replace("%", "%%") return value - def _unescape_identifier(self, value): + def _unescape_identifier(self, value: str) -> str: """Canonicalize an escaped identifier. Subclasses should override this to provide database-dependent @@ -5336,7 +5537,7 @@ class IdentifierPreparer: ) return element - def quote_identifier(self, value): + def quote_identifier(self, value: str) -> str: """Quote an identifier. Subclasses should override this to provide database-dependent @@ -5349,7 +5550,7 @@ class IdentifierPreparer: + self.final_quote ) - def _requires_quotes(self, value): + def _requires_quotes(self, value: str) -> bool: """Return True if the given identifier requires quoting.""" lc_value = value.lower() return ( @@ -5364,7 +5565,7 @@ class IdentifierPreparer: not taking case convention into account.""" return not self.legal_characters.match(str(value)) - def quote_schema(self, schema, force=None): + def quote_schema(self, schema: str, force: Any = None) -> str: """Conditionally quote a schema name. @@ -5403,7 +5604,7 @@ class IdentifierPreparer: return self.quote(schema) - def quote(self, ident, force=None): + def quote(self, ident: str, force: Any = None) -> str: """Conditionally quote an identifier. The identifier is quoted if it is a reserved word, contains @@ -5474,11 +5675,19 @@ class IdentifierPreparer: name = self.quote_schema(effective_schema) + "." + name return name - def format_label(self, label, name=None): + def format_label( + self, label: Label[Any], name: Optional[str] = None + ) -> str: return self.quote(name or label.name) - def format_alias(self, alias, name=None): - return self.quote(name or alias.name) + def format_alias( + self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None + ) -> str: + if name is None: + assert alias is not None + return self.quote(alias.name) + else: + return self.quote(name) def format_savepoint(self, savepoint, name=None): # Running the savepoint name through quoting is unnecessary diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 5aded307b..96e90b0ea 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -13,6 +13,10 @@ from __future__ import annotations import collections.abc as collections_abc import typing +from typing import Any +from typing import List +from typing import MutableMapping +from typing import Optional from . import coercions from . import roles @@ -40,8 +44,8 @@ from .. import util class DMLState(CompileState): _no_parameters = True - _dict_parameters = None - _multi_parameters = None + _dict_parameters: Optional[MutableMapping[str, Any]] = None + _multi_parameters: Optional[List[MutableMapping[str, Any]]] = None _ordered_values = None _parameter_ordering = None _has_multi_parameters = False diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 168da17cc..08d632afd 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -18,7 +18,9 @@ import re import typing from typing import Any from typing import Callable +from typing import Dict from typing import Generic +from typing import List from typing import Optional from typing import overload from typing import Sequence @@ -47,6 +49,7 @@ from .coercions import _document_text_coercion # noqa from .operators import ColumnOperators from .traversals import HasCopyInternals from .visitors import cloned_traverse +from .visitors import ExternallyTraversible from .visitors import InternalTraversal from .visitors import traverse from .visitors import Visitable @@ -68,6 +71,8 @@ if typing.TYPE_CHECKING: from ..engine import Connection from ..engine import Dialect from ..engine import Engine + from ..engine.base import _CompiledCacheType + from ..engine.base import _SchemaTranslateMapType _NUMERIC = Union[complex, "Decimal"] @@ -238,6 +243,7 @@ class ClauseElement( SupportsWrappingAnnotations, MemoizedHasCacheKey, HasCopyInternals, + ExternallyTraversible, CompilerElement, ): """Base class for elements of a programmatically constructed SQL @@ -398,7 +404,9 @@ class ClauseElement( """ return self._replace_params(True, optionaldict, kwargs) - def params(self, *optionaldict, **kwargs): + def params( + self, *optionaldict: Dict[str, Any], **kwargs: Any + ) -> ClauseElement: """Return a copy with :func:`_expression.bindparam` elements replaced. @@ -415,7 +423,12 @@ class ClauseElement( """ return self._replace_params(False, optionaldict, kwargs) - def _replace_params(self, unique, optionaldict, kwargs): + def _replace_params( + self, + unique: bool, + optionaldict: Optional[Dict[str, Any]], + kwargs: Dict[str, Any], + ) -> ClauseElement: if len(optionaldict) == 1: kwargs.update(optionaldict[0]) @@ -487,12 +500,12 @@ class ClauseElement( def _compile_w_cache( self, - dialect, - compiled_cache=None, - column_keys=None, - for_executemany=False, - schema_translate_map=None, - **kw, + dialect: Dialect, + compiled_cache: Optional[_CompiledCacheType] = None, + column_keys: Optional[List[str]] = None, + for_executemany: bool = False, + schema_translate_map: Optional[_SchemaTranslateMapType] = None, + **kw: Any, ): if compiled_cache is not None and dialect._supports_statement_cache: elem_cache_key = self._generate_cache_key() @@ -1383,7 +1396,7 @@ class ColumnElement( """ return Cast(self, type_) - def label(self, name): + def label(self, name: Optional[str]) -> Label[_T]: """Produce a column label, i.e. ``<columnname> AS <name>``. This is a shortcut to the :func:`_expression.label` function. @@ -1608,6 +1621,9 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): ("value", InternalTraversal.dp_plain_obj), ] + key: str + type: TypeEngine + _is_crud = False _is_bind_parameter = True _key_is_anon = False diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index eb3d17ee4..6e5eec127 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -12,6 +12,7 @@ from __future__ import annotations from typing import Any +from typing import Sequence from typing import TypeVar from . import annotation @@ -839,6 +840,8 @@ class Function(FunctionElement): identifier: str + packagenames: Sequence[str] + type: TypeEngine = sqltypes.NULLTYPE """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 64bd4b951..1a7a5f4d4 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -7,14 +7,22 @@ from __future__ import annotations import typing +from typing import Any +from typing import Iterable +from typing import Mapping +from typing import Optional +from typing import Sequence -from sqlalchemy.util.langhelpers import TypingOnly from .. import util - +from ..util import TypingOnly +from ..util.typing import Literal if typing.TYPE_CHECKING: + from .base import ColumnCollection from .elements import ClauseElement + from .elements import Label from .selectable import FromClause + from .selectable import Subquery class SQLRole: @@ -35,7 +43,7 @@ class SQLRole: class UsesInspection: __slots__ = () - _post_inspect = None + _post_inspect: Literal[None] = None uses_inspection = True @@ -96,7 +104,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): _role_name = "Column expression or FROM clause" @property - def _select_iterable(self): + def _select_iterable(self) -> Sequence[ColumnsClauseRole]: raise NotImplementedError() @@ -150,6 +158,9 @@ class ExpressionElementRole(SQLRole): __slots__ = () _role_name = "SQL expression element" + def label(self, name: Optional[str]) -> Label[Any]: + raise NotImplementedError() + class ConstExprRole(ExpressionElementRole): __slots__ = () @@ -187,7 +198,7 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole): _is_subquery = False @property - def _hide_froms(self): + def _hide_froms(self) -> Iterable[FromClause]: raise NotImplementedError() @@ -195,8 +206,10 @@ class StrictFromClauseRole(FromClauseRole): __slots__ = () # does not allow text() or select() objects + c: ColumnCollection + @property - def description(self): + def description(self) -> str: raise NotImplementedError() @@ -204,7 +217,9 @@ class AnonymizedFromClauseRole(StrictFromClauseRole): __slots__ = () # calls .alias() as a post processor - def _anonymous_fromclause(self, name=None, flat=False): + def _anonymous_fromclause( + self, name: Optional[str] = None, flat: bool = False + ) -> FromClause: raise NotImplementedError() @@ -220,14 +235,14 @@ class StatementRole(SQLRole): __slots__ = () _role_name = "Executable SQL or text() construct" - _propagate_attrs = util.immutabledict() + _propagate_attrs: Mapping[str, Any] = util.immutabledict() class SelectStatementRole(StatementRole, ReturnsRowsRole): __slots__ = () _role_name = "SELECT construct or equivalent text() construct" - def subquery(self): + def subquery(self) -> Subquery: raise NotImplementedError( "All SelectStatementRole objects should implement a " ".subquery() method." diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index c270e1564..33e300bf6 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -51,7 +51,7 @@ from . import visitors from .base import DedupeColumnCollection from .base import DialectKWArgs from .base import Executable -from .base import SchemaEventTarget +from .base import SchemaEventTarget as SchemaEventTarget from .coercions import _document_text_coercion from .elements import ClauseElement from .elements import ColumnClause @@ -2676,6 +2676,10 @@ class DefaultGenerator(Executable, SchemaItem): def __init__(self, for_update=False): self.for_update = for_update + @util.memoized_property + def is_callable(self): + raise NotImplementedError() + def _set_parent(self, column, **kw): self.column = column if self.for_update: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e5c2bef68..09befb078 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -53,7 +53,6 @@ from .base import Generative from .base import HasCompileState from .base import HasMemoized from .base import Immutable -from .base import prefix_anon_map from .coercions import _document_text_coercion from .elements import _anonymous_label from .elements import BindParameter @@ -69,10 +68,10 @@ from .elements import literal_column from .elements import TableValuedColumn from .elements import UnaryExpression from .visitors import InternalTraversal +from .visitors import prefix_anon_map from .. import exc from .. import util - and_ = BooleanClauseList.and_ _T = TypeVar("_T", bound=Any) @@ -855,6 +854,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.alias(name=name) +class NamedFromClause(FromClause): + named_with_column = True + + name: str + + class SelectLabelStyle(Enum): """Label style constants that may be passed to :meth:`_sql.Select.set_label_style`.""" @@ -1317,15 +1322,16 @@ class NoInit: # -> Lateral -> FromClause, but we accept SelectBase # w/ non-deprecated coercion # -> TableSample -> only for FromClause -class AliasedReturnsRows(NoInit, FromClause): +class AliasedReturnsRows(NoInit, NamedFromClause): """Base class of aliases against tables, subqueries, and other selectables.""" _is_from_container = True - named_with_column = True _supports_derived_columns = False + element: ClauseElement + _traverse_internals = [ ("element", InternalTraversal.dp_clauseelement), ("name", InternalTraversal.dp_anon_name), @@ -1423,6 +1429,8 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows): inherit_cache = True + element: FromClause + @classmethod def _factory(cls, selectable, name=None, flat=False): return coercions.expect( @@ -1689,6 +1697,8 @@ class CTE( + HasSuffixes._has_suffixes_traverse_internals ) + element: HasCTE + @classmethod def _factory(cls, selectable, name=None, recursive=False): r"""Return a new :class:`_expression.CTE`, @@ -1819,7 +1829,7 @@ class _CTEOpts(NamedTuple): nesting: bool -class HasCTE(roles.HasCTERole): +class HasCTE(roles.HasCTERole, ClauseElement): """Mixin that declares a class to include CTE support. .. versionadded:: 1.1 @@ -2247,6 +2257,8 @@ class Subquery(AliasedReturnsRows): inherit_cache = True + element: Select + @classmethod def _factory(cls, selectable, name=None): """Return a :class:`.Subquery` object.""" @@ -2331,7 +2343,7 @@ class FromGrouping(GroupedElement, FromClause): self.element = state["element"] -class TableClause(roles.DMLTableRole, Immutable, FromClause): +class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): """Represents a minimal "table" construct. This is a lightweight table object that has only a name, a @@ -2371,8 +2383,6 @@ class TableClause(roles.DMLTableRole, Immutable, FromClause): ("name", InternalTraversal.dp_string), ] - named_with_column = True - _is_table = True implicit_returning = False @@ -2542,7 +2552,7 @@ class ForUpdateArg(ClauseElement): SelfValues = typing.TypeVar("SelfValues", bound="Values") -class Values(Generative, FromClause): +class Values(Generative, NamedFromClause): """Represent a ``VALUES`` construct that can be used as a FROM element in a statement. @@ -2553,7 +2563,6 @@ class Values(Generative, FromClause): """ - named_with_column = True __visit_name__ = "values" _data = () diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 7d21f1262..b2b1d9bc2 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -35,13 +35,13 @@ from .elements import _NONE_NAME from .elements import quoted_name from .elements import Slice from .elements import TypeCoerce as type_coerce # noqa -from .traversals import InternalTraversal from .type_api import Emulated from .type_api import NativeForEmulated # noqa from .type_api import to_instance from .type_api import TypeDecorator from .type_api import TypeEngine from .type_api import Variant # noqa +from .visitors import InternalTraversal from .. import event from .. import exc from .. import inspection diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 4fa23d370..cf9487f93 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -15,7 +15,10 @@ import operator import typing from typing import Any from typing import Callable +from typing import Deque from typing import Dict +from typing import Set +from typing import Tuple from typing import Type from typing import TypeVar @@ -23,9 +26,9 @@ from . import operators from .cache_key import HasCacheKey from .visitors import _TraverseInternalsType from .visitors import anon_map -from .visitors import ExtendedInternalTraversal +from .visitors import ExternallyTraversible +from .visitors import HasTraversalDispatch from .visitors import HasTraverseInternals -from .visitors import InternalTraversal from .. import util from ..util import langhelpers @@ -35,6 +38,7 @@ COMPARE_SUCCEEDED = True def compare(obj1, obj2, **kw): + strategy: TraversalComparatorStrategy if kw.get("use_proxies", False): strategy = ColIdentityComparatorStrategy() else: @@ -45,16 +49,18 @@ def compare(obj1, obj2, **kw): def _preconfigure_traversals(target_hierarchy): for cls in util.walk_subclasses(target_hierarchy): - if hasattr(cls, "_traverse_internals"): - cls._generate_cache_attrs() + if hasattr(cls, "_generate_cache_attrs") and hasattr( + cls, "_traverse_internals" + ): + cls._generate_cache_attrs() # type: ignore _copy_internals.generate_dispatch( - cls, - cls._traverse_internals, + cls, # type: ignore + cls._traverse_internals, # type: ignore "_generated_copy_internals_traversal", ) _get_children.generate_dispatch( - cls, - cls._traverse_internals, + cls, # type: ignore + cls._traverse_internals, # type: ignore "_generated_get_children_traversal", ) @@ -125,54 +131,58 @@ class HasShallowCopy(HasTraverseInternals): meth_text = f"def {method_name}(self, d):\n{code}\n" return langhelpers._exec_code_in_env(meth_text, {}, method_name) - def _shallow_from_dict(self, d: Dict) -> None: + def _shallow_from_dict(self, d: Dict[str, Any]) -> None: cls = self.__class__ + shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None] try: shallow_from_dict = cls.__dict__[ "_generated_shallow_from_dict_traversal" ] except KeyError: - shallow_from_dict = ( - cls._generated_shallow_from_dict_traversal # type: ignore - ) = self._generate_shallow_from_dict( + shallow_from_dict = self._generate_shallow_from_dict( cls._traverse_internals, "_generated_shallow_from_dict_traversal", ) + cls._generated_shallow_from_dict_traversal = shallow_from_dict # type: ignore # noqa E501 + shallow_from_dict(self, d) def _shallow_to_dict(self) -> Dict[str, Any]: cls = self.__class__ + shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]] + try: shallow_to_dict = cls.__dict__[ "_generated_shallow_to_dict_traversal" ] except KeyError: - shallow_to_dict = ( - cls._generated_shallow_to_dict_traversal # type: ignore - ) = self._generate_shallow_to_dict( + shallow_to_dict = self._generate_shallow_to_dict( cls._traverse_internals, "_generated_shallow_to_dict_traversal" ) + cls._generated_shallow_to_dict_traversal = shallow_to_dict # type: ignore # noqa E501 return shallow_to_dict(self) - def _shallow_copy_to(self: SelfHasShallowCopy, other: SelfHasShallowCopy): + def _shallow_copy_to( + self: SelfHasShallowCopy, other: SelfHasShallowCopy + ) -> None: cls = self.__class__ + shallow_copy: Callable[[SelfHasShallowCopy, SelfHasShallowCopy], None] try: shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"] except KeyError: - shallow_copy = ( - cls._generated_shallow_copy_traversal # type: ignore - ) = self._generate_shallow_copy( + shallow_copy = self._generate_shallow_copy( cls._traverse_internals, "_generated_shallow_copy_traversal" ) + cls._generated_shallow_copy_traversal = shallow_copy # type: ignore # noqa: E501 shallow_copy(self, other) - def _clone(self: SelfHasShallowCopy, **kw) -> SelfHasShallowCopy: + def _clone(self: SelfHasShallowCopy, **kw: Any) -> SelfHasShallowCopy: """Create a shallow copy""" c = self.__class__.__new__(self.__class__) self._shallow_copy_to(c) @@ -246,7 +256,7 @@ class HasCopyInternals(HasTraverseInternals): setattr(self, attrname, result) -class _CopyInternalsTraversal(InternalTraversal): +class _CopyInternalsTraversal(HasTraversalDispatch): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -381,7 +391,7 @@ def _flatten_clauseelement(element): return element -class _GetChildrenTraversal(InternalTraversal): +class _GetChildrenTraversal(HasTraversalDispatch): """Generate a _children_traversal internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -463,13 +473,13 @@ def _resolve_name_for_compare(element, name, anon_map, **kw): return name -class TraversalComparatorStrategy( - ExtendedInternalTraversal, util.MemoizedSlots -): +class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): __slots__ = "stack", "cache", "anon_map" def __init__(self): - self.stack = deque() + self.stack: Deque[ + Tuple[ExternallyTraversible, ExternallyTraversible] + ] = deque() self.cache = set() def _memoized_attr_anon_map(self): @@ -653,7 +663,7 @@ class TraversalComparatorStrategy( if seq1 is None: return seq2 is None - completed = set() + completed: Set[object] = set() for clause in seq1: for other_clause in set(seq2).difference(completed): if self.compare_inner(clause, other_clause, **kw): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index e0248adf0..5114a2431 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -21,9 +21,9 @@ from . import coercions from . import operators from . import roles from . import visitors -from .annotation import _deep_annotate # noqa -from .annotation import _deep_deannotate # noqa -from .annotation import _shallow_annotate # noqa +from .annotation import _deep_annotate as _deep_annotate +from .annotation import _deep_deannotate as _deep_deannotate +from .annotation import _shallow_annotate as _shallow_annotate from .base import _expand_cloned from .base import _from_objects from .base import ColumnSet diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 111ecd32e..0c41e440e 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -7,43 +7,46 @@ """Visitor/traversal interface and library functions. -SQLAlchemy schema and expression constructs rely on a Python-centric -version of the classic "visitor" pattern as the primary way in which -they apply functionality. The most common use of this pattern -is statement compilation, where individual expression classes match -up to rendering methods that produce a string result. Beyond this, -the visitor system is also used to inspect expressions for various -information and patterns, as well as for the purposes of applying -transformations to expressions. - -Examples of how the visit system is used can be seen in the source code -of for example the ``sqlalchemy.sql.util`` and the ``sqlalchemy.sql.compiler`` -modules. Some background on clause adaption is also at -https://techspot.zzzeek.org/2008/01/23/expression-transformations/ . """ from __future__ import annotations from collections import deque +from enum import Enum import itertools import operator import typing from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import Iterator from typing import List +from typing import Mapping +from typing import Optional from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union from .. import exc from .. import util from ..util import langhelpers -from ..util import symbol from ..util._has_cy import HAS_CYEXTENSION -from ..util.langhelpers import _symbol +from ..util.typing import Protocol +from ..util.typing import Self if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_util import cache_anon_map as anon_map # noqa + from ._py_util import prefix_anon_map as prefix_anon_map + from ._py_util import cache_anon_map as anon_map else: - from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa + from sqlalchemy.cyextension.util import prefix_anon_map as prefix_anon_map + from sqlalchemy.cyextension.util import cache_anon_map as anon_map + __all__ = [ "iterate", @@ -54,57 +57,23 @@ __all__ = [ "Visitable", "ExternalTraversal", "InternalTraversal", + "anon_map", ] -_TraverseInternalsType = List[Tuple[str, _symbol]] - - -class HasTraverseInternals: - """base for classes that have a "traverse internals" element, - which defines all kinds of ways of traversing the elements of an object. - - """ - - __slots__ = () - - _traverse_internals: _TraverseInternalsType - - @util.preload_module("sqlalchemy.sql.traversals") - def get_children(self, omit_attrs=(), **kw): - r"""Return immediate child :class:`.visitors.Visitable` - elements of this :class:`.visitors.Visitable`. - - This is used for visit traversal. - - \**kw may contain flags that change the collection that is - returned, for example to return a subset of items in order to - cut down on larger traversals, or to return child items from a - different context (such as schema-level collections instead of - clause-level). - - """ - - traversals = util.preloaded.sql_traversals - - try: - traverse_internals = self._traverse_internals - except AttributeError: - # user-defined classes may not have a _traverse_internals - return [] - dispatch = traversals._get_children.run_generated_dispatch - return itertools.chain.from_iterable( - meth(obj, **kw) - for attrname, obj, meth in dispatch( - self, traverse_internals, "_generated_get_children_traversal" - ) - if attrname not in omit_attrs and obj is not None - ) +class _CompilerDispatchType(Protocol): + def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any: + ... class Visitable: """Base class for visitable objects. + :class:`.Visitable` is used to implement the SQL compiler dispatch + functions. Other forms of traversal such as for cache key generation + are implemented separately using the :class:`.HasTraverseInternals` + interface. + .. versionchanged:: 2.0 The :class:`.Visitable` class was named :class:`.Traversible` in the 1.4 series; the name is changed back to :class:`.Visitable` in 2.0 which is what it was prior to 1.4. @@ -117,32 +86,20 @@ class Visitable: __visit_name__: str + _original_compiler_dispatch: _CompilerDispatchType + + if typing.TYPE_CHECKING: + + def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str: + ... + def __init_subclass__(cls) -> None: if "__visit_name__" in cls.__dict__: cls._generate_compiler_dispatch() super().__init_subclass__() @classmethod - def _generate_compiler_dispatch(cls): - """Assign dispatch attributes to various kinds of - "visitable" classes. - - Attributes include: - - * The ``_compiler_dispatch`` method, corresponding to - ``__visit_name__``. This is called "external traversal" because the - caller of each visit() method is responsible for sub-traversing the - inner elements of each object. This is appropriate for string - compilers and other traversals that need to call upon the inner - elements in a specific pattern. - - * internal traversal collections ``_children_traversal``, - ``_cache_key_traversal``, ``_copy_internals_traversal``, generated - from an optional ``_traverse_internals`` collection of symbols which - comes from the :class:`.InternalTraversal` list of symbols. This is - called "internal traversal". - - """ + def _generate_compiler_dispatch(cls) -> None: visit_name = cls.__visit_name__ if "_compiler_dispatch" in cls.__dict__: @@ -161,7 +118,9 @@ class Visitable: name = "visit_%s" % visit_name getter = operator.attrgetter(name) - def _compiler_dispatch(self, visitor, **kw): + def _compiler_dispatch( + self: Visitable, visitor: Any, **kw: Any + ) -> str: """Look for an attribute named "visit_<visit_name>" on the visitor, and call it with the same kw params. @@ -169,105 +128,20 @@ class Visitable: try: meth = getter(visitor) except AttributeError as err: - return visitor.visit_unsupported_compilation(self, err, **kw) + return visitor.visit_unsupported_compilation(self, err, **kw) # type: ignore # noqa E501 else: - return meth(self, **kw) + return meth(self, **kw) # type: ignore # noqa E501 - cls._compiler_dispatch = ( + cls._compiler_dispatch = ( # type: ignore cls._original_compiler_dispatch ) = _compiler_dispatch - def __class_getitem__(cls, key): + def __class_getitem__(cls, key: str) -> Any: # allow generic classes in py3.9+ return cls -class _HasTraversalDispatch: - r"""Define infrastructure for the :class:`.InternalTraversal` class. - - .. versionadded:: 2.0 - - """ - - __slots__ = () - - def __init_subclass__(cls) -> None: - cls._generate_traversal_dispatch() - super().__init_subclass__() - - def dispatch(self, visit_symbol): - """Given a method from :class:`._HasTraversalDispatch`, return the - corresponding method on a subclass. - - """ - name = self._dispatch_lookup[visit_symbol] - return getattr(self, name, None) - - def run_generated_dispatch( - self, target, internal_dispatch, generate_dispatcher_name - ): - try: - dispatcher = target.__class__.__dict__[generate_dispatcher_name] - except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). - # this block will generate any remaining dispatchers. - dispatcher = self.generate_dispatch( - target.__class__, internal_dispatch, generate_dispatcher_name - ) - return dispatcher(target, self) - - def generate_dispatch( - self, target_cls, internal_dispatch, generate_dispatcher_name - ): - dispatcher = self._generate_dispatcher( - internal_dispatch, generate_dispatcher_name - ) - # assert isinstance(target_cls, type) - setattr(target_cls, generate_dispatcher_name, dispatcher) - return dispatcher - - @classmethod - def _generate_traversal_dispatch(cls): - lookup = {} - clsdict = cls.__dict__ - for key, sym in clsdict.items(): - if key.startswith("dp_"): - visit_key = key.replace("dp_", "visit_") - sym_name = sym.name - assert sym_name not in lookup, sym_name - lookup[sym] = lookup[sym_name] = visit_key - if hasattr(cls, "_dispatch_lookup"): - lookup.update(cls._dispatch_lookup) - cls._dispatch_lookup = lookup - - def _generate_dispatcher(self, internal_dispatch, method_name): - names = [] - for attrname, visit_sym in internal_dispatch: - meth = self.dispatch(visit_sym) - if meth: - visit_name = ExtendedInternalTraversal._dispatch_lookup[ - visit_sym - ] - names.append((attrname, visit_name)) - - code = ( - (" return [\n") - + ( - ", \n".join( - " (%r, self.%s, visitor.%s)" - % (attrname, attrname, visit_name) - for attrname, visit_name in names - ) - ) - + ("\n ]\n") - ) - meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" - return langhelpers._exec_code_in_env(meth_text, {}, method_name) - - -class InternalTraversal(_HasTraversalDispatch): +class InternalTraversal(Enum): r"""Defines visitor symbols used for internal traversal. The :class:`.InternalTraversal` class is used in two ways. One is that @@ -306,18 +180,16 @@ class InternalTraversal(_HasTraversalDispatch): """ - __slots__ = () - - dp_has_cache_key = symbol("HC") + dp_has_cache_key = "HC" """Visit a :class:`.HasCacheKey` object.""" - dp_has_cache_key_list = symbol("HL") + dp_has_cache_key_list = "HL" """Visit a list of :class:`.HasCacheKey` objects.""" - dp_clauseelement = symbol("CE") + dp_clauseelement = "CE" """Visit a :class:`_expression.ClauseElement` object.""" - dp_fromclause_canonical_column_collection = symbol("FC") + dp_fromclause_canonical_column_collection = "FC" """Visit a :class:`_expression.FromClause` object in the context of the ``columns`` attribute. @@ -329,30 +201,30 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_clauseelement_tuples = symbol("CTS") + dp_clauseelement_tuples = "CTS" """Visit a list of tuples which contain :class:`_expression.ClauseElement` objects. """ - dp_clauseelement_list = symbol("CL") + dp_clauseelement_list = "CL" """Visit a list of :class:`_expression.ClauseElement` objects. """ - dp_clauseelement_tuple = symbol("CT") + dp_clauseelement_tuple = "CT" """Visit a tuple of :class:`_expression.ClauseElement` objects. """ - dp_executable_options = symbol("EO") + dp_executable_options = "EO" - dp_with_context_options = symbol("WC") + dp_with_context_options = "WC" - dp_fromclause_ordered_set = symbol("CO") + dp_fromclause_ordered_set = "CO" """Visit an ordered set of :class:`_expression.FromClause` objects. """ - dp_string = symbol("S") + dp_string = "S" """Visit a plain string value. Examples include table and column names, bound parameter keys, special @@ -363,10 +235,10 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_string_list = symbol("SL") + dp_string_list = "SL" """Visit a list of strings.""" - dp_anon_name = symbol("AN") + dp_anon_name = "AN" """Visit a potentially "anonymized" string value. The string value is considered to be significant for cache key @@ -374,7 +246,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_boolean = symbol("B") + dp_boolean = "B" """Visit a boolean value. The boolean value is considered to be significant for cache key @@ -382,7 +254,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_operator = symbol("O") + dp_operator = "O" """Visit an operator. The operator is a function from the :mod:`sqlalchemy.sql.operators` @@ -393,7 +265,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_type = symbol("T") + dp_type = "T" """Visit a :class:`.TypeEngine` object The type object is considered to be significant for cache key @@ -401,7 +273,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_plain_dict = symbol("PD") + dp_plain_dict = "PD" """Visit a dictionary with string keys. The keys of the dictionary should be strings, the values should @@ -410,22 +282,22 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_dialect_options = symbol("DO") + dp_dialect_options = "DO" """Visit a dialect options structure.""" - dp_string_clauseelement_dict = symbol("CD") + dp_string_clauseelement_dict = "CD" """Visit a dictionary of string keys to :class:`_expression.ClauseElement` objects. """ - dp_string_multi_dict = symbol("MD") + dp_string_multi_dict = "MD" """Visit a dictionary of string keys to values which may either be plain immutable/hashable or :class:`.HasCacheKey` objects. """ - dp_annotations_key = symbol("AK") + dp_annotations_key = "AK" """Visit the _annotations_cache_key element. This is a dictionary of additional information about a ClauseElement @@ -436,7 +308,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_plain_obj = symbol("PO") + dp_plain_obj = "PO" """Visit a plain python object. The value should be immutable and hashable, such as an integer. @@ -444,7 +316,7 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_named_ddl_element = symbol("DD") + dp_named_ddl_element = "DD" """Visit a simple named DDL element. The current object used by this method is the :class:`.Sequence`. @@ -454,57 +326,56 @@ class InternalTraversal(_HasTraversalDispatch): """ - dp_prefix_sequence = symbol("PS") + dp_prefix_sequence = "PS" """Visit the sequence represented by :class:`_expression.HasPrefixes` or :class:`_expression.HasSuffixes`. """ - dp_table_hint_list = symbol("TH") + dp_table_hint_list = "TH" """Visit the ``_hints`` collection of a :class:`_expression.Select` object. """ - dp_setup_join_tuple = symbol("SJ") + dp_setup_join_tuple = "SJ" - dp_memoized_select_entities = symbol("ME") + dp_memoized_select_entities = "ME" - dp_statement_hint_list = symbol("SH") + dp_statement_hint_list = "SH" """Visit the ``_statement_hints`` collection of a :class:`_expression.Select` object. """ - dp_unknown_structure = symbol("UK") + dp_unknown_structure = "UK" """Visit an unknown structure. """ - dp_dml_ordered_values = symbol("DML_OV") + dp_dml_ordered_values = "DML_OV" """Visit the values() ordered tuple list of an :class:`_expression.Update` object.""" - dp_dml_values = symbol("DML_V") + dp_dml_values = "DML_V" """Visit the values() dictionary of a :class:`.ValuesBase` (e.g. Insert or Update) object. """ - dp_dml_multi_values = symbol("DML_MV") + dp_dml_multi_values = "DML_MV" """Visit the values() multi-valued list of dictionaries of an :class:`_expression.Insert` object. """ - dp_propagate_attrs = symbol("PA") + dp_propagate_attrs = "PA" """Visit the propagate attrs dict. This hardcodes to the particular elements we care about right now.""" - -class ExtendedInternalTraversal(InternalTraversal): - """Defines additional symbols that are useful in caching applications. + """Symbols that follow are additional symbols that are useful in + caching applications. Traversals for :class:`_expression.ClauseElement` objects only need to use those symbols present in :class:`.InternalTraversal`. However, for @@ -513,9 +384,7 @@ class ExtendedInternalTraversal(InternalTraversal): """ - __slots__ = () - - dp_ignore = symbol("IG") + dp_ignore = "IG" """Specify an object that should be ignored entirely. This currently applies function call argument caching where some @@ -523,29 +392,235 @@ class ExtendedInternalTraversal(InternalTraversal): """ - dp_inspectable = symbol("IS") + dp_inspectable = "IS" """Visit an inspectable object where the return value is a :class:`.HasCacheKey` object.""" - dp_multi = symbol("M") + dp_multi = "M" """Visit an object that may be a :class:`.HasCacheKey` or may be a plain hashable object.""" - dp_multi_list = symbol("MT") + dp_multi_list = "MT" """Visit a tuple containing elements that may be :class:`.HasCacheKey` or may be a plain hashable object.""" - dp_has_cache_key_tuples = symbol("HT") + dp_has_cache_key_tuples = "HT" """Visit a list of tuples which contain :class:`.HasCacheKey` objects. """ - dp_inspectable_list = symbol("IL") + dp_inspectable_list = "IL" """Visit a list of inspectable objects which upon inspection are HasCacheKey objects.""" +_TraverseInternalsType = List[Tuple[str, InternalTraversal]] +"""a structure that defines how a HasTraverseInternals should be +traversed. + +This structure consists of a list of (attributename, internaltraversal) +tuples, where the "attributename" refers to the name of an attribute on an +instance of the HasTraverseInternals object, and "internaltraversal" refers +to an :class:`.InternalTraversal` enumeration symbol defining what kind +of data this attribute stores, which indicates to the traverser how it should +be handled. + +""" + + +class HasTraverseInternals: + """base for classes that have a "traverse internals" element, + which defines all kinds of ways of traversing the elements of an object. + + Compared to :class:`.Visitable`, which relies upon an external visitor to + define how the object is travered (i.e. the :class:`.SQLCompiler`), the + :class:`.HasTraverseInternals` interface allows classes to define their own + traversal, that is, what attributes are accessed and in what order. + + """ + + __slots__ = () + + _traverse_internals: _TraverseInternalsType + + @util.preload_module("sqlalchemy.sql.traversals") + def get_children( + self, omit_attrs: Tuple[str, ...] = (), **kw: Any + ) -> Iterable[HasTraverseInternals]: + r"""Return immediate child :class:`.visitors.HasTraverseInternals` + elements of this :class:`.visitors.HasTraverseInternals`. + + This is used for visit traversal. + + \**kw may contain flags that change the collection that is + returned, for example to return a subset of items in order to + cut down on larger traversals, or to return child items from a + different context (such as schema-level collections instead of + clause-level). + + """ + + traversals = util.preloaded.sql_traversals + + try: + traverse_internals = self._traverse_internals + except AttributeError: + # user-defined classes may not have a _traverse_internals + return [] + + dispatch = traversals._get_children.run_generated_dispatch + return itertools.chain.from_iterable( + meth(obj, **kw) + for attrname, obj, meth in dispatch( + self, traverse_internals, "_generated_get_children_traversal" + ) + if attrname not in omit_attrs and obj is not None + ) + + +class _InternalTraversalDispatchType(Protocol): + def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any: + ... + + +class HasTraversalDispatch: + r"""Define infrastructure for classes that perform internal traversals + + .. versionadded:: 2.0 + + """ + + __slots__ = () + + _dispatch_lookup: ClassVar[Dict[Union[InternalTraversal, str], str]] = {} + + def dispatch(self, visit_symbol: InternalTraversal) -> Callable[..., Any]: + """Given a method from :class:`.HasTraversalDispatch`, return the + corresponding method on a subclass. + + """ + name = _dispatch_lookup[visit_symbol] + return getattr(self, name, None) # type: ignore + + def run_generated_dispatch( + self, + target: object, + internal_dispatch: _TraverseInternalsType, + generate_dispatcher_name: str, + ) -> Any: + dispatcher: _InternalTraversalDispatchType + try: + dispatcher = target.__class__.__dict__[generate_dispatcher_name] + except KeyError: + # traversals.py -> _preconfigure_traversals() + # may be used to run these ahead of time, but + # is not enabled right now. + # this block will generate any remaining dispatchers. + dispatcher = self.generate_dispatch( + target.__class__, internal_dispatch, generate_dispatcher_name + ) + return dispatcher(target, self) + + def generate_dispatch( + self, + target_cls: Type[object], + internal_dispatch: _TraverseInternalsType, + generate_dispatcher_name: str, + ) -> _InternalTraversalDispatchType: + dispatcher = self._generate_dispatcher( + internal_dispatch, generate_dispatcher_name + ) + # assert isinstance(target_cls, type) + setattr(target_cls, generate_dispatcher_name, dispatcher) + return dispatcher + + def _generate_dispatcher( + self, internal_dispatch: _TraverseInternalsType, method_name: str + ) -> _InternalTraversalDispatchType: + names = [] + for attrname, visit_sym in internal_dispatch: + meth = self.dispatch(visit_sym) + if meth: + visit_name = _dispatch_lookup[visit_sym] + names.append((attrname, visit_name)) + + code = ( + (" return [\n") + + ( + ", \n".join( + " (%r, self.%s, visitor.%s)" + % (attrname, attrname, visit_name) + for attrname, visit_name in names + ) + ) + + ("\n ]\n") + ) + meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" + return cast( + _InternalTraversalDispatchType, + langhelpers._exec_code_in_env(meth_text, {}, method_name), + ) + + +ExtendedInternalTraversal = InternalTraversal + + +def _generate_traversal_dispatch() -> None: + lookup = _dispatch_lookup + + for sym in InternalTraversal: + key = sym.name + if key.startswith("dp_"): + visit_key = key.replace("dp_", "visit_") + sym_name = sym.value + assert sym_name not in lookup, sym_name + lookup[sym] = lookup[sym_name] = visit_key + + +_dispatch_lookup = HasTraversalDispatch._dispatch_lookup +_generate_traversal_dispatch() + + +class ExternallyTraversible(HasTraverseInternals, Visitable): + __slots__ = () + + _annotations: Collection[Any] = () + + if typing.TYPE_CHECKING: + + def get_children( + self, omit_attrs: Tuple[str, ...] = (), **kw: Any + ) -> Iterable[ExternallyTraversible]: + ... + + def _clone(self: Self, **kw: Any) -> Self: + """clone this element""" + raise NotImplementedError() + + def _copy_internals( + self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any + ) -> Self: + """Reassign internal elements to be clones of themselves. + + Called during a copy-and-traverse operation on newly + shallow-copied elements to create a deep copy. + + The given clone function should be used, which may be applying + additional transformations to the element (i.e. replacement + traversal, cloned traversal, annotations). + + """ + raise NotImplementedError() + + +_ET = TypeVar("_ET", bound=ExternallyTraversible) +_TraverseCallableType = Callable[[_ET], None] +_TraverseTransformCallableType = Callable[ + [ExternallyTraversible], Optional[ExternallyTraversible] +] + + class ExternalTraversal: """Base class for visitor objects which can traverse externally using the :func:`.visitors.traverse` function. @@ -555,7 +630,8 @@ class ExternalTraversal: """ - __traverse_options__ = {} + __traverse_options__: Dict[str, Any] = {} + _next: Optional[ExternalTraversal] def traverse_single(self, obj: Visitable, **kw: Any) -> Any: for v in self.visitor_iterator: @@ -563,20 +639,22 @@ class ExternalTraversal: if meth: return meth(obj, **kw) - def iterate(self, obj): + def iterate( + self, obj: ExternallyTraversible + ) -> Iterator[ExternallyTraversible]: """Traverse the given expression structure, returning an iterator of all elements. """ return iterate(obj, self.__traverse_options__) - def traverse(self, obj): + def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: """Traverse and visit the given expression structure.""" return traverse(obj, self.__traverse_options__, self._visitor_dict) @util.memoized_property - def _visitor_dict(self): + def _visitor_dict(self) -> Dict[str, _TraverseCallableType[Any]]: visitors = {} for name in dir(self): @@ -585,16 +663,16 @@ class ExternalTraversal: return visitors @property - def visitor_iterator(self): + def visitor_iterator(self) -> Iterator[ExternalTraversal]: """Iterate through this visitor and each 'chained' visitor.""" - v = self + v: Optional[ExternalTraversal] = self while v: yield v v = getattr(v, "_next", None) - def chain(self, visitor): - """'Chain' an additional ClauseVisitor onto this ClauseVisitor. + def chain(self, visitor: ExternalTraversal) -> ExternalTraversal: + """'Chain' an additional ExternalTraversal onto this ExternalTraversal The chained visitor will receive all visit events after this one. @@ -614,14 +692,16 @@ class CloningExternalTraversal(ExternalTraversal): """ - def copy_and_process(self, list_): + def copy_and_process( + self, list_: List[ExternallyTraversible] + ) -> List[ExternallyTraversible]: """Apply cloned traversal to the given list of elements, and return the new list. """ return [self.traverse(x) for x in list_] - def traverse(self, obj): + def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: """Traverse and visit the given expression structure.""" return cloned_traverse( @@ -638,7 +718,9 @@ class ReplacingExternalTraversal(CloningExternalTraversal): """ - def replace(self, elem): + def replace( + self, elem: ExternallyTraversible + ) -> Optional[ExternallyTraversible]: """Receive pre-copied elements during a cloning traversal. If the method returns a new element, the element is used @@ -647,15 +729,19 @@ class ReplacingExternalTraversal(CloningExternalTraversal): """ return None - def traverse(self, obj): + def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: """Traverse and visit the given expression structure.""" - def replace(elem): + def replace( + elem: ExternallyTraversible, + ) -> Optional[ExternallyTraversible]: for v in self.visitor_iterator: - e = v.replace(elem) + e = cast(ReplacingExternalTraversal, v).replace(elem) if e is not None: return e + return None + return replacement_traverse(obj, self.__traverse_options__, replace) @@ -667,7 +753,9 @@ CloningVisitor = CloningExternalTraversal ReplacingCloningVisitor = ReplacingExternalTraversal -def iterate(obj, opts=util.immutabledict()): +def iterate( + obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT +) -> Iterator[ExternallyTraversible]: r"""Traverse the given expression structure, returning an iterator. Traversal is configured to be breadth-first. @@ -702,7 +790,11 @@ def iterate(obj, opts=util.immutabledict()): stack.append(t.get_children(**opts)) -def traverse_using(iterator, obj, visitors): +def traverse_using( + iterator: Iterable[ExternallyTraversible], + obj: ExternallyTraversible, + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> ExternallyTraversible: """Visit the given expression structure using the given iterator of objects. @@ -734,7 +826,11 @@ def traverse_using(iterator, obj, visitors): return obj -def traverse(obj, opts, visitors): +def traverse( + obj: ExternallyTraversible, + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> ExternallyTraversible: """Traverse and visit the given expression structure using the default iterator. @@ -767,7 +863,11 @@ def traverse(obj, opts, visitors): return traverse_using(iterate(obj, opts), obj, visitors) -def cloned_traverse(obj, opts, visitors): +def cloned_traverse( + obj: ExternallyTraversible, + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseTransformCallableType], +) -> ExternallyTraversible: """Clone the given expression structure, allowing modifications by visitors. @@ -794,20 +894,24 @@ def cloned_traverse(obj, opts, visitors): """ - cloned = {} + cloned: Dict[int, ExternallyTraversible] = {} stop_on = set(opts.get("stop_on", [])) - def deferred_copy_internals(obj): + def deferred_copy_internals( + obj: ExternallyTraversible, + ) -> ExternallyTraversible: return cloned_traverse(obj, opts, visitors) - def clone(elem, **kw): + def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible: if elem in stop_on: return elem else: if id(elem) not in cloned: if "replace" in kw: - newelem = kw["replace"](elem) + newelem = cast( + Optional[ExternallyTraversible], kw["replace"](elem) + ) if newelem is not None: cloned[id(elem)] = newelem return newelem @@ -823,11 +927,15 @@ def cloned_traverse(obj, opts, visitors): obj = clone( obj, deferred_copy_internals=deferred_copy_internals, **opts ) - clone = None # remove gc cycles + clone = None # type: ignore[assignment] # remove gc cycles return obj -def replacement_traverse(obj, opts, replace): +def replacement_traverse( + obj: ExternallyTraversible, + opts: Mapping[str, Any], + replace: _TraverseTransformCallableType, +) -> ExternallyTraversible: """Clone the given expression structure, allowing element replacement by a given replacement function. @@ -854,10 +962,12 @@ def replacement_traverse(obj, opts, replace): cloned = {} stop_on = {id(x) for x in opts.get("stop_on", [])} - def deferred_copy_internals(obj): + def deferred_copy_internals( + obj: ExternallyTraversible, + ) -> ExternallyTraversible: return replacement_traverse(obj, opts, replace) - def clone(elem, **kw): + def clone(elem: ExternallyTraversible, **kw: Any) -> ExternallyTraversible: if ( id(elem) in stop_on or "no_replacement_traverse" in elem._annotations @@ -888,5 +998,5 @@ def replacement_traverse(obj, opts, replace): obj = clone( obj, deferred_copy_internals=deferred_copy_internals, **opts ) - clone = None # remove gc cycles + clone = None # type: ignore[assignment] # remove gc cycles return obj |
