diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2019-11-04 19:50:17 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2019-11-04 19:50:17 +0000 |
commit | 98f7c0c080ebbc8236fc0cc98970cb1b0112c01e (patch) | |
tree | f119c0003461e0bd7426bb493391a1c563817bb4 | |
parent | bee7c3462d6235b7e108dbdbcd0727cdcbc49eb0 (diff) | |
parent | 29330ec1596f12462c501a65404ff52005b16b6c (diff) | |
download | sqlalchemy-98f7c0c080ebbc8236fc0cc98970cb1b0112c01e.tar.gz |
Merge "Add anonymizing context to cache keys, comparison; convert traversal"
37 files changed, 2336 insertions, 1290 deletions
diff --git a/doc/build/core/sqlelement.rst b/doc/build/core/sqlelement.rst index 41a27af71..f7f9cab64 100644 --- a/doc/build/core/sqlelement.rst +++ b/doc/build/core/sqlelement.rst @@ -82,6 +82,9 @@ the FROM clause of a SELECT statement. .. autoclass:: BindParameter :members: +.. autoclass:: CacheKey + :members: + .. autoclass:: Case :members: @@ -90,6 +93,7 @@ the FROM clause of a SELECT statement. .. autoclass:: ClauseElement :members: + :inherited-members: .. autoclass:: ClauseList diff --git a/doc/build/core/visitors.rst b/doc/build/core/visitors.rst index 02f6e24fc..539d66440 100644 --- a/doc/build/core/visitors.rst +++ b/doc/build/core/visitors.rst @@ -23,3 +23,4 @@ as well as when building out custom SQL expressions using the .. automodule:: sqlalchemy.sql.visitors :members: + :private-members:
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index 293aa426d..b43b364fa 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -103,7 +103,6 @@ class Insert(StandardInsert): inserted_alias = getattr(self, "inserted_alias", None) self._post_values_clause = OnDuplicateClause(inserted_alias, values) - return self insert = public_factory(Insert, ".dialects.mysql.insert") diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 909d568a7..e94f9913c 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1658,23 +1658,20 @@ class PGCompiler(compiler.SQLCompiler): return "ONLY " + sqltext def get_select_precolumns(self, select, **kw): - if select._distinct is not False: - if select._distinct is True: - return "DISTINCT " - elif isinstance(select._distinct, (list, tuple)): + if select._distinct or select._distinct_on: + if select._distinct_on: return ( "DISTINCT ON (" + ", ".join( - [self.process(col, **kw) for col in select._distinct] + [ + self.process(col, **kw) + for col in select._distinct_on + ] ) + ") " ) else: - return ( - "DISTINCT ON (" - + self.process(select._distinct, **kw) - + ") " - ) + return "DISTINCT " else: return "" diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 4e77f5a4c..f4467976a 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -103,7 +103,6 @@ class Insert(StandardInsert): self._post_values_clause = OnConflictDoUpdate( constraint, index_elements, index_where, set_, where ) - return self @_generative def on_conflict_do_nothing( @@ -138,7 +137,6 @@ class Insert(StandardInsert): self._post_values_clause = OnConflictDoNothing( constraint, index_elements, index_where ) - return self insert = public_factory(Insert, 
".dialects.postgresql.insert") diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index d18a35a40..8e137f141 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -198,7 +198,7 @@ class BakedQuery(object): self.spoil() else: for opt in options: - cache_key = opt._generate_cache_key(cache_path) + cache_key = opt._generate_path_cache_key(cache_path) if cache_key is False: self.spoil() elif cache_key is not None: diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 4a5a8ba9c..c2b234758 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -455,7 +455,7 @@ def deregister(class_): if hasattr(class_, "_compiler_dispatcher"): # regenerate default _compiler_dispatch - visitors._generate_dispatch(class_) + visitors._generate_compiler_dispatch(class_) # remove custom directive del class_._compiler_dispatcher diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 83069f113..aa2986205 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -47,6 +47,8 @@ from .base import state_str from .. import event from .. import inspection from .. 
import util +from ..sql import base as sql_base +from ..sql import visitors @inspection._self_inspects @@ -54,6 +56,7 @@ class QueryableAttribute( interfaces._MappedAttribute, interfaces.InspectionAttr, interfaces.PropComparator, + sql_base.HasCacheKey, ): """Base class for :term:`descriptor` objects that intercept attribute events on behalf of a :class:`.MapperProperty` @@ -102,6 +105,13 @@ class QueryableAttribute( if base[key].dispatch._active_history: self.dispatch._active_history = True + _cache_key_traversal = [ + # ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("key", visitors.ExtendedInternalTraversal.dp_string), + ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), + ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ] + @util.memoized_property def _supports_population(self): return self.impl.supports_population diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 6f8d19293..a3dea6b0e 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -216,7 +216,6 @@ def _assertions(*assertions): for assertion in assertions: assertion(self, fn.__name__) fn(self, *args[1:], **kw) - return self return generate diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index e94a81fed..704ce9df7 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -36,6 +36,8 @@ from .. import inspect from .. import inspection from .. import util from ..sql import operators +from ..sql import visitors +from ..sql.traversals import HasCacheKey __all__ = ( @@ -54,7 +56,9 @@ __all__ = ( ) -class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): +class MapperProperty( + HasCacheKey, _MappedAttribute, InspectionAttr, util.MemoizedSlots +): """Represent a particular class attribute mapped by :class:`.Mapper`. 
The most common occurrences of :class:`.MapperProperty` are the @@ -74,6 +78,11 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): "info", ) + _cache_key_traversal = [ + ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("key", visitors.ExtendedInternalTraversal.dp_string), + ] + cascade = frozenset() """The set of 'cascade' attribute names. @@ -647,7 +656,7 @@ class MapperOption(object): self.process_query(query) - def _generate_cache_key(self, path): + def _generate_path_cache_key(self, path): """Used by the "baked lazy loader" to see if this option can be cached. The "baked lazy loader" refers to the :class:`.Query` that is diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 376ad1923..548eca58d 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -71,7 +71,7 @@ _CONFIGURE_MUTEX = util.threading.RLock() @inspection._self_inspects @log.class_logger -class Mapper(InspectionAttr): +class Mapper(sql_base.HasCacheKey, InspectionAttr): """Define the correlation of class attributes to database table columns. @@ -729,6 +729,10 @@ class Mapper(InspectionAttr): """ return self + _cache_key_traversal = [ + ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj) + ] + @property def entity(self): r"""Part of the inspection API. diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 2f680a3a1..585cb80bc 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -15,7 +15,8 @@ from .base import class_mapper from .. import exc from .. import inspection from .. import util - +from ..sql import visitors +from ..sql.traversals import HasCacheKey log = logging.getLogger(__name__) @@ -28,7 +29,7 @@ _WILDCARD_TOKEN = "*" _DEFAULT_TOKEN = "_sa_default" -class PathRegistry(object): +class PathRegistry(HasCacheKey): """Represent query load paths and registry functions. 
Basically represents structures like: @@ -57,6 +58,10 @@ class PathRegistry(object): is_token = False is_root = False + _cache_key_traversal = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list) + ] + def __eq__(self, other): return other is not None and self.path == other.path @@ -78,6 +83,9 @@ class PathRegistry(object): def __len__(self): return len(self.path) + def __hash__(self): + return id(self) + @property def length(self): return len(self.path) diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 26f47f616..99bbbe37c 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -26,11 +26,13 @@ from .. import inspect from .. import util from ..sql import coercions from ..sql import roles +from ..sql import visitors from ..sql.base import _generative from ..sql.base import Generative +from ..sql.traversals import HasCacheKey -class Load(Generative, MapperOption): +class Load(HasCacheKey, Generative, MapperOption): """Represents loader options which modify the state of a :class:`.Query` in order to affect how various mapped attributes are loaded. 
@@ -70,6 +72,17 @@ class Load(Generative, MapperOption): """ + _cache_key_traversal = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ( + "_context_cache_key", + visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples, + ), + ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict), + ] + def __init__(self, entity): insp = inspect(entity) self.path = insp._path_registry @@ -89,7 +102,16 @@ class Load(Generative, MapperOption): load._of_type = None return load - def _generate_cache_key(self, path): + @property + def _context_cache_key(self): + serialized = [] + for (key, loader_path), obj in self.context.items(): + if key != "loader": + continue + serialized.append(loader_path + (obj,)) + return serialized + + def _generate_path_cache_key(self, path): if path.path[0].is_aliased_class: return False @@ -522,9 +544,16 @@ class _UnboundLoad(Load): self._to_bind = [] self.local_opts = {} + _cache_key_traversal = [ + ("path", visitors.ExtendedInternalTraversal.dp_multi_list), + ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list), + ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict), + ] + _is_chain_link = False - def _generate_cache_key(self, path): + def _generate_path_cache_key(self, path): serialized = () for val in self._to_bind: for local_elem, val_elem in zip(self.path, val.path): @@ -533,7 +562,7 @@ class _UnboundLoad(Load): else: opt = val._bind_loader([path.path[0]], None, None, False) if opt: - c_key = opt._generate_cache_key(path) + c_key = opt._generate_path_cache_key(path) if c_key is False: return False elif c_key: @@ -660,7 +689,6 @@ class _UnboundLoad(Load): opt = meth(opt, all_tokens[-1], **kw) opt._is_chain_link = False - return opt def _chop_path(self, to_chop, path): diff --git 
a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 5f0f41e8d..c86993678 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -30,10 +30,12 @@ from .. import exc as sa_exc from .. import inspection from .. import sql from .. import util +from ..sql import base as sql_base from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import util as sql_util +from ..sql import visitors all_cascades = frozenset( @@ -530,7 +532,7 @@ class AliasedClass(object): return str(self._aliased_insp) -class AliasedInsp(InspectionAttr): +class AliasedInsp(sql_base.HasCacheKey, InspectionAttr): """Provide an inspection interface for an :class:`.AliasedClass` object. @@ -627,6 +629,12 @@ class AliasedInsp(InspectionAttr): def __clause_element__(self): return self.selectable + _cache_key_traversal = [ + ("name", visitors.ExtendedInternalTraversal.dp_string), + ("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean), + ("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement), + ] + @property def class_(self): """Return the mapped class ultimately represented by this diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index a0264845e..0d995ec8a 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -12,12 +12,32 @@ associations. """ from . import operators +from .base import HasCacheKey +from .visitors import InternalTraversal from .. 
import util -class SupportsCloneAnnotations(object): +class SupportsAnnotations(object): + @util.memoized_property + def _annotation_traversals(self): + return [ + ( + key, + InternalTraversal.dp_has_cache_key + if isinstance(value, HasCacheKey) + else InternalTraversal.dp_plain_obj, + ) + for key, value in self._annotations.items() + ] + + +class SupportsCloneAnnotations(SupportsAnnotations): _annotations = util.immutabledict() + _traverse_internals = [ + ("_annotations", InternalTraversal.dp_annotations_state) + ] + def _annotate(self, values): """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -25,6 +45,7 @@ class SupportsCloneAnnotations(object): """ new = self._clone() new._annotations = new._annotations.union(values) + new.__dict__.pop("_annotation_traversals", None) return new def _with_annotations(self, values): @@ -34,6 +55,7 @@ class SupportsCloneAnnotations(object): """ new = self._clone() new._annotations = util.immutabledict(values) + new.__dict__.pop("_annotation_traversals", None) return new def _deannotate(self, values=None, clone=False): @@ -49,12 +71,13 @@ class SupportsCloneAnnotations(object): # the expression for a deep deannotation new = self._clone() new._annotations = {} + new.__dict__.pop("_annotation_traversals", None) return new else: return self -class SupportsWrappingAnnotations(object): +class SupportsWrappingAnnotations(SupportsAnnotations): def _annotate(self, values): """return a copy of this ClauseElement with annotations updated by the given dictionary. 
@@ -123,6 +146,7 @@ class Annotated(object): def __init__(self, element, values): self.__dict__ = element.__dict__.copy() + self.__dict__.pop("_annotation_traversals", None) self.__element = element self._annotations = values self._hash = hash(element) @@ -135,6 +159,7 @@ class Annotated(object): def _with_annotations(self, values): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() + clone.__dict__.pop("_annotation_traversals", None) clone._annotations = values return clone @@ -192,7 +217,17 @@ def _deep_annotate(element, annotations, exclude=None): """ - def clone(elem): + # annotated objects hack the __hash__() method so if we want to + # uniquely process them we have to use id() + + cloned_ids = {} + + def clone(elem, **kw): + id_ = id(elem) + + if id_ in cloned_ids: + return cloned_ids[id_] + if ( exclude and hasattr(elem, "proxy_set") @@ -204,6 +239,7 @@ def _deep_annotate(element, annotations, exclude=None): else: newelem = elem newelem._copy_internals(clone=clone) + cloned_ids[id_] = newelem return newelem if element is not None: @@ -214,23 +250,21 @@ def _deep_annotate(element, annotations, exclude=None): def _deep_deannotate(element, values=None): """Deep copy the given element, removing annotations.""" - cloned = util.column_dict() + cloned = {} - def clone(elem): - # if a values dict is given, - # the elem must be cloned each time it appears, - # as there may be different annotations in source - # elements that are remaining. if totally - # removing all annotations, can assume the same - # slate... 
- if values or elem not in cloned: + def clone(elem, **kw): + if values: + key = id(elem) + else: + key = elem + + if key not in cloned: newelem = elem._deannotate(values=values, clone=True) newelem._copy_internals(clone=clone) - if not values: - cloned[elem] = newelem + cloned[key] = newelem return newelem else: - return cloned[elem] + return cloned[key] if element is not None: element = clone(element) @@ -268,6 +302,11 @@ def _new_annotation_type(cls, base_cls): "Annotated%s" % cls.__name__, (base_cls, cls), {} ) globals()["Annotated%s" % cls.__name__] = anno_cls + + if "_traverse_internals" in cls.__dict__: + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_state) + ] return anno_cls diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 7e9199bfa..d11a3a313 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -14,6 +14,7 @@ import itertools import operator import re +from .traversals import HasCacheKey # noqa from .visitors import ClauseVisitor from .. import exc from .. import util @@ -38,18 +39,41 @@ class Immutable(object): def _clone(self): return self + def _copy_internals(self, **kw): + pass + + +class HasMemoized(object): + def _reset_memoizations(self): + self._memoized_property.expire_instance(self) + + def _reset_exported(self): + self._memoized_property.expire_instance(self) + + def _copy_internals(self, **kw): + super(HasMemoized, self)._copy_internals(**kw) + self._reset_memoizations() + def _from_objects(*elements): return itertools.chain(*[element._from_objects for element in elements]) def _generative(fn): + """non-caching _generative() decorator. + + This is basically the legacy decorator that copies the object and + runs a method on the new copy. 
+ + """ + @util.decorator - def _generative(fn, *args, **kw): + def _generative(fn, self, *args, **kw): """Mark a method as generative.""" - self = args[0]._generate() - fn(self, *args[1:], **kw) + self = self._generate() + x = fn(self, *args, **kw) + assert x is None, "generative methods must have no return value" return self decorated = _generative(fn) @@ -357,10 +381,8 @@ class DialectKWArgs(object): class Generative(object): - """Allow a ClauseElement to generate itself via the - @_generative decorator. - - """ + """Provide a method-chaining pattern in conjunction with the + @_generative decorator.""" def _generate(self): s = self.__class__.__new__(self.__class__) diff --git a/lib/sqlalchemy/sql/clause_compare.py b/lib/sqlalchemy/sql/clause_compare.py deleted file mode 100644 index 30a90348c..000000000 --- a/lib/sqlalchemy/sql/clause_compare.py +++ /dev/null @@ -1,334 +0,0 @@ -from collections import deque - -from . import operators -from .. import util - - -SKIP_TRAVERSE = util.symbol("skip_traverse") - - -def compare(obj1, obj2, **kw): - if kw.get("use_proxies", False): - strategy = ColIdentityComparatorStrategy() - else: - strategy = StructureComparatorStrategy() - - return strategy.compare(obj1, obj2, **kw) - - -class StructureComparatorStrategy(object): - __slots__ = "compare_stack", "cache" - - def __init__(self): - self.compare_stack = deque() - self.cache = set() - - def compare(self, obj1, obj2, **kw): - stack = self.compare_stack - cache = self.cache - - stack.append((obj1, obj2)) - - while stack: - left, right = stack.popleft() - - if left is right: - continue - elif left is None or right is None: - # we know they are different so no match - return False - elif (left, right) in cache: - continue - cache.add((left, right)) - - visit_name = left.__visit_name__ - - # we're not exactly looking for identical types, because - # there are things like Column and AnnotatedColumn. 
So the - # visit_name has to at least match up - if visit_name != right.__visit_name__: - return False - - meth = getattr(self, "compare_%s" % visit_name, None) - - if meth: - comparison = meth(left, right, **kw) - if comparison is False: - return False - elif comparison is SKIP_TRAVERSE: - continue - - for c1, c2 in util.zip_longest( - left.get_children(column_collections=False), - right.get_children(column_collections=False), - fillvalue=None, - ): - if c1 is None or c2 is None: - # collections are different sizes, comparison fails - return False - stack.append((c1, c2)) - - return True - - def compare_inner(self, obj1, obj2, **kw): - stack = self.compare_stack - try: - self.compare_stack = deque() - return self.compare(obj1, obj2, **kw) - finally: - self.compare_stack = stack - - def _compare_unordered_sequences(self, seq1, seq2, **kw): - if seq1 is None: - return seq2 is None - - completed = set() - for clause in seq1: - for other_clause in set(seq2).difference(completed): - if self.compare_inner(clause, other_clause, **kw): - completed.add(other_clause) - break - return len(completed) == len(seq1) == len(seq2) - - def compare_bindparam(self, left, right, **kw): - # note the ".key" is often generated from id(self) so can't - # be compared, as far as determining structure. 
- return ( - left.type._compare_type_affinity(right.type) - and left.value == right.value - and left.callable == right.callable - and left._orig_key == right._orig_key - ) - - def compare_clauselist(self, left, right, **kw): - if left.operator is right.operator: - if operators.is_associative(left.operator): - if self._compare_unordered_sequences( - left.clauses, right.clauses - ): - return SKIP_TRAVERSE - else: - return False - else: - # normal ordered traversal - return True - else: - return False - - def compare_unary(self, left, right, **kw): - if left.operator: - disp = self._get_operator_dispatch( - left.operator, "unary", "operator" - ) - if disp is not None: - result = disp(left, right, left.operator, **kw) - if result is not True: - return result - elif left.modifier: - disp = self._get_operator_dispatch( - left.modifier, "unary", "modifier" - ) - if disp is not None: - result = disp(left, right, left.operator, **kw) - if result is not True: - return result - return ( - left.operator == right.operator and left.modifier == right.modifier - ) - - def compare_binary(self, left, right, **kw): - disp = self._get_operator_dispatch(left.operator, "binary", None) - if disp: - result = disp(left, right, left.operator, **kw) - if result is not True: - return result - - if left.operator == right.operator: - if operators.is_commutative(left.operator): - if ( - compare(left.left, right.left, **kw) - and compare(left.right, right.right, **kw) - ) or ( - compare(left.left, right.right, **kw) - and compare(left.right, right.left, **kw) - ): - return SKIP_TRAVERSE - else: - return False - else: - return True - else: - return False - - def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): - # used by compare_binary, compare_unary - attrname = "visit_%s_%s%s" % ( - operator_.__name__, - qualifier1, - "_" + qualifier2 if qualifier2 else "", - ) - return getattr(self, attrname, None) - - def visit_function_as_comparison_op_binary( - self, left, right, operator, 
**kw - ): - return ( - left.left_index == right.left_index - and left.right_index == right.right_index - ) - - def compare_function(self, left, right, **kw): - return left.name == right.name - - def compare_column(self, left, right, **kw): - if left.table is not None: - self.compare_stack.appendleft((left.table, right.table)) - return ( - left.key == right.key - and left.name == right.name - and ( - left.type._compare_type_affinity(right.type) - if left.type is not None - else right.type is None - ) - and left.is_literal == right.is_literal - ) - - def compare_collation(self, left, right, **kw): - return left.collation == right.collation - - def compare_type_coerce(self, left, right, **kw): - return left.type._compare_type_affinity(right.type) - - @util.dependencies("sqlalchemy.sql.elements") - def compare_alias(self, elements, left, right, **kw): - return ( - left.name == right.name - if not isinstance(left.name, elements._anonymous_label) - else isinstance(right.name, elements._anonymous_label) - ) - - def compare_cte(self, elements, left, right, **kw): - raise NotImplementedError("TODO") - - def compare_extract(self, left, right, **kw): - return left.field == right.field - - def compare_textual_label_reference(self, left, right, **kw): - return left.element == right.element - - def compare_slice(self, left, right, **kw): - return ( - left.start == right.start - and left.stop == right.stop - and left.step == right.step - ) - - def compare_over(self, left, right, **kw): - return left.range_ == right.range_ and left.rows == right.rows - - @util.dependencies("sqlalchemy.sql.elements") - def compare_label(self, elements, left, right, **kw): - return left._type._compare_type_affinity(right._type) and ( - left.name == right.name - if not isinstance(left.name, elements._anonymous_label) - else isinstance(right.name, elements._anonymous_label) - ) - - def compare_typeclause(self, left, right, **kw): - return left.type._compare_type_affinity(right.type) - - def 
compare_join(self, left, right, **kw): - return left.isouter == right.isouter and left.full == right.full - - def compare_table(self, left, right, **kw): - if left.name != right.name: - return False - - self.compare_stack.extendleft( - util.zip_longest(left.columns, right.columns) - ) - - def compare_compound_select(self, left, right, **kw): - - if not self._compare_unordered_sequences( - left.selects, right.selects, **kw - ): - return False - - if left.keyword != right.keyword: - return False - - if left._for_update_arg != right._for_update_arg: - return False - - if not self.compare_inner( - left._order_by_clause, right._order_by_clause, **kw - ): - return False - - if not self.compare_inner( - left._group_by_clause, right._group_by_clause, **kw - ): - return False - - return SKIP_TRAVERSE - - def compare_select(self, left, right, **kw): - if not self._compare_unordered_sequences( - left._correlate, right._correlate - ): - return False - if not self._compare_unordered_sequences( - left._correlate_except, right._correlate_except - ): - return False - - if not self._compare_unordered_sequences( - left._from_obj, right._from_obj - ): - return False - - if left._for_update_arg != right._for_update_arg: - return False - - return True - - def compare_textual_select(self, left, right, **kw): - self.compare_stack.extendleft( - util.zip_longest(left.column_args, right.column_args) - ) - return left.positional == right.positional - - -class ColIdentityComparatorStrategy(StructureComparatorStrategy): - def compare_column_element( - self, left, right, use_proxies=True, equivalents=(), **kw - ): - """Compare ColumnElements using proxies and equivalent collections. - - This is a comparison strategy specific to the ORM. 
- """ - - to_compare = (right,) - if equivalents and right in equivalents: - to_compare = equivalents[right].union(to_compare) - - for oth in to_compare: - if use_proxies and left.shares_lineage(oth): - return True - elif hash(left) == hash(right): - return True - else: - return False - - def compare_column(self, left, right, **kw): - return self.compare_column_element(left, right, **kw) - - def compare_label(self, left, right, **kw): - return self.compare_column_element(left, right, **kw) - - def compare_table(self, left, right, **kw): - # tables compare on identity, since it's not really feasible to - # compare them column by column with the above rules - return left is right diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5ecec7d6c..546fffc6c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -434,6 +434,27 @@ class _CompileLabel(elements.ColumnElement): return self +class prefix_anon_map(dict): + """A map that creates new keys for missing key access. + + Considers keys of the form "<ident> <name>" to produce + new symbols "<name>_<index>", where "index" is an incrementing integer + corresponding to <name>. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + def __missing__(self, key): + (ident, derived) = key.split(" ", 1) + anonymous_counter = self.get(derived, 1) + self[derived] = anonymous_counter + 1 + value = derived + "_" + str(anonymous_counter) + self[key] = value + return value + + class SQLCompiler(Compiled): """Default implementation of :class:`.Compiled`. 
@@ -574,7 +595,7 @@ class SQLCompiler(Compiled): # a map which tracks "anonymous" identifiers that are created on # the fly here - self.anon_map = util.PopulateDict(self._process_anon) + self.anon_map = prefix_anon_map() # a map which tracks "truncated" names based on # dialect.label_length or dialect.max_identifier_length @@ -1712,12 +1733,6 @@ class SQLCompiler(Compiled): def _anonymize(self, name): return name % self.anon_map - def _process_anon(self, key): - (ident, derived) = key.split(" ", 1) - anonymous_counter = self.anon_map.get(derived, 1) - self.anon_map[derived] = anonymous_counter + 1 - return derived + "_" + str(anonymous_counter) - def bindparam_string( self, name, diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 918f7524e..c0baa8555 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -178,6 +178,9 @@ def _unsupported_impl(expr, op, *arg, **kw): def _inv_impl(expr, op, **kw): """See :meth:`.ColumnOperators.__inv__`.""" + + # undocumented element currently used by the ORM for + # relationship.contains() if hasattr(expr, "negation_clause"): return expr.negation_clause else: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e6f57b8d1..ba615bc3f 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -16,23 +16,29 @@ import itertools import operator import re -from . import clause_compare from . import coercions from . import operators from . import roles +from . import traversals from . 
import type_api from .annotation import Annotated from .annotation import SupportsWrappingAnnotations from .base import _clone from .base import _generative from .base import Executable +from .base import HasCacheKey +from .base import HasMemoized from .base import Immutable from .base import NO_ARG from .base import PARSE_AUTOCOMMIT from .coercions import _document_text_coercion +from .traversals import _copy_internals +from .traversals import _get_children +from .traversals import NO_CACHE from .visitors import cloned_traverse +from .visitors import InternalTraversal from .visitors import traverse -from .visitors import Visitable +from .visitors import Traversible from .. import exc from .. import inspection from .. import util @@ -162,7 +168,9 @@ def not_(clause): @inspection._self_inspects -class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): +class ClauseElement( + roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible +): """Base class for elements of a programmatically constructed SQL expression. @@ -190,6 +198,13 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): _order_by_label_element = None + @property + def _cache_key_traversal(self): + try: + return self._traverse_internals + except AttributeError: + return NO_CACHE + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -221,28 +236,6 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): """ return self - def _cache_key(self, **kw): - """return an optional cache key. - - The cache key is a tuple which can contain any series of - objects that are hashable and also identifies - this object uniquely within the presence of a larger SQL expression - or statement, for the purposes of caching the resulting query. - - The cache key should be based on the SQL compiled structure that would - ultimately be produced. 
That is, two structures that are composed in - exactly the same way should produce the same cache key; any difference - in the strucures that would affect the SQL string or the type handlers - should result in a different cache key. - - If a structure cannot produce a useful cache key, it should raise - NotImplementedError, which will result in the entire structure - for which it's part of not being useful as a cache key. - - - """ - raise NotImplementedError() - @property def _constructor(self): """return the 'constructor' for this ClauseElement. @@ -336,9 +329,9 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): (see :class:`.ColumnElement`) """ - return clause_compare.compare(self, other, **kw) + return traversals.compare(self, other, **kw) - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals(self, **kw): """Reassign internal elements to be clones of themselves. Called during a copy-and-traverse operation on newly @@ -349,21 +342,46 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): traversal, cloned traversal, annotations). """ - pass - def get_children(self, **kwargs): - r"""Return immediate child elements of this :class:`.ClauseElement`. + try: + traverse_internals = self._traverse_internals + except AttributeError: + return + + for attrname, obj, meth in _copy_internals.run_generated_dispatch( + self, traverse_internals, "_generated_copy_internals_traversal" + ): + if obj is not None: + result = meth(self, obj, **kw) + if result is not None: + setattr(self, attrname, result) + + def get_children(self, omit_attrs=None, **kw): + r"""Return immediate child :class:`.Traversible` elements of this + :class:`.Traversible`. This is used for visit traversal. 
- \**kwargs may contain flags that change the collection that is + \**kw may contain flags that change the collection that is returned, for example to return a subset of items in order to cut down on larger traversals, or to return child items from a different context (such as schema-level collections instead of clause-level). """ - return [] + result = [] + try: + traverse_internals = self._traverse_internals + except AttributeError: + return result + + for attrname, obj, meth in _get_children.run_generated_dispatch( + self, traverse_internals, "_generated_get_children_traversal" + ): + if obj is None or omit_attrs and attrname in omit_attrs: + continue + result.extend(meth(obj, **kw)) + return result def self_group(self, against=None): # type: (Optional[Any]) -> ClauseElement @@ -501,6 +519,8 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): return or_(self, other) def __invert__(self): + # undocumented element currently used by the ORM for + # relationship.contains() if hasattr(self, "negation_clause"): return self.negation_clause else: @@ -508,9 +528,7 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): def _negate(self): return UnaryExpression( - self.self_group(against=operators.inv), - operator=operators.inv, - negate=None, + self.self_group(against=operators.inv), operator=operators.inv ) def __bool__(self): @@ -731,9 +749,6 @@ class ColumnElement( else: return comparator_factory(self) - def _cache_key(self, **kw): - raise NotImplementedError(self.__class__) - def __getattr__(self, key): try: return getattr(self.comparator, key) @@ -969,6 +984,13 @@ class BindParameter(roles.InElementRole, ColumnElement): __visit_name__ = "bindparam" + _traverse_internals = [ + ("key", InternalTraversal.dp_anon_name), + ("type", InternalTraversal.dp_type), + ("callable", InternalTraversal.dp_plain_dict), + ("value", InternalTraversal.dp_plain_obj), + ] + _is_crud = False _expanding_in_types = () @@ -1321,26 +1343,19 @@ 
class BindParameter(roles.InElementRole, ColumnElement): ) return c - def _cache_key(self, bindparams=None, **kw): - if bindparams is None: - # even though _cache_key is a private method, we would like to - # be super paranoid about this point. You can't include the - # "value" or "callable" in the cache key, because the value is - # not part of the structure of a statement and is likely to - # change every time. However you cannot *throw it away* either, - # because you can't invoke the statement without the parameter - # values that were explicitly placed. So require that they - # are collected here to make sure this happens. - if self._value_required_for_cache: - raise NotImplementedError( - "bindparams collection argument required for _cache_key " - "implementation. Bound parameter cache keys are not safe " - "to use without accommodating for the value or callable " - "within the parameter itself." - ) - else: - bindparams.append(self) - return (BindParameter, self.type._cache_key, self._orig_key) + def _gen_cache_key(self, anon_map, bindparams): + if self in anon_map: + return (anon_map[self], self.__class__) + + id_ = anon_map[self] + bindparams.append(self) + + return ( + id_, + self.__class__, + self.type._gen_cache_key, + traversals._resolve_name_for_compare(self, self.key, anon_map), + ) def _convert_to_unique(self): if not self.unique: @@ -1377,12 +1392,11 @@ class TypeClause(ClauseElement): __visit_name__ = "typeclause" + _traverse_internals = [("type", InternalTraversal.dp_type)] + def __init__(self, type_): self.type = type_ - def _cache_key(self, **kw): - return (TypeClause, self.type._cache_key) - class TextClause( roles.DDLConstraintColumnRole, @@ -1419,6 +1433,11 @@ class TextClause( __visit_name__ = "textclause" + _traverse_internals = [ + ("_bindparams", InternalTraversal.dp_string_clauseelement_dict), + ("text", InternalTraversal.dp_string), + ] + _is_text_clause = True _is_textual = True @@ -1861,19 +1880,6 @@ class TextClause( else: return 
self - def _copy_internals(self, clone=_clone, **kw): - self._bindparams = dict( - (b.key, clone(b, **kw)) for b in self._bindparams.values() - ) - - def get_children(self, **kwargs): - return list(self._bindparams.values()) - - def _cache_key(self, **kw): - return (self.text,) + tuple( - bind._cache_key for bind in self._bindparams.values() - ) - class Null(roles.ConstExprRole, ColumnElement): """Represent the NULL keyword in a SQL statement. @@ -1885,6 +1891,8 @@ class Null(roles.ConstExprRole, ColumnElement): __visit_name__ = "null" + _traverse_internals = [] + @util.memoized_property def type(self): return type_api.NULLTYPE @@ -1895,9 +1903,6 @@ class Null(roles.ConstExprRole, ColumnElement): return Null() - def _cache_key(self, **kw): - return (Null,) - class False_(roles.ConstExprRole, ColumnElement): """Represent the ``false`` keyword, or equivalent, in a SQL statement. @@ -1908,6 +1913,7 @@ class False_(roles.ConstExprRole, ColumnElement): """ __visit_name__ = "false" + _traverse_internals = [] @util.memoized_property def type(self): @@ -1954,9 +1960,6 @@ class False_(roles.ConstExprRole, ColumnElement): return False_() - def _cache_key(self, **kw): - return (False_,) - class True_(roles.ConstExprRole, ColumnElement): """Represent the ``true`` keyword, or equivalent, in a SQL statement. 
@@ -1968,6 +1971,8 @@ class True_(roles.ConstExprRole, ColumnElement): __visit_name__ = "true" + _traverse_internals = [] + @util.memoized_property def type(self): return type_api.BOOLEANTYPE @@ -2020,9 +2025,6 @@ class True_(roles.ConstExprRole, ColumnElement): return True_() - def _cache_key(self, **kw): - return (True_,) - class ClauseList( roles.InElementRole, @@ -2038,6 +2040,11 @@ class ClauseList( __visit_name__ = "clauselist" + _traverse_internals = [ + ("clauses", InternalTraversal.dp_clauseelement_list), + ("operator", InternalTraversal.dp_operator), + ] + def __init__(self, *clauses, **kwargs): self.operator = kwargs.pop("operator", operators.comma_op) self.group = kwargs.pop("group", True) @@ -2082,17 +2089,6 @@ class ClauseList( coercions.expect(self._text_converter_role, clause) ) - def _copy_internals(self, clone=_clone, **kw): - self.clauses = [clone(clause, **kw) for clause in self.clauses] - - def get_children(self, **kwargs): - return self.clauses - - def _cache_key(self, **kw): - return (ClauseList, self.operator) + tuple( - clause._cache_key(**kw) for clause in self.clauses - ) - @property def _from_objects(self): return list(itertools.chain(*[c._from_objects for c in self.clauses])) @@ -2115,11 +2111,6 @@ class BooleanClauseList(ClauseList, ColumnElement): "BooleanClauseList has a private constructor" ) - def _cache_key(self, **kw): - return (BooleanClauseList, self.operator) + tuple( - clause._cache_key(**kw) for clause in self.clauses - ) - @classmethod def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): convert_clauses = [] @@ -2250,6 +2241,8 @@ or_ = BooleanClauseList.or_ class Tuple(ClauseList, ColumnElement): """Represent a SQL tuple.""" + _traverse_internals = ClauseList._traverse_internals + [] + def __init__(self, *clauses, **kw): """Return a :class:`.Tuple`. 
@@ -2289,11 +2282,6 @@ class Tuple(ClauseList, ColumnElement): def _select_iterable(self): return (self,) - def _cache_key(self, **kw): - return (Tuple,) + tuple( - clause._cache_key(**kw) for clause in self.clauses - ) - def _bind_param(self, operator, obj, type_=None): return Tuple( *[ @@ -2339,6 +2327,12 @@ class Case(ColumnElement): __visit_name__ = "case" + _traverse_internals = [ + ("value", InternalTraversal.dp_clauseelement), + ("whens", InternalTraversal.dp_clauseelement_tuples), + ("else_", InternalTraversal.dp_clauseelement), + ] + def __init__(self, whens, value=None, else_=None): r"""Produce a ``CASE`` expression. @@ -2501,40 +2495,6 @@ class Case(ColumnElement): else: self.else_ = None - def _copy_internals(self, clone=_clone, **kw): - if self.value is not None: - self.value = clone(self.value, **kw) - self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens] - if self.else_ is not None: - self.else_ = clone(self.else_, **kw) - - def get_children(self, **kwargs): - if self.value is not None: - yield self.value - for x, y in self.whens: - yield x - yield y - if self.else_ is not None: - yield self.else_ - - def _cache_key(self, **kw): - return ( - ( - Case, - self.value._cache_key(**kw) - if self.value is not None - else None, - ) - + tuple( - (x._cache_key(**kw), y._cache_key(**kw)) for x, y in self.whens - ) - + ( - self.else_._cache_key(**kw) - if self.else_ is not None - else None, - ) - ) - @property def _from_objects(self): return list( @@ -2603,6 +2563,11 @@ class Cast(WrapsColumnExpression, ColumnElement): __visit_name__ = "cast" + _traverse_internals = [ + ("clause", InternalTraversal.dp_clauseelement), + ("typeclause", InternalTraversal.dp_clauseelement), + ] + def __init__(self, expression, type_): r"""Produce a ``CAST`` expression. 
@@ -2662,20 +2627,6 @@ class Cast(WrapsColumnExpression, ColumnElement): ) self.typeclause = TypeClause(self.type) - def _copy_internals(self, clone=_clone, **kw): - self.clause = clone(self.clause, **kw) - self.typeclause = clone(self.typeclause, **kw) - - def get_children(self, **kwargs): - return self.clause, self.typeclause - - def _cache_key(self, **kw): - return ( - Cast, - self.clause._cache_key(**kw), - self.typeclause._cache_key(**kw), - ) - @property def _from_objects(self): return self.clause._from_objects @@ -2685,7 +2636,7 @@ class Cast(WrapsColumnExpression, ColumnElement): return self.clause -class TypeCoerce(WrapsColumnExpression, ColumnElement): +class TypeCoerce(HasMemoized, WrapsColumnExpression, ColumnElement): """Represent a Python-side type-coercion wrapper. :class:`.TypeCoerce` supplies the :func:`.expression.type_coerce` @@ -2705,6 +2656,13 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement): __visit_name__ = "type_coerce" + _traverse_internals = [ + ("clause", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ] + + _memoized_property = util.group_expirable_memoized_property() + def __init__(self, expression, type_): r"""Associate a SQL expression with a particular type, without rendering ``CAST``. 
@@ -2773,21 +2731,11 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement): roles.ExpressionElementRole, expression, type_=self.type ) - def _copy_internals(self, clone=_clone, **kw): - self.clause = clone(self.clause, **kw) - self.__dict__.pop("typed_expression", None) - - def get_children(self, **kwargs): - return (self.clause,) - - def _cache_key(self, **kw): - return (TypeCoerce, self.type._cache_key, self.clause._cache_key(**kw)) - @property def _from_objects(self): return self.clause._from_objects - @util.memoized_property + @_memoized_property def typed_expression(self): if isinstance(self.clause, BindParameter): bp = self.clause._clone() @@ -2806,6 +2754,11 @@ class Extract(ColumnElement): __visit_name__ = "extract" + _traverse_internals = [ + ("expr", InternalTraversal.dp_clauseelement), + ("field", InternalTraversal.dp_string), + ] + def __init__(self, field, expr, **kwargs): """Return a :class:`.Extract` construct. @@ -2818,15 +2771,6 @@ class Extract(ColumnElement): self.field = field self.expr = coercions.expect(roles.ExpressionElementRole, expr) - def _copy_internals(self, clone=_clone, **kw): - self.expr = clone(self.expr, **kw) - - def get_children(self, **kwargs): - return (self.expr,) - - def _cache_key(self, **kw): - return (Extract, self.field, self.expr._cache_key(**kw)) - @property def _from_objects(self): return self.expr._from_objects @@ -2847,18 +2791,11 @@ class _label_reference(ColumnElement): __visit_name__ = "label_reference" + _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + def __init__(self, element): self.element = element - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return (_label_reference, self.element._cache_key(**kw)) - - def get_children(self, **kwargs): - return [self.element] - @property def _from_objects(self): return () @@ -2867,6 +2804,8 @@ class _label_reference(ColumnElement): class 
_textual_label_reference(ColumnElement): __visit_name__ = "textual_label_reference" + _traverse_internals = [("element", InternalTraversal.dp_string)] + def __init__(self, element): self.element = element @@ -2874,9 +2813,6 @@ class _textual_label_reference(ColumnElement): def _text_clause(self): return TextClause._create_text(self.element) - def _cache_key(self, **kw): - return (_textual_label_reference, self.element) - class UnaryExpression(ColumnElement): """Define a 'unary' expression. @@ -2894,13 +2830,18 @@ class UnaryExpression(ColumnElement): __visit_name__ = "unary" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("operator", InternalTraversal.dp_operator), + ("modifier", InternalTraversal.dp_operator), + ] + def __init__( self, element, operator=None, modifier=None, type_=None, - negate=None, wraps_column_expression=False, ): self.operator = operator @@ -2909,7 +2850,6 @@ class UnaryExpression(ColumnElement): against=self.operator or self.modifier ) self.type = type_api.to_instance(type_) - self.negate = negate self.wraps_column_expression = wraps_column_expression @classmethod @@ -3135,37 +3075,13 @@ class UnaryExpression(ColumnElement): def _from_objects(self): return self.element._from_objects - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return ( - UnaryExpression, - self.element._cache_key(**kw), - self.operator, - self.modifier, - ) - - def get_children(self, **kwargs): - return (self.element,) - def _negate(self): - if self.negate is not None: - return UnaryExpression( - self.element, - operator=self.negate, - negate=self.operator, - modifier=self.modifier, - type_=self.type, - wraps_column_expression=self.wraps_column_expression, - ) - elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: + if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return UnaryExpression( self.self_group(against=operators.inv), 
operator=operators.inv, type_=type_api.BOOLEANTYPE, wraps_column_expression=self.wraps_column_expression, - negate=None, ) else: return ClauseElement._negate(self) @@ -3286,15 +3202,6 @@ class AsBoolean(WrapsColumnExpression, UnaryExpression): # type: (Optional[Any]) -> ClauseElement return self - def _cache_key(self, **kw): - return ( - self.element._cache_key(**kw), - self.type._cache_key, - self.operator, - self.negate, - self.modifier, - ) - def _negate(self): if isinstance(self.element, (True_, False_)): return self.element._negate() @@ -3318,6 +3225,14 @@ class BinaryExpression(ColumnElement): __visit_name__ = "binary" + _traverse_internals = [ + ("left", InternalTraversal.dp_clauseelement), + ("right", InternalTraversal.dp_clauseelement), + ("operator", InternalTraversal.dp_operator), + ("negate", InternalTraversal.dp_operator), + ("modifiers", InternalTraversal.dp_plain_dict), + ] + _is_implicitly_boolean = True """Indicates that any database will know this is a boolean expression even if the database does not have an explicit boolean datatype. 
@@ -3360,20 +3275,6 @@ class BinaryExpression(ColumnElement): def _from_objects(self): return self.left._from_objects + self.right._from_objects - def _copy_internals(self, clone=_clone, **kw): - self.left = clone(self.left, **kw) - self.right = clone(self.right, **kw) - - def get_children(self, **kwargs): - return self.left, self.right - - def _cache_key(self, **kw): - return ( - BinaryExpression, - self.left._cache_key(**kw), - self.right._cache_key(**kw), - ) - def self_group(self, against=None): # type: (Optional[Any]) -> ClauseElement @@ -3406,6 +3307,12 @@ class Slice(ColumnElement): __visit_name__ = "slice" + _traverse_internals = [ + ("start", InternalTraversal.dp_plain_obj), + ("stop", InternalTraversal.dp_plain_obj), + ("step", InternalTraversal.dp_plain_obj), + ] + def __init__(self, start, stop, step): self.start = start self.stop = stop @@ -3417,9 +3324,6 @@ class Slice(ColumnElement): assert against is operator.getitem return self - def _cache_key(self, **kw): - return (Slice, self.start, self.stop, self.step) - class IndexExpression(BinaryExpression): """Represent the class of expressions that are like an "index" operation. 
@@ -3444,6 +3348,11 @@ class GroupedElement(ClauseElement): class Grouping(GroupedElement, ColumnElement): """Represent a grouping within a column expression""" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ] + def __init__(self, element): self.element = element self.type = getattr(element, "type", type_api.NULLTYPE) @@ -3460,15 +3369,6 @@ class Grouping(GroupedElement, ColumnElement): def _label(self): return getattr(self.element, "_label", None) or self.anon_label - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def get_children(self, **kwargs): - return (self.element,) - - def _cache_key(self, **kw): - return (Grouping, self.element._cache_key(**kw)) - @property def _from_objects(self): return self.element._from_objects @@ -3501,6 +3401,14 @@ class Over(ColumnElement): __visit_name__ = "over" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("order_by", InternalTraversal.dp_clauseelement), + ("partition_by", InternalTraversal.dp_clauseelement), + ("range_", InternalTraversal.dp_plain_obj), + ("rows", InternalTraversal.dp_plain_obj), + ] + order_by = None partition_by = None @@ -3667,30 +3575,6 @@ class Over(ColumnElement): def type(self): return self.element.type - def get_children(self, **kwargs): - return [ - c - for c in (self.element, self.partition_by, self.order_by) - if c is not None - ] - - def _cache_key(self, **kw): - return ( - (Over,) - + tuple( - e._cache_key(**kw) if e is not None else None - for e in (self.element, self.partition_by, self.order_by) - ) - + (self.range_, self.rows) - ) - - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - if self.partition_by is not None: - self.partition_by = clone(self.partition_by, **kw) - if self.order_by is not None: - self.order_by = clone(self.order_by, **kw) - @property def _from_objects(self): return list( @@ -3723,6 
+3607,11 @@ class WithinGroup(ColumnElement): __visit_name__ = "withingroup" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("order_by", InternalTraversal.dp_clauseelement), + ] + order_by = None def __init__(self, element, *order_by): @@ -3791,25 +3680,6 @@ class WithinGroup(ColumnElement): else: return self.element.type - def get_children(self, **kwargs): - return [c for c in (self.element, self.order_by) if c is not None] - - def _cache_key(self, **kw): - return ( - WithinGroup, - self.element._cache_key(**kw) - if self.element is not None - else None, - self.order_by._cache_key(**kw) - if self.order_by is not None - else None, - ) - - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - if self.order_by is not None: - self.order_by = clone(self.order_by, **kw) - @property def _from_objects(self): return list( @@ -3845,6 +3715,11 @@ class FunctionFilter(ColumnElement): __visit_name__ = "funcfilter" + _traverse_internals = [ + ("func", InternalTraversal.dp_clauseelement), + ("criterion", InternalTraversal.dp_clauseelement), + ] + criterion = None def __init__(self, func, *criterion): @@ -3932,23 +3807,6 @@ class FunctionFilter(ColumnElement): def type(self): return self.func.type - def get_children(self, **kwargs): - return [c for c in (self.func, self.criterion) if c is not None] - - def _copy_internals(self, clone=_clone, **kw): - self.func = clone(self.func, **kw) - if self.criterion is not None: - self.criterion = clone(self.criterion, **kw) - - def _cache_key(self, **kw): - return ( - FunctionFilter, - self.func._cache_key(**kw), - self.criterion._cache_key(**kw) - if self.criterion is not None - else None, - ) - @property def _from_objects(self): return list( @@ -3962,7 +3820,7 @@ class FunctionFilter(ColumnElement): ) -class Label(roles.LabeledColumnExprRole, ColumnElement): +class Label(HasMemoized, roles.LabeledColumnExprRole, ColumnElement): """Represents a column label (AS). 
Represent a label, as typically applied to any column-level @@ -3972,6 +3830,14 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): __visit_name__ = "label" + _traverse_internals = [ + ("name", InternalTraversal.dp_anon_name), + ("_type", InternalTraversal.dp_type), + ("_element", InternalTraversal.dp_clauseelement), + ] + + _memoized_property = util.group_expirable_memoized_property() + def __init__(self, name, element, type_=None): """Return a :class:`Label` object for the given :class:`.ColumnElement`. @@ -4010,14 +3876,11 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): def __reduce__(self): return self.__class__, (self.name, self._element, self._type) - def _cache_key(self, **kw): - return (Label, self.element._cache_key(**kw), self._resolve_label) - @util.memoized_property def _is_implicitly_boolean(self): return self.element._is_implicitly_boolean - @util.memoized_property + @_memoized_property def _allow_label_resolve(self): return self.element._allow_label_resolve @@ -4031,7 +3894,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): self._type or getattr(self._element, "type", None) ) - @util.memoized_property + @_memoized_property def element(self): return self._element.self_group(against=operators.as_) @@ -4057,13 +3920,9 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): def foreign_keys(self): return self.element.foreign_keys - def get_children(self, **kwargs): - return (self.element,) - def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw): + self._reset_memoizations() self._element = clone(self._element, **kw) - self.__dict__.pop("element", None) - self.__dict__.pop("_allow_label_resolve", None) if anonymize_labels: self.name = self._resolve_label = _anonymous_label( "%%(%d %s)s" @@ -4124,6 +3983,13 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): __visit_name__ = "column" + _traverse_internals = [ + ("name", InternalTraversal.dp_string), + ("type", 
InternalTraversal.dp_type), + ("table", InternalTraversal.dp_clauseelement), + ("is_literal", InternalTraversal.dp_boolean), + ] + onupdate = default = server_default = server_onupdate = None _is_multiparam_column = False @@ -4254,14 +4120,6 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): table = property(_get_table, _set_table) - def _cache_key(self, **kw): - return ( - self.name, - self.table.name if self.table is not None else None, - self.is_literal, - self.type._cache_key, - ) - @_memoized_property def _from_objects(self): t = self.table @@ -4395,12 +4253,11 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): class CollationClause(ColumnElement): __visit_name__ = "collation" + _traverse_internals = [("collation", InternalTraversal.dp_string)] + def __init__(self, collation): self.collation = collation - def _cache_key(self, **kw): - return (CollationClause, self.collation) - class _IdentifiedClause(Executable, ClauseElement): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 7ce822669..08e69f075 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -86,7 +86,6 @@ __all__ = [ from .base import _from_objects # noqa from .base import ColumnCollection # noqa from .base import Executable # noqa -from .base import Generative # noqa from .base import PARSE_AUTOCOMMIT # noqa from .dml import Delete # noqa from .dml import Insert # noqa @@ -242,7 +241,6 @@ _UnaryExpression = UnaryExpression _Case = Case _Tuple = Tuple _Over = Over -_Generative = Generative _TypeClause = TypeClause _Extract = Extract _Exists = Exists diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index cbc8e539f..96e64dc28 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -17,7 +17,6 @@ from . import sqltypes from . 
import util as sqlutil from .base import ColumnCollection from .base import Executable -from .elements import _clone from .elements import _type_from_args from .elements import BinaryExpression from .elements import BindParameter @@ -33,7 +32,8 @@ from .elements import WithinGroup from .selectable import Alias from .selectable import FromClause from .selectable import Select -from .visitors import VisitableType +from .visitors import InternalTraversal +from .visitors import TraversibleType from .. import util @@ -78,10 +78,14 @@ class FunctionElement(Executable, ColumnElement, FromClause): """ + _traverse_internals = [("clause_expr", InternalTraversal.dp_clauseelement)] + packagenames = () _has_args = False + _memoized_property = FromClause._memoized_property + def __init__(self, *clauses, **kwargs): r"""Construct a :class:`.FunctionElement`. @@ -136,7 +140,7 @@ class FunctionElement(Executable, ColumnElement, FromClause): col = self.label(None) return ColumnCollection(columns=[(col.key, col)]) - @util.memoized_property + @_memoized_property def clauses(self): """Return the underlying :class:`.ClauseList` which contains the arguments for this :class:`.FunctionElement`. 
@@ -283,17 +287,6 @@ class FunctionElement(Executable, ColumnElement, FromClause): def _from_objects(self): return self.clauses._from_objects - def get_children(self, **kwargs): - return (self.clause_expr,) - - def _cache_key(self, **kw): - return (FunctionElement, self.clause_expr._cache_key(**kw)) - - def _copy_internals(self, clone=_clone, **kw): - self.clause_expr = clone(self.clause_expr, **kw) - self._reset_exported() - FunctionElement.clauses._reset(self) - def within_group_type(self, within_group): """For types that define their return type as based on the criteria within a WITHIN GROUP (ORDER BY) expression, called by the @@ -404,6 +397,13 @@ class FunctionElement(Executable, ColumnElement, FromClause): class FunctionAsBinary(BinaryExpression): + _traverse_internals = [ + ("sql_function", InternalTraversal.dp_clauseelement), + ("left_index", InternalTraversal.dp_plain_obj), + ("right_index", InternalTraversal.dp_plain_obj), + ("modifiers", InternalTraversal.dp_plain_dict), + ] + def __init__(self, fn, left_index, right_index): self.sql_function = fn self.left_index = left_index @@ -431,20 +431,6 @@ class FunctionAsBinary(BinaryExpression): def right(self, value): self.sql_function.clauses.clauses[self.right_index - 1] = value - def _copy_internals(self, clone=_clone, **kw): - self.sql_function = clone(self.sql_function, **kw) - - def get_children(self, **kw): - yield self.sql_function - - def _cache_key(self, **kw): - return ( - FunctionAsBinary, - self.sql_function._cache_key(**kw), - self.left_index, - self.right_index, - ) - class _FunctionGenerator(object): """Generate SQL function expressions. @@ -606,6 +592,12 @@ class Function(FunctionElement): __visit_name__ = "function" + _traverse_internals = FunctionElement._traverse_internals + [ + ("packagenames", InternalTraversal.dp_plain_obj), + ("name", InternalTraversal.dp_string), + ("type", InternalTraversal.dp_type), + ] + def __init__(self, name, *clauses, **kw): """Construct a :class:`.Function`. 
@@ -630,15 +622,8 @@ class Function(FunctionElement): unique=True, ) - def _cache_key(self, **kw): - return ( - (Function,) + tuple(self.packagenames) - if self.packagenames - else () + (self.name, self.clause_expr._cache_key(**kw)) - ) - -class _GenericMeta(VisitableType): +class _GenericMeta(TraversibleType): def __init__(cls, clsname, bases, clsdict): if annotation.Annotated not in cls.__mro__: cls.name = name = clsdict.get("name", clsname) @@ -764,6 +749,10 @@ class next_value(GenericFunction): type = sqltypes.Integer() name = "next_value" + _traverse_internals = [ + ("sequence", InternalTraversal.dp_named_ddl_element) + ] + def __init__(self, seq, **kw): assert isinstance( seq, schema.Sequence @@ -771,21 +760,12 @@ class next_value(GenericFunction): self._bind = kw.get("bind", None) self.sequence = seq - def _cache_key(self, **kw): - return (next_value, self.sequence.name) - def compare(self, other, **kw): return ( isinstance(other, next_value) and self.sequence.name == other.sequence.name ) - def get_children(self, **kwargs): - return [] - - def _copy_internals(self, **kw): - pass - @property def _from_objects(self): return [] diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 4e8f4a397..ee7dc61ce 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -50,6 +50,7 @@ from .elements import ColumnElement from .elements import quoted_name from .elements import TextClause from .selectable import TableClause +from .visitors import InternalTraversal from .. import event from .. import exc from .. 
import inspection @@ -425,6 +426,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause): __visit_name__ = "table" + _traverse_internals = TableClause._traverse_internals + [ + ("schema", InternalTraversal.dp_string) + ] + + def _gen_cache_key(self, anon_map, bindparams): + return (self,) + + @util.deprecated_params( + useexisting=( + "0.7", + "The :paramref:`.Table.useexisting` parameter is deprecated and " + "will be removed in a future release. Please use " + ":paramref:`.Table.extend_existing`.", + ) + ) def __new__(cls, *args, **kw): if not args: # python3k pickle seems to call this @@ -763,6 +779,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause): def get_children( self, column_collections=True, schema_visitor=False, **kw ): + # TODO: consider that we probably don't need column_collections=True + # at all, it does not seem to impact anything if not schema_visitor: return TableClause.get_children( self, column_collections=column_collections, **kw diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 6a7413fc0..4b3844eec 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -31,6 +31,7 @@ from .base import ColumnSet from .base import DedupeColumnCollection from .base import Executable from .base import Generative +from .base import HasMemoized from .base import Immutable from .coercions import _document_text_coercion from .elements import _anonymous_label @@ -39,11 +40,13 @@ from .elements import and_ from .elements import BindParameter from .elements import ClauseElement from .elements import ClauseList +from .elements import ColumnClause from .elements import GroupedElement from .elements import Grouping from .elements import literal_column from .elements import True_ from .elements import UnaryExpression +from .visitors import InternalTraversal from .. import exc from .. 
import util @@ -201,6 +204,8 @@ class Selectable(ReturnsRows): class HasPrefixes(object): _prefixes = () + _traverse_internals = [("_prefixes", InternalTraversal.dp_prefix_sequence)] + @_generative @_document_text_coercion( "expr", @@ -252,6 +257,8 @@ class HasPrefixes(object): class HasSuffixes(object): _suffixes = () + _traverse_internals = [("_suffixes", InternalTraversal.dp_prefix_sequence)] + @_generative @_document_text_coercion( "expr", @@ -295,7 +302,7 @@ class HasSuffixes(object): ) -class FromClause(roles.AnonymizedFromClauseRole, Selectable): +class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -529,11 +536,6 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return getattr(self, "name", self.__class__.__name__ + " object") - def _reset_exported(self): - """delete memoized collections when a FromClause is cloned.""" - - self._memoized_property.expire_instance(self) - def _generate_fromclause_column_proxies(self, fromclause): fromclause._columns._populate_separate_keys( col._make_proxy(fromclause) for col in self.c @@ -668,6 +670,14 @@ class Join(FromClause): __visit_name__ = "join" + _traverse_internals = [ + ("left", InternalTraversal.dp_clauseelement), + ("right", InternalTraversal.dp_clauseelement), + ("onclause", InternalTraversal.dp_clauseelement), + ("isouter", InternalTraversal.dp_boolean), + ("full", InternalTraversal.dp_boolean), + ] + _is_join = True def __init__(self, left, right, onclause=None, isouter=False, full=False): @@ -805,25 +815,6 @@ class Join(FromClause): self.left._refresh_for_new_column(column) self.right._refresh_for_new_column(column) - def _copy_internals(self, clone=_clone, **kw): - self._reset_exported() - self.left = clone(self.left, **kw) - self.right = clone(self.right, **kw) - self.onclause = clone(self.onclause, **kw) - - def get_children(self, **kwargs): - return self.left, 
self.right, self.onclause - - def _cache_key(self, **kw): - return ( - Join, - self.isouter, - self.full, - self.left._cache_key(**kw), - self.right._cache_key(**kw), - self.onclause._cache_key(**kw), - ) - def _match_primaries(self, left, right): if isinstance(left, Join): left_right = left.right @@ -1175,6 +1166,11 @@ class AliasedReturnsRows(FromClause): _is_from_container = True named_with_column = True + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("name", InternalTraversal.dp_anon_name), + ] + def __init__(self, *arg, **kw): raise NotImplementedError( "The %s class is not intended to be constructed " @@ -1243,18 +1239,13 @@ class AliasedReturnsRows(FromClause): def _copy_internals(self, clone=_clone, **kw): element = clone(self.element, **kw) + + # the element clone is usually against a Table that returns the + # same object. don't reset exported .c. collections and other + # memoized details if nothing changed if element is not self.element: self._reset_exported() - self.element = element - - def get_children(self, column_collections=True, **kw): - if column_collections: - for c in self.c: - yield c - yield self.element - - def _cache_key(self, **kw): - return (self.__class__, self.element._cache_key(**kw), self._orig_name) + self.element = element @property def _from_objects(self): @@ -1396,6 +1387,11 @@ class TableSample(AliasedReturnsRows): __visit_name__ = "tablesample" + _traverse_internals = AliasedReturnsRows._traverse_internals + [ + ("sampling", InternalTraversal.dp_clauseelement), + ("seed", InternalTraversal.dp_clauseelement), + ] + @classmethod def _factory(cls, selectable, sampling, name=None, seed=None): """Return a :class:`.TableSample` object. 
@@ -1466,6 +1462,16 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows): __visit_name__ = "cte" + _traverse_internals = ( + AliasedReturnsRows._traverse_internals + + [ + ("_cte_alias", InternalTraversal.dp_clauseelement), + ("_restates", InternalTraversal.dp_clauseelement_unordered_set), + ("recursive", InternalTraversal.dp_boolean), + ] + + HasSuffixes._traverse_internals + ) + @classmethod def _factory(cls, selectable, name=None, recursive=False): r"""Return a new :class:`.CTE`, or Common Table Expression instance. @@ -1495,15 +1501,13 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows): def _copy_internals(self, clone=_clone, **kw): super(CTE, self)._copy_internals(clone, **kw) + # TODO: I don't like that we can't use the traversal data here if self._cte_alias is not None: self._cte_alias = clone(self._cte_alias, **kw) self._restates = frozenset( [clone(elem, **kw) for elem in self._restates] ) - def _cache_key(self, *arg, **kw): - raise NotImplementedError("TODO") - def alias(self, name=None, flat=False): """Return an :class:`.Alias` of this :class:`.CTE`. 
@@ -1764,6 +1768,8 @@ class Subquery(AliasedReturnsRows): class FromGrouping(GroupedElement, FromClause): """Represent a grouping of a FROM clause""" + _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + def __init__(self, element): self.element = coercions.expect(roles.FromClauseRole, element) @@ -1792,15 +1798,6 @@ class FromGrouping(GroupedElement, FromClause): def _hide_froms(self): return self.element._hide_froms - def get_children(self, **kwargs): - return (self.element,) - - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return (FromGrouping, self.element._cache_key(**kw)) - @property def _from_objects(self): return self.element._from_objects @@ -1843,6 +1840,14 @@ class TableClause(Immutable, FromClause): __visit_name__ = "table" + _traverse_internals = [ + ( + "columns", + InternalTraversal.dp_fromclause_canonical_column_collection, + ), + ("name", InternalTraversal.dp_string), + ] + named_with_column = True implicit_returning = False @@ -1895,17 +1900,6 @@ class TableClause(Immutable, FromClause): self._columns.add(c) c.table = self - def get_children(self, column_collections=True, **kwargs): - if column_collections: - return [c for c in self.c] - else: - return [] - - def _cache_key(self, **kw): - return (TableClause, self.name) + tuple( - col._cache_key(**kw) for col in self._columns - ) - @util.dependencies("sqlalchemy.sql.dml") def insert(self, dml, values=None, inline=False, **kwargs): """Generate an :func:`.insert` construct against this @@ -1965,6 +1959,13 @@ class TableClause(Immutable, FromClause): class ForUpdateArg(ClauseElement): + _traverse_internals = [ + ("of", InternalTraversal.dp_clauseelement_list), + ("nowait", InternalTraversal.dp_boolean), + ("read", InternalTraversal.dp_boolean), + ("skip_locked", InternalTraversal.dp_boolean), + ] + @classmethod def parse_legacy_select(self, arg): """Parse the for_update argument of :func:`.select`. 
@@ -2029,19 +2030,6 @@ class ForUpdateArg(ClauseElement): def __hash__(self): return id(self) - def _copy_internals(self, clone=_clone, **kw): - if self.of is not None: - self.of = [clone(col, **kw) for col in self.of] - - def _cache_key(self, **kw): - return ( - ForUpdateArg, - self.nowait, - self.read, - self.skip_locked, - self.of._cache_key(**kw) if self.of is not None else None, - ) - def __init__( self, nowait=False, @@ -2074,6 +2062,7 @@ class SelectBase( roles.DMLSelectRole, roles.CompoundElementRole, roles.InElementRole, + HasMemoized, HasCTE, Executable, SupportsCloneAnnotations, @@ -2092,9 +2081,6 @@ class SelectBase( _memoized_property = util.group_expirable_memoized_property() - def _reset_memoizations(self): - self._memoized_property.expire_instance(self) - def _generate_fromclause_column_proxies(self, fromclause): # type: (FromClause) raise NotImplementedError() @@ -2339,6 +2325,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ __visit_name__ = "grouping" + _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] _is_select_container = True @@ -2350,9 +2337,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def select_statement(self): return self.element - def get_children(self, **kwargs): - return (self.element,) - def self_group(self, against=None): # type: (Optional[Any]) -> FromClause return self @@ -2377,12 +2361,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ return self.element.selected_columns - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return (SelectStatementGrouping, self.element._cache_key(**kw)) - @property def _from_objects(self): return self.element._from_objects @@ -2758,9 +2736,6 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): def _label_resolve_dict(self): raise NotImplementedError() - def _copy_internals(self, clone=_clone, **kw): - raise NotImplementedError() - 
class CompoundSelect(GenerativeSelect): """Forms the basis of ``UNION``, ``UNION ALL``, and other @@ -2785,6 +2760,16 @@ class CompoundSelect(GenerativeSelect): __visit_name__ = "compound_select" + _traverse_internals = [ + ("selects", InternalTraversal.dp_clauseelement_list), + ("_limit_clause", InternalTraversal.dp_clauseelement), + ("_offset_clause", InternalTraversal.dp_clauseelement), + ("_order_by_clause", InternalTraversal.dp_clauseelement), + ("_group_by_clause", InternalTraversal.dp_clauseelement), + ("_for_update_arg", InternalTraversal.dp_clauseelement), + ("keyword", InternalTraversal.dp_string), + ] + SupportsCloneAnnotations._traverse_internals + UNION = util.symbol("UNION") UNION_ALL = util.symbol("UNION ALL") EXCEPT = util.symbol("EXCEPT") @@ -3004,47 +2989,6 @@ class CompoundSelect(GenerativeSelect): """ return self.selects[0].selected_columns - def _copy_internals(self, clone=_clone, **kw): - self._reset_memoizations() - self.selects = [clone(s, **kw) for s in self.selects] - if hasattr(self, "_col_map"): - del self._col_map - for attr in ( - "_limit_clause", - "_offset_clause", - "_order_by_clause", - "_group_by_clause", - "_for_update_arg", - ): - if getattr(self, attr) is not None: - setattr(self, attr, clone(getattr(self, attr), **kw)) - - def get_children(self, **kwargs): - return [self._order_by_clause, self._group_by_clause] + list( - self.selects - ) - - def _cache_key(self, **kw): - return ( - (CompoundSelect, self.keyword) - + tuple(stmt._cache_key(**kw) for stmt in self.selects) - + ( - self._order_by_clause._cache_key(**kw) - if self._order_by_clause is not None - else None, - ) - + ( - self._group_by_clause._cache_key(**kw) - if self._group_by_clause is not None - else None, - ) - + ( - self._for_update_arg._cache_key(**kw) - if self._for_update_arg is not None - else None, - ) - ) - def bind(self): if self._bind: return self._bind @@ -3193,11 +3137,35 @@ class Select( _hints = util.immutabledict() _statement_hints = () _distinct = 
False - _from_cloned = None + _distinct_on = () _correlate = () _correlate_except = None _memoized_property = SelectBase._memoized_property + _traverse_internals = ( + [ + ("_from_obj", InternalTraversal.dp_fromclause_ordered_set), + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("_whereclause", InternalTraversal.dp_clauseelement), + ("_having", InternalTraversal.dp_clauseelement), + ("_order_by_clause", InternalTraversal.dp_clauseelement_list), + ("_group_by_clause", InternalTraversal.dp_clauseelement_list), + ("_correlate", InternalTraversal.dp_clauseelement_unordered_set), + ( + "_correlate_except", + InternalTraversal.dp_clauseelement_unordered_set, + ), + ("_for_update_arg", InternalTraversal.dp_clauseelement), + ("_statement_hints", InternalTraversal.dp_statement_hint_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ("_distinct", InternalTraversal.dp_boolean), + ("_distinct_on", InternalTraversal.dp_clauseelement_list), + ] + + HasPrefixes._traverse_internals + + HasSuffixes._traverse_internals + + SupportsCloneAnnotations._traverse_internals + ) + @util.deprecated_params( autocommit=( "0.6", @@ -3416,13 +3384,14 @@ class Select( """ self._auto_correlate = correlate if distinct is not False: - if distinct is True: - self._distinct = True - else: - self._distinct = [ - coercions.expect(roles.WhereHavingRole, e) - for e in util.to_list(distinct) - ] + self._distinct = True + if not isinstance(distinct, bool): + self._distinct_on = tuple( + [ + coercions.expect(roles.WhereHavingRole, e) + for e in util.to_list(distinct) + ] + ) if from_obj is not None: self._from_obj = util.OrderedSet( @@ -3472,15 +3441,17 @@ class Select( GenerativeSelect.__init__(self, **kwargs) + # @_memoized_property @property def _froms(self): - # would love to cache this, - # but there's just enough edge cases, particularly now that - # declarative encourages construction of SQL expressions - # without tables present, to just regen this each time. 
+ # current roadblock to caching is two tests that test that the + # SELECT can be compiled to a string, then a Table is created against + # columns, then it can be compiled again and works. this is somewhat + # valid as people make select() against declarative class where + # columns don't have their Table yet and perhaps some operations + # call upon _froms and cache it too soon. froms = [] seen = set() - translate = self._from_cloned for item in itertools.chain( _from_objects(*self._raw_columns), @@ -3493,8 +3464,6 @@ class Select( raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" ) - if translate and item in translate: - item = translate[item] if not seen.intersection(item._cloned_set): froms.append(item) seen.update(item._cloned_set) @@ -3518,15 +3487,6 @@ class Select( itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms]) ) if toremove: - # if we're maintaining clones of froms, - # add the copies out to the toremove list. only include - # clones that are lexical equivalents. - if self._from_cloned: - toremove.update( - self._from_cloned[f] - for f in toremove.intersection(self._from_cloned) - if self._from_cloned[f]._is_lexical_equivalent(f) - ) # filter out to FROM clauses not in the list, # using a list to maintain ordering froms = [f for f in froms if f not in toremove] @@ -3707,7 +3667,6 @@ class Select( return False def _copy_internals(self, clone=_clone, **kw): - # Select() object has been cloned and probably adapted by the # given clone function. Apply the cloning function to internal # objects @@ -3719,37 +3678,42 @@ class Select( # as of 0.7.4 we also put the current version of _froms, which # gets cleared on each generation. previously we were "baking" # _froms into self._from_obj. - self._from_cloned = from_cloned = dict( - (f, clone(f, **kw)) for f in self._from_obj.union(self._froms) - ) - # 3. update persistent _from_obj with the cloned versions. 
- self._from_obj = util.OrderedSet( - from_cloned[f] for f in self._from_obj + all_the_froms = list( + itertools.chain( + _from_objects(*self._raw_columns), + _from_objects(self._whereclause) + if self._whereclause is not None + else (), + ) ) + new_froms = {f: clone(f, **kw) for f in all_the_froms} + # copy FROM collections - # the _correlate collection is done separately, what can happen - # here is the same item is _correlate as in _from_obj but the - # _correlate version has an annotation on it - (specifically - # RelationshipProperty.Comparator._criterion_exists() does - # this). Also keep _correlate liberally open with its previous - # contents, as this set is used for matching, not rendering. - self._correlate = set(clone(f) for f in self._correlate).union( - self._correlate - ) + self._from_obj = util.OrderedSet( + clone(f, **kw) for f in self._from_obj + ).union(f for f in new_froms.values() if isinstance(f, Join)) - # do something similar for _correlate_except - this is a more - # unusual case but same idea applies + self._correlate = set(clone(f) for f in self._correlate) if self._correlate_except: self._correlate_except = set( clone(f) for f in self._correlate_except - ).union(self._correlate_except) + ) # 4. clone other things. The difficulty here is that Column - # objects are not actually cloned, and refer to their original - # .table, resulting in the wrong "from" parent after a clone - # operation. Hence _from_cloned and _from_obj supersede what is - # present here. + # objects are usually not altered by a straight clone because they + # are dependent on the FROM cloning we just did above in order to + # be targeted correctly, or a new FROM we have might be a JOIN + # object which doesn't have its own columns. so give the cloner a + # hint. 
+ def replace(obj, **kw): + if isinstance(obj, ColumnClause) and obj.table in new_froms: + newelem = new_froms[obj.table].corresponding_column(obj) + return newelem + + kw["replace"] = replace + + # TODO: I'd still like to try to leverage the traversal data self._raw_columns = [clone(c, **kw) for c in self._raw_columns] for attr in ( "_limit_clause", @@ -3763,67 +3727,12 @@ class Select( if getattr(self, attr) is not None: setattr(self, attr, clone(getattr(self, attr), **kw)) - # erase _froms collection, - # etc. self._reset_memoizations() def get_children(self, **kwargs): - """return child elements as per the ClauseElement specification.""" - - return ( - self._raw_columns - + list(self._froms) - + [ - x - for x in ( - self._whereclause, - self._having, - self._order_by_clause, - self._group_by_clause, - ) - if x is not None - ] - ) - - def _cache_key(self, **kw): - return ( - (Select,) - + ("raw_columns",) - + tuple(elem._cache_key(**kw) for elem in self._raw_columns) - + ("elements",) - + tuple( - elem._cache_key(**kw) if elem is not None else None - for elem in ( - self._whereclause, - self._having, - self._order_by_clause, - self._group_by_clause, - ) - ) - + ("from_obj",) - + tuple(elem._cache_key(**kw) for elem in self._from_obj) - + ("correlate",) - + tuple( - elem._cache_key(**kw) - for elem in ( - self._correlate if self._correlate is not None else () - ) - ) - + ("correlate_except",) - + tuple( - elem._cache_key(**kw) - for elem in ( - self._correlate_except - if self._correlate_except is not None - else () - ) - ) - + ("for_update",), - ( - self._for_update_arg._cache_key(**kw) - if self._for_update_arg is not None - else None, - ), + # TODO: define "get_children" traversal items separately? 
+ return self._froms + super(Select, self).get_children( + omit_attrs=["_from_obj", "_correlate", "_correlate_except"] ) @_generative @@ -3987,10 +3896,8 @@ class Select( """ if expr: expr = [coercions.expect(roles.ByOfRole, e) for e in expr] - if isinstance(self._distinct, list): - self._distinct = self._distinct + expr - else: - self._distinct = expr + self._distinct = True + self._distinct_on = self._distinct_on + tuple(expr) else: self._distinct = True @@ -4489,6 +4396,11 @@ class TextualSelect(SelectBase): __visit_name__ = "textual_select" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("column_args", InternalTraversal.dp_clauseelement_list), + ] + SupportsCloneAnnotations._traverse_internals + _is_textual = True def __init__(self, text, columns, positional=False): @@ -4534,18 +4446,6 @@ class TextualSelect(SelectBase): c._make_proxy(fromclause) for c in self.column_args ) - def _copy_internals(self, clone=_clone, **kw): - self._reset_memoizations() - self.element = clone(self.element, **kw) - - def get_children(self, **kw): - return [self.element] - - def _cache_key(self, **kw): - return (TextualSelect, self.element._cache_key(**kw)) + tuple( - col._cache_key(**kw) for col in self.column_args - ) - def _scalar_type(self): return self.column_args[0].type diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py new file mode 100644 index 000000000..c0782ce48 --- /dev/null +++ b/lib/sqlalchemy/sql/traversals.py @@ -0,0 +1,768 @@ +from collections import deque +from collections import namedtuple + +from . import operators +from .visitors import ExtendedInternalTraversal +from .visitors import InternalTraversal +from .. import inspect +from .. 
import util + +SKIP_TRAVERSE = util.symbol("skip_traverse") +COMPARE_FAILED = False +COMPARE_SUCCEEDED = True +NO_CACHE = util.symbol("no_cache") + + +def compare(obj1, obj2, **kw): + if kw.get("use_proxies", False): + strategy = ColIdentityComparatorStrategy() + else: + strategy = TraversalComparatorStrategy() + + return strategy.compare(obj1, obj2, **kw) + + +class HasCacheKey(object): + _cache_key_traversal = NO_CACHE + + def _gen_cache_key(self, anon_map, bindparams): + """return an optional cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the strucures that would affect the SQL string or the type handlers + should result in a different cache key. + + If a structure cannot produce a useful cache key, it should raise + NotImplementedError, which will result in the entire structure + for which it's part of not being useful as a cache key. + + + """ + + if self in anon_map: + return (anon_map[self], self.__class__) + + id_ = anon_map[self] + + if self._cache_key_traversal is NO_CACHE: + anon_map[NO_CACHE] = True + return None + + result = (id_, self.__class__) + + for attrname, obj, meth in _cache_key_traversal.run_generated_dispatch( + self, self._cache_key_traversal, "_generated_cache_key_traversal" + ): + if obj is not None: + result += meth(attrname, obj, self, anon_map, bindparams) + return result + + def _generate_cache_key(self): + """return a cache key. 
+ + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the strucures that would affect the SQL string or the type handlers + should result in a different cache key. + + The cache key returned by this method is an instance of + :class:`.CacheKey`, which consists of a tuple representing the + cache key, as well as a list of :class:`.BindParameter` objects + which are extracted from the expression. While two expressions + that produce identical cache key tuples will themselves generate + identical SQL strings, the list of :class:`.BindParameter` objects + indicates the bound values which may have different values in + each one; these bound parameters must be consulted in order to + execute the statement with the correct parameters. + + a :class:`.ClauseElement` structure that does not implement + a :meth:`._gen_cache_key` method and does not implement a + :attr:`.traverse_internals` attribute will not be cacheable; when + such an element is embedded into a larger structure, this method + will return None, indicating no cache key is available. 
+ + """ + bindparams = [] + + _anon_map = anon_map() + key = self._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + +class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): + def __hash__(self): + return hash(self.key) + + def __eq__(self, other): + return self.key == other.key + + +def _clone(element, **kw): + return element._clone() + + +class _CacheKey(ExtendedInternalTraversal): + def visit_has_cache_key(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj._gen_cache_key(anon_map, bindparams)) + + def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): + return self.visit_has_cache_key( + attrname, inspect(obj), parent, anon_map, bindparams + ) + + def visit_clauseelement(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj._gen_cache_key(anon_map, bindparams)) + + def visit_multi(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj, + ) + + def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem + for elem in obj + ), + ) + + def visit_has_cache_key_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in tup_elem + ) + for tup_elem in obj + ), + ) + + def visit_has_cache_key_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_inspectable_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_list( + attrname, [inspect(o) for o in obj], parent, anon_map, bindparams + ) + + def visit_clauseelement_list( + 
self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_clauseelement_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_tuples( + attrname, obj, parent, anon_map, bindparams + ) + + def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams): + from . import elements + + name = obj + if isinstance(name, elements._anonymous_label): + name = name.apply_map(anon_map) + + return (attrname, name) + + def visit_fromclause_ordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_clauseelement_unordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + cache_keys = [ + elem._gen_cache_key(anon_map, bindparams) for elem in obj + ] + return ( + attrname, + tuple( + sorted(cache_keys) + ), # cache keys all start with (id_, class) + ) + + def visit_named_ddl_element( + self, attrname, obj, parent, anon_map, bindparams + ): + return (attrname, obj.name) + + def visit_prefix_sequence( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + (clause._gen_cache_key(anon_map, bindparams), strval) + for clause, strval in obj + ), + ) + + def visit_statement_hint_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return (attrname, obj) + + def visit_table_hint_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + clause._gen_cache_key(anon_map, bindparams), + dialect_name, + text, + ) + for (clause, dialect_name), text in obj.items() + ), + ) + + def visit_type(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj._gen_cache_key) + + def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, tuple((key, obj[key]) for key in sorted(obj))) 
+ + def visit_string_clauseelement_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + (key, obj[key]._gen_cache_key(anon_map, bindparams)) + for key in sorted(obj) + ), + ) + + def visit_string_multi_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key, + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value, + ) + for key, value in [(key, obj[key]) for key in sorted(obj)] + ), + ) + + def visit_string(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_boolean(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_operator(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_plain_obj(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_fromclause_canonical_column_collection( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(col._gen_cache_key(anon_map, bindparams) for col in obj), + ) + + def visit_annotations_state( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key, + self.dispatch(sym)( + key, obj[key], obj, anon_map, bindparams + ), + ) + for key, sym in parent._annotation_traversals + ), + ) + + def visit_unknown_structure( + self, attrname, obj, parent, anon_map, bindparams + ): + anon_map[NO_CACHE] = True + return () + + +_cache_key_traversal = _CacheKey() + + +class _CopyInternals(InternalTraversal): + """Generate a _copy_internals internal traversal dispatch for classes + with a _traverse_internals collection.""" + + def visit_clauseelement(self, parent, element, clone=_clone, **kw): + return clone(element, **kw) + + def visit_clauseelement_list(self, parent, element, clone=_clone, **kw): + return [clone(clause, **kw) for clause in element] + + def visit_clauseelement_tuples(self, parent, 
element, clone=_clone, **kw): + return [ + tuple(clone(tup_elem, **kw) for tup_elem in elem) + for elem in element + ] + + def visit_string_clauseelement_dict( + self, parent, element, clone=_clone, **kw + ): + return dict( + (key, clone(value, **kw)) for key, value in element.items() + ) + + +_copy_internals = _CopyInternals() + + +class _GetChildren(InternalTraversal): + """Generate a _children_traversal internal traversal dispatch for classes + with a _traverse_internals collection.""" + + def visit_has_cache_key(self, element, **kw): + return (element,) + + def visit_clauseelement(self, element, **kw): + return (element,) + + def visit_clauseelement_list(self, element, **kw): + return tuple(element) + + def visit_clauseelement_tuples(self, element, **kw): + tup = () + for elem in element: + tup += elem + return tup + + def visit_fromclause_canonical_column_collection(self, element, **kw): + if kw.get("column_collections", False): + return tuple(element) + else: + return () + + def visit_string_clauseelement_dict(self, element, **kw): + return tuple(element.values()) + + def visit_fromclause_ordered_set(self, element, **kw): + return tuple(element) + + def visit_clauseelement_unordered_set(self, element, **kw): + return tuple(element) + + +_get_children = _GetChildren() + + +@util.dependencies("sqlalchemy.sql.elements") +def _resolve_name_for_compare(elements, element, name, anon_map, **kw): + if isinstance(name, elements._anonymous_label): + name = name.apply_map(anon_map) + + return name + + +class anon_map(dict): + """A map that creates new keys for missing key access. + + Produces an incrementing sequence given a series of unique keys. + + This is similar to the compiler prefix_anon_map class although simpler. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. 
+ + """ + + def __init__(self): + self.index = 0 + + def __missing__(self, key): + self[key] = val = str(self.index) + self.index += 1 + return val + + +class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): + __slots__ = "stack", "cache", "anon_map" + + def __init__(self): + self.stack = deque() + self.cache = set() + + def _memoized_attr_anon_map(self): + return (anon_map(), anon_map()) + + def compare(self, obj1, obj2, **kw): + stack = self.stack + cache = self.cache + + compare_annotations = kw.get("compare_annotations", False) + + stack.append((obj1, obj2)) + + while stack: + left, right = stack.popleft() + + if left is right: + continue + elif left is None or right is None: + # we know they are different so no match + return False + elif (left, right) in cache: + continue + cache.add((left, right)) + + visit_name = left.__visit_name__ + if visit_name != right.__visit_name__: + return False + + meth = getattr(self, "compare_%s" % visit_name, None) + + if meth: + attributes_compared = meth(left, right, **kw) + if attributes_compared is COMPARE_FAILED: + return False + elif attributes_compared is SKIP_TRAVERSE: + continue + + # attributes_compared is returned as a list of attribute + # names that were "handled" by the comparison method above. + # remaining attribute names in the _traverse_internals + # will be compared. 
+ else: + attributes_compared = () + + for ( + (left_attrname, left_visit_sym), + (right_attrname, right_visit_sym), + ) in util.zip_longest( + left._traverse_internals, + right._traverse_internals, + fillvalue=(None, None), + ): + if ( + left_attrname != right_attrname + or left_visit_sym is not right_visit_sym + ): + if not compare_annotations and ( + ( + left_visit_sym + is InternalTraversal.dp_annotations_state, + ) + or ( + right_visit_sym + is InternalTraversal.dp_annotations_state, + ) + ): + continue + + return False + elif left_attrname in attributes_compared: + continue + + dispatch = self.dispatch(left_visit_sym) + left_child = getattr(left, left_attrname) + right_child = getattr(right, right_attrname) + if left_child is None: + if right_child is not None: + return False + else: + continue + + comparison = dispatch( + left, left_child, right, right_child, **kw + ) + if comparison is COMPARE_FAILED: + return False + + return True + + def compare_inner(self, obj1, obj2, **kw): + comparator = self.__class__() + return comparator.compare(obj1, obj2, **kw) + + def visit_has_cache_key( + self, left_parent, left, right_parent, right, **kw + ): + if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key( + self.anon_map[1], [] + ): + return COMPARE_FAILED + + def visit_clauseelement( + self, left_parent, left, right_parent, right, **kw + ): + self.stack.append((left, right)) + + def visit_fromclause_canonical_column_collection( + self, left_parent, left, right_parent, right, **kw + ): + for lcol, rcol in util.zip_longest(left, right, fillvalue=None): + self.stack.append((lcol, rcol)) + + def visit_fromclause_derived_column_collection( + self, left_parent, left, right_parent, right, **kw + ): + pass + + def visit_string_clauseelement_dict( + self, left_parent, left, right_parent, right, **kw + ): + for lstr, rstr in util.zip_longest( + sorted(left), sorted(right), fillvalue=None + ): + if lstr != rstr: + return COMPARE_FAILED + 
self.stack.append((left[lstr], right[rstr])) + + def visit_annotations_state( + self, left_parent, left, right_parent, right, **kw + ): + if not kw.get("compare_annotations", False): + return + + for (lstr, lmeth), (rstr, rmeth) in util.zip_longest( + left_parent._annotation_traversals, + right_parent._annotation_traversals, + fillvalue=(None, None), + ): + if lstr != rstr or (lmeth is not rmeth): + return COMPARE_FAILED + + dispatch = self.dispatch(lmeth) + left_child = left[lstr] + right_child = right[rstr] + if left_child is None: + if right_child is not None: + return False + else: + continue + + comparison = dispatch(None, left_child, None, right_child, **kw) + if comparison is COMPARE_FAILED: + return comparison + + def visit_clauseelement_tuples( + self, left_parent, left, right_parent, right, **kw + ): + for ltup, rtup in util.zip_longest(left, right, fillvalue=None): + if ltup is None or rtup is None: + return COMPARE_FAILED + + for l, r in util.zip_longest(ltup, rtup, fillvalue=None): + self.stack.append((l, r)) + + def visit_clauseelement_list( + self, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + self.stack.append((l, r)) + + def _compare_unordered_sequences(self, seq1, seq2, **kw): + if seq1 is None: + return seq2 is None + + completed = set() + for clause in seq1: + for other_clause in set(seq2).difference(completed): + if self.compare_inner(clause, other_clause, **kw): + completed.add(other_clause) + break + return len(completed) == len(seq1) == len(seq2) + + def visit_clauseelement_unordered_set( + self, left_parent, left, right_parent, right, **kw + ): + return self._compare_unordered_sequences(left, right, **kw) + + def visit_fromclause_ordered_set( + self, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + self.stack.append((l, r)) + + def visit_string(self, left_parent, left, right_parent, right, **kw): + return left 
== right + + def visit_anon_name(self, left_parent, left, right_parent, right, **kw): + return _resolve_name_for_compare( + left_parent, left, self.anon_map[0], **kw + ) == _resolve_name_for_compare( + right_parent, right, self.anon_map[1], **kw + ) + + def visit_boolean(self, left_parent, left, right_parent, right, **kw): + return left == right + + def visit_operator(self, left_parent, left, right_parent, right, **kw): + return left is right + + def visit_type(self, left_parent, left, right_parent, right, **kw): + return left._compare_type_affinity(right) + + def visit_plain_dict(self, left_parent, left, right_parent, right, **kw): + return left == right + + def visit_plain_obj(self, left_parent, left, right_parent, right, **kw): + return left == right + + def visit_named_ddl_element( + self, left_parent, left, right_parent, right, **kw + ): + if left is None: + if right is not None: + return COMPARE_FAILED + + return left.name == right.name + + def visit_prefix_sequence( + self, left_parent, left, right_parent, right, **kw + ): + for (l_clause, l_str), (r_clause, r_str) in util.zip_longest( + left, right, fillvalue=(None, None) + ): + if l_str != r_str: + return COMPARE_FAILED + else: + self.stack.append((l_clause, r_clause)) + + def visit_table_hint_list( + self, left_parent, left, right_parent, right, **kw + ): + left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1])) + right_keys = sorted( + right, key=lambda elem: (elem[0].fullname, elem[1]) + ) + for (ltable, ldialect), (rtable, rdialect) in util.zip_longest( + left_keys, right_keys, fillvalue=(None, None) + ): + if ldialect != rdialect: + return COMPARE_FAILED + elif left[(ltable, ldialect)] != right[(rtable, rdialect)]: + return COMPARE_FAILED + else: + self.stack.append((ltable, rtable)) + + def visit_statement_hint_list( + self, left_parent, left, right_parent, right, **kw + ): + return left == right + + def visit_unknown_structure( + self, left_parent, left, right_parent, right, **kw + ): 
+ raise NotImplementedError() + + def compare_clauselist(self, left, right, **kw): + if left.operator is right.operator: + if operators.is_associative(left.operator): + if self._compare_unordered_sequences( + left.clauses, right.clauses, **kw + ): + return ["operator", "clauses"] + else: + return COMPARE_FAILED + else: + return ["operator"] + else: + return COMPARE_FAILED + + def compare_binary(self, left, right, **kw): + if left.operator == right.operator: + if operators.is_commutative(left.operator): + if ( + compare(left.left, right.left, **kw) + and compare(left.right, right.right, **kw) + ) or ( + compare(left.left, right.right, **kw) + and compare(left.right, right.left, **kw) + ): + return ["operator", "negate", "left", "right"] + else: + return COMPARE_FAILED + else: + return ["operator", "negate"] + else: + return COMPARE_FAILED + + +class ColIdentityComparatorStrategy(TraversalComparatorStrategy): + def compare_column_element( + self, left, right, use_proxies=True, equivalents=(), **kw + ): + """Compare ColumnElements using proxies and equivalent collections. + + This is a comparison strategy specific to the ORM. 
+ """ + + to_compare = (right,) + if equivalents and right in equivalents: + to_compare = equivalents[right].union(to_compare) + + for oth in to_compare: + if use_proxies and left.shares_lineage(oth): + return SKIP_TRAVERSE + elif hash(left) == hash(right): + return SKIP_TRAVERSE + else: + return COMPARE_FAILED + + def compare_column(self, left, right, **kw): + return self.compare_column_element(left, right, **kw) + + def compare_label(self, left, right, **kw): + return self.compare_column_element(left, right, **kw) + + def compare_table(self, left, right, **kw): + # tables compare on identity, since it's not really feasible to + # compare them column by column with the above rules + return SKIP_TRAVERSE if left is right else COMPARE_FAILED diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9c5f5dd47..d09bb28bb 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -12,8 +12,8 @@ from . import operators from .base import SchemaEventTarget -from .visitors import Visitable -from .visitors import VisitableType +from .visitors import Traversible +from .visitors import TraversibleType from .. import exc from .. import util @@ -28,7 +28,7 @@ INDEXABLE = None _resolve_value_to_type = None -class TypeEngine(Visitable): +class TypeEngine(Traversible): """The ultimate base class for all SQL datatypes. 
Common subclasses of :class:`.TypeEngine` include @@ -535,8 +535,13 @@ class TypeEngine(Visitable): return dialect.type_descriptor(self) @util.memoized_property - def _cache_key(self): - return util.constructor_key(self, self.__class__) + def _gen_cache_key(self): + names = util.get_cls_kwargs(self.__class__) + return (self.__class__,) + tuple( + (k, self.__dict__[k]) + for k in names + if k in self.__dict__ and not k.startswith("_") + ) def adapt(self, cls, **kw): """Produce an "adapted" form of this type, given an "impl" class @@ -617,7 +622,7 @@ class TypeEngine(Visitable): return util.generic_repr(self) -class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType): +class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType): pass diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index e109852a2..8539f4845 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -734,7 +734,7 @@ def criterion_as_pairs( return pairs -class ClauseAdapter(visitors.ReplacingCloningVisitor): +class ClauseAdapter(visitors.ReplacingExternalTraversal): """Clones and modifies clauses based on column correspondence. E.g.:: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 7b2ac285a..8c06eb8af 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -28,14 +28,10 @@ import operator from .. import exc from .. import util - +from ..util import langhelpers +from ..util import symbol __all__ = [ - "VisitableType", - "Visitable", - "ClauseVisitor", - "CloningVisitor", - "ReplacingCloningVisitor", "iterate", "iterate_depthfirst", "traverse_using", @@ -43,85 +39,382 @@ __all__ = [ "traverse_depthfirst", "cloned_traverse", "replacement_traverse", + "Traversible", + "TraversibleType", + "ExternalTraversal", + "InternalTraversal", ] -class VisitableType(type): - """Metaclass which assigns a ``_compiler_dispatch`` method to classes - having a ``__visit_name__`` attribute. 
+def _generate_compiler_dispatch(cls): + """Generate a _compiler_dispatch() external traversal on classes with a + __visit_name__ attribute. + + """ + visit_name = cls.__visit_name__ + + if isinstance(visit_name, util.compat.string_types): + # There is an optimization opportunity here because the + # the string name of the class's __visit_name__ is known at + # this early stage (import time) so it can be pre-constructed. + getter = operator.attrgetter("visit_%s" % visit_name) + + def _compiler_dispatch(self, visitor, **kw): + try: + meth = getter(visitor) + except AttributeError: + raise exc.UnsupportedCompilationError(visitor, cls) + else: + return meth(self, **kw) + + else: + # The optimization opportunity is lost for this case because the + # __visit_name__ is not yet a string. As a result, the visit + # string has to be recalculated with each compilation. + def _compiler_dispatch(self, visitor, **kw): + visit_attr = "visit_%s" % self.__visit_name__ + try: + meth = getattr(visitor, visit_attr) + except AttributeError: + raise exc.UnsupportedCompilationError(visitor, cls) + else: + return meth(self, **kw) + + _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + + self.__visit_name__ on the visitor, and call it with the same + kw params. + """ + cls._compiler_dispatch = _compiler_dispatch + + +class TraversibleType(type): + """Metaclass which assigns dispatch attributes to various kinds of + "visitable" classes. - The ``_compiler_dispatch`` attribute becomes an instance method which - looks approximately like the following:: + Attributes include: - def _compiler_dispatch (self, visitor, **kw): - '''Look for an attribute named "visit_" + self.__visit_name__ - on the visitor, and call it with the same kw params.''' - visit_attr = 'visit_%s' % self.__visit_name__ - return getattr(visitor, visit_attr)(self, **kw) + * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``. 
+    This is called "external traversal" because the caller of each visit()
+    method is responsible for sub-traversing the inner elements of each
+    object. This is appropriate for string compilers and other traversals
+    that need to call upon the inner elements in a specific pattern.

-    Classes having no ``__visit_name__`` attribute will remain unaffected.
+    * internal traversal collections ``_children_traversal``,
+      ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
+      an optional ``_traverse_internals`` collection of symbols which comes
+      from the :class:`.InternalTraversal` list of symbols.   This is called
+      "internal traversal" because the object itself drives the traversal of
+      its own internal state, guided by the ``_traverse_internals``
+      collection.

    """

    def __init__(cls, clsname, bases, clsdict):
-        if clsname != "Visitable" and hasattr(cls, "__visit_name__"):
-            _generate_dispatch(cls)
+        if clsname != "Traversible":
+            if "__visit_name__" in clsdict:
+                _generate_compiler_dispatch(cls)
+
+            super(TraversibleType, cls).__init__(clsname, bases, clsdict)

-        super(VisitableType, cls).__init__(clsname, bases, clsdict)


+class Traversible(util.with_metaclass(TraversibleType)):
+    """Base class for visitable objects, applies the
+    :class:`.visitors.TraversibleType` metaclass.

-def _generate_dispatch(cls):
-    """Return an optimized visit dispatch function for the cls
-    for use by the compiler.
     """
-    if "__visit_name__" in cls.__dict__:
-        visit_name = cls.__visit_name__
-        if isinstance(visit_name, util.compat.string_types):
-            # There is an optimization opportunity here because the
-            # the string name of the class's __visit_name__ is known at
-            # this early stage (import time) so it can be pre-constructed.
- getter = operator.attrgetter("visit_%s" % visit_name) - def _compiler_dispatch(self, visitor, **kw): - try: - meth = getter(visitor) - except AttributeError: - raise exc.UnsupportedCompilationError(visitor, cls) - else: - return meth(self, **kw) +class _InternalTraversalType(type): + def __init__(cls, clsname, bases, clsdict): + if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"): + lookup = {} + for key, sym in clsdict.items(): + if key.startswith("dp_"): + visit_key = key.replace("dp_", "visit_") + sym_name = sym.name + assert sym_name not in lookup, sym_name + lookup[sym] = lookup[sym_name] = visit_key + if hasattr(cls, "_dispatch_lookup"): + lookup.update(cls._dispatch_lookup) + cls._dispatch_lookup = lookup + + super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict) + + +def _generate_dispatcher(visitor, internal_dispatch, method_name): + names = [] + for attrname, visit_sym in internal_dispatch: + meth = visitor.dispatch(visit_sym) + if meth: + visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym] + names.append((attrname, visit_name)) + + code = ( + (" return [\n") + + ( + ", \n".join( + " (%r, self.%s, visitor.%s)" + % (attrname, attrname, visit_name) + for attrname, visit_name in names + ) + ) + + ("\n ]\n") + ) + meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" + # print(meth_text) + return langhelpers._exec_code_in_env(meth_text, {}, method_name) - else: - # The optimization opportunity is lost for this case because the - # __visit_name__ is not yet a string. As a result, the visit - # string has to be recalculated with each compilation. 
-            def _compiler_dispatch(self, visitor, **kw):
-                visit_attr = "visit_%s" % self.__visit_name__
-                try:
-                    meth = getattr(visitor, visit_attr)
-                except AttributeError:
-                    raise exc.UnsupportedCompilationError(visitor, cls)
-                else:
-                    return meth(self, **kw)
-
-        _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
-        on the visitor, and call it with the same kw params.
-        """
-        cls._compiler_dispatch = _compiler_dispatch
-
-
-class Visitable(util.with_metaclass(VisitableType, object)):
-    """Base class for visitable objects, applies the
-    :class:`.visitors.VisitableType` metaclass.

-    The :class:`.Visitable` class is essentially at the base of the
-    :class:`.ClauseElement` hierarchy.

+class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
+    r"""Defines visitor symbols used for internal traversal.
+
+    The :class:`.InternalTraversal` class is used in two ways.  One is that
+    it can serve as the superclass for an object that implements the
+    various visit methods of the class.  The other is that the symbols
+    themselves of :class:`.InternalTraversal` are used within
+    the ``_traverse_internals`` collection.  Such as, the :class:`.Case`
+    object defines ``_traverse_internals`` as ::
+
+        _traverse_internals = [
+            ("value", InternalTraversal.dp_clauseelement),
+            ("whens", InternalTraversal.dp_clauseelement_tuples),
+            ("else_", InternalTraversal.dp_clauseelement),
+        ]
+
+    Above, the :class:`.Case` class indicates its internal state as the
+    attributes named ``value``, ``whens``, and ``else\_``.  They each
+    link to an :class:`.InternalTraversal` method which indicates the type
+    of datastructure referred towards.
+ + Using the ``_traverse_internals`` structure, objects of type + :class:`.InternalTraversible` will have the following methods automatically + implemented: + + * :meth:`.Traversible.get_children` + + * :meth:`.Traversible._copy_internals` + + * :meth:`.Traversible._gen_cache_key` + + Subclasses can also implement these methods directly, particularly for the + :meth:`.Traversible._copy_internals` method, when special steps + are needed. + + .. versionadded:: 1.4 """ + def dispatch(self, visit_symbol): + """Given a method from :class:`.InternalTraversal`, return the + corresponding method on a subclass. -class ClauseVisitor(object): - """Base class for visitor objects which can traverse using + """ + name = self._dispatch_lookup[visit_symbol] + return getattr(self, name, None) + + def run_generated_dispatch( + self, target, internal_dispatch, generate_dispatcher_name + ): + try: + dispatcher = target.__class__.__dict__[generate_dispatcher_name] + except KeyError: + dispatcher = _generate_dispatcher( + self, internal_dispatch, generate_dispatcher_name + ) + setattr(target.__class__, generate_dispatcher_name, dispatcher) + return dispatcher(target, self) + + dp_has_cache_key = symbol("HC") + """Visit a :class:`.HasCacheKey` object.""" + + dp_clauseelement = symbol("CE") + """Visit a :class:`.ClauseElement` object.""" + + dp_fromclause_canonical_column_collection = symbol("FC") + """Visit a :class:`.FromClause` object in the context of the + ``columns`` attribute. + + The column collection is "canonical", meaning it is the originally + defined location of the :class:`.ColumnClause` objects. Right now + this means that the object being visited is a :class:`.TableClause` + or :class:`.Table` object only. + + """ + + dp_clauseelement_tuples = symbol("CT") + """Visit a list of tuples which contain :class:`.ClauseElement` + objects. + + """ + + dp_clauseelement_list = symbol("CL") + """Visit a list of :class:`.ClauseElement` objects. 
+ + """ + + dp_clauseelement_unordered_set = symbol("CU") + """Visit an unordered set of :class:`.ClauseElement` objects. """ + + dp_fromclause_ordered_set = symbol("CO") + """Visit an ordered set of :class:`.FromClause` objects. """ + + dp_string = symbol("S") + """Visit a plain string value. + + Examples include table and column names, bound parameter keys, special + keywords such as "UNION", "UNION ALL". + + The string value is considered to be significant for cache key + generation. + + """ + + dp_anon_name = symbol("AN") + """Visit a potentially "anonymized" string value. + + The string value is considered to be significant for cache key + generation. + + """ + + dp_boolean = symbol("B") + """Visit a boolean value. + + The boolean value is considered to be significant for cache key + generation. + + """ + + dp_operator = symbol("O") + """Visit an operator. + + The operator is a function from the :mod:`sqlalchemy.sql.operators` + module. + + The operator value is considered to be significant for cache key + generation. + + """ + + dp_type = symbol("T") + """Visit a :class:`.TypeEngine` object + + The type object is considered to be significant for cache key + generation. + + """ + + dp_plain_dict = symbol("PD") + """Visit a dictionary with string keys. + + The keys of the dictionary should be strings, the values should + be immutable and hashable. The dictionary is considered to be + significant for cache key generation. + + """ + + dp_string_clauseelement_dict = symbol("CD") + """Visit a dictionary of string keys to :class:`.ClauseElement` + objects. + + """ + + dp_string_multi_dict = symbol("MD") + """Visit a dictionary of string keys to values which may either be + plain immutable/hashable or :class:`.HasCacheKey` objects. + + """ + + dp_plain_obj = symbol("PO") + """Visit a plain python object. + + The value should be immutable and hashable, such as an integer. + The value is considered to be significant for cache key generation. 
+ + """ + + dp_annotations_state = symbol("A") + """Visit the state of the :class:`.Annotatated` version of an object. + + """ + + dp_named_ddl_element = symbol("DD") + """Visit a simple named DDL element. + + The current object used by this method is the :class:`.Sequence`. + + The object is only considered to be important for cache key generation + as far as its name, but not any other aspects of it. + + """ + + dp_prefix_sequence = symbol("PS") + """Visit the sequence represented by :class:`.HasPrefixes` + or :class:`.HasSuffixes`. + + """ + + dp_table_hint_list = symbol("TH") + """Visit the ``_hints`` collection of a :class:`.Select` object. + + """ + + dp_statement_hint_list = symbol("SH") + """Visit the ``_statement_hints`` collection of a :class:`.Select` + object. + + """ + + dp_unknown_structure = symbol("UK") + """Visit an unknown structure. + + """ + + +class ExtendedInternalTraversal(InternalTraversal): + """defines additional symbols that are useful in caching applications. + + Traversals for :class:`.ClauseElement` objects only need to use + those symbols present in :class:`.InternalTraversal`. However, for + additional caching use cases within the ORM, symbols dealing with the + :class:`.HasCacheKey` class are added here. + + """ + + dp_ignore = symbol("IG") + """Specify an object that should be ignored entirely. + + This currently applies function call argument caching where some + arguments should not be considered to be part of a cache key. + + """ + + dp_inspectable = symbol("IS") + """Visit an inspectable object where the return value is a HasCacheKey` + object.""" + + dp_multi = symbol("M") + """Visit an object that may be a :class:`.HasCacheKey` or may be a + plain hashable object.""" + + dp_multi_list = symbol("MT") + """Visit a tuple containing elements that may be :class:`.HasCacheKey` or + may be a plain hashable object.""" + + dp_has_cache_key_tuples = symbol("HT") + """Visit a list of tuples which contain :class:`.HasCacheKey` + objects. 
+ + """ + + dp_has_cache_key_list = symbol("HL") + """Visit a list of :class:`.HasCacheKey` objects.""" + + dp_inspectable_list = symbol("IL") + """Visit a list of inspectable objects which upon inspection are + HasCacheKey objects.""" + + +class ExternalTraversal(object): + """Base class for visitor objects which can traverse externally using the :func:`.visitors.traverse` function. Direct usage of the :func:`.visitors.traverse` function is usually @@ -178,7 +471,7 @@ class ClauseVisitor(object): return self -class CloningVisitor(ClauseVisitor): +class CloningExternalTraversal(ExternalTraversal): """Base class for visitor objects which can traverse using the :func:`.visitors.cloned_traverse` function. @@ -203,7 +496,7 @@ class CloningVisitor(ClauseVisitor): ) -class ReplacingCloningVisitor(CloningVisitor): +class ReplacingExternalTraversal(CloningExternalTraversal): """Base class for visitor objects which can traverse using the :func:`.visitors.replacement_traverse` function. @@ -233,6 +526,14 @@ class ReplacingCloningVisitor(CloningVisitor): return replacement_traverse(obj, self.__traverse_options__, replace) +# backwards compatibility +Visitable = Traversible +VisitableType = TraversibleType +ClauseVisitor = ExternalTraversal +CloningVisitor = CloningExternalTraversal +ReplacingCloningVisitor = ReplacingExternalTraversal + + def iterate(obj, opts): r"""traverse the given expression structure, returning an iterator. 
@@ -405,11 +706,18 @@ def cloned_traverse(obj, opts, visitors): cloned = {} stop_on = set(opts.get("stop_on", [])) - def clone(elem): + def clone(elem, **kw): if elem in stop_on: return elem else: if id(elem) not in cloned: + + if "replace" in kw: + newelem = kw["replace"](elem) + if newelem is not None: + cloned[id(elem)] = newelem + return newelem + cloned[id(elem)] = newelem = elem._clone() newelem._copy_internals(clone=clone) meth = visitors.get(newelem.__visit_name__, None) @@ -461,7 +769,14 @@ def replacement_traverse(obj, opts, replace): stop_on.add(id(newelem)) return newelem else: + if elem not in cloned: + if "replace" in kw: + newelem = kw["replace"](elem) + if newelem is not None: + cloned[elem] = newelem + return newelem + cloned[elem] = newelem = elem._clone() newelem._copy_internals(clone=clone, **kw) return cloned[elem] diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 632f55937..209bc02e3 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -934,7 +934,7 @@ class BranchedOptionTest(fixtures.MappedTest): configure_mappers() - def test_generate_cache_key_unbound_branching(self): + def test_generate_path_cache_key_unbound_branching(self): A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") base = joinedload(A.bs) @@ -950,11 +950,11 @@ class BranchedOptionTest(fixtures.MappedTest): @profiling.function_call_count() def go(): for opt in opts: - opt._generate_cache_key(cache_path) + opt._generate_path_cache_key(cache_path) go() - def test_generate_cache_key_bound_branching(self): + def test_generate_path_cache_key_bound_branching(self): A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") base = Load(A).joinedload(A.bs) @@ -970,7 +970,7 @@ class BranchedOptionTest(fixtures.MappedTest): @profiling.function_call_count() def go(): for opt in opts: - opt._generate_cache_key(cache_path) + opt._generate_path_cache_key(cache_path) go() diff --git 
a/test/ext/test_baked.py b/test/ext/test_baked.py index acefe625a..01f0e267f 100644 --- a/test/ext/test_baked.py +++ b/test/ext/test_baked.py @@ -1533,7 +1533,7 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest): if query._current_path: query._cache_key = "user7_addresses" - def _generate_cache_key(self, path): + def _generate_path_cache_key(self, path): return None return RelationshipCache() diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py new file mode 100644 index 000000000..79a94848e --- /dev/null +++ b/test/orm/test_cache_key.py @@ -0,0 +1,120 @@ +from sqlalchemy import inspect +from sqlalchemy.orm import aliased +from sqlalchemy.orm import defaultload +from sqlalchemy.orm import defer +from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Load +from sqlalchemy.orm import subqueryload +from sqlalchemy.testing import eq_ +from test.orm import _fixtures +from ..sql.test_compare import CacheKeyFixture + + +class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): + run_setup_mappers = "once" + run_inserts = None + run_deletes = None + + @classmethod + def setup_mappers(cls): + cls._setup_stock_mapping() + + def test_mapper_and_aliased(self): + User, Address, Keyword = self.classes("User", "Address", "Keyword") + + self._run_cache_key_fixture( + lambda: (inspect(User), inspect(Address), inspect(aliased(User))) + ) + + def test_attributes(self): + User, Address, Keyword = self.classes("User", "Address", "Keyword") + + self._run_cache_key_fixture( + lambda: ( + User.id, + Address.id, + aliased(User).id, + aliased(User, name="foo").id, + aliased(User, name="bar").id, + User.name, + User.addresses, + Address.email_address, + aliased(User).addresses, + ) + ) + + def test_unbound_options(self): + User, Address, Keyword, Order, Item = self.classes( + "User", "Address", "Keyword", "Order", "Item" + ) + + self._run_cache_key_fixture( + lambda: ( + joinedload(User.addresses), + joinedload("addresses"), + 
joinedload(User.orders).selectinload("items"), + joinedload(User.orders).selectinload(Order.items), + defer(User.id), + defer("id"), + defer(Address.id), + joinedload(User.addresses).defer(Address.id), + joinedload(aliased(User).addresses).defer(Address.id), + joinedload(User.addresses).defer("id"), + joinedload(User.orders).joinedload(Order.items), + joinedload(User.orders).subqueryload(Order.items), + subqueryload(User.orders).subqueryload(Order.items), + subqueryload(User.orders) + .subqueryload(Order.items) + .defer(Item.description), + defaultload(User.orders).defaultload(Order.items), + defaultload(User.orders), + ) + ) + + def test_bound_options(self): + User, Address, Keyword, Order, Item = self.classes( + "User", "Address", "Keyword", "Order", "Item" + ) + + self._run_cache_key_fixture( + lambda: ( + Load(User).joinedload(User.addresses), + Load(User).joinedload(User.orders), + Load(User).defer(User.id), + Load(User).subqueryload("addresses"), + Load(Address).defer("id"), + Load(aliased(Address)).defer("id"), + Load(User).joinedload(User.addresses).defer(Address.id), + Load(User).joinedload(User.orders).joinedload(Order.items), + Load(User).joinedload(User.orders).subqueryload(Order.items), + Load(User).subqueryload(User.orders).subqueryload(Order.items), + Load(User) + .subqueryload(User.orders) + .subqueryload(Order.items) + .defer(Item.description), + Load(User).defaultload(User.orders).defaultload(Order.items), + Load(User).defaultload(User.orders), + ) + ) + + def test_bound_options_equiv_on_strname(self): + """Bound loader options resolve on string name so test that the cache + key for the string version matches the resolved version. 
+ + """ + User, Address, Keyword, Order, Item = self.classes( + "User", "Address", "Keyword", "Order", "Item" + ) + + for left, right in [ + (Load(User).defer(User.id), Load(User).defer("id")), + ( + Load(User).joinedload(User.addresses), + Load(User).joinedload("addresses"), + ), + ( + Load(User).joinedload(User.orders).joinedload(Order.items), + Load(User).joinedload("orders").joinedload("items"), + ), + ]: + eq_(left._generate_cache_key(), right._generate_cache_key()) diff --git a/test/orm/test_options.py b/test/orm/test_options.py index bf099e7e6..e84d5950c 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -1790,7 +1790,7 @@ class SubOptionsTest(PathTest, QueryTest): ) -class CacheKeyTest(PathTest, QueryTest): +class PathedCacheKeyTest(PathTest, QueryTest): run_create_tables = False run_inserts = None @@ -1805,7 +1805,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = joinedload(User.orders).joinedload(Order.items) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), (((Order, "items", Item, ("lazy", "joined")),)), ) @@ -1821,12 +1821,12 @@ class CacheKeyTest(PathTest, QueryTest): opt2 = base.joinedload(Order.address) eq_( - opt1._generate_cache_key(query_path), + opt1._generate_path_cache_key(query_path), (((Order, "items", Item, ("lazy", "joined")),)), ) eq_( - opt2._generate_cache_key(query_path), + opt2._generate_path_cache_key(query_path), (((Order, "address", Address, ("lazy", "joined")),)), ) @@ -1842,12 +1842,12 @@ class CacheKeyTest(PathTest, QueryTest): opt2 = base.joinedload(Order.address) eq_( - opt1._generate_cache_key(query_path), + opt1._generate_path_cache_key(query_path), (((Order, "items", Item, ("lazy", "joined")),)), ) eq_( - opt2._generate_cache_key(query_path), + opt2._generate_path_cache_key(query_path), (((Order, "address", Address, ("lazy", "joined")),)), ) @@ -1860,7 +1860,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = 
Load(User).joinedload(User.orders).joinedload(Order.items) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), (((Order, "items", Item, ("lazy", "joined")),)), ) @@ -1872,7 +1872,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "addresses"]) opt = joinedload(User.orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_bound_cache_key_excluded_on_other(self): User, Address, Order, Item, SubItem = self.classes( @@ -1882,7 +1882,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "addresses"]) opt = Load(User).joinedload(User.orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_unbound_cache_key_excluded_on_aliased(self): User, Address, Order, Item, SubItem = self.classes( @@ -1901,7 +1901,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "orders"]) opt = joinedload(aliased(User).orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_bound_cache_key_wildcard_one(self): # do not change this test, it is testing @@ -1911,7 +1911,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "addresses"]) opt = Load(User).lazyload("*") - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_unbound_cache_key_wildcard_one(self): User, Address = self.classes("User", "Address") @@ -1920,7 +1920,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = lazyload("*") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), (("relationship:_sa_default", ("lazy", "select")),), ) @@ -1933,7 +1933,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = 
Load(User).lazyload("orders").lazyload("*") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ("orders", Order, ("lazy", "select")), ("orders", Order, "relationship:*", ("lazy", "select")), @@ -1949,7 +1949,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = lazyload("orders").lazyload("*") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ("orders", Order, ("lazy", "select")), ("orders", Order, "relationship:*", ("lazy", "select")), @@ -1968,7 +1968,7 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (SubItem, ("lazy", "subquery")), ("extra_keywords", Keyword, ("lazy", "subquery")), @@ -1987,7 +1987,7 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (SubItem, ("lazy", "subquery")), ("extra_keywords", Keyword, ("lazy", "subquery")), @@ -2008,7 +2008,7 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (SubItem, ("lazy", "subquery")), ("extra_keywords", Keyword, ("lazy", "subquery")), @@ -2029,7 +2029,7 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (SubItem, ("lazy", "subquery")), ("extra_keywords", Keyword, ("lazy", "subquery")), @@ -2056,7 +2056,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = subqueryload(User.orders).subqueryload( Order.items.of_type(SubItem) ) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_unbound_cache_key_excluded_of_type_unsafe(self): User, Address, Order, Item, SubItem = self.classes( @@ -2078,7 +2078,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = subqueryload(User.orders).subqueryload( Order.items.of_type(aliased(SubItem)) ) - 
eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_bound_cache_key_excluded_of_type_safe(self): User, Address, Order, Item, SubItem = self.classes( @@ -2102,7 +2102,7 @@ class CacheKeyTest(PathTest, QueryTest): .subqueryload(User.orders) .subqueryload(Order.items.of_type(SubItem)) ) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_bound_cache_key_excluded_of_type_unsafe(self): User, Address, Order, Item, SubItem = self.classes( @@ -2126,7 +2126,7 @@ class CacheKeyTest(PathTest, QueryTest): .subqueryload(User.orders) .subqueryload(Order.items.of_type(aliased(SubItem))) ) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_unbound_cache_key_included_of_type_safe(self): User, Address, Order, Item, SubItem = self.classes( @@ -2137,7 +2137,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = joinedload(User.orders).joinedload(Order.items.of_type(SubItem)) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ((Order, "items", SubItem, ("lazy", "joined")),), ) @@ -2155,7 +2155,7 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ((Order, "items", SubItem, ("lazy", "joined")),), ) @@ -2169,7 +2169,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = joinedload(User.orders).joinedload( Order.items.of_type(aliased(SubItem)) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_unbound_cache_key_included_unsafe_option_two(self): User, Address, Order, Item, SubItem = self.classes( @@ -2181,7 +2181,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = joinedload(User.orders).joinedload( Order.items.of_type(aliased(SubItem)) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), 
False) def test_unbound_cache_key_included_unsafe_option_three(self): User, Address, Order, Item, SubItem = self.classes( @@ -2193,7 +2193,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = joinedload(User.orders).joinedload( Order.items.of_type(aliased(SubItem)) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_unbound_cache_key_included_unsafe_query(self): User, Address, Order, Item, SubItem = self.classes( @@ -2204,7 +2204,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([inspect(au), "orders"]) opt = joinedload(au.orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_unbound_cache_key_included_safe_w_deferred(self): User, Address, Order, Item, SubItem = self.classes( @@ -2219,7 +2219,7 @@ class CacheKeyTest(PathTest, QueryTest): .defer(Address.user_id) ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ( Address, @@ -2247,12 +2247,12 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt1._generate_cache_key(query_path), + opt1._generate_path_cache_key(query_path), ((Order, "items", Item, ("lazy", "joined")),), ) eq_( - opt2._generate_cache_key(query_path), + opt2._generate_path_cache_key(query_path), ( (Order, "address", Address, ("lazy", "joined")), ( @@ -2288,7 +2288,7 @@ class CacheKeyTest(PathTest, QueryTest): .defer(Address.user_id) ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ( Address, @@ -2316,12 +2316,12 @@ class CacheKeyTest(PathTest, QueryTest): ) eq_( - opt1._generate_cache_key(query_path), + opt1._generate_path_cache_key(query_path), ((Order, "items", Item, ("lazy", "joined")),), ) eq_( - opt2._generate_cache_key(query_path), + opt2._generate_path_cache_key(query_path), ( (Order, "address", Address, ("lazy", "joined")), ( @@ -2356,7 +2356,7 @@ class 
CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "orders"]) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ( Order, @@ -2385,7 +2385,7 @@ class CacheKeyTest(PathTest, QueryTest): au = aliased(User) opt = Load(au).joinedload(au.orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), None) + eq_(opt._generate_path_cache_key(query_path), None) def test_bound_cache_key_included_unsafe_option_one(self): User, Address, Order, Item, SubItem = self.classes( @@ -2399,7 +2399,7 @@ class CacheKeyTest(PathTest, QueryTest): .joinedload(User.orders) .joinedload(Order.items.of_type(aliased(SubItem))) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_bound_cache_key_included_unsafe_option_two(self): User, Address, Order, Item, SubItem = self.classes( @@ -2413,7 +2413,7 @@ class CacheKeyTest(PathTest, QueryTest): .joinedload(User.orders) .joinedload(Order.items.of_type(aliased(SubItem))) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_bound_cache_key_included_unsafe_option_three(self): User, Address, Order, Item, SubItem = self.classes( @@ -2427,7 +2427,7 @@ class CacheKeyTest(PathTest, QueryTest): .joinedload(User.orders) .joinedload(Order.items.of_type(aliased(SubItem))) ) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_bound_cache_key_included_unsafe_query(self): User, Address, Order, Item, SubItem = self.classes( @@ -2438,7 +2438,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([inspect(au), "orders"]) opt = Load(au).joinedload(au.orders).joinedload(Order.items) - eq_(opt._generate_cache_key(query_path), False) + eq_(opt._generate_path_cache_key(query_path), False) def test_bound_cache_key_included_safe_w_option(self): User, Address, Order, Item, SubItem = 
self.classes( @@ -2454,7 +2454,7 @@ class CacheKeyTest(PathTest, QueryTest): query_path = self._make_path_registry([User, "orders"]) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( ( Order, @@ -2483,7 +2483,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = defaultload(User.addresses).load_only("id", "email_address") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (Address, "id", ("deferred", False), ("instrument", True)), ( @@ -2513,7 +2513,7 @@ class CacheKeyTest(PathTest, QueryTest): Address.id, Address.email_address ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (Address, "id", ("deferred", False), ("instrument", True)), ( @@ -2545,7 +2545,7 @@ class CacheKeyTest(PathTest, QueryTest): .load_only("id", "email_address") ) eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ( (Address, "id", ("deferred", False), ("instrument", True)), ( @@ -2572,7 +2572,7 @@ class CacheKeyTest(PathTest, QueryTest): opt = defaultload(User.addresses).undefer_group("xyz") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ((Address, "column:*", ("undefer_group_xyz", True)),), ) @@ -2584,6 +2584,6 @@ class CacheKeyTest(PathTest, QueryTest): opt = Load(User).defaultload(User.addresses).undefer_group("xyz") eq_( - opt._generate_cache_key(query_path), + opt._generate_path_cache_key(query_path), ((Address, "column:*", ("undefer_group_xyz", True)),), ) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index d48a8ed33..5d21960b7 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -32,6 +32,7 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import True_ from sqlalchemy.sql import type_coerce from sqlalchemy.sql import visitors +from sqlalchemy.sql.base import HasCacheKey from sqlalchemy.sql.elements import _label_reference from 
sqlalchemy.sql.elements import _textual_label_reference from sqlalchemy.sql.elements import Annotated @@ -46,13 +47,13 @@ from sqlalchemy.sql.functions import FunctionElement from sqlalchemy.sql.functions import GenericFunction from sqlalchemy.sql.functions import ReturnTypeFromArgs from sqlalchemy.sql.selectable import _OffsetLimitParam +from sqlalchemy.sql.selectable import AliasedReturnsRows from sqlalchemy.sql.selectable import FromGrouping from sqlalchemy.sql.selectable import Selectable from sqlalchemy.sql.selectable import SelectStatementGrouping -from sqlalchemy.testing import assert_raises_message +from sqlalchemy.sql.visitors import InternalTraversal from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import ne_ @@ -63,8 +64,17 @@ meta = MetaData() meta2 = MetaData() table_a = Table("a", meta, Column("a", Integer), Column("b", String)) +table_b_like_a = Table("b2", meta, Column("a", Integer), Column("b", String)) + table_a_2 = Table("a", meta2, Column("a", Integer), Column("b", String)) +table_a_2_fs = Table( + "a", meta2, Column("a", Integer), Column("b", String), schema="fs" +) +table_a_2_bs = Table( + "a", meta2, Column("a", Integer), Column("b", String), schema="bs" +) + table_b = Table("b", meta, Column("a", Integer), Column("b", Integer)) table_c = Table("c", meta, Column("x", Integer), Column("y", Integer)) @@ -72,8 +82,18 @@ table_c = Table("c", meta, Column("x", Integer), Column("y", Integer)) table_d = Table("d", meta, Column("y", Integer), Column("z", Integer)) -class CompareAndCopyTest(fixtures.TestBase): +class MyEntity(HasCacheKey): + def __init__(self, name, element): + self.name = name + self.element = element + + _cache_key_traversal = [ + ("name", InternalTraversal.dp_string), + ("element", InternalTraversal.dp_clauseelement), + ] + +class CoreFixtures(object): # lambdas which 
return a tuple of ColumnElement objects. # must return at least two objects that should compare differently. # to test more varieties of "difference" additional objects can be added. @@ -100,11 +120,47 @@ class CompareAndCopyTest(fixtures.TestBase): text("select a, b, c from table").columns( a=Integer, b=String, c=Integer ), + text("select a, b, c from table where foo=:bar").bindparams( + bindparam("bar", Integer) + ), + text("select a, b, c from table where foo=:foo").bindparams( + bindparam("foo", Integer) + ), + text("select a, b, c from table where foo=:bar").bindparams( + bindparam("bar", String) + ), ), lambda: ( column("q") == column("x"), column("q") == column("y"), column("z") == column("x"), + column("z") + column("x"), + column("z") - column("x"), + column("x") - column("z"), + column("z") > column("x"), + # note these two are mathematically equivalent but for now they + # are considered to be different + column("z") >= column("x"), + column("x") <= column("z"), + column("q").between(5, 6), + column("q").between(5, 6, symmetric=True), + column("q").like("somstr"), + column("q").like("somstr", escape="\\"), + column("q").like("somstr", escape="X"), + ), + lambda: ( + table_a.c.a, + table_a.c.a._annotate({"orm": True}), + table_a.c.a._annotate({"orm": True})._annotate({"bar": False}), + table_a.c.a._annotate( + {"orm": True, "parententity": MyEntity("a", table_a)} + ), + table_a.c.a._annotate( + {"orm": True, "parententity": MyEntity("b", table_a)} + ), + table_a.c.a._annotate( + {"orm": True, "parententity": MyEntity("b", select([table_a]))} + ), ), lambda: ( cast(column("q"), Integer), @@ -226,6 +282,58 @@ class CompareAndCopyTest(fixtures.TestBase): .correlate_except(table_b), ), lambda: ( + select([table_a.c.a]).cte(), + select([table_a.c.a]).cte(recursive=True), + select([table_a.c.a]).cte(name="some_cte", recursive=True), + select([table_a.c.a]).cte(name="some_cte"), + select([table_a.c.a]).cte(name="some_cte").alias("other_cte"), + 
select([table_a.c.a]) + .cte(name="some_cte") + .union_all(select([table_a.c.a])), + select([table_a.c.a]) + .cte(name="some_cte") + .union_all(select([table_a.c.b])), + select([table_a.c.a]).lateral(), + select([table_a.c.a]).lateral(name="bar"), + table_a.tablesample(func.bernoulli(1)), + table_a.tablesample(func.bernoulli(1), seed=func.random()), + table_a.tablesample(func.bernoulli(1), seed=func.other_random()), + table_a.tablesample(func.hoho(1)), + table_a.tablesample(func.bernoulli(1), name="bar"), + table_a.tablesample( + func.bernoulli(1), name="bar", seed=func.random() + ), + ), + lambda: ( + select([table_a.c.a]), + select([table_a.c.a]).prefix_with("foo"), + select([table_a.c.a]).prefix_with("foo", dialect="mysql"), + select([table_a.c.a]).prefix_with("foo", dialect="postgresql"), + select([table_a.c.a]).prefix_with("bar"), + select([table_a.c.a]).suffix_with("bar"), + ), + lambda: ( + select([table_a_2.c.a]), + select([table_a_2_fs.c.a]), + select([table_a_2_bs.c.a]), + ), + lambda: ( + select([table_a.c.a]), + select([table_a.c.a]).with_hint(None, "some hint"), + select([table_a.c.a]).with_hint(None, "some other hint"), + select([table_a.c.a]).with_hint(table_a, "some hint"), + select([table_a.c.a]) + .with_hint(table_a, "some hint") + .with_hint(None, "some other hint"), + select([table_a.c.a]).with_hint(table_a, "some other hint"), + select([table_a.c.a]).with_hint( + table_a, "some hint", dialect_name="mysql" + ), + select([table_a.c.a]).with_hint( + table_a, "some hint", dialect_name="postgresql" + ), + ), + lambda: ( table_a.join(table_b, table_a.c.a == table_b.c.a), table_a.join( table_b, and_(table_a.c.a == table_b.c.a, table_a.c.b == 1) @@ -273,12 +381,202 @@ class CompareAndCopyTest(fixtures.TestBase): table("a", column("x"), column("y", Integer)), table("a", column("q"), column("y", Integer)), ), - lambda: ( - Table("a", MetaData(), Column("q", Integer), Column("b", String)), - Table("b", MetaData(), Column("q", Integer), Column("b", 
String)), - ), + lambda: (table_a, table_b), ] + def _complex_fixtures(): + def one(): + a1 = table_a.alias() + a2 = table_b_like_a.alias() + + stmt = ( + select([table_a.c.a, a1.c.b, a2.c.b]) + .where(table_a.c.b == a1.c.b) + .where(a1.c.b == a2.c.b) + .where(a1.c.a == 5) + ) + + return stmt + + def one_diff(): + a1 = table_b_like_a.alias() + a2 = table_a.alias() + + stmt = ( + select([table_a.c.a, a1.c.b, a2.c.b]) + .where(table_a.c.b == a1.c.b) + .where(a1.c.b == a2.c.b) + .where(a1.c.a == 5) + ) + + return stmt + + def two(): + inner = one().subquery() + + stmt = select([table_b.c.a, inner.c.a, inner.c.b]).select_from( + table_b.join(inner, table_b.c.b == inner.c.b) + ) + + return stmt + + def three(): + + a1 = table_a.alias() + a2 = table_a.alias() + ex = exists().where(table_b.c.b == a1.c.a) + + stmt = ( + select([a1.c.a, a2.c.a]) + .select_from(a1.join(a2, a1.c.b == a2.c.b)) + .where(ex) + ) + return stmt + + return [one(), one_diff(), two(), three()] + + fixtures.append(_complex_fixtures) + + +class CacheKeyFixture(object): + def _run_cache_key_fixture(self, fixture): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + if a == b: + a_key = case_a[a]._generate_cache_key() + b_key = case_b[b]._generate_cache_key() + eq_(a_key.key, b_key.key) + + for a_param, b_param in zip( + a_key.bindparams, b_key.bindparams + ): + assert a_param.compare(b_param, compare_values=False) + else: + a_key = case_a[a]._generate_cache_key() + b_key = case_b[b]._generate_cache_key() + + if a_key.key == b_key.key: + for a_param, b_param in zip( + a_key.bindparams, b_key.bindparams + ): + if not a_param.compare(b_param, compare_values=True): + break + else: + # this fails unconditionally since we could not + # find bound parameter values that differed. + # Usually we intended to get two distinct keys here + # so the failure will be more descriptive using the + # ne_() assertion. 
+ ne_(a_key.key, b_key.key) + else: + ne_(a_key.key, b_key.key) + + # ClauseElement-specific test to ensure the cache key + # collected all the bound parameters + if isinstance(case_a[a], ClauseElement) and isinstance( + case_b[b], ClauseElement + ): + assert_a_params = [] + assert_b_params = [] + visitors.traverse_depthfirst( + case_a[a], {}, {"bindparam": assert_a_params.append} + ) + visitors.traverse_depthfirst( + case_b[b], {}, {"bindparam": assert_b_params.append} + ) + + # note we're asserting the order of the params as well as + # if there are dupes or not. ordering has to be deterministic + # and matches what a traversal would provide. + # regular traverse_depthfirst does produce dupes in cases like + # select([some_alias]). + # select_from(join(some_alias, other_table)) + # where a bound parameter is inside of some_alias. the + # cache key case is more minimalistic + eq_( + sorted(a_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_a_params), key=lambda b: b.key + ), + ) + eq_( + sorted(b_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_b_params), key=lambda b: b.key + ), + ) + + +class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase): + def test_cache_key(self): + for fixture in self.fixtures: + self._run_cache_key_fixture(fixture) + + def test_cache_key_unknown_traverse(self): + class Foobar1(ClauseElement): + _traverse_internals = [ + ("key", InternalTraversal.dp_anon_name), + ("type_", InternalTraversal.dp_unknown_structure), + ] + + def __init__(self, key, type_): + self.key = key + self.type_ = type_ + + f1 = Foobar1("foo", String()) + eq_(f1._generate_cache_key(), None) + + def test_cache_key_no_method(self): + class Foobar1(ClauseElement): + pass + + class Foobar2(ColumnElement): + pass + + # the None for cache key will prevent objects + # which contain these elements from being cached. 
+ f1 = Foobar1() + eq_(f1._generate_cache_key(), None) + + f2 = Foobar2() + eq_(f2._generate_cache_key(), None) + + s1 = select([column("q"), Foobar2()]) + + eq_(s1._generate_cache_key(), None) + + def test_get_children_no_method(self): + class Foobar1(ClauseElement): + pass + + class Foobar2(ColumnElement): + pass + + f1 = Foobar1() + eq_(f1.get_children(), []) + + f2 = Foobar2() + eq_(f2.get_children(), []) + + def test_copy_internals_no_method(self): + class Foobar1(ClauseElement): + pass + + class Foobar2(ColumnElement): + pass + + f1 = Foobar1() + f2 = Foobar2() + + f1._copy_internals() + f2._copy_internals() + + +class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): @classmethod def setup_class(cls): # TODO: we need to get dialects here somehow, perhaps in test_suite? @@ -293,7 +591,10 @@ class CompareAndCopyTest(fixtures.TestBase): cls for cls in class_hierarchy(ClauseElement) if issubclass(cls, (ColumnElement, Selectable)) - and "__init__" in cls.__dict__ + and ( + "__init__" in cls.__dict__ + or issubclass(cls, AliasedReturnsRows) + ) and not issubclass(cls, (Annotated)) and "orm" not in cls.__module__ and "compiler" not in cls.__module__ @@ -318,123 +619,16 @@ class CompareAndCopyTest(fixtures.TestBase): ): if a == b: is_true( - case_a[a].compare( - case_b[b], arbitrary_expression=True - ), + case_a[a].compare(case_b[b], compare_annotations=True), "%r != %r" % (case_a[a], case_b[b]), ) else: is_false( - case_a[a].compare( - case_b[b], arbitrary_expression=True - ), + case_a[a].compare(case_b[b], compare_annotations=True), "%r == %r" % (case_a[a], case_b[b]), ) - def test_cache_key(self): - def assert_params_append(assert_params): - def append(param): - if param._value_required_for_cache: - assert_params.append(param) - else: - is_(param.value, None) - - return append - - for fixture in self.fixtures: - case_a = fixture() - case_b = fixture() - - for a, b in itertools.combinations_with_replacement( - range(len(case_a)), 2 - ): - - assert_a_params = [] 
- assert_b_params = [] - - visitors.traverse_depthfirst( - case_a[a], - {}, - {"bindparam": assert_params_append(assert_a_params)}, - ) - visitors.traverse_depthfirst( - case_b[b], - {}, - {"bindparam": assert_params_append(assert_b_params)}, - ) - if assert_a_params: - assert_raises_message( - NotImplementedError, - "bindparams collection argument required ", - case_a[a]._cache_key, - ) - if assert_b_params: - assert_raises_message( - NotImplementedError, - "bindparams collection argument required ", - case_b[b]._cache_key, - ) - - if not assert_a_params and not assert_b_params: - if a == b: - eq_(case_a[a]._cache_key(), case_b[b]._cache_key()) - else: - ne_(case_a[a]._cache_key(), case_b[b]._cache_key()) - - def test_cache_key_gather_bindparams(self): - for fixture in self.fixtures: - case_a = fixture() - case_b = fixture() - - # in the "bindparams" case, the cache keys for bound parameters - # with only different values will be the same, but the params - # themselves are gathered into a collection. 
- for a, b in itertools.combinations_with_replacement( - range(len(case_a)), 2 - ): - a_params = {"bindparams": []} - b_params = {"bindparams": []} - if a == b: - a_key = case_a[a]._cache_key(**a_params) - b_key = case_b[b]._cache_key(**b_params) - eq_(a_key, b_key) - - if a_params["bindparams"]: - for a_param, b_param in zip( - a_params["bindparams"], b_params["bindparams"] - ): - assert a_param.compare(b_param) - else: - a_key = case_a[a]._cache_key(**a_params) - b_key = case_b[b]._cache_key(**b_params) - - if a_key == b_key: - for a_param, b_param in zip( - a_params["bindparams"], b_params["bindparams"] - ): - if not a_param.compare(b_param): - break - else: - assert False, "Bound parameters are all the same" - else: - ne_(a_key, b_key) - - assert_a_params = [] - assert_b_params = [] - visitors.traverse_depthfirst( - case_a[a], {}, {"bindparam": assert_a_params.append} - ) - visitors.traverse_depthfirst( - case_b[b], {}, {"bindparam": assert_b_params.append} - ) - - # note we're asserting the order of the params as well as - # if there are dupes or not. ordering has to be deterministic - # and matches what a traversal would provide. - eq_(a_params["bindparams"], assert_a_params) - eq_(b_params["bindparams"], assert_b_params) - def test_compare_col_identity(self): stmt1 = ( select([table_a.c.a, table_b.c.b]) @@ -473,8 +667,9 @@ class CompareAndCopyTest(fixtures.TestBase): assert case_a[0].compare(case_b[0]) - clone = case_a[0]._clone() - clone._copy_internals() + clone = visitors.replacement_traverse( + case_a[0], {}, lambda elem: None + ) assert clone.compare(case_b[0]) @@ -511,6 +706,37 @@ class CompareAndCopyTest(fixtures.TestBase): class CompareClausesTest(fixtures.TestBase): + def test_compare_metadata_tables(self): + # metadata Table objects cache on their own identity, not their + # structure. 
This is mainly to reduce the size of cache keys + # as well as reduce computational overhead, as Table objects have + # very large internal state and they are also generally global + # objects. + + t1 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer)) + t2 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer)) + + ne_(t1._generate_cache_key(), t2._generate_cache_key()) + + eq_(t1._generate_cache_key().key, (t1,)) + + def test_compare_adhoc_tables(self): + # non-metadata tables compare on their structure. these objects are + # not commonly used. + + # note this test is a bit redundant as we have a similar test + # via the fixtures also + t1 = table("a", Column("q", Integer), Column("p", Integer)) + t2 = table("a", Column("q", Integer), Column("p", Integer)) + t3 = table("b", Column("q", Integer), Column("p", Integer)) + t4 = table("a", Column("q", Integer), Column("x", Integer)) + + eq_(t1._generate_cache_key(), t2._generate_cache_key()) + + ne_(t1._generate_cache_key(), t3._generate_cache_key()) + ne_(t1._generate_cache_key(), t4._generate_cache_key()) + ne_(t3._generate_cache_key(), t4._generate_cache_key()) + def test_compare_comparison_associative(self): l1 = table_c.c.x == table_d.c.y @@ -521,6 +747,15 @@ class CompareClausesTest(fixtures.TestBase): is_true(l1.compare(l2)) is_false(l1.compare(l3)) + def test_compare_comparison_non_commutative_inverses(self): + l1 = table_c.c.x >= table_d.c.y + l2 = table_d.c.y < table_c.c.x + l3 = table_d.c.y <= table_c.c.x + + # we're not doing this kind of commutativity right now. 
+ is_false(l1.compare(l2)) + is_false(l1.compare(l3)) + def test_compare_clauselist_associative(self): l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z) @@ -624,3 +859,45 @@ class CompareClausesTest(fixtures.TestBase): use_proxies=True, ) ) + + def test_compare_annotated_clears_mapping(self): + t = table("t", column("x"), column("y")) + x_a = t.c.x._annotate({"foo": True}) + x_b = t.c.x._annotate({"foo": True}) + + is_true(x_a.compare(x_b, compare_annotations=True)) + is_false( + x_a.compare(x_b._annotate({"bar": True}), compare_annotations=True) + ) + + s1 = select([t.c.x])._annotate({"foo": True}) + s2 = select([t.c.x])._annotate({"foo": True}) + + is_true(s1.compare(s2, compare_annotations=True)) + + is_false( + s1.compare(s2._annotate({"bar": True}), compare_annotations=True) + ) + + def test_compare_annotated_wo_annotations(self): + t = table("t", column("x"), column("y")) + x_a = t.c.x._annotate({}) + x_b = t.c.x._annotate({"foo": True}) + + is_true(t.c.x.compare(x_a)) + is_true(x_b.compare(x_a)) + + is_true(x_a.compare(t.c.x)) + is_false(x_a.compare(t.c.y)) + is_false(t.c.y.compare(x_a)) + is_true((t.c.x == 5).compare(x_a == 5)) + is_false((t.c.y == 5).compare(x_a == 5)) + + s = select([t]).subquery() + x_p = s.c.x + is_false(x_a.compare(x_p)) + is_false(t.c.x.compare(x_p)) + x_p_a = x_p._annotate({}) + is_true(x_p_a.compare(x_p)) + is_true(x_p.compare(x_p_a)) + is_false(x_p_a.compare(x_a)) diff --git a/test/sql/test_generative.py b/test/sql/test_external_traversal.py index 8d347a522..8bfe5cf6f 100644 --- a/test/sql/test_generative.py +++ b/test/sql/test_external_traversal.py @@ -55,6 +55,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): # identity semantics. 
class A(ClauseElement): __visit_name__ = "a" + _traverse_internals = [] def __init__(self, expr): self.expr = expr diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 637f1f8a5..06cfdc4b5 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -118,11 +118,14 @@ class DefaultColumnComparatorTest(fixtures.TestBase): ) ) + modifiers = operator(left, right).modifiers + assert operator(left, right).compare( BinaryExpression( coercions.expect(roles.WhereHavingRole, left), coercions.expect(roles.WhereHavingRole, right), operator, + modifiers=modifiers, ) ) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 2bc7ccc93..184e4a99c 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -1070,7 +1070,7 @@ class SelectableTest( s4 = s3.with_only_columns([table2.c.b]) self.assert_compile(s4, "SELECT t2.b FROM t2") - def test_from_list_warning_against_existing(self): + def test_from_list_against_existing_one(self): c1 = Column("c1", Integer) s = select([c1]) @@ -1081,7 +1081,7 @@ class SelectableTest( self.assert_compile(s, "SELECT t.c1 FROM t") - def test_from_list_recovers_after_warning(self): + def test_from_list_against_existing_two(self): c1 = Column("c1", Integer) c2 = Column("c2", Integer) @@ -1090,18 +1090,11 @@ class SelectableTest( # force a compile. eq_(str(s), "SELECT c1") - @testing.emits_warning() - def go(): - return Table("t", MetaData(), c1, c2) - - t = go() + t = Table("t", MetaData(), c1, c2) eq_(c1._from_objects, [t]) eq_(c2._from_objects, [t]) - # 's' has been baked. Can't afford - # not caching select._froms. 
- # hopefully the warning will clue the user self.assert_compile(s, "SELECT t.c1 FROM t") self.assert_compile(select([c1]), "SELECT t.c1 FROM t") self.assert_compile(select([c2]), "SELECT t.c2 FROM t") @@ -1124,6 +1117,26 @@ class SelectableTest( "foo", ) + def test_whereclause_adapted(self): + table1 = table("t1", column("a")) + + s1 = select([table1]).subquery() + + s2 = select([s1]).where(s1.c.a == 5) + + assert s2._whereclause.left.table is s1 + + ta = select([table1]).subquery() + + s3 = sql_util.ClauseAdapter(ta).traverse(s2) + + assert s1 not in s3._froms + + # these are new assumptions with the newer approach that + # actively swaps out whereclause and others + assert s3._whereclause.left.table is not s1 + assert s3._whereclause.left.table in s3._froms + class RefreshForNewColTest(fixtures.TestBase): def test_join_uninit(self): @@ -2241,25 +2254,6 @@ class AnnotationsTest(fixtures.TestBase): annot = obj._annotate({}) ne_(set([obj]), set([annot])) - def test_compare(self): - t = table("t", column("x"), column("y")) - x_a = t.c.x._annotate({}) - assert t.c.x.compare(x_a) - assert x_a.compare(t.c.x) - assert not x_a.compare(t.c.y) - assert not t.c.y.compare(x_a) - assert (t.c.x == 5).compare(x_a == 5) - assert not (t.c.y == 5).compare(x_a == 5) - - s = select([t]).subquery() - x_p = s.c.x - assert not x_a.compare(x_p) - assert not t.c.x.compare(x_p) - x_p_a = x_p._annotate({}) - assert x_p_a.compare(x_p) - assert x_p.compare(x_p_a) - assert not x_p_a.compare(x_a) - def test_proxy_set_iteration_includes_annotated(self): from sqlalchemy.schema import Column @@ -2542,13 +2536,13 @@ class AnnotationsTest(fixtures.TestBase): ): # the columns clause isn't changed at all assert sel._raw_columns[0].table is a1 - assert sel._froms[0] is sel._froms[1].left + assert sel._froms[0].element is sel._froms[1].left.element eq_(str(s), str(sel)) # when we are modifying annotations sets only - # partially, each element is copied unconditionally - # when encountered. 
+ # partially, elements are copied uniquely based on id(). + # this is new as of 1.4, previously they'd be copied every time for sel in ( sql_util._deep_deannotate(s, {"foo": "bar"}), sql_util._deep_annotate(s, {"foo": "bar"}), diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index 988d5331e..48d6de6db 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -7,6 +7,6 @@ from sqlalchemy.testing import fixtures class MiscTest(fixtures.TestBase): def test_column_element_no_visit(self): class MyElement(ColumnElement): - pass + _traverse_internals = [] eq_(sql_util.find_tables(MyElement(), check_columns=True), []) |