26 files changed, 1120 insertions, 1028 deletions
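The central pattern in this change is replacing class-creation metaclasses (_EventMeta, _GenericMeta, _MetaOptions' __init__ hook) with the __init_subclass__() hook, so the per-subclass setup work runs without a custom metaclass. The sketch below is a minimal, self-contained illustration of that pattern only; it is not SQLAlchemy's actual code, and every name in it (_EventMetaSketch, OldEvents, NewEvents, MyEvents, _registrars here) is invented for the example.

# Minimal sketch of the metaclass -> __init_subclass__ refactoring pattern.
# All names are hypothetical stand-ins, not SQLAlchemy's real classes.

from typing import ClassVar, Dict, List

_registrars: Dict[str, List[type]] = {}  # toy stand-in for a registry dict


# --- old style: a metaclass intercepts each new subclass -------------------
class _EventMetaSketch(type):
    def __init__(cls, classname, bases, dict_):
        # runs for every class created with this metaclass, including the base
        for name in dict_:
            if not name.startswith("_"):
                _registrars.setdefault(name, []).append(cls)
        super().__init__(classname, bases, dict_)


class OldEvents(metaclass=_EventMetaSketch):
    def before_update(self):
        ...


# --- new style: the same hook via __init_subclass__(), no metaclass --------
class NewEvents:
    _event_names: ClassVar[List[str]] = []

    def __init_subclass__(cls) -> None:
        # cls.__dict__ plays the role of the metaclass's ``dict_`` argument;
        # unlike the metaclass hook, this does not run for NewEvents itself.
        cls._event_names = [
            name for name in cls.__dict__ if not name.startswith("_")
        ]
        for name in cls._event_names:
            _registrars.setdefault(name, []).append(cls)
        super().__init_subclass__()


class MyEvents(NewEvents):
    def before_insert(self):
        ...


print(MyEvents._event_names)          # ['before_insert']
print(_registrars["before_insert"])   # [<class '__main__.MyEvents'>]

Note that __init_subclass__() fires only for subclasses, never for the class that defines it, which is presumably one reason the hook lives on a dedicated _HasEventsDispatch base in this diff rather than on Events itself, and why DeclarativeMeta is kept as a metaclass where class-level __setattr__()/__delattr__() behavior is still required.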
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index c5b03dd72..25d369240 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -15,13 +15,16 @@ at the class level of a particular ``_Dispatch`` class as well as within instances of ``_Dispatch``. """ +from typing import ClassVar +from typing import Optional +from typing import Type import weakref from .attr import _ClsLevelDispatch from .attr import _EmptyListener from .attr import _JoinedListener from .. import util - +from ..util.typing import Protocol _registrars = util.defaultdict(list) @@ -63,8 +66,8 @@ class _Dispatch: of the :class:`._Dispatch` class is returned. A :class:`._Dispatch` class is generated for each :class:`.Events` - class defined, by the :func:`._create_dispatcher_class` function. - The original :class:`.Events` classes remain untouched. + class defined, by the :meth:`._HasEventsDispatch._create_dispatcher_class` + method. The original :class:`.Events` classes remain untouched. This decouples the construction of :class:`.Events` subclasses from the implementation used by the event internals, and allows inspecting tools like Sphinx to work in an unsurprising @@ -78,6 +81,13 @@ class _Dispatch: _empty_listener_reg = weakref.WeakKeyDictionary() + _events: Type["_HasEventsDispatch"] + """reference back to the Events class. + + Bidirectional against _HasEventsDispatch.dispatch + + """ + def __init__(self, parent, instance_cls=None): self._parent = parent self._instance_cls = instance_cls @@ -159,56 +169,6 @@ class _Dispatch: ls.for_modify(self).clear() -class _EventMeta(type): - """Intercept new Event subclasses and create - associated _Dispatch classes.""" - - def __init__(cls, classname, bases, dict_): - _create_dispatcher_class(cls, classname, bases, dict_) - type.__init__(cls, classname, bases, dict_) - - -def _create_dispatcher_class(cls, classname, bases, dict_): - """Create a :class:`._Dispatch` class corresponding to an - :class:`.Events` class.""" - - # there's all kinds of ways to do this, - # i.e. make a Dispatch class that shares the '_listen' method - # of the Event class, this is the straight monkeypatch. 
- if hasattr(cls, "dispatch"): - dispatch_base = cls.dispatch.__class__ - else: - dispatch_base = _Dispatch - - event_names = [k for k in dict_ if _is_event_name(k)] - dispatch_cls = type( - "%sDispatch" % classname, (dispatch_base,), {"__slots__": event_names} - ) - - dispatch_cls._event_names = event_names - - dispatch_inst = cls._set_dispatch(cls, dispatch_cls) - for k in dispatch_cls._event_names: - setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k])) - _registrars[k].append(cls) - - for super_ in dispatch_cls.__bases__: - if issubclass(super_, _Dispatch) and super_ is not _Dispatch: - for ls in super_._events.dispatch._event_descriptors: - setattr(dispatch_inst, ls.name, ls) - dispatch_cls._event_names.append(ls.name) - - if getattr(cls, "_dispatch_target", None): - the_cls = cls._dispatch_target - if ( - hasattr(the_cls, "__slots__") - and "_slots_dispatch" in the_cls.__slots__ - ): - cls._dispatch_target.dispatch = slots_dispatcher(cls) - else: - cls._dispatch_target.dispatch = dispatcher(cls) - - def _remove_dispatcher(cls): for k in cls.dispatch._event_names: _registrars[k].remove(cls) @@ -216,8 +176,31 @@ def _remove_dispatcher(cls): del _registrars[k] -class Events(metaclass=_EventMeta): - """Define event listening functions for a particular target type.""" +class _HasEventsDispatchProto(Protocol): + """protocol for non-event classes that will also receive the 'dispatch' + attribute in the form of a descriptor. + + """ + + dispatch: ClassVar["dispatcher"] + + +class _HasEventsDispatch: + _dispatch_target: Optional[Type[_HasEventsDispatchProto]] + """class which will receive the .dispatch collection""" + + dispatch: _Dispatch + """reference back to the _Dispatch class. + + Bidirectional against _Dispatch._events + + """ + + def __init_subclass__(cls) -> None: + """Intercept new Event subclasses and create associated _Dispatch + classes.""" + + cls._create_dispatcher_class(cls.__name__, cls.__bases__, cls.__dict__) @staticmethod def _set_dispatch(cls, dispatch_cls): @@ -231,6 +214,54 @@ class Events(metaclass=_EventMeta): return cls.dispatch @classmethod + def _create_dispatcher_class(cls, classname, bases, dict_): + """Create a :class:`._Dispatch` class corresponding to an + :class:`.Events` class.""" + + # there's all kinds of ways to do this, + # i.e. make a Dispatch class that shares the '_listen' method + # of the Event class, this is the straight monkeypatch. 
+ if hasattr(cls, "dispatch"): + dispatch_base = cls.dispatch.__class__ + else: + dispatch_base = _Dispatch + + event_names = [k for k in dict_ if _is_event_name(k)] + dispatch_cls = type( + "%sDispatch" % classname, + (dispatch_base,), + {"__slots__": event_names}, + ) + + dispatch_cls._event_names = event_names + + dispatch_inst = cls._set_dispatch(cls, dispatch_cls) + for k in dispatch_cls._event_names: + setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k])) + _registrars[k].append(cls) + + for super_ in dispatch_cls.__bases__: + if issubclass(super_, _Dispatch) and super_ is not _Dispatch: + for ls in super_._events.dispatch._event_descriptors: + setattr(dispatch_inst, ls.name, ls) + dispatch_cls._event_names.append(ls.name) + + if getattr(cls, "_dispatch_target", None): + dispatch_target_cls = cls._dispatch_target + assert dispatch_target_cls is not None + if ( + hasattr(dispatch_target_cls, "__slots__") + and "_slots_dispatch" in dispatch_target_cls.__slots__ + ): + dispatch_target_cls.dispatch = slots_dispatcher(cls) + else: + dispatch_target_cls.dispatch = dispatcher(cls) + + +class Events(_HasEventsDispatch): + """Define event listening functions for a particular target type.""" + + @classmethod def _accept_with(cls, target): def dispatch_is(*types): return all(isinstance(target.dispatch, t) for t in types) diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index a2e1a3826..d75cf667b 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -269,10 +269,10 @@ and to also route attribute set events via ``__setattr__`` to the def __ne__(self, other): return not self.__eq__(other) -The :class:`.MutableComposite` class uses a Python metaclass to automatically -establish listeners for any usage of :func:`_orm.composite` that specifies our -``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class, -listeners are established which will route change events from ``Point`` +The :class:`.MutableComposite` class makes use of class mapping events to +automatically establish listeners for any usage of :func:`_orm.composite` that +specifies our ``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` +class, listeners are established which will route change events from ``Point`` objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes:: from sqlalchemy.orm import composite, mapper diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 1094fa516..2a5b1bb2b 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -51,6 +51,10 @@ def has_inherited_table(cls): class DeclarativeMeta(type): + # DeclarativeMeta could be replaced by __subclass_init__() + # except for the class-level __setattr__() and __delattr__ hooks, + # which are still very important. 
+ def __init__(cls, classname, bases, dict_, **kw): # early-consume registry from the initial declarative base, # assign privately to not conflict with subclass attributes named diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 6c84f89cf..d842df221 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -37,7 +37,7 @@ from ..sql import operators from ..sql import roles from ..sql import visitors from ..sql.base import ExecutableOption -from ..sql.traversals import HasCacheKey +from ..sql.cache_key import HasCacheKey __all__ = ( diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 2e64696d9..0d87739cc 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -20,7 +20,7 @@ from .. import exc from .. import inspection from .. import util from ..sql import visitors -from ..sql.traversals import HasCacheKey +from ..sql.cache_key import HasCacheKey log = logging.getLogger(__name__) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index ad31c2432..e6d16f178 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -134,7 +134,7 @@ class Query( # local Query builder state, not needed for # compilation or execution _enable_assertions = True - _last_joined_entity = None + _statement = None # mirrors that of ClauseElement, used to propagate the "orm" diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 4df275c71..c2cfbb9fc 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -29,9 +29,9 @@ from .. import exc as sa_exc from .. import inspect from .. import util from ..sql import and_ +from ..sql import cache_key from ..sql import coercions from ..sql import roles -from ..sql import traversals from ..sql import visitors from ..sql.base import _generative from ..sql.base import Generative @@ -1316,7 +1316,7 @@ class _WildcardLoad(_AbstractLoad): self.__dict__.update(state) -class _LoadElement(traversals.HasCacheKey): +class _LoadElement(cache_key.HasCacheKey): """represents strategy information to select for a LoaderStrategy and pass options to it. diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 805f7b1a0..6ab9a75f6 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -20,9 +20,9 @@ import typing from . import roles from . import visitors -from .traversals import HasCacheKey # noqa +from .cache_key import HasCacheKey # noqa +from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa -from .traversals import MemoizedHasCacheKey # noqa from .visitors import ClauseVisitor from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal @@ -37,7 +37,6 @@ try: except ImportError: from ._py_util import prefix_anon_map # noqa - coercions = None elements = None type_api = None @@ -610,18 +609,13 @@ class HasCompileState(Generative): class _MetaOptions(type): - """metaclass for the Options class.""" + """metaclass for the Options class. 
- def __init__(cls, classname, bases, dict_): - cls._cache_attrs = tuple( - sorted( - d - for d in dict_ - if not d.startswith("__") - and d not in ("_cache_key_traversal",) - ) - ) - type.__init__(cls, classname, bases, dict_) + This metaclass is actually necessary despite the availability of the + ``__init_subclass__()`` hook as this type also provides custom class-level + behavior for the ``__add__()`` method. + + """ def __add__(self, other): o1 = self() @@ -640,6 +634,18 @@ class _MetaOptions(type): class Options(metaclass=_MetaOptions): """A cacheable option dictionary with defaults.""" + def __init_subclass__(cls) -> None: + dict_ = cls.__dict__ + cls._cache_attrs = tuple( + sorted( + d + for d in dict_ + if not d.startswith("__") + and d not in ("_cache_key_traversal",) + ) + ) + super().__init_subclass__() + def __init__(self, **kw): self.__dict__.update(kw) diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py new file mode 100644 index 000000000..8dd44dbf0 --- /dev/null +++ b/lib/sqlalchemy/sql/cache_key.py @@ -0,0 +1,762 @@ +# sql/cache_key.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from collections import namedtuple +import enum +from itertools import zip_longest +from typing import Callable +from typing import Union + +from .visitors import anon_map +from .visitors import ExtendedInternalTraversal +from .visitors import InternalTraversal +from .. import util +from ..inspection import inspect +from ..util import HasMemoized +from ..util.typing import Literal + + +class CacheConst(enum.Enum): + NO_CACHE = 0 + + +NO_CACHE = CacheConst.NO_CACHE + + +class CacheTraverseTarget(enum.Enum): + CACHE_IN_PLACE = 0 + CALL_GEN_CACHE_KEY = 1 + STATIC_CACHE_KEY = 2 + PROPAGATE_ATTRS = 3 + ANON_NAME = 4 + + +( + CACHE_IN_PLACE, + CALL_GEN_CACHE_KEY, + STATIC_CACHE_KEY, + PROPAGATE_ATTRS, + ANON_NAME, +) = tuple(CacheTraverseTarget) + + +class HasCacheKey: + """Mixin for objects which can produce a cache key. + + .. seealso:: + + :class:`.CacheKey` + + :ref:`sql_caching` + + """ + + _cache_key_traversal = NO_CACHE + + _is_has_cache_key = True + + _hierarchy_supports_caching = True + """private attribute which may be set to False to prevent the + inherit_cache warning from being emitted for a hierarchy of subclasses. + + Currently applies to the DDLElement hierarchy which does not implement + caching. + + """ + + inherit_cache = None + """Indicate if this :class:`.HasCacheKey` instance should make use of the + cache key generation scheme used by its immediate superclass. + + The attribute defaults to ``None``, which indicates that a construct has + not yet taken into account whether or not its appropriate for it to + participate in caching; this is functionally equivalent to setting the + value to ``False``, except that a warning is also emitted. + + This flag can be set to ``True`` on a particular class, if the SQL that + corresponds to the object does not change based on attributes which + are local to this class, and not its superclass. + + .. seealso:: + + :ref:`compilerext_caching` - General guideslines for setting the + :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user + defined SQL constructs. + + """ + + __slots__ = () + + @classmethod + def _generate_cache_attrs(cls): + """generate cache key dispatcher for a new class. 
+ + This sets the _generated_cache_key_traversal attribute once called + so should only be called once per class. + + """ + inherit_cache = cls.__dict__.get("inherit_cache", None) + inherit = bool(inherit_cache) + + if inherit: + _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) + if _cache_key_traversal is None: + try: + _cache_key_traversal = cls._traverse_internals + except AttributeError: + cls._generated_cache_key_traversal = NO_CACHE + return NO_CACHE + + # TODO: wouldn't we instead get this from our superclass? + # also, our superclass may not have this yet, but in any case, + # we'd generate for the superclass that has it. this is a little + # more complicated, so for the moment this is a little less + # efficient on startup but simpler. + return _cache_key_traversal_visitor.generate_dispatch( + cls, _cache_key_traversal, "_generated_cache_key_traversal" + ) + else: + _cache_key_traversal = cls.__dict__.get( + "_cache_key_traversal", None + ) + if _cache_key_traversal is None: + _cache_key_traversal = cls.__dict__.get( + "_traverse_internals", None + ) + if _cache_key_traversal is None: + cls._generated_cache_key_traversal = NO_CACHE + if ( + inherit_cache is None + and cls._hierarchy_supports_caching + ): + util.warn( + "Class %s will not make use of SQL compilation " + "caching as it does not set the 'inherit_cache' " + "attribute to ``True``. This can have " + "significant performance implications including " + "some performance degradations in comparison to " + "prior SQLAlchemy versions. Set this attribute " + "to True if this object can make use of the cache " + "key generated by the superclass. Alternatively, " + "this attribute may be set to False which will " + "disable this warning." % (cls.__name__), + code="cprf", + ) + return NO_CACHE + + return _cache_key_traversal_visitor.generate_dispatch( + cls, _cache_key_traversal, "_generated_cache_key_traversal" + ) + + @util.preload_module("sqlalchemy.sql.elements") + def _gen_cache_key(self, anon_map, bindparams): + """return an optional cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the structures that would affect the SQL string or the type handlers + should result in a different cache key. + + If a structure cannot produce a useful cache key, the NO_CACHE + symbol should be added to the anon_map and the method should + return None. + + """ + + cls = self.__class__ + + id_, found = anon_map.get_anon(self) + if found: + return (id_, cls) + + dispatcher: Union[ + Literal[CacheConst.NO_CACHE], + Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"], + ] + + try: + dispatcher = cls.__dict__["_generated_cache_key_traversal"] + except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. 
+ dispatcher = cls._generate_cache_attrs() + + if dispatcher is NO_CACHE: + anon_map[NO_CACHE] = True + return None + + result = (id_, cls) + + # inline of _cache_key_traversal_visitor.run_generated_dispatch() + + for attrname, obj, meth in dispatcher( + self, _cache_key_traversal_visitor + ): + if obj is not None: + # TODO: see if C code can help here as Python lacks an + # efficient switch construct + + if meth is STATIC_CACHE_KEY: + sck = obj._static_cache_key + if sck is NO_CACHE: + anon_map[NO_CACHE] = True + return None + result += (attrname, sck) + elif meth is ANON_NAME: + elements = util.preloaded.sql_elements + if isinstance(obj, elements._anonymous_label): + obj = obj.apply_map(anon_map) + result += (attrname, obj) + elif meth is CALL_GEN_CACHE_KEY: + result += ( + attrname, + obj._gen_cache_key(anon_map, bindparams), + ) + + # remaining cache functions are against + # Python tuples, dicts, lists, etc. so we can skip + # if they are empty + elif obj: + if meth is CACHE_IN_PLACE: + result += (attrname, obj) + elif meth is PROPAGATE_ATTRS: + result += ( + attrname, + obj["compile_state_plugin"], + obj["plugin_subject"]._gen_cache_key( + anon_map, bindparams + ) + if obj["plugin_subject"] + else None, + ) + elif meth is InternalTraversal.dp_annotations_key: + # obj is here is the _annotations dict. however, we + # want to use the memoized cache key version of it. for + # Columns, this should be long lived. For select() + # statements, not so much, but they usually won't have + # annotations. + result += self._annotations_cache_key + elif ( + meth is InternalTraversal.dp_clauseelement_list + or meth is InternalTraversal.dp_clauseelement_tuple + or meth + is InternalTraversal.dp_memoized_select_entities + ): + result += ( + attrname, + tuple( + [ + elem._gen_cache_key(anon_map, bindparams) + for elem in obj + ] + ), + ) + else: + result += meth( + attrname, obj, self, anon_map, bindparams + ) + return result + + def _generate_cache_key(self): + """return a cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the structures that would affect the SQL string or the type handlers + should result in a different cache key. + + The cache key returned by this method is an instance of + :class:`.CacheKey`, which consists of a tuple representing the + cache key, as well as a list of :class:`.BindParameter` objects + which are extracted from the expression. While two expressions + that produce identical cache key tuples will themselves generate + identical SQL strings, the list of :class:`.BindParameter` objects + indicates the bound values which may have different values in + each one; these bound parameters must be consulted in order to + execute the statement with the correct parameters. + + a :class:`_expression.ClauseElement` structure that does not implement + a :meth:`._gen_cache_key` method and does not implement a + :attr:`.traverse_internals` attribute will not be cacheable; when + such an element is embedded into a larger structure, this method + will return None, indicating no cache key is available. 
+ + """ + + bindparams = [] + + _anon_map = anon_map() + key = self._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + @classmethod + def _generate_cache_key_for_object(cls, obj): + bindparams = [] + + _anon_map = anon_map() + key = obj._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + +class MemoizedHasCacheKey(HasCacheKey, HasMemoized): + @HasMemoized.memoized_instancemethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key(self) + + +class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): + """The key used to identify a SQL statement construct in the + SQL compilation cache. + + .. seealso:: + + :ref:`sql_caching` + + """ + + def __hash__(self): + """CacheKey itself is not hashable - hash the .key portion""" + + return None + + def to_offline_string(self, statement_cache, statement, parameters): + """Generate an "offline string" form of this :class:`.CacheKey` + + The "offline string" is basically the string SQL for the + statement plus a repr of the bound parameter values in series. + Whereas the :class:`.CacheKey` object is dependent on in-memory + identities in order to work as a cache key, the "offline" version + is suitable for a cache that will work for other processes as well. + + The given ``statement_cache`` is a dictionary-like object where the + string form of the statement itself will be cached. This dictionary + should be in a longer lived scope in order to reduce the time spent + stringifying statements. + + + """ + if self.key not in statement_cache: + statement_cache[self.key] = sql_str = str(statement) + else: + sql_str = statement_cache[self.key] + + if not self.bindparams: + param_tuple = tuple(parameters[key] for key in sorted(parameters)) + else: + param_tuple = tuple( + parameters.get(bindparam.key, bindparam.value) + for bindparam in self.bindparams + ) + + return repr((sql_str, param_tuple)) + + def __eq__(self, other): + return self.key == other.key + + @classmethod + def _diff_tuples(cls, left, right): + ck1 = CacheKey(left, []) + ck2 = CacheKey(right, []) + return ck1._diff(ck2) + + def _whats_different(self, other): + + k1 = self.key + k2 = other.key + + stack = [] + pickup_index = 0 + while True: + s1, s2 = k1, k2 + for idx in stack: + s1 = s1[idx] + s2 = s2[idx] + + for idx, (e1, e2) in enumerate(zip_longest(s1, s2)): + if idx < pickup_index: + continue + if e1 != e2: + if isinstance(e1, tuple) and isinstance(e2, tuple): + stack.append(idx) + break + else: + yield "key%s[%d]: %s != %s" % ( + "".join("[%d]" % id_ for id_ in stack), + idx, + e1, + e2, + ) + else: + pickup_index = stack.pop(-1) + break + + def _diff(self, other): + return ", ".join(self._whats_different(other)) + + def __str__(self): + stack = [self.key] + + output = [] + sentinel = object() + indent = -1 + while stack: + elem = stack.pop(0) + if elem is sentinel: + output.append((" " * (indent * 2)) + "),") + indent -= 1 + elif isinstance(elem, tuple): + if not elem: + output.append((" " * ((indent + 1) * 2)) + "()") + else: + indent += 1 + stack = list(elem) + [sentinel] + stack + output.append((" " * (indent * 2)) + "(") + else: + if isinstance(elem, HasCacheKey): + repr_ = "<%s object at %s>" % ( + type(elem).__name__, + hex(id(elem)), + ) + else: + repr_ = repr(elem) + output.append((" " * (indent * 2)) + " " + repr_ + ", ") + + return "CacheKey(key=%s)" % ("\n".join(output),) + + def 
_generate_param_dict(self): + """used for testing""" + + from .compiler import prefix_anon_map + + _anon_map = prefix_anon_map() + return {b.key % _anon_map: b.effective_value for b in self.bindparams} + + def _apply_params_to_element(self, original_cache_key, target_element): + translate = { + k.key: v.value + for k, v in zip(original_cache_key.bindparams, self.bindparams) + } + + return target_element.params(translate) + + +class _CacheKeyTraversal(ExtendedInternalTraversal): + # very common elements are inlined into the main _get_cache_key() method + # to produce a dramatic savings in Python function call overhead + + visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY + visit_clauseelement_list = InternalTraversal.dp_clauseelement_list + visit_annotations_key = InternalTraversal.dp_annotations_key + visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple + visit_memoized_select_entities = ( + InternalTraversal.dp_memoized_select_entities + ) + + visit_string = ( + visit_boolean + ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE + visit_statement_hint_list = CACHE_IN_PLACE + visit_type = STATIC_CACHE_KEY + visit_anon_name = ANON_NAME + + visit_propagate_attrs = PROPAGATE_ATTRS + + def visit_with_context_options( + self, attrname, obj, parent, anon_map, bindparams + ): + return tuple((fn.__code__, c_key) for fn, c_key in obj) + + def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) + + def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): + return tuple(obj) + + def visit_multi(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj, + ) + + def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem + for elem in obj + ), + ) + + def visit_has_cache_key_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple( + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in tup_elem + ) + for tup_elem in obj + ), + ) + + def visit_has_cache_key_list( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_executable_options( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in obj + if elem._is_has_cache_key + ), + ) + + def visit_inspectable_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_list( + attrname, [inspect(o) for o in obj], parent, anon_map, bindparams + ) + + def visit_clauseelement_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_tuples( + attrname, obj, parent, anon_map, bindparams + ) + + def visit_fromclause_ordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]), + ) + + def visit_clauseelement_unordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + cache_keys = [ + elem._gen_cache_key(anon_map, bindparams) 
for elem in obj + ] + return ( + attrname, + tuple( + sorted(cache_keys) + ), # cache keys all start with (id_, class) + ) + + def visit_named_ddl_element( + self, attrname, obj, parent, anon_map, bindparams + ): + return (attrname, obj.name) + + def visit_prefix_sequence( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + + return ( + attrname, + tuple( + [ + (clause._gen_cache_key(anon_map, bindparams), strval) + for clause, strval in obj + ] + ), + ) + + def visit_setup_join_tuple( + self, attrname, obj, parent, anon_map, bindparams + ): + return tuple( + ( + target._gen_cache_key(anon_map, bindparams), + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None, + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None, + tuple([(key, flags[key]) for key in sorted(flags)]), + ) + for (target, onclause, from_, flags) in obj + ) + + def visit_table_hint_list( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + + return ( + attrname, + tuple( + [ + ( + clause._gen_cache_key(anon_map, bindparams), + dialect_name, + text, + ) + for (clause, dialect_name), text in obj.items() + ] + ), + ) + + def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) + + def visit_dialect_options( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + dialect_name, + tuple( + [ + (key, obj[dialect_name][key]) + for key in sorted(obj[dialect_name]) + ] + ), + ) + for dialect_name in sorted(obj) + ), + ) + + def visit_string_clauseelement_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + (key, obj[key]._gen_cache_key(anon_map, bindparams)) + for key in sorted(obj) + ), + ) + + def visit_string_multi_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key, + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value, + ) + for key, value in [(key, obj[key]) for key in sorted(obj)] + ), + ) + + def visit_fromclause_canonical_column_collection( + self, attrname, obj, parent, anon_map, bindparams + ): + # inlining into the internals of ColumnCollection + return ( + attrname, + tuple( + col._gen_cache_key(anon_map, bindparams) + for k, col in obj._collection + ), + ) + + def visit_unknown_structure( + self, attrname, obj, parent, anon_map, bindparams + ): + anon_map[NO_CACHE] = True + return () + + def visit_dml_ordered_values( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key._gen_cache_key(anon_map, bindparams) + if hasattr(key, "__clause_element__") + else key, + value._gen_cache_key(anon_map, bindparams), + ) + for key, value in obj + ), + ) + + def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): + # in py37 we can assume two dictionaries created in the same + # insert ordering will retain that sorting + return ( + attrname, + tuple( + ( + k._gen_cache_key(anon_map, bindparams) + if hasattr(k, "__clause_element__") + else k, + obj[k]._gen_cache_key(anon_map, bindparams), + ) + for k in obj + ), + ) + + def visit_dml_multi_values( + self, attrname, obj, parent, anon_map, bindparams + ): + # multivalues are simply not cacheable right now + anon_map[NO_CACHE] = True + return () + + +_cache_key_traversal_visitor = _CacheKeyTraversal() diff --git a/lib/sqlalchemy/sql/coercions.py 
b/lib/sqlalchemy/sql/coercions.py index 95697806e..fe2b498c8 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -14,7 +14,7 @@ from . import roles from . import visitors from .base import ExecutableOption from .base import Options -from .traversals import HasCacheKey +from .cache_key import HasCacheKey from .visitors import Visitable from .. import exc from .. import inspection diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 697550df4..cb10811c6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -503,7 +503,7 @@ class Compiled: return self.construct_params() -class TypeCompiler(metaclass=util.EnsureKWArgType): +class TypeCompiler(util.EnsureKWArg): """Produces DDL specification for TypeEngine objects.""" ensure_kwarg = r"visit_\w+" diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 12282de05..a025cce35 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -29,10 +29,10 @@ from .base import HasMemoized from .base import Immutable from .base import NO_ARG from .base import SingletonConstant +from .cache_key import MemoizedHasCacheKey +from .cache_key import NO_CACHE from .coercions import _document_text_coercion from .traversals import HasCopyInternals -from .traversals import MemoizedHasCacheKey -from .traversals import NO_CACHE from .visitors import cloned_traverse from .visitors import InternalTraversal from .visitors import traverse diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 2a3cd07d0..54f67b930 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -94,6 +94,7 @@ from .base import _from_objects from .base import _select_iterables from .base import ColumnCollection from .base import Executable +from .cache_key import CacheKey from .dml import Delete from .dml import Insert from .dml import Update @@ -173,7 +174,6 @@ from .selectable import TableValuedAlias from .selectable import TextAsFrom from .selectable import TextualSelect from .selectable import Values -from .traversals import CacheKey from .visitors import Visitable from ..util.langhelpers import public_factory diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index b7b9257b4..3b6da7175 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -37,8 +37,8 @@ from .elements import WithinGroup from .selectable import FromClause from .selectable import Select from .selectable import TableValuedAlias +from .type_api import TypeEngine from .visitors import InternalTraversal -from .visitors import TraversibleType from .. import util @@ -48,7 +48,7 @@ _registry = util.defaultdict(dict) def register_function(identifier, fn, package="_default"): """Associate a callable with a particular func. name. - This is normally called by _GenericMeta, but is also + This is normally called by GenericFunction, but is also available by itself so that a non-Function construct can be associated with the :data:`.func` accessor (i.e. CAST, EXTRACT). @@ -828,7 +828,11 @@ class Function(FunctionElement): ("type", InternalTraversal.dp_type), ] - type = sqltypes.NULLTYPE + name: str + + identifier: str + + type: TypeEngine = sqltypes.NULLTYPE """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. 
@@ -871,30 +875,7 @@ class Function(FunctionElement): ) -class _GenericMeta(TraversibleType): - def __init__(cls, clsname, bases, clsdict): - if annotation.Annotated not in cls.__mro__: - cls.name = name = clsdict.get("name", clsname) - cls.identifier = identifier = clsdict.get("identifier", name) - package = clsdict.pop("package", "_default") - # legacy - if "__return_type__" in clsdict: - cls.type = clsdict["__return_type__"] - - # Check _register attribute status - cls._register = getattr(cls, "_register", True) - - # Register the function if required - if cls._register: - register_function(identifier, cls, package) - else: - # Set _register to True to register child classes by default - cls._register = True - - super(_GenericMeta, cls).__init__(clsname, bases, clsdict) - - -class GenericFunction(Function, metaclass=_GenericMeta): +class GenericFunction(Function): """Define a 'generic' function. A generic function is a pre-established :class:`.Function` @@ -986,9 +967,34 @@ class GenericFunction(Function, metaclass=_GenericMeta): """ coerce_arguments = True - _register = False inherit_cache = True + name = "GenericFunction" + + def __init_subclass__(cls) -> None: + if annotation.Annotated not in cls.__mro__: + cls._register_generic_function(cls.__name__, cls.__dict__) + super().__init_subclass__() + + @classmethod + def _register_generic_function(cls, clsname, clsdict): + cls.name = name = clsdict.get("name", clsname) + cls.identifier = identifier = clsdict.get("identifier", name) + package = clsdict.get("package", "_default") + # legacy + if "__return_type__" in clsdict: + cls.type = clsdict["__return_type__"] + + # Check _register attribute status + cls._register = getattr(cls, "_register", True) + + # Register the function if required + if cls._register: + register_function(identifier, cls, package) + else: + # Set _register to True to register child classes by default + cls._register = True + def __init__(self, *args, **kwargs): parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: @@ -1006,6 +1012,7 @@ class GenericFunction(Function, metaclass=_GenericMeta): self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *parsed_args ).self_group() + self.type = sqltypes.to_instance( kwargs.pop("type_", None) or getattr(self, "type", None) ) diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 2387e551e..d71c85d60 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -11,6 +11,7 @@ import operator import types import weakref +from . import cache_key as _cache_key from . import coercions from . import elements from . 
import roles @@ -185,7 +186,7 @@ class LambdaElement(elements.ClauseElement): else: parent_closure_cache_key = () - if parent_closure_cache_key is not traversals.NO_CACHE: + if parent_closure_cache_key is not _cache_key.NO_CACHE: anon_map = traversals.anon_map() cache_key = tuple( [ @@ -194,7 +195,7 @@ class LambdaElement(elements.ClauseElement): ] ) - if traversals.NO_CACHE not in anon_map: + if _cache_key.NO_CACHE not in anon_map: cache_key = parent_closure_cache_key + cache_key self.closure_cache_key = cache_key @@ -204,17 +205,17 @@ class LambdaElement(elements.ClauseElement): except KeyError: rec = None else: - cache_key = traversals.NO_CACHE + cache_key = _cache_key.NO_CACHE rec = None else: - cache_key = traversals.NO_CACHE + cache_key = _cache_key.NO_CACHE rec = None self.closure_cache_key = cache_key if rec is None: - if cache_key is not traversals.NO_CACHE: + if cache_key is not _cache_key.NO_CACHE: rec = AnalyzedFunction( tracker, self, apply_propagate_attrs, fn ) @@ -233,7 +234,7 @@ class LambdaElement(elements.ClauseElement): self._rec = rec - if cache_key is not traversals.NO_CACHE: + if cache_key is not _cache_key.NO_CACHE: if self.parent_lambda is not None: bindparams[:0] = self.parent_lambda._resolved_bindparams @@ -326,8 +327,8 @@ class LambdaElement(elements.ClauseElement): return expr def _gen_cache_key(self, anon_map, bindparams): - if self.closure_cache_key is traversals.NO_CACHE: - anon_map[traversals.NO_CACHE] = True + if self.closure_cache_key is _cache_key.NO_CACHE: + anon_map[_cache_key.NO_CACHE] = True return None cache_key = ( @@ -808,7 +809,7 @@ class AnalyzedCode: for tup_elem in opts.track_on[idx] ) - elif isinstance(elem, traversals.HasCacheKey): + elif isinstance(elem, _cache_key.HasCacheKey): def get(closure, opts, anon_map, bindparams): return opts.track_on[idx]._gen_cache_key(anon_map, bindparams) @@ -834,7 +835,7 @@ class AnalyzedCode: """ - if isinstance(cell_contents, traversals.HasCacheKey): + if isinstance(cell_contents, _cache_key.HasCacheKey): def get(closure, opts, anon_map, bindparams): @@ -1166,7 +1167,7 @@ class PyWrapper(ColumnOperators): and not isinstance( # TODO: coverage where an ORM option or similar is here value, - traversals.HasCacheKey, + _cache_key.HasCacheKey, ) ): name = object.__getattribute__(self, "_name") diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 655d98b02..e674c4b74 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -18,6 +18,7 @@ import typing from typing import Type from typing import Union +from . import cache_key from . import coercions from . import operators from . import roles @@ -4300,7 +4301,7 @@ class _SelectFromElements: class _MemoizedSelectEntities( - traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible + cache_key.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible ): __visit_name__ = "memoized_select_entities" diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 05007eff1..d5506cda2 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -22,11 +22,11 @@ from . import roles from . 
import type_api from .base import NO_ARG from .base import SchemaEventTarget +from .cache_key import HasCacheKey from .elements import _NONE_NAME from .elements import quoted_name from .elements import Slice from .elements import TypeCoerce as type_coerce # noqa -from .traversals import HasCacheKey from .traversals import InternalTraversal from .type_api import Emulated from .type_api import NativeForEmulated # noqa diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index b689fe578..2fa3a0408 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -1,32 +1,25 @@ +# sql/traversals.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + from collections import deque -from collections import namedtuple import collections.abc as collections_abc import itertools from itertools import zip_longest import operator from . import operators -from .visitors import ExtendedInternalTraversal +from .visitors import anon_map from .visitors import InternalTraversal from .. import util -from ..inspection import inspect -from ..util import HasMemoized - -try: - from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa -except ImportError: - from ._py_util import cache_anon_map as anon_map # noqa SKIP_TRAVERSE = util.symbol("skip_traverse") COMPARE_FAILED = False COMPARE_SUCCEEDED = True -NO_CACHE = util.symbol("no_cache") -CACHE_IN_PLACE = util.symbol("cache_in_place") -CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key") -STATIC_CACHE_KEY = util.symbol("static_cache_key") -PROPAGATE_ATTRS = util.symbol("propagate_attrs") -ANON_NAME = util.symbol("anon_name") def compare(obj1, obj2, **kw): @@ -54,729 +47,10 @@ def _preconfigure_traversals(target_hierarchy): ) -class HasCacheKey: - """Mixin for objects which can produce a cache key. - - .. seealso:: - - :class:`.CacheKey` - - :ref:`sql_caching` - - """ - - _cache_key_traversal = NO_CACHE - - _is_has_cache_key = True - - _hierarchy_supports_caching = True - """private attribute which may be set to False to prevent the - inherit_cache warning from being emitted for a hierarchy of subclasses. - - Currently applies to the DDLElement hierarchy which does not implement - caching. - - """ - - inherit_cache = None - """Indicate if this :class:`.HasCacheKey` instance should make use of the - cache key generation scheme used by its immediate superclass. - - The attribute defaults to ``None``, which indicates that a construct has - not yet taken into account whether or not its appropriate for it to - participate in caching; this is functionally equivalent to setting the - value to ``False``, except that a warning is also emitted. - - This flag can be set to ``True`` on a particular class, if the SQL that - corresponds to the object does not change based on attributes which - are local to this class, and not its superclass. - - .. seealso:: - - :ref:`compilerext_caching` - General guideslines for setting the - :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user - defined SQL constructs. - - """ - - __slots__ = () - - @classmethod - def _generate_cache_attrs(cls): - """generate cache key dispatcher for a new class. - - This sets the _generated_cache_key_traversal attribute once called - so should only be called once per class. 
- - """ - inherit_cache = cls.__dict__.get("inherit_cache", None) - inherit = bool(inherit_cache) - - if inherit: - _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) - if _cache_key_traversal is None: - try: - _cache_key_traversal = cls._traverse_internals - except AttributeError: - cls._generated_cache_key_traversal = NO_CACHE - return NO_CACHE - - # TODO: wouldn't we instead get this from our superclass? - # also, our superclass may not have this yet, but in any case, - # we'd generate for the superclass that has it. this is a little - # more complicated, so for the moment this is a little less - # efficient on startup but simpler. - return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" - ) - else: - _cache_key_traversal = cls.__dict__.get( - "_cache_key_traversal", None - ) - if _cache_key_traversal is None: - _cache_key_traversal = cls.__dict__.get( - "_traverse_internals", None - ) - if _cache_key_traversal is None: - cls._generated_cache_key_traversal = NO_CACHE - if ( - inherit_cache is None - and cls._hierarchy_supports_caching - ): - util.warn( - "Class %s will not make use of SQL compilation " - "caching as it does not set the 'inherit_cache' " - "attribute to ``True``. This can have " - "significant performance implications including " - "some performance degradations in comparison to " - "prior SQLAlchemy versions. Set this attribute " - "to True if this object can make use of the cache " - "key generated by the superclass. Alternatively, " - "this attribute may be set to False which will " - "disable this warning." % (cls.__name__), - code="cprf", - ) - return NO_CACHE - - return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" - ) - - @util.preload_module("sqlalchemy.sql.elements") - def _gen_cache_key(self, anon_map, bindparams): - """return an optional cache key. - - The cache key is a tuple which can contain any series of - objects that are hashable and also identifies - this object uniquely within the presence of a larger SQL expression - or statement, for the purposes of caching the resulting query. - - The cache key should be based on the SQL compiled structure that would - ultimately be produced. That is, two structures that are composed in - exactly the same way should produce the same cache key; any difference - in the structures that would affect the SQL string or the type handlers - should result in a different cache key. - - If a structure cannot produce a useful cache key, the NO_CACHE - symbol should be added to the anon_map and the method should - return None. - - """ - - cls = self.__class__ - - id_, found = anon_map.get_anon(self) - if found: - return (id_, cls) - - try: - dispatcher = cls.__dict__["_generated_cache_key_traversal"] - except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). - # this block will generate any remaining dispatchers. 
- dispatcher = cls._generate_cache_attrs() - - if dispatcher is NO_CACHE: - anon_map[NO_CACHE] = True - return None - - result = (id_, cls) - - # inline of _cache_key_traversal_visitor.run_generated_dispatch() - - for attrname, obj, meth in dispatcher( - self, _cache_key_traversal_visitor - ): - if obj is not None: - # TODO: see if C code can help here as Python lacks an - # efficient switch construct - - if meth is STATIC_CACHE_KEY: - sck = obj._static_cache_key - if sck is NO_CACHE: - anon_map[NO_CACHE] = True - return None - result += (attrname, sck) - elif meth is ANON_NAME: - elements = util.preloaded.sql_elements - if isinstance(obj, elements._anonymous_label): - obj = obj.apply_map(anon_map) - result += (attrname, obj) - elif meth is CALL_GEN_CACHE_KEY: - result += ( - attrname, - obj._gen_cache_key(anon_map, bindparams), - ) - - # remaining cache functions are against - # Python tuples, dicts, lists, etc. so we can skip - # if they are empty - elif obj: - if meth is CACHE_IN_PLACE: - result += (attrname, obj) - elif meth is PROPAGATE_ATTRS: - result += ( - attrname, - obj["compile_state_plugin"], - obj["plugin_subject"]._gen_cache_key( - anon_map, bindparams - ) - if obj["plugin_subject"] - else None, - ) - elif meth is InternalTraversal.dp_annotations_key: - # obj is here is the _annotations dict. however, we - # want to use the memoized cache key version of it. for - # Columns, this should be long lived. For select() - # statements, not so much, but they usually won't have - # annotations. - result += self._annotations_cache_key - elif ( - meth is InternalTraversal.dp_clauseelement_list - or meth is InternalTraversal.dp_clauseelement_tuple - or meth - is InternalTraversal.dp_memoized_select_entities - ): - result += ( - attrname, - tuple( - [ - elem._gen_cache_key(anon_map, bindparams) - for elem in obj - ] - ), - ) - else: - result += meth( - attrname, obj, self, anon_map, bindparams - ) - return result - - def _generate_cache_key(self): - """return a cache key. - - The cache key is a tuple which can contain any series of - objects that are hashable and also identifies - this object uniquely within the presence of a larger SQL expression - or statement, for the purposes of caching the resulting query. - - The cache key should be based on the SQL compiled structure that would - ultimately be produced. That is, two structures that are composed in - exactly the same way should produce the same cache key; any difference - in the structures that would affect the SQL string or the type handlers - should result in a different cache key. - - The cache key returned by this method is an instance of - :class:`.CacheKey`, which consists of a tuple representing the - cache key, as well as a list of :class:`.BindParameter` objects - which are extracted from the expression. While two expressions - that produce identical cache key tuples will themselves generate - identical SQL strings, the list of :class:`.BindParameter` objects - indicates the bound values which may have different values in - each one; these bound parameters must be consulted in order to - execute the statement with the correct parameters. - - a :class:`_expression.ClauseElement` structure that does not implement - a :meth:`._gen_cache_key` method and does not implement a - :attr:`.traverse_internals` attribute will not be cacheable; when - such an element is embedded into a larger structure, this method - will return None, indicating no cache key is available. 
- - """ - - bindparams = [] - - _anon_map = anon_map() - key = self._gen_cache_key(_anon_map, bindparams) - if NO_CACHE in _anon_map: - return None - else: - return CacheKey(key, bindparams) - - @classmethod - def _generate_cache_key_for_object(cls, obj): - bindparams = [] - - _anon_map = anon_map() - key = obj._gen_cache_key(_anon_map, bindparams) - if NO_CACHE in _anon_map: - return None - else: - return CacheKey(key, bindparams) - - -class MemoizedHasCacheKey(HasCacheKey, HasMemoized): - @HasMemoized.memoized_instancemethod - def _generate_cache_key(self): - return HasCacheKey._generate_cache_key(self) - - -class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): - """The key used to identify a SQL statement construct in the - SQL compilation cache. - - .. seealso:: - - :ref:`sql_caching` - - """ - - def __hash__(self): - """CacheKey itself is not hashable - hash the .key portion""" - - return None - - def to_offline_string(self, statement_cache, statement, parameters): - """Generate an "offline string" form of this :class:`.CacheKey` - - The "offline string" is basically the string SQL for the - statement plus a repr of the bound parameter values in series. - Whereas the :class:`.CacheKey` object is dependent on in-memory - identities in order to work as a cache key, the "offline" version - is suitable for a cache that will work for other processes as well. - - The given ``statement_cache`` is a dictionary-like object where the - string form of the statement itself will be cached. This dictionary - should be in a longer lived scope in order to reduce the time spent - stringifying statements. - - - """ - if self.key not in statement_cache: - statement_cache[self.key] = sql_str = str(statement) - else: - sql_str = statement_cache[self.key] - - if not self.bindparams: - param_tuple = tuple(parameters[key] for key in sorted(parameters)) - else: - param_tuple = tuple( - parameters.get(bindparam.key, bindparam.value) - for bindparam in self.bindparams - ) - - return repr((sql_str, param_tuple)) - - def __eq__(self, other): - return self.key == other.key - - @classmethod - def _diff_tuples(cls, left, right): - ck1 = CacheKey(left, []) - ck2 = CacheKey(right, []) - return ck1._diff(ck2) - - def _whats_different(self, other): - - k1 = self.key - k2 = other.key - - stack = [] - pickup_index = 0 - while True: - s1, s2 = k1, k2 - for idx in stack: - s1 = s1[idx] - s2 = s2[idx] - - for idx, (e1, e2) in enumerate(zip_longest(s1, s2)): - if idx < pickup_index: - continue - if e1 != e2: - if isinstance(e1, tuple) and isinstance(e2, tuple): - stack.append(idx) - break - else: - yield "key%s[%d]: %s != %s" % ( - "".join("[%d]" % id_ for id_ in stack), - idx, - e1, - e2, - ) - else: - pickup_index = stack.pop(-1) - break - - def _diff(self, other): - return ", ".join(self._whats_different(other)) - - def __str__(self): - stack = [self.key] - - output = [] - sentinel = object() - indent = -1 - while stack: - elem = stack.pop(0) - if elem is sentinel: - output.append((" " * (indent * 2)) + "),") - indent -= 1 - elif isinstance(elem, tuple): - if not elem: - output.append((" " * ((indent + 1) * 2)) + "()") - else: - indent += 1 - stack = list(elem) + [sentinel] + stack - output.append((" " * (indent * 2)) + "(") - else: - if isinstance(elem, HasCacheKey): - repr_ = "<%s object at %s>" % ( - type(elem).__name__, - hex(id(elem)), - ) - else: - repr_ = repr(elem) - output.append((" " * (indent * 2)) + " " + repr_ + ", ") - - return "CacheKey(key=%s)" % ("\n".join(output),) - - def 
_generate_param_dict(self): - """used for testing""" - - from .compiler import prefix_anon_map - - _anon_map = prefix_anon_map() - return {b.key % _anon_map: b.effective_value for b in self.bindparams} - - def _apply_params_to_element(self, original_cache_key, target_element): - translate = { - k.key: v.value - for k, v in zip(original_cache_key.bindparams, self.bindparams) - } - - return target_element.params(translate) - - def _clone(element, **kw): return element._clone() -class _CacheKey(ExtendedInternalTraversal): - # very common elements are inlined into the main _get_cache_key() method - # to produce a dramatic savings in Python function call overhead - - visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY - visit_clauseelement_list = InternalTraversal.dp_clauseelement_list - visit_annotations_key = InternalTraversal.dp_annotations_key - visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple - visit_memoized_select_entities = ( - InternalTraversal.dp_memoized_select_entities - ) - - visit_string = ( - visit_boolean - ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE - visit_statement_hint_list = CACHE_IN_PLACE - visit_type = STATIC_CACHE_KEY - visit_anon_name = ANON_NAME - - visit_propagate_attrs = PROPAGATE_ATTRS - - def visit_with_context_options( - self, attrname, obj, parent, anon_map, bindparams - ): - return tuple((fn.__code__, c_key) for fn, c_key in obj) - - def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): - return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) - - def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): - return tuple(obj) - - def visit_multi(self, attrname, obj, parent, anon_map, bindparams): - return ( - attrname, - obj._gen_cache_key(anon_map, bindparams) - if isinstance(obj, HasCacheKey) - else obj, - ) - - def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): - return ( - attrname, - tuple( - elem._gen_cache_key(anon_map, bindparams) - if isinstance(elem, HasCacheKey) - else elem - for elem in obj - ), - ) - - def visit_has_cache_key_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple( - tuple( - elem._gen_cache_key(anon_map, bindparams) - for elem in tup_elem - ) - for tup_elem in obj - ), - ) - - def visit_has_cache_key_list( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), - ) - - def visit_executable_options( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple( - elem._gen_cache_key(anon_map, bindparams) - for elem in obj - if elem._is_has_cache_key - ), - ) - - def visit_inspectable_list( - self, attrname, obj, parent, anon_map, bindparams - ): - return self.visit_has_cache_key_list( - attrname, [inspect(o) for o in obj], parent, anon_map, bindparams - ) - - def visit_clauseelement_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): - return self.visit_has_cache_key_tuples( - attrname, obj, parent, anon_map, bindparams - ) - - def visit_fromclause_ordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]), - ) - - def visit_clauseelement_unordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - cache_keys = [ - 
elem._gen_cache_key(anon_map, bindparams) for elem in obj - ] - return ( - attrname, - tuple( - sorted(cache_keys) - ), # cache keys all start with (id_, class) - ) - - def visit_named_ddl_element( - self, attrname, obj, parent, anon_map, bindparams - ): - return (attrname, obj.name) - - def visit_prefix_sequence( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - - return ( - attrname, - tuple( - [ - (clause._gen_cache_key(anon_map, bindparams), strval) - for clause, strval in obj - ] - ), - ) - - def visit_setup_join_tuple( - self, attrname, obj, parent, anon_map, bindparams - ): - is_legacy = "legacy" in attrname - - return tuple( - ( - target - if is_legacy and isinstance(target, str) - else target._gen_cache_key(anon_map, bindparams), - onclause - if is_legacy and isinstance(onclause, str) - else onclause._gen_cache_key(anon_map, bindparams) - if onclause is not None - else None, - from_._gen_cache_key(anon_map, bindparams) - if from_ is not None - else None, - tuple([(key, flags[key]) for key in sorted(flags)]), - ) - for (target, onclause, from_, flags) in obj - ) - - def visit_table_hint_list( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - - return ( - attrname, - tuple( - [ - ( - clause._gen_cache_key(anon_map, bindparams), - dialect_name, - text, - ) - for (clause, dialect_name), text in obj.items() - ] - ), - ) - - def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): - return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) - - def visit_dialect_options( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - ( - dialect_name, - tuple( - [ - (key, obj[dialect_name][key]) - for key in sorted(obj[dialect_name]) - ] - ), - ) - for dialect_name in sorted(obj) - ), - ) - - def visit_string_clauseelement_dict( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - (key, obj[key]._gen_cache_key(anon_map, bindparams)) - for key in sorted(obj) - ), - ) - - def visit_string_multi_dict( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - ( - key, - value._gen_cache_key(anon_map, bindparams) - if isinstance(value, HasCacheKey) - else value, - ) - for key, value in [(key, obj[key]) for key in sorted(obj)] - ), - ) - - def visit_fromclause_canonical_column_collection( - self, attrname, obj, parent, anon_map, bindparams - ): - # inlining into the internals of ColumnCollection - return ( - attrname, - tuple( - col._gen_cache_key(anon_map, bindparams) - for k, col in obj._collection - ), - ) - - def visit_unknown_structure( - self, attrname, obj, parent, anon_map, bindparams - ): - anon_map[NO_CACHE] = True - return () - - def visit_dml_ordered_values( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - ( - key._gen_cache_key(anon_map, bindparams) - if hasattr(key, "__clause_element__") - else key, - value._gen_cache_key(anon_map, bindparams), - ) - for key, value in obj - ), - ) - - def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): - # in py37 we can assume two dictionaries created in the same - # insert ordering will retain that sorting - return ( - attrname, - tuple( - ( - k._gen_cache_key(anon_map, bindparams) - if hasattr(k, "__clause_element__") - else k, - obj[k]._gen_cache_key(anon_map, bindparams), - ) - for k in obj - ), - ) - - def visit_dml_multi_values( - self, attrname, obj, parent, anon_map, bindparams - ): - # 
multivalues are simply not cacheable right now - anon_map[NO_CACHE] = True - return () - - -_cache_key_traversal_visitor = _CacheKey() - - class HasCopyInternals: __slots__ = () @@ -813,7 +87,7 @@ class HasCopyInternals: setattr(self, attrname, result) -class _CopyInternals(InternalTraversal): +class _CopyInternalsTraversal(InternalTraversal): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -936,7 +210,7 @@ class _CopyInternals(InternalTraversal): return element -_copy_internals = _CopyInternals() +_copy_internals = _CopyInternalsTraversal() def _flatten_clauseelement(element): @@ -948,7 +222,7 @@ def _flatten_clauseelement(element): return element -class _GetChildren(InternalTraversal): +class _GetChildrenTraversal(InternalTraversal): """Generate a _children_traversal internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -1019,7 +293,7 @@ class _GetChildren(InternalTraversal): return () -_get_children = _GetChildren() +_get_children = _GetChildrenTraversal() @util.preload_module("sqlalchemy.sql.elements") diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9dd702410..55289cb85 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -13,9 +13,8 @@ import typing from . import operators from .base import SchemaEventTarget -from .traversals import NO_CACHE +from .cache_key import NO_CACHE from .visitors import Traversible -from .visitors import TraversibleType from .. import exc from .. import util @@ -858,10 +857,6 @@ class TypeEngine(Traversible): return util.generic_repr(self) -class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType): - pass - - class ExternalType: """mixin that defines attributes and behaviors specific to third-party datatypes. @@ -1038,7 +1033,7 @@ class ExternalType: return NO_CACHE -class UserDefinedType(ExternalType, TypeEngine, metaclass=VisitableCheckKWArg): +class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg): """Base for user defined types. This should be the base of new types. Note that diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 69e83e46a..63067585e 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -22,6 +22,7 @@ from .annotation import _shallow_annotate # noqa from .base import _expand_cloned from .base import _from_objects from .base import ColumnSet +from .cache_key import HasCacheKey # noqa from .ddl import sort_tables # noqa from .elements import _find_columns # noqa from .elements import _label_reference @@ -41,7 +42,6 @@ from .selectable import Join from .selectable import ScalarSelect from .selectable import SelectBase from .selectable import TableClause -from .traversals import HasCacheKey # noqa from .. import exc from .. import util diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 87fe36944..70c4dc133 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -32,6 +32,11 @@ from .. 
import util from ..util import langhelpers from ..util import symbol +try: + from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa +except ImportError: + from ._py_util import cache_anon_map as anon_map # noqa + __all__ = [ "iterate", "traverse_using", @@ -39,88 +44,77 @@ __all__ = [ "cloned_traverse", "replacement_traverse", "Traversible", - "TraversibleType", "ExternalTraversal", "InternalTraversal", ] -def _generate_compiler_dispatch(cls): - """Generate a _compiler_dispatch() external traversal on classes with a - __visit_name__ attribute. - - """ - visit_name = cls.__visit_name__ - - if "_compiler_dispatch" in cls.__dict__: - # class has a fixed _compiler_dispatch() method. - # copy it to "original" so that we can get it back if - # sqlalchemy.ext.compiles overrides it. - cls._original_compiler_dispatch = cls._compiler_dispatch - return - - if not isinstance(visit_name, str): - raise exc.InvalidRequestError( - "__visit_name__ on class %s must be a string at the class level" - % cls.__name__ - ) - - name = "visit_%s" % visit_name - getter = operator.attrgetter(name) - - def _compiler_dispatch(self, visitor, **kw): - """Look for an attribute named "visit_<visit_name>" on the - visitor, and call it with the same kw params. - - """ - try: - meth = getter(visitor) - except AttributeError as err: - return visitor.visit_unsupported_compilation(self, err, **kw) - - else: - return meth(self, **kw) - - cls._compiler_dispatch = ( - cls._original_compiler_dispatch - ) = _compiler_dispatch +class Traversible: + """Base class for visitable objects.""" + __slots__ = () -class TraversibleType(type): - """Metaclass which assigns dispatch attributes to various kinds of - "visitable" classes. + __visit_name__: str - Attributes include: + def __init_subclass__(cls) -> None: + if "__visit_name__" in cls.__dict__: + cls._generate_compiler_dispatch() + super().__init_subclass__() - * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``. - This is called "external traversal" because the caller of each visit() - method is responsible for sub-traversing the inner elements of each - object. This is appropriate for string compilers and other traversals - that need to call upon the inner elements in a specific pattern. + @classmethod + def _generate_compiler_dispatch(cls): + """Assign dispatch attributes to various kinds of + "visitable" classes. - * internal traversal collections ``_children_traversal``, - ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from - an optional ``_traverse_internals`` collection of symbols which comes - from the :class:`.InternalTraversal` list of symbols. This is called - "internal traversal" MARKMARK + Attributes include: - """ + * The ``_compiler_dispatch`` method, corresponding to + ``__visit_name__``. This is called "external traversal" because the + caller of each visit() method is responsible for sub-traversing the + inner elements of each object. This is appropriate for string + compilers and other traversals that need to call upon the inner + elements in a specific pattern. - def __init__(cls, clsname, bases, clsdict): - if clsname != "Traversible": - if "__visit_name__" in clsdict: - _generate_compiler_dispatch(cls) + * internal traversal collections ``_children_traversal``, + ``_cache_key_traversal``, ``_copy_internals_traversal``, generated + from an optional ``_traverse_internals`` collection of symbols which + comes from the :class:`.InternalTraversal` list of symbols. This is + called "internal traversal". 
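As a point of reference for the hunk above, here is a minimal, self-contained sketch of the `__init_subclass__`-based external dispatch that replaces the old metaclass hook: a subclass declaring `__visit_name__` receives a `_compiler_dispatch()` method built with `operator.attrgetter` at class-creation time. The `Node`, `Case` and `SQLPrinter` names are invented for illustration and are not SQLAlchemy classes.

import operator


class Node:
    """Base class: subclasses that declare __visit_name__ get a
    _compiler_dispatch() method generated when the class is created."""

    __slots__ = ()

    def __init_subclass__(cls) -> None:
        if "__visit_name__" in cls.__dict__:
            cls._generate_compiler_dispatch()
        super().__init_subclass__()

    @classmethod
    def _generate_compiler_dispatch(cls):
        getter = operator.attrgetter("visit_%s" % cls.__visit_name__)

        def _compiler_dispatch(self, visitor, **kw):
            # "external traversal": the visitor itself is responsible for
            # recursing into child elements
            return getter(visitor)(self, **kw)

        cls._compiler_dispatch = _compiler_dispatch


class Case(Node):
    __visit_name__ = "case"


class SQLPrinter:
    def visit_case(self, element, **kw):
        return "CASE ... END"


print(Case()._compiler_dispatch(SQLPrinter()))  # CASE ... END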
- super(TraversibleType, cls).__init__(clsname, bases, clsdict) + """ + visit_name = cls.__visit_name__ + + if "_compiler_dispatch" in cls.__dict__: + # class has a fixed _compiler_dispatch() method. + # copy it to "original" so that we can get it back if + # sqlalchemy.ext.compiles overrides it. + cls._original_compiler_dispatch = cls._compiler_dispatch + return + + if not isinstance(visit_name, str): + raise exc.InvalidRequestError( + f"__visit_name__ on class {cls.__name__} must be a string " + "at the class level" + ) + name = "visit_%s" % visit_name + getter = operator.attrgetter(name) -class Traversible(metaclass=TraversibleType): - """Base class for visitable objects, applies the - :class:`.visitors.TraversibleType` metaclass. + def _compiler_dispatch(self, visitor, **kw): + """Look for an attribute named "visit_<visit_name>" on the + visitor, and call it with the same kw params. - """ + """ + try: + meth = getter(visitor) + except AttributeError as err: + return visitor.visit_unsupported_compilation(self, err, **kw) + else: + return meth(self, **kw) - __slots__ = () + cls._compiler_dispatch = ( + cls._original_compiler_dispatch + ) = _compiler_dispatch def __class_getitem__(cls, key): # allow generic classes in py3.9+ @@ -159,48 +153,90 @@ class Traversible(metaclass=TraversibleType): ) -class _InternalTraversalType(type): - def __init__(cls, clsname, bases, clsdict): - if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"): - lookup = {} - for key, sym in clsdict.items(): - if key.startswith("dp_"): - visit_key = key.replace("dp_", "visit_") - sym_name = sym.name - assert sym_name not in lookup, sym_name - lookup[sym] = lookup[sym_name] = visit_key - if hasattr(cls, "_dispatch_lookup"): - lookup.update(cls._dispatch_lookup) - cls._dispatch_lookup = lookup +class _HasTraversalDispatch: + r"""Define infrastructure for the :class:`.InternalTraversal` class. - super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict) + .. versionadded:: 2.0 + """ -def _generate_dispatcher(visitor, internal_dispatch, method_name): - names = [] - for attrname, visit_sym in internal_dispatch: - meth = visitor.dispatch(visit_sym) - if meth: - visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym] - names.append((attrname, visit_name)) - - code = ( - (" return [\n") - + ( - ", \n".join( - " (%r, self.%s, visitor.%s)" - % (attrname, attrname, visit_name) - for attrname, visit_name in names + def __init_subclass__(cls) -> None: + cls._generate_traversal_dispatch() + super().__init_subclass__() + + def dispatch(self, visit_symbol): + """Given a method from :class:`._HasTraversalDispatch`, return the + corresponding method on a subclass. + + """ + name = self._dispatch_lookup[visit_symbol] + return getattr(self, name, None) + + def run_generated_dispatch( + self, target, internal_dispatch, generate_dispatcher_name + ): + try: + dispatcher = target.__class__.__dict__[generate_dispatcher_name] + except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. 
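`run_generated_dispatch()` above memoizes the generated dispatcher directly in the element class's `__dict__` and only builds it on a `KeyError`. The following is a rough standalone sketch of that lazy per-class caching; the `Element`/`Visitor` classes and the `_generated_dispatch` name are invented, and the closure stands in for the source-generated dispatcher shown further below.

class Element:
    # pairs of (attribute name, visit method name); in the real code these
    # are (attrname, InternalTraversal symbol) pairs
    _traverse_internals = [("left", "visit_str"), ("right", "visit_str")]

    def __init__(self, left, right):
        self.left, self.right = left, right


class Visitor:
    def visit_str(self, value):
        return value.upper()

    def run_generated_dispatch(self, target, internal_dispatch, name):
        try:
            # fast path: a previous call already attached the dispatcher
            # to the element class
            dispatcher = target.__class__.__dict__[name]
        except KeyError:
            # slow path, taken once per element class
            dispatcher = self.generate_dispatch(
                target.__class__, internal_dispatch, name
            )
        return dispatcher(target, self)

    def generate_dispatch(self, target_cls, internal_dispatch, name):
        def dispatcher(element, visitor):
            return [
                getattr(visitor, visit_name)(getattr(element, attrname))
                for attrname, visit_name in internal_dispatch
            ]

        setattr(target_cls, name, dispatcher)
        return dispatcher


v = Visitor()
print(v.run_generated_dispatch(Element("a", "b"), Element._traverse_internals,
                               "_generated_dispatch"))  # ['A', 'B']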
+ dispatcher = self.generate_dispatch( + target.__class__, internal_dispatch, generate_dispatcher_name + ) + return dispatcher(target, self) + + def generate_dispatch( + self, target_cls, internal_dispatch, generate_dispatcher_name + ): + dispatcher = self._generate_dispatcher( + internal_dispatch, generate_dispatcher_name + ) + # assert isinstance(target_cls, type) + setattr(target_cls, generate_dispatcher_name, dispatcher) + return dispatcher + + @classmethod + def _generate_traversal_dispatch(cls): + lookup = {} + clsdict = cls.__dict__ + for key, sym in clsdict.items(): + if key.startswith("dp_"): + visit_key = key.replace("dp_", "visit_") + sym_name = sym.name + assert sym_name not in lookup, sym_name + lookup[sym] = lookup[sym_name] = visit_key + if hasattr(cls, "_dispatch_lookup"): + lookup.update(cls._dispatch_lookup) + cls._dispatch_lookup = lookup + + def _generate_dispatcher(self, internal_dispatch, method_name): + names = [] + for attrname, visit_sym in internal_dispatch: + meth = self.dispatch(visit_sym) + if meth: + visit_name = ExtendedInternalTraversal._dispatch_lookup[ + visit_sym + ] + names.append((attrname, visit_name)) + + code = ( + (" return [\n") + + ( + ", \n".join( + " (%r, self.%s, visitor.%s)" + % (attrname, attrname, visit_name) + for attrname, visit_name in names + ) ) + + ("\n ]\n") ) - + ("\n ]\n") - ) - meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" - # print(meth_text) - return langhelpers._exec_code_in_env(meth_text, {}, method_name) + meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" + return langhelpers._exec_code_in_env(meth_text, {}, method_name) -class InternalTraversal(metaclass=_InternalTraversalType): +class InternalTraversal(_HasTraversalDispatch): r"""Defines visitor symbols used for internal traversal. The :class:`.InternalTraversal` class is used in two ways. One is that @@ -239,39 +275,6 @@ class InternalTraversal(metaclass=_InternalTraversalType): """ - def dispatch(self, visit_symbol): - """Given a method from :class:`.InternalTraversal`, return the - corresponding method on a subclass. - - """ - name = self._dispatch_lookup[visit_symbol] - return getattr(self, name, None) - - def run_generated_dispatch( - self, target, internal_dispatch, generate_dispatcher_name - ): - try: - dispatcher = target.__class__.__dict__[generate_dispatcher_name] - except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). - # this block will generate any remaining dispatchers. 
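`_generate_dispatcher()` above assembles the dispatcher as Python source text and compiles it once per class/traversal combination instead of looping over the symbols on every call. A simplified sketch of that code-generation step, using a bare `exec()` in place of `langhelpers._exec_code_in_env()` and assuming the `(attrname, visit_name)` pairs have already been resolved from the `dp_*` symbols:

def generate_dispatcher(method_name, names):
    # names: [(attrname, visit_name), ...]
    body = ",\n".join(
        "        (%r, self.%s, visitor.%s)" % (attrname, attrname, visit_name)
        for attrname, visit_name in names
    )
    meth_text = "def %s(self, visitor):\n    return [\n%s\n    ]\n" % (
        method_name,
        body,
    )
    env = {}
    exec(meth_text, env)  # compiled once, then attached to the element class
    return env[method_name]


class Element:
    left, right = 1, 2


class Visitor:
    def visit_plain(self, *arg):
        pass


disp = generate_dispatcher(
    "_dispatch_traverse", [("left", "visit_plain"), ("right", "visit_plain")]
)
# disp(Element(), Visitor()) returns
# [('left', 1, <bound visit_plain>), ('right', 2, <bound visit_plain>)]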
- dispatcher = self.generate_dispatch( - target.__class__, internal_dispatch, generate_dispatcher_name - ) - return dispatcher(target, self) - - def generate_dispatch( - self, target_cls, internal_dispatch, generate_dispatcher_name - ): - dispatcher = _generate_dispatcher( - self, internal_dispatch, generate_dispatcher_name - ) - # assert isinstance(target_cls, type) - setattr(target_cls, generate_dispatcher_name, dispatcher) - return dispatcher - dp_has_cache_key = symbol("HC") """Visit a :class:`.HasCacheKey` object.""" @@ -623,7 +626,6 @@ class ReplacingExternalTraversal(CloningExternalTraversal): # backwards compatibility Visitable = Traversible -VisitableType = TraversibleType ClauseVisitor = ExternalTraversal CloningVisitor = CloningExternalTraversal ReplacingCloningVisitor = ReplacingExternalTraversal diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 74c86e85a..ecc20f163 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -21,7 +21,6 @@ from .. import event from .. import util from ..orm import declarative_base from ..orm import registry -from ..orm.decl_api import DeclarativeMeta from ..schema import sort_tables_and_constraints @@ -647,15 +646,11 @@ class MappedTest(TablesTest, assertions.AssertsExecutionResults): """ cls_registry = cls.classes - assert cls_registry is not None - - class FindFixture(type): - def __init__(cls, classname, bases, dict_): - cls_registry[classname] = cls - type.__init__(cls, classname, bases, dict_) - - class _Base(metaclass=FindFixture): - pass + class _Base: + def __init_subclass__(cls) -> None: + assert cls_registry is not None + cls_registry[cls.__name__] = cls + super().__init_subclass__() class Basic(BasicEntity, _Base): pass @@ -699,17 +694,16 @@ class DeclarativeMappedTest(MappedTest): def _with_register_classes(cls, fn): cls_registry = cls.classes - class FindFixtureDeclarative(DeclarativeMeta): - def __init__(cls, classname, bases, dict_): - cls_registry[classname] = cls - DeclarativeMeta.__init__(cls, classname, bases, dict_) - class DeclarativeBasic: __table_cls__ = schema.Table + def __init_subclass__(cls) -> None: + assert cls_registry is not None + cls_registry[cls.__name__] = cls + super().__init_subclass__() + _DeclBase = declarative_base( metadata=cls._tables_metadata, - metaclass=FindFixtureDeclarative, cls=DeclarativeBasic, ) diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 2d2ff3565..7c03bcd4b 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -98,7 +98,7 @@ from .langhelpers import decorator from .langhelpers import dictlike_iteritems from .langhelpers import duck_type_collection from .langhelpers import ellipses_string -from .langhelpers import EnsureKWArgType +from .langhelpers import EnsureKWArg from .langhelpers import format_argspec_init from .langhelpers import format_argspec_plus from .langhelpers import generic_repr diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 1b277cdee..66c530867 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1743,14 +1743,28 @@ def attrsetter(attrname): return env["set"] -class EnsureKWArgType(type): +class EnsureKWArg: r"""Apply translation of functions to accept \**kw arguments if they don't already. 
+ Used to ensure cross-compatibility with third party legacy code, for things + like compiler visit methods that need to accept ``**kw`` arguments, + but may have been copied from old code that didn't accept them. + + """ + + ensure_kwarg: str + """a regular expression that indicates method names for which the method + should accept ``**kw`` arguments. + + The class will scan for methods matching the name template and decorate + them if necessary to ensure ``**kw`` parameters are accepted. + """ - def __init__(cls, clsname, bases, clsdict): + def __init_subclass__(cls) -> None: fn_reg = cls.ensure_kwarg + clsdict = cls.__dict__ if fn_reg: for key in clsdict: m = re.match(fn_reg, key) @@ -1758,11 +1772,12 @@ class EnsureKWArgType(type): fn = clsdict[key] spec = compat.inspect_getfullargspec(fn) if not spec.varkw: - clsdict[key] = wrapped = cls._wrap_w_kw(fn) + wrapped = cls._wrap_w_kw(fn) setattr(cls, key, wrapped) - super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict) + super().__init_subclass__() - def _wrap_w_kw(self, fn): + @classmethod + def _wrap_w_kw(cls, fn): def wrap(*arg, **kw): return fn(*arg) diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 92ef241ed..37b4d0ae1 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -32,7 +32,7 @@ from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import subqueryload from sqlalchemy.orm import UserDefinedOption -from sqlalchemy.sql.traversals import NO_CACHE +from sqlalchemy.sql.cache_key import NO_CACHE from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index bbf9716f5..43aed0672 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -18,7 +18,7 @@ from sqlalchemy.sql import select from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql.base import ExecutableOption -from sqlalchemy.sql.traversals import HasCacheKey +from sqlalchemy.sql.cache_key import HasCacheKey from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ |
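`EnsureKWArg` keeps the behavior of the old `EnsureKWArgType` metaclass, now run from `__init_subclass__`: methods whose names match the `ensure_kwarg` pattern and that do not already accept `**kw` are wrapped so that callers can always pass keyword arguments. A standalone sketch of that behavior, substituting stdlib `inspect.getfullargspec` for SQLAlchemy's compat helper and using an invented `LegacyCompiler` subclass:

import inspect
import re


class EnsureKWArg:
    ensure_kwarg = r"visit_\w+"  # illustrative default; subclasses set this

    def __init_subclass__(cls) -> None:
        for key, fn in list(cls.__dict__.items()):
            if not callable(fn) or not re.match(cls.ensure_kwarg, key):
                continue
            if not inspect.getfullargspec(fn).varkw:
                # the method was written without **kw; wrap it so callers
                # can pass keyword arguments anyway
                setattr(cls, key, cls._wrap_w_kw(fn))
        super().__init_subclass__()

    @classmethod
    def _wrap_w_kw(cls, fn):
        def wrap(*arg, **kw):
            return fn(*arg)  # keyword arguments are silently dropped

        return wrap


class LegacyCompiler(EnsureKWArg):
    def visit_thing(self, element):  # no **kw in the original signature
        return "THING(%s)" % element


print(LegacyCompiler().visit_thing("x", literal_binds=True))  # THING(x)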