diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-01-11 19:02:22 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-01-11 19:02:22 +0000 |
| commit | a869dc8fe3cd579ed9bab665d215a6c3e3d8a4ca (patch) | |
| tree | fd4e7eb4dac959d80599179d1ac9d07f5ff3f957 /lib/sqlalchemy/sql | |
| parent | 71b6425db3517fc6194a349e6cc5abea851c7f35 (diff) | |
| parent | 3a23e8ed29180e914883a263ec83373ecbd02efa (diff) | |
| download | sqlalchemy-a869dc8fe3cd579ed9bab665d215a6c3e3d8a4ca.tar.gz | |
Merge "remove internal use of metaclasses" into main
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 34 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/cache_key.py | 762 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 63 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/lambdas.py | 23 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 750 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 270 |
14 files changed, 988 insertions, 940 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 805f7b1a0..6ab9a75f6 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -20,9 +20,9 @@ import typing from . import roles from . import visitors -from .traversals import HasCacheKey # noqa +from .cache_key import HasCacheKey # noqa +from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa -from .traversals import MemoizedHasCacheKey # noqa from .visitors import ClauseVisitor from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal @@ -37,7 +37,6 @@ try: except ImportError: from ._py_util import prefix_anon_map # noqa - coercions = None elements = None type_api = None @@ -610,18 +609,13 @@ class HasCompileState(Generative): class _MetaOptions(type): - """metaclass for the Options class.""" + """metaclass for the Options class. - def __init__(cls, classname, bases, dict_): - cls._cache_attrs = tuple( - sorted( - d - for d in dict_ - if not d.startswith("__") - and d not in ("_cache_key_traversal",) - ) - ) - type.__init__(cls, classname, bases, dict_) + This metaclass is actually necessary despite the availability of the + ``__init_subclass__()`` hook as this type also provides custom class-level + behavior for the ``__add__()`` method. + + """ def __add__(self, other): o1 = self() @@ -640,6 +634,18 @@ class _MetaOptions(type): class Options(metaclass=_MetaOptions): """A cacheable option dictionary with defaults.""" + def __init_subclass__(cls) -> None: + dict_ = cls.__dict__ + cls._cache_attrs = tuple( + sorted( + d + for d in dict_ + if not d.startswith("__") + and d not in ("_cache_key_traversal",) + ) + ) + super().__init_subclass__() + def __init__(self, **kw): self.__dict__.update(kw) diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py new file mode 100644 index 000000000..8dd44dbf0 --- /dev/null +++ b/lib/sqlalchemy/sql/cache_key.py @@ -0,0 +1,762 @@ +# sql/cache_key.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from collections import namedtuple +import enum +from itertools import zip_longest +from typing import Callable +from typing import Union + +from .visitors import anon_map +from .visitors import ExtendedInternalTraversal +from .visitors import InternalTraversal +from .. import util +from ..inspection import inspect +from ..util import HasMemoized +from ..util.typing import Literal + + +class CacheConst(enum.Enum): + NO_CACHE = 0 + + +NO_CACHE = CacheConst.NO_CACHE + + +class CacheTraverseTarget(enum.Enum): + CACHE_IN_PLACE = 0 + CALL_GEN_CACHE_KEY = 1 + STATIC_CACHE_KEY = 2 + PROPAGATE_ATTRS = 3 + ANON_NAME = 4 + + +( + CACHE_IN_PLACE, + CALL_GEN_CACHE_KEY, + STATIC_CACHE_KEY, + PROPAGATE_ATTRS, + ANON_NAME, +) = tuple(CacheTraverseTarget) + + +class HasCacheKey: + """Mixin for objects which can produce a cache key. + + .. seealso:: + + :class:`.CacheKey` + + :ref:`sql_caching` + + """ + + _cache_key_traversal = NO_CACHE + + _is_has_cache_key = True + + _hierarchy_supports_caching = True + """private attribute which may be set to False to prevent the + inherit_cache warning from being emitted for a hierarchy of subclasses. + + Currently applies to the DDLElement hierarchy which does not implement + caching. + + """ + + inherit_cache = None + """Indicate if this :class:`.HasCacheKey` instance should make use of the + cache key generation scheme used by its immediate superclass. + + The attribute defaults to ``None``, which indicates that a construct has + not yet taken into account whether or not its appropriate for it to + participate in caching; this is functionally equivalent to setting the + value to ``False``, except that a warning is also emitted. + + This flag can be set to ``True`` on a particular class, if the SQL that + corresponds to the object does not change based on attributes which + are local to this class, and not its superclass. + + .. seealso:: + + :ref:`compilerext_caching` - General guideslines for setting the + :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user + defined SQL constructs. + + """ + + __slots__ = () + + @classmethod + def _generate_cache_attrs(cls): + """generate cache key dispatcher for a new class. + + This sets the _generated_cache_key_traversal attribute once called + so should only be called once per class. + + """ + inherit_cache = cls.__dict__.get("inherit_cache", None) + inherit = bool(inherit_cache) + + if inherit: + _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) + if _cache_key_traversal is None: + try: + _cache_key_traversal = cls._traverse_internals + except AttributeError: + cls._generated_cache_key_traversal = NO_CACHE + return NO_CACHE + + # TODO: wouldn't we instead get this from our superclass? + # also, our superclass may not have this yet, but in any case, + # we'd generate for the superclass that has it. this is a little + # more complicated, so for the moment this is a little less + # efficient on startup but simpler. + return _cache_key_traversal_visitor.generate_dispatch( + cls, _cache_key_traversal, "_generated_cache_key_traversal" + ) + else: + _cache_key_traversal = cls.__dict__.get( + "_cache_key_traversal", None + ) + if _cache_key_traversal is None: + _cache_key_traversal = cls.__dict__.get( + "_traverse_internals", None + ) + if _cache_key_traversal is None: + cls._generated_cache_key_traversal = NO_CACHE + if ( + inherit_cache is None + and cls._hierarchy_supports_caching + ): + util.warn( + "Class %s will not make use of SQL compilation " + "caching as it does not set the 'inherit_cache' " + "attribute to ``True``. This can have " + "significant performance implications including " + "some performance degradations in comparison to " + "prior SQLAlchemy versions. Set this attribute " + "to True if this object can make use of the cache " + "key generated by the superclass. Alternatively, " + "this attribute may be set to False which will " + "disable this warning." % (cls.__name__), + code="cprf", + ) + return NO_CACHE + + return _cache_key_traversal_visitor.generate_dispatch( + cls, _cache_key_traversal, "_generated_cache_key_traversal" + ) + + @util.preload_module("sqlalchemy.sql.elements") + def _gen_cache_key(self, anon_map, bindparams): + """return an optional cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the structures that would affect the SQL string or the type handlers + should result in a different cache key. + + If a structure cannot produce a useful cache key, the NO_CACHE + symbol should be added to the anon_map and the method should + return None. + + """ + + cls = self.__class__ + + id_, found = anon_map.get_anon(self) + if found: + return (id_, cls) + + dispatcher: Union[ + Literal[CacheConst.NO_CACHE], + Callable[[HasCacheKey, "_CacheKeyTraversal"], "CacheKey"], + ] + + try: + dispatcher = cls.__dict__["_generated_cache_key_traversal"] + except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. + dispatcher = cls._generate_cache_attrs() + + if dispatcher is NO_CACHE: + anon_map[NO_CACHE] = True + return None + + result = (id_, cls) + + # inline of _cache_key_traversal_visitor.run_generated_dispatch() + + for attrname, obj, meth in dispatcher( + self, _cache_key_traversal_visitor + ): + if obj is not None: + # TODO: see if C code can help here as Python lacks an + # efficient switch construct + + if meth is STATIC_CACHE_KEY: + sck = obj._static_cache_key + if sck is NO_CACHE: + anon_map[NO_CACHE] = True + return None + result += (attrname, sck) + elif meth is ANON_NAME: + elements = util.preloaded.sql_elements + if isinstance(obj, elements._anonymous_label): + obj = obj.apply_map(anon_map) + result += (attrname, obj) + elif meth is CALL_GEN_CACHE_KEY: + result += ( + attrname, + obj._gen_cache_key(anon_map, bindparams), + ) + + # remaining cache functions are against + # Python tuples, dicts, lists, etc. so we can skip + # if they are empty + elif obj: + if meth is CACHE_IN_PLACE: + result += (attrname, obj) + elif meth is PROPAGATE_ATTRS: + result += ( + attrname, + obj["compile_state_plugin"], + obj["plugin_subject"]._gen_cache_key( + anon_map, bindparams + ) + if obj["plugin_subject"] + else None, + ) + elif meth is InternalTraversal.dp_annotations_key: + # obj is here is the _annotations dict. however, we + # want to use the memoized cache key version of it. for + # Columns, this should be long lived. For select() + # statements, not so much, but they usually won't have + # annotations. + result += self._annotations_cache_key + elif ( + meth is InternalTraversal.dp_clauseelement_list + or meth is InternalTraversal.dp_clauseelement_tuple + or meth + is InternalTraversal.dp_memoized_select_entities + ): + result += ( + attrname, + tuple( + [ + elem._gen_cache_key(anon_map, bindparams) + for elem in obj + ] + ), + ) + else: + result += meth( + attrname, obj, self, anon_map, bindparams + ) + return result + + def _generate_cache_key(self): + """return a cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the structures that would affect the SQL string or the type handlers + should result in a different cache key. + + The cache key returned by this method is an instance of + :class:`.CacheKey`, which consists of a tuple representing the + cache key, as well as a list of :class:`.BindParameter` objects + which are extracted from the expression. While two expressions + that produce identical cache key tuples will themselves generate + identical SQL strings, the list of :class:`.BindParameter` objects + indicates the bound values which may have different values in + each one; these bound parameters must be consulted in order to + execute the statement with the correct parameters. + + a :class:`_expression.ClauseElement` structure that does not implement + a :meth:`._gen_cache_key` method and does not implement a + :attr:`.traverse_internals` attribute will not be cacheable; when + such an element is embedded into a larger structure, this method + will return None, indicating no cache key is available. + + """ + + bindparams = [] + + _anon_map = anon_map() + key = self._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + @classmethod + def _generate_cache_key_for_object(cls, obj): + bindparams = [] + + _anon_map = anon_map() + key = obj._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + +class MemoizedHasCacheKey(HasCacheKey, HasMemoized): + @HasMemoized.memoized_instancemethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key(self) + + +class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): + """The key used to identify a SQL statement construct in the + SQL compilation cache. + + .. seealso:: + + :ref:`sql_caching` + + """ + + def __hash__(self): + """CacheKey itself is not hashable - hash the .key portion""" + + return None + + def to_offline_string(self, statement_cache, statement, parameters): + """Generate an "offline string" form of this :class:`.CacheKey` + + The "offline string" is basically the string SQL for the + statement plus a repr of the bound parameter values in series. + Whereas the :class:`.CacheKey` object is dependent on in-memory + identities in order to work as a cache key, the "offline" version + is suitable for a cache that will work for other processes as well. + + The given ``statement_cache`` is a dictionary-like object where the + string form of the statement itself will be cached. This dictionary + should be in a longer lived scope in order to reduce the time spent + stringifying statements. + + + """ + if self.key not in statement_cache: + statement_cache[self.key] = sql_str = str(statement) + else: + sql_str = statement_cache[self.key] + + if not self.bindparams: + param_tuple = tuple(parameters[key] for key in sorted(parameters)) + else: + param_tuple = tuple( + parameters.get(bindparam.key, bindparam.value) + for bindparam in self.bindparams + ) + + return repr((sql_str, param_tuple)) + + def __eq__(self, other): + return self.key == other.key + + @classmethod + def _diff_tuples(cls, left, right): + ck1 = CacheKey(left, []) + ck2 = CacheKey(right, []) + return ck1._diff(ck2) + + def _whats_different(self, other): + + k1 = self.key + k2 = other.key + + stack = [] + pickup_index = 0 + while True: + s1, s2 = k1, k2 + for idx in stack: + s1 = s1[idx] + s2 = s2[idx] + + for idx, (e1, e2) in enumerate(zip_longest(s1, s2)): + if idx < pickup_index: + continue + if e1 != e2: + if isinstance(e1, tuple) and isinstance(e2, tuple): + stack.append(idx) + break + else: + yield "key%s[%d]: %s != %s" % ( + "".join("[%d]" % id_ for id_ in stack), + idx, + e1, + e2, + ) + else: + pickup_index = stack.pop(-1) + break + + def _diff(self, other): + return ", ".join(self._whats_different(other)) + + def __str__(self): + stack = [self.key] + + output = [] + sentinel = object() + indent = -1 + while stack: + elem = stack.pop(0) + if elem is sentinel: + output.append((" " * (indent * 2)) + "),") + indent -= 1 + elif isinstance(elem, tuple): + if not elem: + output.append((" " * ((indent + 1) * 2)) + "()") + else: + indent += 1 + stack = list(elem) + [sentinel] + stack + output.append((" " * (indent * 2)) + "(") + else: + if isinstance(elem, HasCacheKey): + repr_ = "<%s object at %s>" % ( + type(elem).__name__, + hex(id(elem)), + ) + else: + repr_ = repr(elem) + output.append((" " * (indent * 2)) + " " + repr_ + ", ") + + return "CacheKey(key=%s)" % ("\n".join(output),) + + def _generate_param_dict(self): + """used for testing""" + + from .compiler import prefix_anon_map + + _anon_map = prefix_anon_map() + return {b.key % _anon_map: b.effective_value for b in self.bindparams} + + def _apply_params_to_element(self, original_cache_key, target_element): + translate = { + k.key: v.value + for k, v in zip(original_cache_key.bindparams, self.bindparams) + } + + return target_element.params(translate) + + +class _CacheKeyTraversal(ExtendedInternalTraversal): + # very common elements are inlined into the main _get_cache_key() method + # to produce a dramatic savings in Python function call overhead + + visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY + visit_clauseelement_list = InternalTraversal.dp_clauseelement_list + visit_annotations_key = InternalTraversal.dp_annotations_key + visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple + visit_memoized_select_entities = ( + InternalTraversal.dp_memoized_select_entities + ) + + visit_string = ( + visit_boolean + ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE + visit_statement_hint_list = CACHE_IN_PLACE + visit_type = STATIC_CACHE_KEY + visit_anon_name = ANON_NAME + + visit_propagate_attrs = PROPAGATE_ATTRS + + def visit_with_context_options( + self, attrname, obj, parent, anon_map, bindparams + ): + return tuple((fn.__code__, c_key) for fn, c_key in obj) + + def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) + + def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): + return tuple(obj) + + def visit_multi(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj, + ) + + def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem + for elem in obj + ), + ) + + def visit_has_cache_key_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple( + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in tup_elem + ) + for tup_elem in obj + ), + ) + + def visit_has_cache_key_list( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_executable_options( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in obj + if elem._is_has_cache_key + ), + ) + + def visit_inspectable_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_list( + attrname, [inspect(o) for o in obj], parent, anon_map, bindparams + ) + + def visit_clauseelement_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_tuples( + attrname, obj, parent, anon_map, bindparams + ) + + def visit_fromclause_ordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + return ( + attrname, + tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]), + ) + + def visit_clauseelement_unordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + cache_keys = [ + elem._gen_cache_key(anon_map, bindparams) for elem in obj + ] + return ( + attrname, + tuple( + sorted(cache_keys) + ), # cache keys all start with (id_, class) + ) + + def visit_named_ddl_element( + self, attrname, obj, parent, anon_map, bindparams + ): + return (attrname, obj.name) + + def visit_prefix_sequence( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + + return ( + attrname, + tuple( + [ + (clause._gen_cache_key(anon_map, bindparams), strval) + for clause, strval in obj + ] + ), + ) + + def visit_setup_join_tuple( + self, attrname, obj, parent, anon_map, bindparams + ): + return tuple( + ( + target._gen_cache_key(anon_map, bindparams), + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None, + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None, + tuple([(key, flags[key]) for key in sorted(flags)]), + ) + for (target, onclause, from_, flags) in obj + ) + + def visit_table_hint_list( + self, attrname, obj, parent, anon_map, bindparams + ): + if not obj: + return () + + return ( + attrname, + tuple( + [ + ( + clause._gen_cache_key(anon_map, bindparams), + dialect_name, + text, + ) + for (clause, dialect_name), text in obj.items() + ] + ), + ) + + def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) + + def visit_dialect_options( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + dialect_name, + tuple( + [ + (key, obj[dialect_name][key]) + for key in sorted(obj[dialect_name]) + ] + ), + ) + for dialect_name in sorted(obj) + ), + ) + + def visit_string_clauseelement_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + (key, obj[key]._gen_cache_key(anon_map, bindparams)) + for key in sorted(obj) + ), + ) + + def visit_string_multi_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key, + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value, + ) + for key, value in [(key, obj[key]) for key in sorted(obj)] + ), + ) + + def visit_fromclause_canonical_column_collection( + self, attrname, obj, parent, anon_map, bindparams + ): + # inlining into the internals of ColumnCollection + return ( + attrname, + tuple( + col._gen_cache_key(anon_map, bindparams) + for k, col in obj._collection + ), + ) + + def visit_unknown_structure( + self, attrname, obj, parent, anon_map, bindparams + ): + anon_map[NO_CACHE] = True + return () + + def visit_dml_ordered_values( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key._gen_cache_key(anon_map, bindparams) + if hasattr(key, "__clause_element__") + else key, + value._gen_cache_key(anon_map, bindparams), + ) + for key, value in obj + ), + ) + + def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): + # in py37 we can assume two dictionaries created in the same + # insert ordering will retain that sorting + return ( + attrname, + tuple( + ( + k._gen_cache_key(anon_map, bindparams) + if hasattr(k, "__clause_element__") + else k, + obj[k]._gen_cache_key(anon_map, bindparams), + ) + for k in obj + ), + ) + + def visit_dml_multi_values( + self, attrname, obj, parent, anon_map, bindparams + ): + # multivalues are simply not cacheable right now + anon_map[NO_CACHE] = True + return () + + +_cache_key_traversal_visitor = _CacheKeyTraversal() diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 95697806e..fe2b498c8 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -14,7 +14,7 @@ from . import roles from . import visitors from .base import ExecutableOption from .base import Options -from .traversals import HasCacheKey +from .cache_key import HasCacheKey from .visitors import Visitable from .. import exc from .. import inspection diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 697550df4..cb10811c6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -503,7 +503,7 @@ class Compiled: return self.construct_params() -class TypeCompiler(metaclass=util.EnsureKWArgType): +class TypeCompiler(util.EnsureKWArg): """Produces DDL specification for TypeEngine objects.""" ensure_kwarg = r"visit_\w+" diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 12282de05..a025cce35 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -29,10 +29,10 @@ from .base import HasMemoized from .base import Immutable from .base import NO_ARG from .base import SingletonConstant +from .cache_key import MemoizedHasCacheKey +from .cache_key import NO_CACHE from .coercions import _document_text_coercion from .traversals import HasCopyInternals -from .traversals import MemoizedHasCacheKey -from .traversals import NO_CACHE from .visitors import cloned_traverse from .visitors import InternalTraversal from .visitors import traverse diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 2a3cd07d0..54f67b930 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -94,6 +94,7 @@ from .base import _from_objects from .base import _select_iterables from .base import ColumnCollection from .base import Executable +from .cache_key import CacheKey from .dml import Delete from .dml import Insert from .dml import Update @@ -173,7 +174,6 @@ from .selectable import TableValuedAlias from .selectable import TextAsFrom from .selectable import TextualSelect from .selectable import Values -from .traversals import CacheKey from .visitors import Visitable from ..util.langhelpers import public_factory diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index b7b9257b4..3b6da7175 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -37,8 +37,8 @@ from .elements import WithinGroup from .selectable import FromClause from .selectable import Select from .selectable import TableValuedAlias +from .type_api import TypeEngine from .visitors import InternalTraversal -from .visitors import TraversibleType from .. import util @@ -48,7 +48,7 @@ _registry = util.defaultdict(dict) def register_function(identifier, fn, package="_default"): """Associate a callable with a particular func. name. - This is normally called by _GenericMeta, but is also + This is normally called by GenericFunction, but is also available by itself so that a non-Function construct can be associated with the :data:`.func` accessor (i.e. CAST, EXTRACT). @@ -828,7 +828,11 @@ class Function(FunctionElement): ("type", InternalTraversal.dp_type), ] - type = sqltypes.NULLTYPE + name: str + + identifier: str + + type: TypeEngine = sqltypes.NULLTYPE """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. @@ -871,30 +875,7 @@ class Function(FunctionElement): ) -class _GenericMeta(TraversibleType): - def __init__(cls, clsname, bases, clsdict): - if annotation.Annotated not in cls.__mro__: - cls.name = name = clsdict.get("name", clsname) - cls.identifier = identifier = clsdict.get("identifier", name) - package = clsdict.pop("package", "_default") - # legacy - if "__return_type__" in clsdict: - cls.type = clsdict["__return_type__"] - - # Check _register attribute status - cls._register = getattr(cls, "_register", True) - - # Register the function if required - if cls._register: - register_function(identifier, cls, package) - else: - # Set _register to True to register child classes by default - cls._register = True - - super(_GenericMeta, cls).__init__(clsname, bases, clsdict) - - -class GenericFunction(Function, metaclass=_GenericMeta): +class GenericFunction(Function): """Define a 'generic' function. A generic function is a pre-established :class:`.Function` @@ -986,9 +967,34 @@ class GenericFunction(Function, metaclass=_GenericMeta): """ coerce_arguments = True - _register = False inherit_cache = True + name = "GenericFunction" + + def __init_subclass__(cls) -> None: + if annotation.Annotated not in cls.__mro__: + cls._register_generic_function(cls.__name__, cls.__dict__) + super().__init_subclass__() + + @classmethod + def _register_generic_function(cls, clsname, clsdict): + cls.name = name = clsdict.get("name", clsname) + cls.identifier = identifier = clsdict.get("identifier", name) + package = clsdict.get("package", "_default") + # legacy + if "__return_type__" in clsdict: + cls.type = clsdict["__return_type__"] + + # Check _register attribute status + cls._register = getattr(cls, "_register", True) + + # Register the function if required + if cls._register: + register_function(identifier, cls, package) + else: + # Set _register to True to register child classes by default + cls._register = True + def __init__(self, *args, **kwargs): parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: @@ -1006,6 +1012,7 @@ class GenericFunction(Function, metaclass=_GenericMeta): self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *parsed_args ).self_group() + self.type = sqltypes.to_instance( kwargs.pop("type_", None) or getattr(self, "type", None) ) diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 2387e551e..d71c85d60 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -11,6 +11,7 @@ import operator import types import weakref +from . import cache_key as _cache_key from . import coercions from . import elements from . import roles @@ -185,7 +186,7 @@ class LambdaElement(elements.ClauseElement): else: parent_closure_cache_key = () - if parent_closure_cache_key is not traversals.NO_CACHE: + if parent_closure_cache_key is not _cache_key.NO_CACHE: anon_map = traversals.anon_map() cache_key = tuple( [ @@ -194,7 +195,7 @@ class LambdaElement(elements.ClauseElement): ] ) - if traversals.NO_CACHE not in anon_map: + if _cache_key.NO_CACHE not in anon_map: cache_key = parent_closure_cache_key + cache_key self.closure_cache_key = cache_key @@ -204,17 +205,17 @@ class LambdaElement(elements.ClauseElement): except KeyError: rec = None else: - cache_key = traversals.NO_CACHE + cache_key = _cache_key.NO_CACHE rec = None else: - cache_key = traversals.NO_CACHE + cache_key = _cache_key.NO_CACHE rec = None self.closure_cache_key = cache_key if rec is None: - if cache_key is not traversals.NO_CACHE: + if cache_key is not _cache_key.NO_CACHE: rec = AnalyzedFunction( tracker, self, apply_propagate_attrs, fn ) @@ -233,7 +234,7 @@ class LambdaElement(elements.ClauseElement): self._rec = rec - if cache_key is not traversals.NO_CACHE: + if cache_key is not _cache_key.NO_CACHE: if self.parent_lambda is not None: bindparams[:0] = self.parent_lambda._resolved_bindparams @@ -326,8 +327,8 @@ class LambdaElement(elements.ClauseElement): return expr def _gen_cache_key(self, anon_map, bindparams): - if self.closure_cache_key is traversals.NO_CACHE: - anon_map[traversals.NO_CACHE] = True + if self.closure_cache_key is _cache_key.NO_CACHE: + anon_map[_cache_key.NO_CACHE] = True return None cache_key = ( @@ -808,7 +809,7 @@ class AnalyzedCode: for tup_elem in opts.track_on[idx] ) - elif isinstance(elem, traversals.HasCacheKey): + elif isinstance(elem, _cache_key.HasCacheKey): def get(closure, opts, anon_map, bindparams): return opts.track_on[idx]._gen_cache_key(anon_map, bindparams) @@ -834,7 +835,7 @@ class AnalyzedCode: """ - if isinstance(cell_contents, traversals.HasCacheKey): + if isinstance(cell_contents, _cache_key.HasCacheKey): def get(closure, opts, anon_map, bindparams): @@ -1166,7 +1167,7 @@ class PyWrapper(ColumnOperators): and not isinstance( # TODO: coverage where an ORM option or similar is here value, - traversals.HasCacheKey, + _cache_key.HasCacheKey, ) ): name = object.__getattribute__(self, "_name") diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 655d98b02..e674c4b74 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -18,6 +18,7 @@ import typing from typing import Type from typing import Union +from . import cache_key from . import coercions from . import operators from . import roles @@ -4300,7 +4301,7 @@ class _SelectFromElements: class _MemoizedSelectEntities( - traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible + cache_key.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible ): __visit_name__ = "memoized_select_entities" diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 744ddd025..93ff53663 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -22,11 +22,11 @@ from . import roles from . import type_api from .base import NO_ARG from .base import SchemaEventTarget +from .cache_key import HasCacheKey from .elements import _NONE_NAME from .elements import quoted_name from .elements import Slice from .elements import TypeCoerce as type_coerce # noqa -from .traversals import HasCacheKey from .traversals import InternalTraversal from .type_api import Emulated from .type_api import NativeForEmulated # noqa diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index b689fe578..2fa3a0408 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -1,32 +1,25 @@ +# sql/traversals.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + from collections import deque -from collections import namedtuple import collections.abc as collections_abc import itertools from itertools import zip_longest import operator from . import operators -from .visitors import ExtendedInternalTraversal +from .visitors import anon_map from .visitors import InternalTraversal from .. import util -from ..inspection import inspect -from ..util import HasMemoized - -try: - from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa -except ImportError: - from ._py_util import cache_anon_map as anon_map # noqa SKIP_TRAVERSE = util.symbol("skip_traverse") COMPARE_FAILED = False COMPARE_SUCCEEDED = True -NO_CACHE = util.symbol("no_cache") -CACHE_IN_PLACE = util.symbol("cache_in_place") -CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key") -STATIC_CACHE_KEY = util.symbol("static_cache_key") -PROPAGATE_ATTRS = util.symbol("propagate_attrs") -ANON_NAME = util.symbol("anon_name") def compare(obj1, obj2, **kw): @@ -54,729 +47,10 @@ def _preconfigure_traversals(target_hierarchy): ) -class HasCacheKey: - """Mixin for objects which can produce a cache key. - - .. seealso:: - - :class:`.CacheKey` - - :ref:`sql_caching` - - """ - - _cache_key_traversal = NO_CACHE - - _is_has_cache_key = True - - _hierarchy_supports_caching = True - """private attribute which may be set to False to prevent the - inherit_cache warning from being emitted for a hierarchy of subclasses. - - Currently applies to the DDLElement hierarchy which does not implement - caching. - - """ - - inherit_cache = None - """Indicate if this :class:`.HasCacheKey` instance should make use of the - cache key generation scheme used by its immediate superclass. - - The attribute defaults to ``None``, which indicates that a construct has - not yet taken into account whether or not its appropriate for it to - participate in caching; this is functionally equivalent to setting the - value to ``False``, except that a warning is also emitted. - - This flag can be set to ``True`` on a particular class, if the SQL that - corresponds to the object does not change based on attributes which - are local to this class, and not its superclass. - - .. seealso:: - - :ref:`compilerext_caching` - General guideslines for setting the - :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user - defined SQL constructs. - - """ - - __slots__ = () - - @classmethod - def _generate_cache_attrs(cls): - """generate cache key dispatcher for a new class. - - This sets the _generated_cache_key_traversal attribute once called - so should only be called once per class. - - """ - inherit_cache = cls.__dict__.get("inherit_cache", None) - inherit = bool(inherit_cache) - - if inherit: - _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) - if _cache_key_traversal is None: - try: - _cache_key_traversal = cls._traverse_internals - except AttributeError: - cls._generated_cache_key_traversal = NO_CACHE - return NO_CACHE - - # TODO: wouldn't we instead get this from our superclass? - # also, our superclass may not have this yet, but in any case, - # we'd generate for the superclass that has it. this is a little - # more complicated, so for the moment this is a little less - # efficient on startup but simpler. - return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" - ) - else: - _cache_key_traversal = cls.__dict__.get( - "_cache_key_traversal", None - ) - if _cache_key_traversal is None: - _cache_key_traversal = cls.__dict__.get( - "_traverse_internals", None - ) - if _cache_key_traversal is None: - cls._generated_cache_key_traversal = NO_CACHE - if ( - inherit_cache is None - and cls._hierarchy_supports_caching - ): - util.warn( - "Class %s will not make use of SQL compilation " - "caching as it does not set the 'inherit_cache' " - "attribute to ``True``. This can have " - "significant performance implications including " - "some performance degradations in comparison to " - "prior SQLAlchemy versions. Set this attribute " - "to True if this object can make use of the cache " - "key generated by the superclass. Alternatively, " - "this attribute may be set to False which will " - "disable this warning." % (cls.__name__), - code="cprf", - ) - return NO_CACHE - - return _cache_key_traversal_visitor.generate_dispatch( - cls, _cache_key_traversal, "_generated_cache_key_traversal" - ) - - @util.preload_module("sqlalchemy.sql.elements") - def _gen_cache_key(self, anon_map, bindparams): - """return an optional cache key. - - The cache key is a tuple which can contain any series of - objects that are hashable and also identifies - this object uniquely within the presence of a larger SQL expression - or statement, for the purposes of caching the resulting query. - - The cache key should be based on the SQL compiled structure that would - ultimately be produced. That is, two structures that are composed in - exactly the same way should produce the same cache key; any difference - in the structures that would affect the SQL string or the type handlers - should result in a different cache key. - - If a structure cannot produce a useful cache key, the NO_CACHE - symbol should be added to the anon_map and the method should - return None. - - """ - - cls = self.__class__ - - id_, found = anon_map.get_anon(self) - if found: - return (id_, cls) - - try: - dispatcher = cls.__dict__["_generated_cache_key_traversal"] - except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). - # this block will generate any remaining dispatchers. - dispatcher = cls._generate_cache_attrs() - - if dispatcher is NO_CACHE: - anon_map[NO_CACHE] = True - return None - - result = (id_, cls) - - # inline of _cache_key_traversal_visitor.run_generated_dispatch() - - for attrname, obj, meth in dispatcher( - self, _cache_key_traversal_visitor - ): - if obj is not None: - # TODO: see if C code can help here as Python lacks an - # efficient switch construct - - if meth is STATIC_CACHE_KEY: - sck = obj._static_cache_key - if sck is NO_CACHE: - anon_map[NO_CACHE] = True - return None - result += (attrname, sck) - elif meth is ANON_NAME: - elements = util.preloaded.sql_elements - if isinstance(obj, elements._anonymous_label): - obj = obj.apply_map(anon_map) - result += (attrname, obj) - elif meth is CALL_GEN_CACHE_KEY: - result += ( - attrname, - obj._gen_cache_key(anon_map, bindparams), - ) - - # remaining cache functions are against - # Python tuples, dicts, lists, etc. so we can skip - # if they are empty - elif obj: - if meth is CACHE_IN_PLACE: - result += (attrname, obj) - elif meth is PROPAGATE_ATTRS: - result += ( - attrname, - obj["compile_state_plugin"], - obj["plugin_subject"]._gen_cache_key( - anon_map, bindparams - ) - if obj["plugin_subject"] - else None, - ) - elif meth is InternalTraversal.dp_annotations_key: - # obj is here is the _annotations dict. however, we - # want to use the memoized cache key version of it. for - # Columns, this should be long lived. For select() - # statements, not so much, but they usually won't have - # annotations. - result += self._annotations_cache_key - elif ( - meth is InternalTraversal.dp_clauseelement_list - or meth is InternalTraversal.dp_clauseelement_tuple - or meth - is InternalTraversal.dp_memoized_select_entities - ): - result += ( - attrname, - tuple( - [ - elem._gen_cache_key(anon_map, bindparams) - for elem in obj - ] - ), - ) - else: - result += meth( - attrname, obj, self, anon_map, bindparams - ) - return result - - def _generate_cache_key(self): - """return a cache key. - - The cache key is a tuple which can contain any series of - objects that are hashable and also identifies - this object uniquely within the presence of a larger SQL expression - or statement, for the purposes of caching the resulting query. - - The cache key should be based on the SQL compiled structure that would - ultimately be produced. That is, two structures that are composed in - exactly the same way should produce the same cache key; any difference - in the structures that would affect the SQL string or the type handlers - should result in a different cache key. - - The cache key returned by this method is an instance of - :class:`.CacheKey`, which consists of a tuple representing the - cache key, as well as a list of :class:`.BindParameter` objects - which are extracted from the expression. While two expressions - that produce identical cache key tuples will themselves generate - identical SQL strings, the list of :class:`.BindParameter` objects - indicates the bound values which may have different values in - each one; these bound parameters must be consulted in order to - execute the statement with the correct parameters. - - a :class:`_expression.ClauseElement` structure that does not implement - a :meth:`._gen_cache_key` method and does not implement a - :attr:`.traverse_internals` attribute will not be cacheable; when - such an element is embedded into a larger structure, this method - will return None, indicating no cache key is available. - - """ - - bindparams = [] - - _anon_map = anon_map() - key = self._gen_cache_key(_anon_map, bindparams) - if NO_CACHE in _anon_map: - return None - else: - return CacheKey(key, bindparams) - - @classmethod - def _generate_cache_key_for_object(cls, obj): - bindparams = [] - - _anon_map = anon_map() - key = obj._gen_cache_key(_anon_map, bindparams) - if NO_CACHE in _anon_map: - return None - else: - return CacheKey(key, bindparams) - - -class MemoizedHasCacheKey(HasCacheKey, HasMemoized): - @HasMemoized.memoized_instancemethod - def _generate_cache_key(self): - return HasCacheKey._generate_cache_key(self) - - -class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): - """The key used to identify a SQL statement construct in the - SQL compilation cache. - - .. seealso:: - - :ref:`sql_caching` - - """ - - def __hash__(self): - """CacheKey itself is not hashable - hash the .key portion""" - - return None - - def to_offline_string(self, statement_cache, statement, parameters): - """Generate an "offline string" form of this :class:`.CacheKey` - - The "offline string" is basically the string SQL for the - statement plus a repr of the bound parameter values in series. - Whereas the :class:`.CacheKey` object is dependent on in-memory - identities in order to work as a cache key, the "offline" version - is suitable for a cache that will work for other processes as well. - - The given ``statement_cache`` is a dictionary-like object where the - string form of the statement itself will be cached. This dictionary - should be in a longer lived scope in order to reduce the time spent - stringifying statements. - - - """ - if self.key not in statement_cache: - statement_cache[self.key] = sql_str = str(statement) - else: - sql_str = statement_cache[self.key] - - if not self.bindparams: - param_tuple = tuple(parameters[key] for key in sorted(parameters)) - else: - param_tuple = tuple( - parameters.get(bindparam.key, bindparam.value) - for bindparam in self.bindparams - ) - - return repr((sql_str, param_tuple)) - - def __eq__(self, other): - return self.key == other.key - - @classmethod - def _diff_tuples(cls, left, right): - ck1 = CacheKey(left, []) - ck2 = CacheKey(right, []) - return ck1._diff(ck2) - - def _whats_different(self, other): - - k1 = self.key - k2 = other.key - - stack = [] - pickup_index = 0 - while True: - s1, s2 = k1, k2 - for idx in stack: - s1 = s1[idx] - s2 = s2[idx] - - for idx, (e1, e2) in enumerate(zip_longest(s1, s2)): - if idx < pickup_index: - continue - if e1 != e2: - if isinstance(e1, tuple) and isinstance(e2, tuple): - stack.append(idx) - break - else: - yield "key%s[%d]: %s != %s" % ( - "".join("[%d]" % id_ for id_ in stack), - idx, - e1, - e2, - ) - else: - pickup_index = stack.pop(-1) - break - - def _diff(self, other): - return ", ".join(self._whats_different(other)) - - def __str__(self): - stack = [self.key] - - output = [] - sentinel = object() - indent = -1 - while stack: - elem = stack.pop(0) - if elem is sentinel: - output.append((" " * (indent * 2)) + "),") - indent -= 1 - elif isinstance(elem, tuple): - if not elem: - output.append((" " * ((indent + 1) * 2)) + "()") - else: - indent += 1 - stack = list(elem) + [sentinel] + stack - output.append((" " * (indent * 2)) + "(") - else: - if isinstance(elem, HasCacheKey): - repr_ = "<%s object at %s>" % ( - type(elem).__name__, - hex(id(elem)), - ) - else: - repr_ = repr(elem) - output.append((" " * (indent * 2)) + " " + repr_ + ", ") - - return "CacheKey(key=%s)" % ("\n".join(output),) - - def _generate_param_dict(self): - """used for testing""" - - from .compiler import prefix_anon_map - - _anon_map = prefix_anon_map() - return {b.key % _anon_map: b.effective_value for b in self.bindparams} - - def _apply_params_to_element(self, original_cache_key, target_element): - translate = { - k.key: v.value - for k, v in zip(original_cache_key.bindparams, self.bindparams) - } - - return target_element.params(translate) - - def _clone(element, **kw): return element._clone() -class _CacheKey(ExtendedInternalTraversal): - # very common elements are inlined into the main _get_cache_key() method - # to produce a dramatic savings in Python function call overhead - - visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY - visit_clauseelement_list = InternalTraversal.dp_clauseelement_list - visit_annotations_key = InternalTraversal.dp_annotations_key - visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple - visit_memoized_select_entities = ( - InternalTraversal.dp_memoized_select_entities - ) - - visit_string = ( - visit_boolean - ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE - visit_statement_hint_list = CACHE_IN_PLACE - visit_type = STATIC_CACHE_KEY - visit_anon_name = ANON_NAME - - visit_propagate_attrs = PROPAGATE_ATTRS - - def visit_with_context_options( - self, attrname, obj, parent, anon_map, bindparams - ): - return tuple((fn.__code__, c_key) for fn, c_key in obj) - - def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): - return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) - - def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): - return tuple(obj) - - def visit_multi(self, attrname, obj, parent, anon_map, bindparams): - return ( - attrname, - obj._gen_cache_key(anon_map, bindparams) - if isinstance(obj, HasCacheKey) - else obj, - ) - - def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): - return ( - attrname, - tuple( - elem._gen_cache_key(anon_map, bindparams) - if isinstance(elem, HasCacheKey) - else elem - for elem in obj - ), - ) - - def visit_has_cache_key_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple( - tuple( - elem._gen_cache_key(anon_map, bindparams) - for elem in tup_elem - ) - for tup_elem in obj - ), - ) - - def visit_has_cache_key_list( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), - ) - - def visit_executable_options( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple( - elem._gen_cache_key(anon_map, bindparams) - for elem in obj - if elem._is_has_cache_key - ), - ) - - def visit_inspectable_list( - self, attrname, obj, parent, anon_map, bindparams - ): - return self.visit_has_cache_key_list( - attrname, [inspect(o) for o in obj], parent, anon_map, bindparams - ) - - def visit_clauseelement_tuples( - self, attrname, obj, parent, anon_map, bindparams - ): - return self.visit_has_cache_key_tuples( - attrname, obj, parent, anon_map, bindparams - ) - - def visit_fromclause_ordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - return ( - attrname, - tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]), - ) - - def visit_clauseelement_unordered_set( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - cache_keys = [ - elem._gen_cache_key(anon_map, bindparams) for elem in obj - ] - return ( - attrname, - tuple( - sorted(cache_keys) - ), # cache keys all start with (id_, class) - ) - - def visit_named_ddl_element( - self, attrname, obj, parent, anon_map, bindparams - ): - return (attrname, obj.name) - - def visit_prefix_sequence( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - - return ( - attrname, - tuple( - [ - (clause._gen_cache_key(anon_map, bindparams), strval) - for clause, strval in obj - ] - ), - ) - - def visit_setup_join_tuple( - self, attrname, obj, parent, anon_map, bindparams - ): - is_legacy = "legacy" in attrname - - return tuple( - ( - target - if is_legacy and isinstance(target, str) - else target._gen_cache_key(anon_map, bindparams), - onclause - if is_legacy and isinstance(onclause, str) - else onclause._gen_cache_key(anon_map, bindparams) - if onclause is not None - else None, - from_._gen_cache_key(anon_map, bindparams) - if from_ is not None - else None, - tuple([(key, flags[key]) for key in sorted(flags)]), - ) - for (target, onclause, from_, flags) in obj - ) - - def visit_table_hint_list( - self, attrname, obj, parent, anon_map, bindparams - ): - if not obj: - return () - - return ( - attrname, - tuple( - [ - ( - clause._gen_cache_key(anon_map, bindparams), - dialect_name, - text, - ) - for (clause, dialect_name), text in obj.items() - ] - ), - ) - - def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): - return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) - - def visit_dialect_options( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - ( - dialect_name, - tuple( - [ - (key, obj[dialect_name][key]) - for key in sorted(obj[dialect_name]) - ] - ), - ) - for dialect_name in sorted(obj) - ), - ) - - def visit_string_clauseelement_dict( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - (key, obj[key]._gen_cache_key(anon_map, bindparams)) - for key in sorted(obj) - ), - ) - - def visit_string_multi_dict( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - ( - key, - value._gen_cache_key(anon_map, bindparams) - if isinstance(value, HasCacheKey) - else value, - ) - for key, value in [(key, obj[key]) for key in sorted(obj)] - ), - ) - - def visit_fromclause_canonical_column_collection( - self, attrname, obj, parent, anon_map, bindparams - ): - # inlining into the internals of ColumnCollection - return ( - attrname, - tuple( - col._gen_cache_key(anon_map, bindparams) - for k, col in obj._collection - ), - ) - - def visit_unknown_structure( - self, attrname, obj, parent, anon_map, bindparams - ): - anon_map[NO_CACHE] = True - return () - - def visit_dml_ordered_values( - self, attrname, obj, parent, anon_map, bindparams - ): - return ( - attrname, - tuple( - ( - key._gen_cache_key(anon_map, bindparams) - if hasattr(key, "__clause_element__") - else key, - value._gen_cache_key(anon_map, bindparams), - ) - for key, value in obj - ), - ) - - def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): - # in py37 we can assume two dictionaries created in the same - # insert ordering will retain that sorting - return ( - attrname, - tuple( - ( - k._gen_cache_key(anon_map, bindparams) - if hasattr(k, "__clause_element__") - else k, - obj[k]._gen_cache_key(anon_map, bindparams), - ) - for k in obj - ), - ) - - def visit_dml_multi_values( - self, attrname, obj, parent, anon_map, bindparams - ): - # multivalues are simply not cacheable right now - anon_map[NO_CACHE] = True - return () - - -_cache_key_traversal_visitor = _CacheKey() - - class HasCopyInternals: __slots__ = () @@ -813,7 +87,7 @@ class HasCopyInternals: setattr(self, attrname, result) -class _CopyInternals(InternalTraversal): +class _CopyInternalsTraversal(InternalTraversal): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -936,7 +210,7 @@ class _CopyInternals(InternalTraversal): return element -_copy_internals = _CopyInternals() +_copy_internals = _CopyInternalsTraversal() def _flatten_clauseelement(element): @@ -948,7 +222,7 @@ def _flatten_clauseelement(element): return element -class _GetChildren(InternalTraversal): +class _GetChildrenTraversal(InternalTraversal): """Generate a _children_traversal internal traversal dispatch for classes with a _traverse_internals collection.""" @@ -1019,7 +293,7 @@ class _GetChildren(InternalTraversal): return () -_get_children = _GetChildren() +_get_children = _GetChildrenTraversal() @util.preload_module("sqlalchemy.sql.elements") diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index ea00a32ad..7981100a4 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -13,9 +13,8 @@ import typing from . import operators from .base import SchemaEventTarget -from .traversals import NO_CACHE +from .cache_key import NO_CACHE from .visitors import Traversible -from .visitors import TraversibleType from .. import exc from .. import util @@ -869,10 +868,6 @@ class TypeEngine(Traversible): return util.generic_repr(self) -class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType): - pass - - class ExternalType: """mixin that defines attributes and behaviors specific to third-party datatypes. @@ -1049,7 +1044,7 @@ class ExternalType: return NO_CACHE -class UserDefinedType(ExternalType, TypeEngine, metaclass=VisitableCheckKWArg): +class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg): """Base for user defined types. This should be the base of new types. Note that diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 69e83e46a..63067585e 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -22,6 +22,7 @@ from .annotation import _shallow_annotate # noqa from .base import _expand_cloned from .base import _from_objects from .base import ColumnSet +from .cache_key import HasCacheKey # noqa from .ddl import sort_tables # noqa from .elements import _find_columns # noqa from .elements import _label_reference @@ -41,7 +42,6 @@ from .selectable import Join from .selectable import ScalarSelect from .selectable import SelectBase from .selectable import TableClause -from .traversals import HasCacheKey # noqa from .. import exc from .. import util diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 87fe36944..70c4dc133 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -32,6 +32,11 @@ from .. import util from ..util import langhelpers from ..util import symbol +try: + from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa +except ImportError: + from ._py_util import cache_anon_map as anon_map # noqa + __all__ = [ "iterate", "traverse_using", @@ -39,88 +44,77 @@ __all__ = [ "cloned_traverse", "replacement_traverse", "Traversible", - "TraversibleType", "ExternalTraversal", "InternalTraversal", ] -def _generate_compiler_dispatch(cls): - """Generate a _compiler_dispatch() external traversal on classes with a - __visit_name__ attribute. - - """ - visit_name = cls.__visit_name__ - - if "_compiler_dispatch" in cls.__dict__: - # class has a fixed _compiler_dispatch() method. - # copy it to "original" so that we can get it back if - # sqlalchemy.ext.compiles overrides it. - cls._original_compiler_dispatch = cls._compiler_dispatch - return - - if not isinstance(visit_name, str): - raise exc.InvalidRequestError( - "__visit_name__ on class %s must be a string at the class level" - % cls.__name__ - ) - - name = "visit_%s" % visit_name - getter = operator.attrgetter(name) - - def _compiler_dispatch(self, visitor, **kw): - """Look for an attribute named "visit_<visit_name>" on the - visitor, and call it with the same kw params. - - """ - try: - meth = getter(visitor) - except AttributeError as err: - return visitor.visit_unsupported_compilation(self, err, **kw) - - else: - return meth(self, **kw) - - cls._compiler_dispatch = ( - cls._original_compiler_dispatch - ) = _compiler_dispatch +class Traversible: + """Base class for visitable objects.""" + __slots__ = () -class TraversibleType(type): - """Metaclass which assigns dispatch attributes to various kinds of - "visitable" classes. + __visit_name__: str - Attributes include: + def __init_subclass__(cls) -> None: + if "__visit_name__" in cls.__dict__: + cls._generate_compiler_dispatch() + super().__init_subclass__() - * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``. - This is called "external traversal" because the caller of each visit() - method is responsible for sub-traversing the inner elements of each - object. This is appropriate for string compilers and other traversals - that need to call upon the inner elements in a specific pattern. + @classmethod + def _generate_compiler_dispatch(cls): + """Assign dispatch attributes to various kinds of + "visitable" classes. - * internal traversal collections ``_children_traversal``, - ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from - an optional ``_traverse_internals`` collection of symbols which comes - from the :class:`.InternalTraversal` list of symbols. This is called - "internal traversal" MARKMARK + Attributes include: - """ + * The ``_compiler_dispatch`` method, corresponding to + ``__visit_name__``. This is called "external traversal" because the + caller of each visit() method is responsible for sub-traversing the + inner elements of each object. This is appropriate for string + compilers and other traversals that need to call upon the inner + elements in a specific pattern. - def __init__(cls, clsname, bases, clsdict): - if clsname != "Traversible": - if "__visit_name__" in clsdict: - _generate_compiler_dispatch(cls) + * internal traversal collections ``_children_traversal``, + ``_cache_key_traversal``, ``_copy_internals_traversal``, generated + from an optional ``_traverse_internals`` collection of symbols which + comes from the :class:`.InternalTraversal` list of symbols. This is + called "internal traversal". - super(TraversibleType, cls).__init__(clsname, bases, clsdict) + """ + visit_name = cls.__visit_name__ + + if "_compiler_dispatch" in cls.__dict__: + # class has a fixed _compiler_dispatch() method. + # copy it to "original" so that we can get it back if + # sqlalchemy.ext.compiles overrides it. + cls._original_compiler_dispatch = cls._compiler_dispatch + return + + if not isinstance(visit_name, str): + raise exc.InvalidRequestError( + f"__visit_name__ on class {cls.__name__} must be a string " + "at the class level" + ) + name = "visit_%s" % visit_name + getter = operator.attrgetter(name) -class Traversible(metaclass=TraversibleType): - """Base class for visitable objects, applies the - :class:`.visitors.TraversibleType` metaclass. + def _compiler_dispatch(self, visitor, **kw): + """Look for an attribute named "visit_<visit_name>" on the + visitor, and call it with the same kw params. - """ + """ + try: + meth = getter(visitor) + except AttributeError as err: + return visitor.visit_unsupported_compilation(self, err, **kw) + else: + return meth(self, **kw) - __slots__ = () + cls._compiler_dispatch = ( + cls._original_compiler_dispatch + ) = _compiler_dispatch def __class_getitem__(cls, key): # allow generic classes in py3.9+ @@ -159,48 +153,90 @@ class Traversible(metaclass=TraversibleType): ) -class _InternalTraversalType(type): - def __init__(cls, clsname, bases, clsdict): - if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"): - lookup = {} - for key, sym in clsdict.items(): - if key.startswith("dp_"): - visit_key = key.replace("dp_", "visit_") - sym_name = sym.name - assert sym_name not in lookup, sym_name - lookup[sym] = lookup[sym_name] = visit_key - if hasattr(cls, "_dispatch_lookup"): - lookup.update(cls._dispatch_lookup) - cls._dispatch_lookup = lookup +class _HasTraversalDispatch: + r"""Define infrastructure for the :class:`.InternalTraversal` class. - super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict) + .. versionadded:: 2.0 + """ -def _generate_dispatcher(visitor, internal_dispatch, method_name): - names = [] - for attrname, visit_sym in internal_dispatch: - meth = visitor.dispatch(visit_sym) - if meth: - visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym] - names.append((attrname, visit_name)) - - code = ( - (" return [\n") - + ( - ", \n".join( - " (%r, self.%s, visitor.%s)" - % (attrname, attrname, visit_name) - for attrname, visit_name in names + def __init_subclass__(cls) -> None: + cls._generate_traversal_dispatch() + super().__init_subclass__() + + def dispatch(self, visit_symbol): + """Given a method from :class:`._HasTraversalDispatch`, return the + corresponding method on a subclass. + + """ + name = self._dispatch_lookup[visit_symbol] + return getattr(self, name, None) + + def run_generated_dispatch( + self, target, internal_dispatch, generate_dispatcher_name + ): + try: + dispatcher = target.__class__.__dict__[generate_dispatcher_name] + except KeyError: + # most of the dispatchers are generated up front + # in sqlalchemy/sql/__init__.py -> + # traversals.py-> _preconfigure_traversals(). + # this block will generate any remaining dispatchers. + dispatcher = self.generate_dispatch( + target.__class__, internal_dispatch, generate_dispatcher_name + ) + return dispatcher(target, self) + + def generate_dispatch( + self, target_cls, internal_dispatch, generate_dispatcher_name + ): + dispatcher = self._generate_dispatcher( + internal_dispatch, generate_dispatcher_name + ) + # assert isinstance(target_cls, type) + setattr(target_cls, generate_dispatcher_name, dispatcher) + return dispatcher + + @classmethod + def _generate_traversal_dispatch(cls): + lookup = {} + clsdict = cls.__dict__ + for key, sym in clsdict.items(): + if key.startswith("dp_"): + visit_key = key.replace("dp_", "visit_") + sym_name = sym.name + assert sym_name not in lookup, sym_name + lookup[sym] = lookup[sym_name] = visit_key + if hasattr(cls, "_dispatch_lookup"): + lookup.update(cls._dispatch_lookup) + cls._dispatch_lookup = lookup + + def _generate_dispatcher(self, internal_dispatch, method_name): + names = [] + for attrname, visit_sym in internal_dispatch: + meth = self.dispatch(visit_sym) + if meth: + visit_name = ExtendedInternalTraversal._dispatch_lookup[ + visit_sym + ] + names.append((attrname, visit_name)) + + code = ( + (" return [\n") + + ( + ", \n".join( + " (%r, self.%s, visitor.%s)" + % (attrname, attrname, visit_name) + for attrname, visit_name in names + ) ) + + ("\n ]\n") ) - + ("\n ]\n") - ) - meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" - # print(meth_text) - return langhelpers._exec_code_in_env(meth_text, {}, method_name) + meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" + return langhelpers._exec_code_in_env(meth_text, {}, method_name) -class InternalTraversal(metaclass=_InternalTraversalType): +class InternalTraversal(_HasTraversalDispatch): r"""Defines visitor symbols used for internal traversal. The :class:`.InternalTraversal` class is used in two ways. One is that @@ -239,39 +275,6 @@ class InternalTraversal(metaclass=_InternalTraversalType): """ - def dispatch(self, visit_symbol): - """Given a method from :class:`.InternalTraversal`, return the - corresponding method on a subclass. - - """ - name = self._dispatch_lookup[visit_symbol] - return getattr(self, name, None) - - def run_generated_dispatch( - self, target, internal_dispatch, generate_dispatcher_name - ): - try: - dispatcher = target.__class__.__dict__[generate_dispatcher_name] - except KeyError: - # most of the dispatchers are generated up front - # in sqlalchemy/sql/__init__.py -> - # traversals.py-> _preconfigure_traversals(). - # this block will generate any remaining dispatchers. - dispatcher = self.generate_dispatch( - target.__class__, internal_dispatch, generate_dispatcher_name - ) - return dispatcher(target, self) - - def generate_dispatch( - self, target_cls, internal_dispatch, generate_dispatcher_name - ): - dispatcher = _generate_dispatcher( - self, internal_dispatch, generate_dispatcher_name - ) - # assert isinstance(target_cls, type) - setattr(target_cls, generate_dispatcher_name, dispatcher) - return dispatcher - dp_has_cache_key = symbol("HC") """Visit a :class:`.HasCacheKey` object.""" @@ -623,7 +626,6 @@ class ReplacingExternalTraversal(CloningExternalTraversal): # backwards compatibility Visitable = Traversible -VisitableType = TraversibleType ClauseVisitor = ExternalTraversal CloningVisitor = CloningExternalTraversal ReplacingCloningVisitor = ReplacingExternalTraversal |
