diff options
Diffstat (limited to 'lib')
26 files changed, 1733 insertions, 1087 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index 293aa426d..b43b364fa 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -103,7 +103,6 @@ class Insert(StandardInsert): inserted_alias = getattr(self, "inserted_alias", None) self._post_values_clause = OnDuplicateClause(inserted_alias, values) - return self insert = public_factory(Insert, ".dialects.mysql.insert") diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 909d568a7..e94f9913c 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1658,23 +1658,20 @@ class PGCompiler(compiler.SQLCompiler): return "ONLY " + sqltext def get_select_precolumns(self, select, **kw): - if select._distinct is not False: - if select._distinct is True: - return "DISTINCT " - elif isinstance(select._distinct, (list, tuple)): + if select._distinct or select._distinct_on: + if select._distinct_on: return ( "DISTINCT ON (" + ", ".join( - [self.process(col, **kw) for col in select._distinct] + [ + self.process(col, **kw) + for col in select._distinct_on + ] ) + ") " ) else: - return ( - "DISTINCT ON (" - + self.process(select._distinct, **kw) - + ") " - ) + return "DISTINCT " else: return "" diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 4e77f5a4c..f4467976a 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -103,7 +103,6 @@ class Insert(StandardInsert): self._post_values_clause = OnConflictDoUpdate( constraint, index_elements, index_where, set_, where ) - return self @_generative def on_conflict_do_nothing( @@ -138,7 +137,6 @@ class Insert(StandardInsert): self._post_values_clause = OnConflictDoNothing( constraint, index_elements, index_where ) - return self insert = public_factory(Insert, ".dialects.postgresql.insert") diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index d18a35a40..8e137f141 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -198,7 +198,7 @@ class BakedQuery(object): self.spoil() else: for opt in options: - cache_key = opt._generate_cache_key(cache_path) + cache_key = opt._generate_path_cache_key(cache_path) if cache_key is False: self.spoil() elif cache_key is not None: diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 4a5a8ba9c..c2b234758 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -455,7 +455,7 @@ def deregister(class_): if hasattr(class_, "_compiler_dispatcher"): # regenerate default _compiler_dispatch - visitors._generate_dispatch(class_) + visitors._generate_compiler_dispatch(class_) # remove custom directive del class_._compiler_dispatcher diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 83069f113..aa2986205 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -47,6 +47,8 @@ from .base import state_str from .. import event from .. import inspection from .. import util +from ..sql import base as sql_base +from ..sql import visitors @inspection._self_inspects @@ -54,6 +56,7 @@ class QueryableAttribute( interfaces._MappedAttribute, interfaces.InspectionAttr, interfaces.PropComparator, + sql_base.HasCacheKey, ): """Base class for :term:`descriptor` objects that intercept attribute events on behalf of a :class:`.MapperProperty` @@ -102,6 +105,13 @@ class QueryableAttribute( if base[key].dispatch._active_history: self.dispatch._active_history = True + _cache_key_traversal = [ + # ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("key", visitors.ExtendedInternalTraversal.dp_string), + ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), + ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ] + @util.memoized_property def _supports_population(self): return self.impl.supports_population diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 6f8d19293..a3dea6b0e 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -216,7 +216,6 @@ def _assertions(*assertions): for assertion in assertions: assertion(self, fn.__name__) fn(self, *args[1:], **kw) - return self return generate diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index e94a81fed..704ce9df7 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -36,6 +36,8 @@ from .. import inspect from .. import inspection from .. import util from ..sql import operators +from ..sql import visitors +from ..sql.traversals import HasCacheKey __all__ = ( @@ -54,7 +56,9 @@ __all__ = ( ) -class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): +class MapperProperty( + HasCacheKey, _MappedAttribute, InspectionAttr, util.MemoizedSlots +): """Represent a particular class attribute mapped by :class:`.Mapper`. The most common occurrences of :class:`.MapperProperty` are the @@ -74,6 +78,11 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): "info", ) + _cache_key_traversal = [ + ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("key", visitors.ExtendedInternalTraversal.dp_string), + ] + cascade = frozenset() """The set of 'cascade' attribute names. @@ -647,7 +656,7 @@ class MapperOption(object): self.process_query(query) - def _generate_cache_key(self, path): + def _generate_path_cache_key(self, path): """Used by the "baked lazy loader" to see if this option can be cached. The "baked lazy loader" refers to the :class:`.Query` that is diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 376ad1923..548eca58d 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -71,7 +71,7 @@ _CONFIGURE_MUTEX = util.threading.RLock() @inspection._self_inspects @log.class_logger -class Mapper(InspectionAttr): +class Mapper(sql_base.HasCacheKey, InspectionAttr): """Define the correlation of class attributes to database table columns. @@ -729,6 +729,10 @@ class Mapper(InspectionAttr): """ return self + _cache_key_traversal = [ + ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj) + ] + @property def entity(self): r"""Part of the inspection API. diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 2f680a3a1..585cb80bc 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -15,7 +15,8 @@ from .base import class_mapper from .. import exc from .. import inspection from .. import util - +from ..sql import visitors +from ..sql.traversals import HasCacheKey log = logging.getLogger(__name__) @@ -28,7 +29,7 @@ _WILDCARD_TOKEN = "*" _DEFAULT_TOKEN = "_sa_default" -class PathRegistry(object): +class PathRegistry(HasCacheKey): """Represent query load paths and registry functions. Basically represents structures like: @@ -57,6 +58,10 @@ class PathRegistry(object): is_token = False is_root = False + _cache_key_traversal = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list) + ] + def __eq__(self, other): return other is not None and self.path == other.path @@ -78,6 +83,9 @@ class PathRegistry(object): def __len__(self): return len(self.path) + def __hash__(self): + return id(self) + @property def length(self): return len(self.path) diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 26f47f616..99bbbe37c 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -26,11 +26,13 @@ from .. import inspect from .. import util from ..sql import coercions from ..sql import roles +from ..sql import visitors from ..sql.base import _generative from ..sql.base import Generative +from ..sql.traversals import HasCacheKey -class Load(Generative, MapperOption): +class Load(HasCacheKey, Generative, MapperOption): """Represents loader options which modify the state of a :class:`.Query` in order to affect how various mapped attributes are loaded. @@ -70,6 +72,17 @@ class Load(Generative, MapperOption): """ + _cache_key_traversal = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ( + "_context_cache_key", + visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples, + ), + ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict), + ] + def __init__(self, entity): insp = inspect(entity) self.path = insp._path_registry @@ -89,7 +102,16 @@ class Load(Generative, MapperOption): load._of_type = None return load - def _generate_cache_key(self, path): + @property + def _context_cache_key(self): + serialized = [] + for (key, loader_path), obj in self.context.items(): + if key != "loader": + continue + serialized.append(loader_path + (obj,)) + return serialized + + def _generate_path_cache_key(self, path): if path.path[0].is_aliased_class: return False @@ -522,9 +544,16 @@ class _UnboundLoad(Load): self._to_bind = [] self.local_opts = {} + _cache_key_traversal = [ + ("path", visitors.ExtendedInternalTraversal.dp_multi_list), + ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list), + ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict), + ] + _is_chain_link = False - def _generate_cache_key(self, path): + def _generate_path_cache_key(self, path): serialized = () for val in self._to_bind: for local_elem, val_elem in zip(self.path, val.path): @@ -533,7 +562,7 @@ class _UnboundLoad(Load): else: opt = val._bind_loader([path.path[0]], None, None, False) if opt: - c_key = opt._generate_cache_key(path) + c_key = opt._generate_path_cache_key(path) if c_key is False: return False elif c_key: @@ -660,7 +689,6 @@ class _UnboundLoad(Load): opt = meth(opt, all_tokens[-1], **kw) opt._is_chain_link = False - return opt def _chop_path(self, to_chop, path): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 5f0f41e8d..c86993678 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -30,10 +30,12 @@ from .. import exc as sa_exc from .. import inspection from .. import sql from .. import util +from ..sql import base as sql_base from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import util as sql_util +from ..sql import visitors all_cascades = frozenset( @@ -530,7 +532,7 @@ class AliasedClass(object): return str(self._aliased_insp) -class AliasedInsp(InspectionAttr): +class AliasedInsp(sql_base.HasCacheKey, InspectionAttr): """Provide an inspection interface for an :class:`.AliasedClass` object. @@ -627,6 +629,12 @@ class AliasedInsp(InspectionAttr): def __clause_element__(self): return self.selectable + _cache_key_traversal = [ + ("name", visitors.ExtendedInternalTraversal.dp_string), + ("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean), + ("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement), + ] + @property def class_(self): """Return the mapped class ultimately represented by this diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index a0264845e..0d995ec8a 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -12,12 +12,32 @@ associations. """ from . import operators +from .base import HasCacheKey +from .visitors import InternalTraversal from .. import util -class SupportsCloneAnnotations(object): +class SupportsAnnotations(object): + @util.memoized_property + def _annotation_traversals(self): + return [ + ( + key, + InternalTraversal.dp_has_cache_key + if isinstance(value, HasCacheKey) + else InternalTraversal.dp_plain_obj, + ) + for key, value in self._annotations.items() + ] + + +class SupportsCloneAnnotations(SupportsAnnotations): _annotations = util.immutabledict() + _traverse_internals = [ + ("_annotations", InternalTraversal.dp_annotations_state) + ] + def _annotate(self, values): """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -25,6 +45,7 @@ class SupportsCloneAnnotations(object): """ new = self._clone() new._annotations = new._annotations.union(values) + new.__dict__.pop("_annotation_traversals", None) return new def _with_annotations(self, values): @@ -34,6 +55,7 @@ class SupportsCloneAnnotations(object): """ new = self._clone() new._annotations = util.immutabledict(values) + new.__dict__.pop("_annotation_traversals", None) return new def _deannotate(self, values=None, clone=False): @@ -49,12 +71,13 @@ class SupportsCloneAnnotations(object): # the expression for a deep deannotation new = self._clone() new._annotations = {} + new.__dict__.pop("_annotation_traversals", None) return new else: return self -class SupportsWrappingAnnotations(object): +class SupportsWrappingAnnotations(SupportsAnnotations): def _annotate(self, values): """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -123,6 +146,7 @@ class Annotated(object): def __init__(self, element, values): self.__dict__ = element.__dict__.copy() + self.__dict__.pop("_annotation_traversals", None) self.__element = element self._annotations = values self._hash = hash(element) @@ -135,6 +159,7 @@ class Annotated(object): def _with_annotations(self, values): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() + clone.__dict__.pop("_annotation_traversals", None) clone._annotations = values return clone @@ -192,7 +217,17 @@ def _deep_annotate(element, annotations, exclude=None): """ - def clone(elem): + # annotated objects hack the __hash__() method so if we want to + # uniquely process them we have to use id() + + cloned_ids = {} + + def clone(elem, **kw): + id_ = id(elem) + + if id_ in cloned_ids: + return cloned_ids[id_] + if ( exclude and hasattr(elem, "proxy_set") @@ -204,6 +239,7 @@ def _deep_annotate(element, annotations, exclude=None): else: newelem = elem newelem._copy_internals(clone=clone) + cloned_ids[id_] = newelem return newelem if element is not None: @@ -214,23 +250,21 @@ def _deep_annotate(element, annotations, exclude=None): def _deep_deannotate(element, values=None): """Deep copy the given element, removing annotations.""" - cloned = util.column_dict() + cloned = {} - def clone(elem): - # if a values dict is given, - # the elem must be cloned each time it appears, - # as there may be different annotations in source - # elements that are remaining. if totally - # removing all annotations, can assume the same - # slate... - if values or elem not in cloned: + def clone(elem, **kw): + if values: + key = id(elem) + else: + key = elem + + if key not in cloned: newelem = elem._deannotate(values=values, clone=True) newelem._copy_internals(clone=clone) - if not values: - cloned[elem] = newelem + cloned[key] = newelem return newelem else: - return cloned[elem] + return cloned[key] if element is not None: element = clone(element) @@ -268,6 +302,11 @@ def _new_annotation_type(cls, base_cls): "Annotated%s" % cls.__name__, (base_cls, cls), {} ) globals()["Annotated%s" % cls.__name__] = anno_cls + + if "_traverse_internals" in cls.__dict__: + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_state) + ] return anno_cls diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 7e9199bfa..d11a3a313 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -14,6 +14,7 @@ import itertools import operator import re +from .traversals import HasCacheKey # noqa from .visitors import ClauseVisitor from .. import exc from .. import util @@ -38,18 +39,41 @@ class Immutable(object): def _clone(self): return self + def _copy_internals(self, **kw): + pass + + +class HasMemoized(object): + def _reset_memoizations(self): + self._memoized_property.expire_instance(self) + + def _reset_exported(self): + self._memoized_property.expire_instance(self) + + def _copy_internals(self, **kw): + super(HasMemoized, self)._copy_internals(**kw) + self._reset_memoizations() + def _from_objects(*elements): return itertools.chain(*[element._from_objects for element in elements]) def _generative(fn): + """non-caching _generative() decorator. + + This is basically the legacy decorator that copies the object and + runs a method on the new copy. + + """ + @util.decorator - def _generative(fn, *args, **kw): + def _generative(fn, self, *args, **kw): """Mark a method as generative.""" - self = args[0]._generate() - fn(self, *args[1:], **kw) + self = self._generate() + x = fn(self, *args, **kw) + assert x is None, "generative methods must have no return value" return self decorated = _generative(fn) @@ -357,10 +381,8 @@ class DialectKWArgs(object): class Generative(object): - """Allow a ClauseElement to generate itself via the - @_generative decorator. - - """ + """Provide a method-chaining pattern in conjunction with the + @_generative decorator.""" def _generate(self): s = self.__class__.__new__(self.__class__) diff --git a/lib/sqlalchemy/sql/clause_compare.py b/lib/sqlalchemy/sql/clause_compare.py deleted file mode 100644 index 30a90348c..000000000 --- a/lib/sqlalchemy/sql/clause_compare.py +++ /dev/null @@ -1,334 +0,0 @@ -from collections import deque - -from . import operators -from .. import util - - -SKIP_TRAVERSE = util.symbol("skip_traverse") - - -def compare(obj1, obj2, **kw): - if kw.get("use_proxies", False): - strategy = ColIdentityComparatorStrategy() - else: - strategy = StructureComparatorStrategy() - - return strategy.compare(obj1, obj2, **kw) - - -class StructureComparatorStrategy(object): - __slots__ = "compare_stack", "cache" - - def __init__(self): - self.compare_stack = deque() - self.cache = set() - - def compare(self, obj1, obj2, **kw): - stack = self.compare_stack - cache = self.cache - - stack.append((obj1, obj2)) - - while stack: - left, right = stack.popleft() - - if left is right: - continue - elif left is None or right is None: - # we know they are different so no match - return False - elif (left, right) in cache: - continue - cache.add((left, right)) - - visit_name = left.__visit_name__ - - # we're not exactly looking for identical types, because - # there are things like Column and AnnotatedColumn. So the - # visit_name has to at least match up - if visit_name != right.__visit_name__: - return False - - meth = getattr(self, "compare_%s" % visit_name, None) - - if meth: - comparison = meth(left, right, **kw) - if comparison is False: - return False - elif comparison is SKIP_TRAVERSE: - continue - - for c1, c2 in util.zip_longest( - left.get_children(column_collections=False), - right.get_children(column_collections=False), - fillvalue=None, - ): - if c1 is None or c2 is None: - # collections are different sizes, comparison fails - return False - stack.append((c1, c2)) - - return True - - def compare_inner(self, obj1, obj2, **kw): - stack = self.compare_stack - try: - self.compare_stack = deque() - return self.compare(obj1, obj2, **kw) - finally: - self.compare_stack = stack - - def _compare_unordered_sequences(self, seq1, seq2, **kw): - if seq1 is None: - return seq2 is None - - completed = set() - for clause in seq1: - for other_clause in set(seq2).difference(completed): - if self.compare_inner(clause, other_clause, **kw): - completed.add(other_clause) - break - return len(completed) == len(seq1) == len(seq2) - - def compare_bindparam(self, left, right, **kw): - # note the ".key" is often generated from id(self) so can't - # be compared, as far as determining structure. - return ( - left.type._compare_type_affinity(right.type) - and left.value == right.value - and left.callable == right.callable - and left._orig_key == right._orig_key - ) - - def compare_clauselist(self, left, right, **kw): - if left.operator is right.operator: - if operators.is_associative(left.operator): - if self._compare_unordered_sequences( - left.clauses, right.clauses - ): - return SKIP_TRAVERSE - else: - return False - else: - # normal ordered traversal - return True - else: - return False - - def compare_unary(self, left, right, **kw): - if left.operator: - disp = self._get_operator_dispatch( - left.operator, "unary", "operator" - ) - if disp is not None: - result = disp(left, right, left.operator, **kw) - if result is not True: - return result - elif left.modifier: - disp = self._get_operator_dispatch( - left.modifier, "unary", "modifier" - ) - if disp is not None: - result = disp(left, right, left.operator, **kw) - if result is not True: - return result - return ( - left.operator == right.operator and left.modifier == right.modifier - ) - - def compare_binary(self, left, right, **kw): - disp = self._get_operator_dispatch(left.operator, "binary", None) - if disp: - result = disp(left, right, left.operator, **kw) - if result is not True: - return result - - if left.operator == right.operator: - if operators.is_commutative(left.operator): - if ( - compare(left.left, right.left, **kw) - and compare(left.right, right.right, **kw) - ) or ( - compare(left.left, right.right, **kw) - and compare(left.right, right.left, **kw) - ): - return SKIP_TRAVERSE - else: - return False - else: - return True - else: - return False - - def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): - # used by compare_binary, compare_unary - attrname = "visit_%s_%s%s" % ( - operator_.__name__, - qualifier1, - "_" + qualifier2 if qualifier2 else "", - ) - return getattr(self, attrname, None) - - def visit_function_as_comparison_op_binary( - self, left, right, operator, **kw - ): - return ( - left.left_index == right.left_index - and left.right_index == right.right_index - ) - - def compare_function(self, left, right, **kw): - return left.name == right.name - - def compare_column(self, left, right, **kw): - if left.table is not None: - self.compare_stack.appendleft((left.table, right.table)) - return ( - left.key == right.key - and left.name == right.name - and ( - left.type._compare_type_affinity(right.type) - if left.type is not None - else right.type is None - ) - and left.is_literal == right.is_literal - ) - - def compare_collation(self, left, right, **kw): - return left.collation == right.collation - - def compare_type_coerce(self, left, right, **kw): - return left.type._compare_type_affinity(right.type) - - @util.dependencies("sqlalchemy.sql.elements") - def compare_alias(self, elements, left, right, **kw): - return ( - left.name == right.name - if not isinstance(left.name, elements._anonymous_label) - else isinstance(right.name, elements._anonymous_label) - ) - - def compare_cte(self, elements, left, right, **kw): - raise NotImplementedError("TODO") - - def compare_extract(self, left, right, **kw): - return left.field == right.field - - def compare_textual_label_reference(self, left, right, **kw): - return left.element == right.element - - def compare_slice(self, left, right, **kw): - return ( - left.start == right.start - and left.stop == right.stop - and left.step == right.step - ) - - def compare_over(self, left, right, **kw): - return left.range_ == right.range_ and left.rows == right.rows - - @util.dependencies("sqlalchemy.sql.elements") - def compare_label(self, elements, left, right, **kw): - return left._type._compare_type_affinity(right._type) and ( - left.name == right.name - if not isinstance(left.name, elements._anonymous_label) - else isinstance(right.name, elements._anonymous_label) - ) - - def compare_typeclause(self, left, right, **kw): - return left.type._compare_type_affinity(right.type) - - def compare_join(self, left, right, **kw): - return left.isouter == right.isouter and left.full == right.full - - def compare_table(self, left, right, **kw): - if left.name != right.name: - return False - - self.compare_stack.extendleft( - util.zip_longest(left.columns, right.columns) - ) - - def compare_compound_select(self, left, right, **kw): - - if not self._compare_unordered_sequences( - left.selects, right.selects, **kw - ): - return False - - if left.keyword != right.keyword: - return False - - if left._for_update_arg != right._for_update_arg: - return False - - if not self.compare_inner( - left._order_by_clause, right._order_by_clause, **kw - ): - return False - - if not self.compare_inner( - left._group_by_clause, right._group_by_clause, **kw - ): - return False - - return SKIP_TRAVERSE - - def compare_select(self, left, right, **kw): - if not self._compare_unordered_sequences( - left._correlate, right._correlate - ): - return False - if not self._compare_unordered_sequences( - left._correlate_except, right._correlate_except - ): - return False - - if not self._compare_unordered_sequences( - left._from_obj, right._from_obj - ): - return False - - if left._for_update_arg != right._for_update_arg: - return False - - return True - - def compare_textual_select(self, left, right, **kw): - self.compare_stack.extendleft( - util.zip_longest(left.column_args, right.column_args) - ) - return left.positional == right.positional - - -class ColIdentityComparatorStrategy(StructureComparatorStrategy): - def compare_column_element( - self, left, right, use_proxies=True, equivalents=(), **kw - ): - """Compare ColumnElements using proxies and equivalent collections. - - This is a comparison strategy specific to the ORM. - """ - - to_compare = (right,) - if equivalents and right in equivalents: - to_compare = equivalents[right].union(to_compare) - - for oth in to_compare: - if use_proxies and left.shares_lineage(oth): - return True - elif hash(left) == hash(right): - return True - else: - return False - - def compare_column(self, left, right, **kw): - return self.compare_column_element(left, right, **kw) - - def compare_label(self, left, right, **kw): - return self.compare_column_element(left, right, **kw) - - def compare_table(self, left, right, **kw): - # tables compare on identity, since it's not really feasible to - # compare them column by column with the above rules - return left is right diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5ecec7d6c..546fffc6c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -434,6 +434,27 @@ class _CompileLabel(elements.ColumnElement): return self +class prefix_anon_map(dict): + """A map that creates new keys for missing key access. + + Considers keys of the form "<ident> <name>" to produce + new symbols "<name>_<index>", where "index" is an incrementing integer + corresponding to <name>. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + def __missing__(self, key): + (ident, derived) = key.split(" ", 1) + anonymous_counter = self.get(derived, 1) + self[derived] = anonymous_counter + 1 + value = derived + "_" + str(anonymous_counter) + self[key] = value + return value + + class SQLCompiler(Compiled): """Default implementation of :class:`.Compiled`. @@ -574,7 +595,7 @@ class SQLCompiler(Compiled): # a map which tracks "anonymous" identifiers that are created on # the fly here - self.anon_map = util.PopulateDict(self._process_anon) + self.anon_map = prefix_anon_map() # a map which tracks "truncated" names based on # dialect.label_length or dialect.max_identifier_length @@ -1712,12 +1733,6 @@ class SQLCompiler(Compiled): def _anonymize(self, name): return name % self.anon_map - def _process_anon(self, key): - (ident, derived) = key.split(" ", 1) - anonymous_counter = self.anon_map.get(derived, 1) - self.anon_map[derived] = anonymous_counter + 1 - return derived + "_" + str(anonymous_counter) - def bindparam_string( self, name, diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 918f7524e..c0baa8555 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -178,6 +178,9 @@ def _unsupported_impl(expr, op, *arg, **kw): def _inv_impl(expr, op, **kw): """See :meth:`.ColumnOperators.__inv__`.""" + + # undocumented element currently used by the ORM for + # relationship.contains() if hasattr(expr, "negation_clause"): return expr.negation_clause else: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e6f57b8d1..ba615bc3f 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -16,23 +16,29 @@ import itertools import operator import re -from . import clause_compare from . import coercions from . import operators from . import roles +from . import traversals from . import type_api from .annotation import Annotated from .annotation import SupportsWrappingAnnotations from .base import _clone from .base import _generative from .base import Executable +from .base import HasCacheKey +from .base import HasMemoized from .base import Immutable from .base import NO_ARG from .base import PARSE_AUTOCOMMIT from .coercions import _document_text_coercion +from .traversals import _copy_internals +from .traversals import _get_children +from .traversals import NO_CACHE from .visitors import cloned_traverse +from .visitors import InternalTraversal from .visitors import traverse -from .visitors import Visitable +from .visitors import Traversible from .. import exc from .. import inspection from .. import util @@ -162,7 +168,9 @@ def not_(clause): @inspection._self_inspects -class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): +class ClauseElement( + roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible +): """Base class for elements of a programmatically constructed SQL expression. @@ -190,6 +198,13 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): _order_by_label_element = None + @property + def _cache_key_traversal(self): + try: + return self._traverse_internals + except AttributeError: + return NO_CACHE + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -221,28 +236,6 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): """ return self - def _cache_key(self, **kw): - """return an optional cache key. - - The cache key is a tuple which can contain any series of - objects that are hashable and also identifies - this object uniquely within the presence of a larger SQL expression - or statement, for the purposes of caching the resulting query. - - The cache key should be based on the SQL compiled structure that would - ultimately be produced. That is, two structures that are composed in - exactly the same way should produce the same cache key; any difference - in the strucures that would affect the SQL string or the type handlers - should result in a different cache key. - - If a structure cannot produce a useful cache key, it should raise - NotImplementedError, which will result in the entire structure - for which it's part of not being useful as a cache key. - - - """ - raise NotImplementedError() - @property def _constructor(self): """return the 'constructor' for this ClauseElement. @@ -336,9 +329,9 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): (see :class:`.ColumnElement`) """ - return clause_compare.compare(self, other, **kw) + return traversals.compare(self, other, **kw) - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals(self, **kw): """Reassign internal elements to be clones of themselves. Called during a copy-and-traverse operation on newly @@ -349,21 +342,46 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): traversal, cloned traversal, annotations). """ - pass - def get_children(self, **kwargs): - r"""Return immediate child elements of this :class:`.ClauseElement`. + try: + traverse_internals = self._traverse_internals + except AttributeError: + return + + for attrname, obj, meth in _copy_internals.run_generated_dispatch( + self, traverse_internals, "_generated_copy_internals_traversal" + ): + if obj is not None: + result = meth(self, obj, **kw) + if result is not None: + setattr(self, attrname, result) + + def get_children(self, omit_attrs=None, **kw): + r"""Return immediate child :class:`.Traversible` elements of this + :class:`.Traversible`. This is used for visit traversal. - \**kwargs may contain flags that change the collection that is + \**kw may contain flags that change the collection that is returned, for example to return a subset of items in order to cut down on larger traversals, or to return child items from a different context (such as schema-level collections instead of clause-level). """ - return [] + result = [] + try: + traverse_internals = self._traverse_internals + except AttributeError: + return result + + for attrname, obj, meth in _get_children.run_generated_dispatch( + self, traverse_internals, "_generated_get_children_traversal" + ): + if obj is None or omit_attrs and attrname in omit_attrs: + continue + result.extend(meth(obj, **kw)) + return result def self_group(self, against=None): # type: (Optional[Any]) -> ClauseElement @@ -501,6 +519,8 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): return or_(self, other) def __invert__(self): + # undocumented element currently used by the ORM for + # relationship.contains() if hasattr(self, "negation_clause"): return self.negation_clause else: @@ -508,9 +528,7 @@ class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): def _negate(self): return UnaryExpression( - self.self_group(against=operators.inv), - operator=operators.inv, - negate=None, + self.self_group(against=operators.inv), operator=operators.inv ) def __bool__(self): @@ -731,9 +749,6 @@ class ColumnElement( else: return comparator_factory(self) - def _cache_key(self, **kw): - raise NotImplementedError(self.__class__) - def __getattr__(self, key): try: return getattr(self.comparator, key) @@ -969,6 +984,13 @@ class BindParameter(roles.InElementRole, ColumnElement): __visit_name__ = "bindparam" + _traverse_internals = [ + ("key", InternalTraversal.dp_anon_name), + ("type", InternalTraversal.dp_type), + ("callable", InternalTraversal.dp_plain_dict), + ("value", InternalTraversal.dp_plain_obj), + ] + _is_crud = False _expanding_in_types = () @@ -1321,26 +1343,19 @@ class BindParameter(roles.InElementRole, ColumnElement): ) return c - def _cache_key(self, bindparams=None, **kw): - if bindparams is None: - # even though _cache_key is a private method, we would like to - # be super paranoid about this point. You can't include the - # "value" or "callable" in the cache key, because the value is - # not part of the structure of a statement and is likely to - # change every time. However you cannot *throw it away* either, - # because you can't invoke the statement without the parameter - # values that were explicitly placed. So require that they - # are collected here to make sure this happens. - if self._value_required_for_cache: - raise NotImplementedError( - "bindparams collection argument required for _cache_key " - "implementation. Bound parameter cache keys are not safe " - "to use without accommodating for the value or callable " - "within the parameter itself." - ) - else: - bindparams.append(self) - return (BindParameter, self.type._cache_key, self._orig_key) + def _gen_cache_key(self, anon_map, bindparams): + if self in anon_map: + return (anon_map[self], self.__class__) + + id_ = anon_map[self] + bindparams.append(self) + + return ( + id_, + self.__class__, + self.type._gen_cache_key, + traversals._resolve_name_for_compare(self, self.key, anon_map), + ) def _convert_to_unique(self): if not self.unique: @@ -1377,12 +1392,11 @@ class TypeClause(ClauseElement): __visit_name__ = "typeclause" + _traverse_internals = [("type", InternalTraversal.dp_type)] + def __init__(self, type_): self.type = type_ - def _cache_key(self, **kw): - return (TypeClause, self.type._cache_key) - class TextClause( roles.DDLConstraintColumnRole, @@ -1419,6 +1433,11 @@ class TextClause( __visit_name__ = "textclause" + _traverse_internals = [ + ("_bindparams", InternalTraversal.dp_string_clauseelement_dict), + ("text", InternalTraversal.dp_string), + ] + _is_text_clause = True _is_textual = True @@ -1861,19 +1880,6 @@ class TextClause( else: return self - def _copy_internals(self, clone=_clone, **kw): - self._bindparams = dict( - (b.key, clone(b, **kw)) for b in self._bindparams.values() - ) - - def get_children(self, **kwargs): - return list(self._bindparams.values()) - - def _cache_key(self, **kw): - return (self.text,) + tuple( - bind._cache_key for bind in self._bindparams.values() - ) - class Null(roles.ConstExprRole, ColumnElement): """Represent the NULL keyword in a SQL statement. @@ -1885,6 +1891,8 @@ class Null(roles.ConstExprRole, ColumnElement): __visit_name__ = "null" + _traverse_internals = [] + @util.memoized_property def type(self): return type_api.NULLTYPE @@ -1895,9 +1903,6 @@ class Null(roles.ConstExprRole, ColumnElement): return Null() - def _cache_key(self, **kw): - return (Null,) - class False_(roles.ConstExprRole, ColumnElement): """Represent the ``false`` keyword, or equivalent, in a SQL statement. @@ -1908,6 +1913,7 @@ class False_(roles.ConstExprRole, ColumnElement): """ __visit_name__ = "false" + _traverse_internals = [] @util.memoized_property def type(self): @@ -1954,9 +1960,6 @@ class False_(roles.ConstExprRole, ColumnElement): return False_() - def _cache_key(self, **kw): - return (False_,) - class True_(roles.ConstExprRole, ColumnElement): """Represent the ``true`` keyword, or equivalent, in a SQL statement. @@ -1968,6 +1971,8 @@ class True_(roles.ConstExprRole, ColumnElement): __visit_name__ = "true" + _traverse_internals = [] + @util.memoized_property def type(self): return type_api.BOOLEANTYPE @@ -2020,9 +2025,6 @@ class True_(roles.ConstExprRole, ColumnElement): return True_() - def _cache_key(self, **kw): - return (True_,) - class ClauseList( roles.InElementRole, @@ -2038,6 +2040,11 @@ class ClauseList( __visit_name__ = "clauselist" + _traverse_internals = [ + ("clauses", InternalTraversal.dp_clauseelement_list), + ("operator", InternalTraversal.dp_operator), + ] + def __init__(self, *clauses, **kwargs): self.operator = kwargs.pop("operator", operators.comma_op) self.group = kwargs.pop("group", True) @@ -2082,17 +2089,6 @@ class ClauseList( coercions.expect(self._text_converter_role, clause) ) - def _copy_internals(self, clone=_clone, **kw): - self.clauses = [clone(clause, **kw) for clause in self.clauses] - - def get_children(self, **kwargs): - return self.clauses - - def _cache_key(self, **kw): - return (ClauseList, self.operator) + tuple( - clause._cache_key(**kw) for clause in self.clauses - ) - @property def _from_objects(self): return list(itertools.chain(*[c._from_objects for c in self.clauses])) @@ -2115,11 +2111,6 @@ class BooleanClauseList(ClauseList, ColumnElement): "BooleanClauseList has a private constructor" ) - def _cache_key(self, **kw): - return (BooleanClauseList, self.operator) + tuple( - clause._cache_key(**kw) for clause in self.clauses - ) - @classmethod def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): convert_clauses = [] @@ -2250,6 +2241,8 @@ or_ = BooleanClauseList.or_ class Tuple(ClauseList, ColumnElement): """Represent a SQL tuple.""" + _traverse_internals = ClauseList._traverse_internals + [] + def __init__(self, *clauses, **kw): """Return a :class:`.Tuple`. @@ -2289,11 +2282,6 @@ class Tuple(ClauseList, ColumnElement): def _select_iterable(self): return (self,) - def _cache_key(self, **kw): - return (Tuple,) + tuple( - clause._cache_key(**kw) for clause in self.clauses - ) - def _bind_param(self, operator, obj, type_=None): return Tuple( *[ @@ -2339,6 +2327,12 @@ class Case(ColumnElement): __visit_name__ = "case" + _traverse_internals = [ + ("value", InternalTraversal.dp_clauseelement), + ("whens", InternalTraversal.dp_clauseelement_tuples), + ("else_", InternalTraversal.dp_clauseelement), + ] + def __init__(self, whens, value=None, else_=None): r"""Produce a ``CASE`` expression. @@ -2501,40 +2495,6 @@ class Case(ColumnElement): else: self.else_ = None - def _copy_internals(self, clone=_clone, **kw): - if self.value is not None: - self.value = clone(self.value, **kw) - self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens] - if self.else_ is not None: - self.else_ = clone(self.else_, **kw) - - def get_children(self, **kwargs): - if self.value is not None: - yield self.value - for x, y in self.whens: - yield x - yield y - if self.else_ is not None: - yield self.else_ - - def _cache_key(self, **kw): - return ( - ( - Case, - self.value._cache_key(**kw) - if self.value is not None - else None, - ) - + tuple( - (x._cache_key(**kw), y._cache_key(**kw)) for x, y in self.whens - ) - + ( - self.else_._cache_key(**kw) - if self.else_ is not None - else None, - ) - ) - @property def _from_objects(self): return list( @@ -2603,6 +2563,11 @@ class Cast(WrapsColumnExpression, ColumnElement): __visit_name__ = "cast" + _traverse_internals = [ + ("clause", InternalTraversal.dp_clauseelement), + ("typeclause", InternalTraversal.dp_clauseelement), + ] + def __init__(self, expression, type_): r"""Produce a ``CAST`` expression. @@ -2662,20 +2627,6 @@ class Cast(WrapsColumnExpression, ColumnElement): ) self.typeclause = TypeClause(self.type) - def _copy_internals(self, clone=_clone, **kw): - self.clause = clone(self.clause, **kw) - self.typeclause = clone(self.typeclause, **kw) - - def get_children(self, **kwargs): - return self.clause, self.typeclause - - def _cache_key(self, **kw): - return ( - Cast, - self.clause._cache_key(**kw), - self.typeclause._cache_key(**kw), - ) - @property def _from_objects(self): return self.clause._from_objects @@ -2685,7 +2636,7 @@ class Cast(WrapsColumnExpression, ColumnElement): return self.clause -class TypeCoerce(WrapsColumnExpression, ColumnElement): +class TypeCoerce(HasMemoized, WrapsColumnExpression, ColumnElement): """Represent a Python-side type-coercion wrapper. :class:`.TypeCoerce` supplies the :func:`.expression.type_coerce` @@ -2705,6 +2656,13 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement): __visit_name__ = "type_coerce" + _traverse_internals = [ + ("clause", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ] + + _memoized_property = util.group_expirable_memoized_property() + def __init__(self, expression, type_): r"""Associate a SQL expression with a particular type, without rendering ``CAST``. @@ -2773,21 +2731,11 @@ class TypeCoerce(WrapsColumnExpression, ColumnElement): roles.ExpressionElementRole, expression, type_=self.type ) - def _copy_internals(self, clone=_clone, **kw): - self.clause = clone(self.clause, **kw) - self.__dict__.pop("typed_expression", None) - - def get_children(self, **kwargs): - return (self.clause,) - - def _cache_key(self, **kw): - return (TypeCoerce, self.type._cache_key, self.clause._cache_key(**kw)) - @property def _from_objects(self): return self.clause._from_objects - @util.memoized_property + @_memoized_property def typed_expression(self): if isinstance(self.clause, BindParameter): bp = self.clause._clone() @@ -2806,6 +2754,11 @@ class Extract(ColumnElement): __visit_name__ = "extract" + _traverse_internals = [ + ("expr", InternalTraversal.dp_clauseelement), + ("field", InternalTraversal.dp_string), + ] + def __init__(self, field, expr, **kwargs): """Return a :class:`.Extract` construct. @@ -2818,15 +2771,6 @@ class Extract(ColumnElement): self.field = field self.expr = coercions.expect(roles.ExpressionElementRole, expr) - def _copy_internals(self, clone=_clone, **kw): - self.expr = clone(self.expr, **kw) - - def get_children(self, **kwargs): - return (self.expr,) - - def _cache_key(self, **kw): - return (Extract, self.field, self.expr._cache_key(**kw)) - @property def _from_objects(self): return self.expr._from_objects @@ -2847,18 +2791,11 @@ class _label_reference(ColumnElement): __visit_name__ = "label_reference" + _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + def __init__(self, element): self.element = element - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return (_label_reference, self.element._cache_key(**kw)) - - def get_children(self, **kwargs): - return [self.element] - @property def _from_objects(self): return () @@ -2867,6 +2804,8 @@ class _label_reference(ColumnElement): class _textual_label_reference(ColumnElement): __visit_name__ = "textual_label_reference" + _traverse_internals = [("element", InternalTraversal.dp_string)] + def __init__(self, element): self.element = element @@ -2874,9 +2813,6 @@ class _textual_label_reference(ColumnElement): def _text_clause(self): return TextClause._create_text(self.element) - def _cache_key(self, **kw): - return (_textual_label_reference, self.element) - class UnaryExpression(ColumnElement): """Define a 'unary' expression. @@ -2894,13 +2830,18 @@ class UnaryExpression(ColumnElement): __visit_name__ = "unary" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("operator", InternalTraversal.dp_operator), + ("modifier", InternalTraversal.dp_operator), + ] + def __init__( self, element, operator=None, modifier=None, type_=None, - negate=None, wraps_column_expression=False, ): self.operator = operator @@ -2909,7 +2850,6 @@ class UnaryExpression(ColumnElement): against=self.operator or self.modifier ) self.type = type_api.to_instance(type_) - self.negate = negate self.wraps_column_expression = wraps_column_expression @classmethod @@ -3135,37 +3075,13 @@ class UnaryExpression(ColumnElement): def _from_objects(self): return self.element._from_objects - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return ( - UnaryExpression, - self.element._cache_key(**kw), - self.operator, - self.modifier, - ) - - def get_children(self, **kwargs): - return (self.element,) - def _negate(self): - if self.negate is not None: - return UnaryExpression( - self.element, - operator=self.negate, - negate=self.operator, - modifier=self.modifier, - type_=self.type, - wraps_column_expression=self.wraps_column_expression, - ) - elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: + if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return UnaryExpression( self.self_group(against=operators.inv), operator=operators.inv, type_=type_api.BOOLEANTYPE, wraps_column_expression=self.wraps_column_expression, - negate=None, ) else: return ClauseElement._negate(self) @@ -3286,15 +3202,6 @@ class AsBoolean(WrapsColumnExpression, UnaryExpression): # type: (Optional[Any]) -> ClauseElement return self - def _cache_key(self, **kw): - return ( - self.element._cache_key(**kw), - self.type._cache_key, - self.operator, - self.negate, - self.modifier, - ) - def _negate(self): if isinstance(self.element, (True_, False_)): return self.element._negate() @@ -3318,6 +3225,14 @@ class BinaryExpression(ColumnElement): __visit_name__ = "binary" + _traverse_internals = [ + ("left", InternalTraversal.dp_clauseelement), + ("right", InternalTraversal.dp_clauseelement), + ("operator", InternalTraversal.dp_operator), + ("negate", InternalTraversal.dp_operator), + ("modifiers", InternalTraversal.dp_plain_dict), + ] + _is_implicitly_boolean = True """Indicates that any database will know this is a boolean expression even if the database does not have an explicit boolean datatype. @@ -3360,20 +3275,6 @@ class BinaryExpression(ColumnElement): def _from_objects(self): return self.left._from_objects + self.right._from_objects - def _copy_internals(self, clone=_clone, **kw): - self.left = clone(self.left, **kw) - self.right = clone(self.right, **kw) - - def get_children(self, **kwargs): - return self.left, self.right - - def _cache_key(self, **kw): - return ( - BinaryExpression, - self.left._cache_key(**kw), - self.right._cache_key(**kw), - ) - def self_group(self, against=None): # type: (Optional[Any]) -> ClauseElement @@ -3406,6 +3307,12 @@ class Slice(ColumnElement): __visit_name__ = "slice" + _traverse_internals = [ + ("start", InternalTraversal.dp_plain_obj), + ("stop", InternalTraversal.dp_plain_obj), + ("step", InternalTraversal.dp_plain_obj), + ] + def __init__(self, start, stop, step): self.start = start self.stop = stop @@ -3417,9 +3324,6 @@ class Slice(ColumnElement): assert against is operator.getitem return self - def _cache_key(self, **kw): - return (Slice, self.start, self.stop, self.step) - class IndexExpression(BinaryExpression): """Represent the class of expressions that are like an "index" operation. @@ -3444,6 +3348,11 @@ class GroupedElement(ClauseElement): class Grouping(GroupedElement, ColumnElement): """Represent a grouping within a column expression""" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ] + def __init__(self, element): self.element = element self.type = getattr(element, "type", type_api.NULLTYPE) @@ -3460,15 +3369,6 @@ class Grouping(GroupedElement, ColumnElement): def _label(self): return getattr(self.element, "_label", None) or self.anon_label - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def get_children(self, **kwargs): - return (self.element,) - - def _cache_key(self, **kw): - return (Grouping, self.element._cache_key(**kw)) - @property def _from_objects(self): return self.element._from_objects @@ -3501,6 +3401,14 @@ class Over(ColumnElement): __visit_name__ = "over" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("order_by", InternalTraversal.dp_clauseelement), + ("partition_by", InternalTraversal.dp_clauseelement), + ("range_", InternalTraversal.dp_plain_obj), + ("rows", InternalTraversal.dp_plain_obj), + ] + order_by = None partition_by = None @@ -3667,30 +3575,6 @@ class Over(ColumnElement): def type(self): return self.element.type - def get_children(self, **kwargs): - return [ - c - for c in (self.element, self.partition_by, self.order_by) - if c is not None - ] - - def _cache_key(self, **kw): - return ( - (Over,) - + tuple( - e._cache_key(**kw) if e is not None else None - for e in (self.element, self.partition_by, self.order_by) - ) - + (self.range_, self.rows) - ) - - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - if self.partition_by is not None: - self.partition_by = clone(self.partition_by, **kw) - if self.order_by is not None: - self.order_by = clone(self.order_by, **kw) - @property def _from_objects(self): return list( @@ -3723,6 +3607,11 @@ class WithinGroup(ColumnElement): __visit_name__ = "withingroup" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("order_by", InternalTraversal.dp_clauseelement), + ] + order_by = None def __init__(self, element, *order_by): @@ -3791,25 +3680,6 @@ class WithinGroup(ColumnElement): else: return self.element.type - def get_children(self, **kwargs): - return [c for c in (self.element, self.order_by) if c is not None] - - def _cache_key(self, **kw): - return ( - WithinGroup, - self.element._cache_key(**kw) - if self.element is not None - else None, - self.order_by._cache_key(**kw) - if self.order_by is not None - else None, - ) - - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - if self.order_by is not None: - self.order_by = clone(self.order_by, **kw) - @property def _from_objects(self): return list( @@ -3845,6 +3715,11 @@ class FunctionFilter(ColumnElement): __visit_name__ = "funcfilter" + _traverse_internals = [ + ("func", InternalTraversal.dp_clauseelement), + ("criterion", InternalTraversal.dp_clauseelement), + ] + criterion = None def __init__(self, func, *criterion): @@ -3932,23 +3807,6 @@ class FunctionFilter(ColumnElement): def type(self): return self.func.type - def get_children(self, **kwargs): - return [c for c in (self.func, self.criterion) if c is not None] - - def _copy_internals(self, clone=_clone, **kw): - self.func = clone(self.func, **kw) - if self.criterion is not None: - self.criterion = clone(self.criterion, **kw) - - def _cache_key(self, **kw): - return ( - FunctionFilter, - self.func._cache_key(**kw), - self.criterion._cache_key(**kw) - if self.criterion is not None - else None, - ) - @property def _from_objects(self): return list( @@ -3962,7 +3820,7 @@ class FunctionFilter(ColumnElement): ) -class Label(roles.LabeledColumnExprRole, ColumnElement): +class Label(HasMemoized, roles.LabeledColumnExprRole, ColumnElement): """Represents a column label (AS). Represent a label, as typically applied to any column-level @@ -3972,6 +3830,14 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): __visit_name__ = "label" + _traverse_internals = [ + ("name", InternalTraversal.dp_anon_name), + ("_type", InternalTraversal.dp_type), + ("_element", InternalTraversal.dp_clauseelement), + ] + + _memoized_property = util.group_expirable_memoized_property() + def __init__(self, name, element, type_=None): """Return a :class:`Label` object for the given :class:`.ColumnElement`. @@ -4010,14 +3876,11 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): def __reduce__(self): return self.__class__, (self.name, self._element, self._type) - def _cache_key(self, **kw): - return (Label, self.element._cache_key(**kw), self._resolve_label) - @util.memoized_property def _is_implicitly_boolean(self): return self.element._is_implicitly_boolean - @util.memoized_property + @_memoized_property def _allow_label_resolve(self): return self.element._allow_label_resolve @@ -4031,7 +3894,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): self._type or getattr(self._element, "type", None) ) - @util.memoized_property + @_memoized_property def element(self): return self._element.self_group(against=operators.as_) @@ -4057,13 +3920,9 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): def foreign_keys(self): return self.element.foreign_keys - def get_children(self, **kwargs): - return (self.element,) - def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw): + self._reset_memoizations() self._element = clone(self._element, **kw) - self.__dict__.pop("element", None) - self.__dict__.pop("_allow_label_resolve", None) if anonymize_labels: self.name = self._resolve_label = _anonymous_label( "%%(%d %s)s" @@ -4124,6 +3983,13 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): __visit_name__ = "column" + _traverse_internals = [ + ("name", InternalTraversal.dp_string), + ("type", InternalTraversal.dp_type), + ("table", InternalTraversal.dp_clauseelement), + ("is_literal", InternalTraversal.dp_boolean), + ] + onupdate = default = server_default = server_onupdate = None _is_multiparam_column = False @@ -4254,14 +4120,6 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): table = property(_get_table, _set_table) - def _cache_key(self, **kw): - return ( - self.name, - self.table.name if self.table is not None else None, - self.is_literal, - self.type._cache_key, - ) - @_memoized_property def _from_objects(self): t = self.table @@ -4395,12 +4253,11 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): class CollationClause(ColumnElement): __visit_name__ = "collation" + _traverse_internals = [("collation", InternalTraversal.dp_string)] + def __init__(self, collation): self.collation = collation - def _cache_key(self, **kw): - return (CollationClause, self.collation) - class _IdentifiedClause(Executable, ClauseElement): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 7ce822669..08e69f075 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -86,7 +86,6 @@ __all__ = [ from .base import _from_objects # noqa from .base import ColumnCollection # noqa from .base import Executable # noqa -from .base import Generative # noqa from .base import PARSE_AUTOCOMMIT # noqa from .dml import Delete # noqa from .dml import Insert # noqa @@ -242,7 +241,6 @@ _UnaryExpression = UnaryExpression _Case = Case _Tuple = Tuple _Over = Over -_Generative = Generative _TypeClause = TypeClause _Extract = Extract _Exists = Exists diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index cbc8e539f..96e64dc28 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -17,7 +17,6 @@ from . import sqltypes from . import util as sqlutil from .base import ColumnCollection from .base import Executable -from .elements import _clone from .elements import _type_from_args from .elements import BinaryExpression from .elements import BindParameter @@ -33,7 +32,8 @@ from .elements import WithinGroup from .selectable import Alias from .selectable import FromClause from .selectable import Select -from .visitors import VisitableType +from .visitors import InternalTraversal +from .visitors import TraversibleType from .. import util @@ -78,10 +78,14 @@ class FunctionElement(Executable, ColumnElement, FromClause): """ + _traverse_internals = [("clause_expr", InternalTraversal.dp_clauseelement)] + packagenames = () _has_args = False + _memoized_property = FromClause._memoized_property + def __init__(self, *clauses, **kwargs): r"""Construct a :class:`.FunctionElement`. @@ -136,7 +140,7 @@ class FunctionElement(Executable, ColumnElement, FromClause): col = self.label(None) return ColumnCollection(columns=[(col.key, col)]) - @util.memoized_property + @_memoized_property def clauses(self): """Return the underlying :class:`.ClauseList` which contains the arguments for this :class:`.FunctionElement`. @@ -283,17 +287,6 @@ class FunctionElement(Executable, ColumnElement, FromClause): def _from_objects(self): return self.clauses._from_objects - def get_children(self, **kwargs): - return (self.clause_expr,) - - def _cache_key(self, **kw): - return (FunctionElement, self.clause_expr._cache_key(**kw)) - - def _copy_internals(self, clone=_clone, **kw): - self.clause_expr = clone(self.clause_expr, **kw) - self._reset_exported() - FunctionElement.clauses._reset(self) - def within_group_type(self, within_group): """For types that define their return type as based on the criteria within a WITHIN GROUP (ORDER BY) expression, called by the @@ -404,6 +397,13 @@ class FunctionElement(Executable, ColumnElement, FromClause): class FunctionAsBinary(BinaryExpression): + _traverse_internals = [ + ("sql_function", InternalTraversal.dp_clauseelement), + ("left_index", InternalTraversal.dp_plain_obj), + ("right_index", InternalTraversal.dp_plain_obj), + ("modifiers", InternalTraversal.dp_plain_dict), + ] + def __init__(self, fn, left_index, right_index): self.sql_function = fn self.left_index = left_index @@ -431,20 +431,6 @@ class FunctionAsBinary(BinaryExpression): def right(self, value): self.sql_function.clauses.clauses[self.right_index - 1] = value - def _copy_internals(self, clone=_clone, **kw): - self.sql_function = clone(self.sql_function, **kw) - - def get_children(self, **kw): - yield self.sql_function - - def _cache_key(self, **kw): - return ( - FunctionAsBinary, - self.sql_function._cache_key(**kw), - self.left_index, - self.right_index, - ) - class _FunctionGenerator(object): """Generate SQL function expressions. @@ -606,6 +592,12 @@ class Function(FunctionElement): __visit_name__ = "function" + _traverse_internals = FunctionElement._traverse_internals + [ + ("packagenames", InternalTraversal.dp_plain_obj), + ("name", InternalTraversal.dp_string), + ("type", InternalTraversal.dp_type), + ] + def __init__(self, name, *clauses, **kw): """Construct a :class:`.Function`. @@ -630,15 +622,8 @@ class Function(FunctionElement): unique=True, ) - def _cache_key(self, **kw): - return ( - (Function,) + tuple(self.packagenames) - if self.packagenames - else () + (self.name, self.clause_expr._cache_key(**kw)) - ) - -class _GenericMeta(VisitableType): +class _GenericMeta(TraversibleType): def __init__(cls, clsname, bases, clsdict): if annotation.Annotated not in cls.__mro__: cls.name = name = clsdict.get("name", clsname) @@ -764,6 +749,10 @@ class next_value(GenericFunction): type = sqltypes.Integer() name = "next_value" + _traverse_internals = [ + ("sequence", InternalTraversal.dp_named_ddl_element) + ] + def __init__(self, seq, **kw): assert isinstance( seq, schema.Sequence @@ -771,21 +760,12 @@ class next_value(GenericFunction): self._bind = kw.get("bind", None) self.sequence = seq - def _cache_key(self, **kw): - return (next_value, self.sequence.name) - def compare(self, other, **kw): return ( isinstance(other, next_value) and self.sequence.name == other.sequence.name ) - def get_children(self, **kwargs): - return [] - - def _copy_internals(self, **kw): - pass - @property def _from_objects(self): return [] diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 4e8f4a397..ee7dc61ce 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -50,6 +50,7 @@ from .elements import ColumnElement from .elements import quoted_name from .elements import TextClause from .selectable import TableClause +from .visitors import InternalTraversal from .. import event from .. import exc from .. import inspection @@ -425,6 +426,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause): __visit_name__ = "table" + _traverse_internals = TableClause._traverse_internals + [ + ("schema", InternalTraversal.dp_string) + ] + + def _gen_cache_key(self, anon_map, bindparams): + return (self,) + + @util.deprecated_params( + useexisting=( + "0.7", + "The :paramref:`.Table.useexisting` parameter is deprecated and " + "will be removed in a future release. Please use " + ":paramref:`.Table.extend_existing`.", + ) + ) def __new__(cls, *args, **kw): if not args: # python3k pickle seems to call this @@ -763,6 +779,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause): def get_children( self, column_collections=True, schema_visitor=False, **kw ): + # TODO: consider that we probably don't need column_collections=True + # at all, it does not seem to impact anything if not schema_visitor: return TableClause.get_children( self, column_collections=column_collections, **kw diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 6a7413fc0..4b3844eec 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -31,6 +31,7 @@ from .base import ColumnSet from .base import DedupeColumnCollection from .base import Executable from .base import Generative +from .base import HasMemoized from .base import Immutable from .coercions import _document_text_coercion from .elements import _anonymous_label @@ -39,11 +40,13 @@ from .elements import and_ from .elements import BindParameter from .elements import ClauseElement from .elements import ClauseList +from .elements import ColumnClause from .elements import GroupedElement from .elements import Grouping from .elements import literal_column from .elements import True_ from .elements import UnaryExpression +from .visitors import InternalTraversal from .. import exc from .. import util @@ -201,6 +204,8 @@ class Selectable(ReturnsRows): class HasPrefixes(object): _prefixes = () + _traverse_internals = [("_prefixes", InternalTraversal.dp_prefix_sequence)] + @_generative @_document_text_coercion( "expr", @@ -252,6 +257,8 @@ class HasPrefixes(object): class HasSuffixes(object): _suffixes = () + _traverse_internals = [("_suffixes", InternalTraversal.dp_prefix_sequence)] + @_generative @_document_text_coercion( "expr", @@ -295,7 +302,7 @@ class HasSuffixes(object): ) -class FromClause(roles.AnonymizedFromClauseRole, Selectable): +class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -529,11 +536,6 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return getattr(self, "name", self.__class__.__name__ + " object") - def _reset_exported(self): - """delete memoized collections when a FromClause is cloned.""" - - self._memoized_property.expire_instance(self) - def _generate_fromclause_column_proxies(self, fromclause): fromclause._columns._populate_separate_keys( col._make_proxy(fromclause) for col in self.c @@ -668,6 +670,14 @@ class Join(FromClause): __visit_name__ = "join" + _traverse_internals = [ + ("left", InternalTraversal.dp_clauseelement), + ("right", InternalTraversal.dp_clauseelement), + ("onclause", InternalTraversal.dp_clauseelement), + ("isouter", InternalTraversal.dp_boolean), + ("full", InternalTraversal.dp_boolean), + ] + _is_join = True def __init__(self, left, right, onclause=None, isouter=False, full=False): @@ -805,25 +815,6 @@ class Join(FromClause): self.left._refresh_for_new_column(column) self.right._refresh_for_new_column(column) - def _copy_internals(self, clone=_clone, **kw): - self._reset_exported() - self.left = clone(self.left, **kw) - self.right = clone(self.right, **kw) - self.onclause = clone(self.onclause, **kw) - - def get_children(self, **kwargs): - return self.left, self.right, self.onclause - - def _cache_key(self, **kw): - return ( - Join, - self.isouter, - self.full, - self.left._cache_key(**kw), - self.right._cache_key(**kw), - self.onclause._cache_key(**kw), - ) - def _match_primaries(self, left, right): if isinstance(left, Join): left_right = left.right @@ -1175,6 +1166,11 @@ class AliasedReturnsRows(FromClause): _is_from_container = True named_with_column = True + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("name", InternalTraversal.dp_anon_name), + ] + def __init__(self, *arg, **kw): raise NotImplementedError( "The %s class is not intended to be constructed " @@ -1243,18 +1239,13 @@ class AliasedReturnsRows(FromClause): def _copy_internals(self, clone=_clone, **kw): element = clone(self.element, **kw) + + # the element clone is usually against a Table that returns the + # same object. don't reset exported .c. collections and other + # memoized details if nothing changed if element is not self.element: self._reset_exported() - self.element = element - - def get_children(self, column_collections=True, **kw): - if column_collections: - for c in self.c: - yield c - yield self.element - - def _cache_key(self, **kw): - return (self.__class__, self.element._cache_key(**kw), self._orig_name) + self.element = element @property def _from_objects(self): @@ -1396,6 +1387,11 @@ class TableSample(AliasedReturnsRows): __visit_name__ = "tablesample" + _traverse_internals = AliasedReturnsRows._traverse_internals + [ + ("sampling", InternalTraversal.dp_clauseelement), + ("seed", InternalTraversal.dp_clauseelement), + ] + @classmethod def _factory(cls, selectable, sampling, name=None, seed=None): """Return a :class:`.TableSample` object. @@ -1466,6 +1462,16 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows): __visit_name__ = "cte" + _traverse_internals = ( + AliasedReturnsRows._traverse_internals + + [ + ("_cte_alias", InternalTraversal.dp_clauseelement), + ("_restates", InternalTraversal.dp_clauseelement_unordered_set), + ("recursive", InternalTraversal.dp_boolean), + ] + + HasSuffixes._traverse_internals + ) + @classmethod def _factory(cls, selectable, name=None, recursive=False): r"""Return a new :class:`.CTE`, or Common Table Expression instance. @@ -1495,15 +1501,13 @@ class CTE(Generative, HasSuffixes, AliasedReturnsRows): def _copy_internals(self, clone=_clone, **kw): super(CTE, self)._copy_internals(clone, **kw) + # TODO: I don't like that we can't use the traversal data here if self._cte_alias is not None: self._cte_alias = clone(self._cte_alias, **kw) self._restates = frozenset( [clone(elem, **kw) for elem in self._restates] ) - def _cache_key(self, *arg, **kw): - raise NotImplementedError("TODO") - def alias(self, name=None, flat=False): """Return an :class:`.Alias` of this :class:`.CTE`. @@ -1764,6 +1768,8 @@ class Subquery(AliasedReturnsRows): class FromGrouping(GroupedElement, FromClause): """Represent a grouping of a FROM clause""" + _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + def __init__(self, element): self.element = coercions.expect(roles.FromClauseRole, element) @@ -1792,15 +1798,6 @@ class FromGrouping(GroupedElement, FromClause): def _hide_froms(self): return self.element._hide_froms - def get_children(self, **kwargs): - return (self.element,) - - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return (FromGrouping, self.element._cache_key(**kw)) - @property def _from_objects(self): return self.element._from_objects @@ -1843,6 +1840,14 @@ class TableClause(Immutable, FromClause): __visit_name__ = "table" + _traverse_internals = [ + ( + "columns", + InternalTraversal.dp_fromclause_canonical_column_collection, + ), + ("name", InternalTraversal.dp_string), + ] + named_with_column = True implicit_returning = False @@ -1895,17 +1900,6 @@ class TableClause(Immutable, FromClause): self._columns.add(c) c.table = self - def get_children(self, column_collections=True, **kwargs): - if column_collections: - return [c for c in self.c] - else: - return [] - - def _cache_key(self, **kw): - return (TableClause, self.name) + tuple( - col._cache_key(**kw) for col in self._columns - ) - @util.dependencies("sqlalchemy.sql.dml") def insert(self, dml, values=None, inline=False, **kwargs): """Generate an :func:`.insert` construct against this @@ -1965,6 +1959,13 @@ class TableClause(Immutable, FromClause): class ForUpdateArg(ClauseElement): + _traverse_internals = [ + ("of", InternalTraversal.dp_clauseelement_list), + ("nowait", InternalTraversal.dp_boolean), + ("read", InternalTraversal.dp_boolean), + ("skip_locked", InternalTraversal.dp_boolean), + ] + @classmethod def parse_legacy_select(self, arg): """Parse the for_update argument of :func:`.select`. @@ -2029,19 +2030,6 @@ class ForUpdateArg(ClauseElement): def __hash__(self): return id(self) - def _copy_internals(self, clone=_clone, **kw): - if self.of is not None: - self.of = [clone(col, **kw) for col in self.of] - - def _cache_key(self, **kw): - return ( - ForUpdateArg, - self.nowait, - self.read, - self.skip_locked, - self.of._cache_key(**kw) if self.of is not None else None, - ) - def __init__( self, nowait=False, @@ -2074,6 +2062,7 @@ class SelectBase( roles.DMLSelectRole, roles.CompoundElementRole, roles.InElementRole, + HasMemoized, HasCTE, Executable, SupportsCloneAnnotations, @@ -2092,9 +2081,6 @@ class SelectBase( _memoized_property = util.group_expirable_memoized_property() - def _reset_memoizations(self): - self._memoized_property.expire_instance(self) - def _generate_fromclause_column_proxies(self, fromclause): # type: (FromClause) raise NotImplementedError() @@ -2339,6 +2325,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ __visit_name__ = "grouping" + _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] _is_select_container = True @@ -2350,9 +2337,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def select_statement(self): return self.element - def get_children(self, **kwargs): - return (self.element,) - def self_group(self, against=None): # type: (Optional[Any]) -> FromClause return self @@ -2377,12 +2361,6 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ return self.element.selected_columns - def _copy_internals(self, clone=_clone, **kw): - self.element = clone(self.element, **kw) - - def _cache_key(self, **kw): - return (SelectStatementGrouping, self.element._cache_key(**kw)) - @property def _from_objects(self): return self.element._from_objects @@ -2758,9 +2736,6 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): def _label_resolve_dict(self): raise NotImplementedError() - def _copy_internals(self, clone=_clone, **kw): - raise NotImplementedError() - class CompoundSelect(GenerativeSelect): """Forms the basis of ``UNION``, ``UNION ALL``, and other @@ -2785,6 +2760,16 @@ class CompoundSelect(GenerativeSelect): __visit_name__ = "compound_select" + _traverse_internals = [ + ("selects", InternalTraversal.dp_clauseelement_list), + ("_limit_clause", InternalTraversal.dp_clauseelement), + ("_offset_clause", InternalTraversal.dp_clauseelement), + ("_order_by_clause", InternalTraversal.dp_clauseelement), + ("_group_by_clause", InternalTraversal.dp_clauseelement), + ("_for_update_arg", InternalTraversal.dp_clauseelement), + ("keyword", InternalTraversal.dp_string), + ] + SupportsCloneAnnotations._traverse_internals + UNION = util.symbol("UNION") UNION_ALL = util.symbol("UNION ALL") EXCEPT = util.symbol("EXCEPT") @@ -3004,47 +2989,6 @@ class CompoundSelect(GenerativeSelect): """ return self.selects[0].selected_columns - def _copy_internals(self, clone=_clone, **kw): - self._reset_memoizations() - self.selects = [clone(s, **kw) for s in self.selects] - if hasattr(self, "_col_map"): - del self._col_map - for attr in ( - "_limit_clause", - "_offset_clause", - "_order_by_clause", - "_group_by_clause", - "_for_update_arg", - ): - if getattr(self, attr) is not None: - setattr(self, attr, clone(getattr(self, attr), **kw)) - - def get_children(self, **kwargs): - return [self._order_by_clause, self._group_by_clause] + list( - self.selects - ) - - def _cache_key(self, **kw): - return ( - (CompoundSelect, self.keyword) - + tuple(stmt._cache_key(**kw) for stmt in self.selects) - + ( - self._order_by_clause._cache_key(**kw) - if self._order_by_clause is not None - else None, - ) - + ( - self._group_by_clause._cache_key(**kw) - if self._group_by_clause is not None - else None, - ) - + ( - self._for_update_arg._cache_key(**kw) - if self._for_update_arg is not None - else None, - ) - ) - def bind(self): if self._bind: return self._bind @@ -3193,11 +3137,35 @@ class Select( _hints = util.immutabledict() _statement_hints = () _distinct = False - _from_cloned = None + _distinct_on = () _correlate = () _correlate_except = None _memoized_property = SelectBase._memoized_property + _traverse_internals = ( + [ + ("_from_obj", InternalTraversal.dp_fromclause_ordered_set), + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("_whereclause", InternalTraversal.dp_clauseelement), + ("_having", InternalTraversal.dp_clauseelement), + ("_order_by_clause", InternalTraversal.dp_clauseelement_list), + ("_group_by_clause", InternalTraversal.dp_clauseelement_list), + ("_correlate", InternalTraversal.dp_clauseelement_unordered_set), + ( + "_correlate_except", + InternalTraversal.dp_clauseelement_unordered_set, + ), + ("_for_update_arg", InternalTraversal.dp_clauseelement), + ("_statement_hints", InternalTraversal.dp_statement_hint_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ("_distinct", InternalTraversal.dp_boolean), + ("_distinct_on", InternalTraversal.dp_clauseelement_list), + ] + + HasPrefixes._traverse_internals + + HasSuffixes._traverse_internals + + SupportsCloneAnnotations._traverse_internals + ) + @util.deprecated_params( autocommit=( "0.6", @@ -3416,13 +3384,14 @@ class Select( """ self._auto_correlate = correlate if distinct is not False: - if distinct is True: - self._distinct = True - else: - self._distinct = [ - coercions.expect(roles.WhereHavingRole, e) - for e in util.to_list(distinct) - ] + self._distinct = True + if not isinstance(distinct, bool): + self._distinct_on = tuple( + [ + coercions.expect(roles.WhereHavingRole, e) + for e in util.to_list(distinct) + ] + ) if from_obj is not None: self._from_obj = util.OrderedSet( @@ -3472,15 +3441,17 @@ class Select( GenerativeSelect.__init__(self, **kwargs) + # @_memoized_property @property def _froms(self): - # would love to cache this, - # but there's just enough edge cases, particularly now that - # declarative encourages construction of SQL expressions - # without tables present, to just regen this each time. + # current roadblock to caching is two tests that test that the + # SELECT can be compiled to a string, then a Table is created against + # columns, then it can be compiled again and works. this is somewhat + # valid as people make select() against declarative class where + # columns don't have their Table yet and perhaps some operations + # call upon _froms and cache it too soon. froms = [] seen = set() - translate = self._from_cloned for item in itertools.chain( _from_objects(*self._raw_columns), @@ -3493,8 +3464,6 @@ class Select( raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" ) - if translate and item in translate: - item = translate[item] if not seen.intersection(item._cloned_set): froms.append(item) seen.update(item._cloned_set) @@ -3518,15 +3487,6 @@ class Select( itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms]) ) if toremove: - # if we're maintaining clones of froms, - # add the copies out to the toremove list. only include - # clones that are lexical equivalents. - if self._from_cloned: - toremove.update( - self._from_cloned[f] - for f in toremove.intersection(self._from_cloned) - if self._from_cloned[f]._is_lexical_equivalent(f) - ) # filter out to FROM clauses not in the list, # using a list to maintain ordering froms = [f for f in froms if f not in toremove] @@ -3707,7 +3667,6 @@ class Select( return False def _copy_internals(self, clone=_clone, **kw): - # Select() object has been cloned and probably adapted by the # given clone function. Apply the cloning function to internal # objects @@ -3719,37 +3678,42 @@ class Select( # as of 0.7.4 we also put the current version of _froms, which # gets cleared on each generation. previously we were "baking" # _froms into self._from_obj. - self._from_cloned = from_cloned = dict( - (f, clone(f, **kw)) for f in self._from_obj.union(self._froms) - ) - # 3. update persistent _from_obj with the cloned versions. - self._from_obj = util.OrderedSet( - from_cloned[f] for f in self._from_obj + all_the_froms = list( + itertools.chain( + _from_objects(*self._raw_columns), + _from_objects(self._whereclause) + if self._whereclause is not None + else (), + ) ) + new_froms = {f: clone(f, **kw) for f in all_the_froms} + # copy FROM collections - # the _correlate collection is done separately, what can happen - # here is the same item is _correlate as in _from_obj but the - # _correlate version has an annotation on it - (specifically - # RelationshipProperty.Comparator._criterion_exists() does - # this). Also keep _correlate liberally open with its previous - # contents, as this set is used for matching, not rendering. - self._correlate = set(clone(f) for f in self._correlate).union( - self._correlate - ) + self._from_obj = util.OrderedSet( + clone(f, **kw) for f in self._from_obj + ).union(f for f in new_froms.values() if isinstance(f, Join)) - # do something similar for _correlate_except - this is a more - # unusual case but same idea applies + self._correlate = set(clone(f) for f in self._correlate) if self._correlate_except: self._correlate_except = set( clone(f) for f in self._correlate_except - ).union(self._correlate_except) + ) # 4. clone other things. The difficulty here is that Column - # objects are not actually cloned, and refer to their original - # .table, resulting in the wrong "from" parent after a clone - # operation. Hence _from_cloned and _from_obj supersede what is - # present here. + # objects are usually not altered by a straight clone because they + # are dependent on the FROM cloning we just did above in order to + # be targeted correctly, or a new FROM we have might be a JOIN + # object which doesn't have its own columns. so give the cloner a + # hint. + def replace(obj, **kw): + if isinstance(obj, ColumnClause) and obj.table in new_froms: + newelem = new_froms[obj.table].corresponding_column(obj) + return newelem + + kw["replace"] = replace + + # TODO: I'd still like to try to leverage the traversal data self._raw_columns = [clone(c, **kw) for c in self._raw_columns] for attr in ( "_limit_clause", @@ -3763,67 +3727,12 @@ class Select( if getattr(self, attr) is not None: setattr(self, attr, clone(getattr(self, attr), **kw)) - # erase _froms collection, - # etc. self._reset_memoizations() def get_children(self, **kwargs): - """return child elements as per the ClauseElement specification.""" - - return ( - self._raw_columns - + list(self._froms) - + [ - x - for x in ( - self._whereclause, - self._having, - self._order_by_clause, - self._group_by_clause, - ) - if x is not None - ] - ) - - def _cache_key(self, **kw): - return ( - (Select,) - + ("raw_columns",) - + tuple(elem._cache_key(**kw) for elem in self._raw_columns) - + ("elements",) - + tuple( - elem._cache_key(**kw) if elem is not None else None - for elem in ( - self._whereclause, - self._having, - self._order_by_clause, - self._group_by_clause, - ) - ) - + ("from_obj",) - + tuple(elem._cache_key(**kw) for elem in self._from_obj) - + ("correlate",) - + tuple( - elem._cache_key(**kw) - for elem in ( - self._correlate if self._correlate is not None else () - ) - ) - + ("correlate_except",) - + tuple( - elem._cache_key(**kw) - for elem in ( - self._correlate_except - if self._correlate_except is not None - else () - ) - ) - + ("for_update",), - ( - self._for_update_arg._cache_key(**kw) - if self._for_update_arg is not None - else None, - ), + # TODO: define "get_children" traversal items separately? + return self._froms + super(Select, self).get_children( + omit_attrs=["_from_obj", "_correlate", "_correlate_except"] ) @_generative @@ -3987,10 +3896,8 @@ class Select( """ if expr: expr = [coercions.expect(roles.ByOfRole, e) for e in expr] - if isinstance(self._distinct, list): - self._distinct = self._distinct + expr - else: - self._distinct = expr + self._distinct = True + self._distinct_on = self._distinct_on + tuple(expr) else: self._distinct = True @@ -4489,6 +4396,11 @@ class TextualSelect(SelectBase): __visit_name__ = "textual_select" + _traverse_internals = [ + ("element", InternalTraversal.dp_clauseelement), + ("column_args", InternalTraversal.dp_clauseelement_list), + ] + SupportsCloneAnnotations._traverse_internals + _is_textual = True def __init__(self, text, columns, positional=False): @@ -4534,18 +4446,6 @@ class TextualSelect(SelectBase): c._make_proxy(fromclause) for c in self.column_args ) - def _copy_internals(self, clone=_clone, **kw): - self._reset_memoizations() - self.element = clone(self.element, **kw) - - def get_children(self, **kw): - return [self.element] - - def _cache_key(self, **kw): - return (TextualSelect, self.element._cache_key(**kw)) + tuple( - col._cache_key(**kw) for col in self.column_args - ) - def _scalar_type(self): return self.column_args[0].type diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py new file mode 100644 index 000000000..c0782ce48 --- /dev/null +++ b/lib/sqlalchemy/sql/traversals.py @@ -0,0 +1,768 @@ +from collections import deque +from collections import namedtuple + +from . import operators +from .visitors import ExtendedInternalTraversal +from .visitors import InternalTraversal +from .. import inspect +from .. import util + +SKIP_TRAVERSE = util.symbol("skip_traverse") +COMPARE_FAILED = False +COMPARE_SUCCEEDED = True +NO_CACHE = util.symbol("no_cache") + + +def compare(obj1, obj2, **kw): + if kw.get("use_proxies", False): + strategy = ColIdentityComparatorStrategy() + else: + strategy = TraversalComparatorStrategy() + + return strategy.compare(obj1, obj2, **kw) + + +class HasCacheKey(object): + _cache_key_traversal = NO_CACHE + + def _gen_cache_key(self, anon_map, bindparams): + """return an optional cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the strucures that would affect the SQL string or the type handlers + should result in a different cache key. + + If a structure cannot produce a useful cache key, it should raise + NotImplementedError, which will result in the entire structure + for which it's part of not being useful as a cache key. + + + """ + + if self in anon_map: + return (anon_map[self], self.__class__) + + id_ = anon_map[self] + + if self._cache_key_traversal is NO_CACHE: + anon_map[NO_CACHE] = True + return None + + result = (id_, self.__class__) + + for attrname, obj, meth in _cache_key_traversal.run_generated_dispatch( + self, self._cache_key_traversal, "_generated_cache_key_traversal" + ): + if obj is not None: + result += meth(attrname, obj, self, anon_map, bindparams) + return result + + def _generate_cache_key(self): + """return a cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the strucures that would affect the SQL string or the type handlers + should result in a different cache key. + + The cache key returned by this method is an instance of + :class:`.CacheKey`, which consists of a tuple representing the + cache key, as well as a list of :class:`.BindParameter` objects + which are extracted from the expression. While two expressions + that produce identical cache key tuples will themselves generate + identical SQL strings, the list of :class:`.BindParameter` objects + indicates the bound values which may have different values in + each one; these bound parameters must be consulted in order to + execute the statement with the correct parameters. + + a :class:`.ClauseElement` structure that does not implement + a :meth:`._gen_cache_key` method and does not implement a + :attr:`.traverse_internals` attribute will not be cacheable; when + such an element is embedded into a larger structure, this method + will return None, indicating no cache key is available. + + """ + bindparams = [] + + _anon_map = anon_map() + key = self._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + +class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): + def __hash__(self): + return hash(self.key) + + def __eq__(self, other): + return self.key == other.key + + +def _clone(element, **kw): + return element._clone() + + +class _CacheKey(ExtendedInternalTraversal): + def visit_has_cache_key(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj._gen_cache_key(anon_map, bindparams)) + + def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): + return self.visit_has_cache_key( + attrname, inspect(obj), parent, anon_map, bindparams + ) + + def visit_clauseelement(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj._gen_cache_key(anon_map, bindparams)) + + def visit_multi(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj, + ) + + def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams): + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem + for elem in obj + ), + ) + + def visit_has_cache_key_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in tup_elem + ) + for tup_elem in obj + ), + ) + + def visit_has_cache_key_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_inspectable_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_list( + attrname, [inspect(o) for o in obj], parent, anon_map, bindparams + ) + + def visit_clauseelement_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_clauseelement_tuples( + self, attrname, obj, parent, anon_map, bindparams + ): + return self.visit_has_cache_key_tuples( + attrname, obj, parent, anon_map, bindparams + ) + + def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams): + from . import elements + + name = obj + if isinstance(name, elements._anonymous_label): + name = name.apply_map(anon_map) + + return (attrname, name) + + def visit_fromclause_ordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_clauseelement_unordered_set( + self, attrname, obj, parent, anon_map, bindparams + ): + cache_keys = [ + elem._gen_cache_key(anon_map, bindparams) for elem in obj + ] + return ( + attrname, + tuple( + sorted(cache_keys) + ), # cache keys all start with (id_, class) + ) + + def visit_named_ddl_element( + self, attrname, obj, parent, anon_map, bindparams + ): + return (attrname, obj.name) + + def visit_prefix_sequence( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + (clause._gen_cache_key(anon_map, bindparams), strval) + for clause, strval in obj + ), + ) + + def visit_statement_hint_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return (attrname, obj) + + def visit_table_hint_list( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + clause._gen_cache_key(anon_map, bindparams), + dialect_name, + text, + ) + for (clause, dialect_name), text in obj.items() + ), + ) + + def visit_type(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj._gen_cache_key) + + def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, tuple((key, obj[key]) for key in sorted(obj))) + + def visit_string_clauseelement_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + (key, obj[key]._gen_cache_key(anon_map, bindparams)) + for key in sorted(obj) + ), + ) + + def visit_string_multi_dict( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key, + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value, + ) + for key, value in [(key, obj[key]) for key in sorted(obj)] + ), + ) + + def visit_string(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_boolean(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_operator(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_plain_obj(self, attrname, obj, parent, anon_map, bindparams): + return (attrname, obj) + + def visit_fromclause_canonical_column_collection( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple(col._gen_cache_key(anon_map, bindparams) for col in obj), + ) + + def visit_annotations_state( + self, attrname, obj, parent, anon_map, bindparams + ): + return ( + attrname, + tuple( + ( + key, + self.dispatch(sym)( + key, obj[key], obj, anon_map, bindparams + ), + ) + for key, sym in parent._annotation_traversals + ), + ) + + def visit_unknown_structure( + self, attrname, obj, parent, anon_map, bindparams + ): + anon_map[NO_CACHE] = True + return () + + +_cache_key_traversal = _CacheKey() + + +class _CopyInternals(InternalTraversal): + """Generate a _copy_internals internal traversal dispatch for classes + with a _traverse_internals collection.""" + + def visit_clauseelement(self, parent, element, clone=_clone, **kw): + return clone(element, **kw) + + def visit_clauseelement_list(self, parent, element, clone=_clone, **kw): + return [clone(clause, **kw) for clause in element] + + def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw): + return [ + tuple(clone(tup_elem, **kw) for tup_elem in elem) + for elem in element + ] + + def visit_string_clauseelement_dict( + self, parent, element, clone=_clone, **kw + ): + return dict( + (key, clone(value, **kw)) for key, value in element.items() + ) + + +_copy_internals = _CopyInternals() + + +class _GetChildren(InternalTraversal): + """Generate a _children_traversal internal traversal dispatch for classes + with a _traverse_internals collection.""" + + def visit_has_cache_key(self, element, **kw): + return (element,) + + def visit_clauseelement(self, element, **kw): + return (element,) + + def visit_clauseelement_list(self, element, **kw): + return tuple(element) + + def visit_clauseelement_tuples(self, element, **kw): + tup = () + for elem in element: + tup += elem + return tup + + def visit_fromclause_canonical_column_collection(self, element, **kw): + if kw.get("column_collections", False): + return tuple(element) + else: + return () + + def visit_string_clauseelement_dict(self, element, **kw): + return tuple(element.values()) + + def visit_fromclause_ordered_set(self, element, **kw): + return tuple(element) + + def visit_clauseelement_unordered_set(self, element, **kw): + return tuple(element) + + +_get_children = _GetChildren() + + +@util.dependencies("sqlalchemy.sql.elements") +def _resolve_name_for_compare(elements, element, name, anon_map, **kw): + if isinstance(name, elements._anonymous_label): + name = name.apply_map(anon_map) + + return name + + +class anon_map(dict): + """A map that creates new keys for missing key access. + + Produces an incrementing sequence given a series of unique keys. + + This is similar to the compiler prefix_anon_map class although simpler. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + def __init__(self): + self.index = 0 + + def __missing__(self, key): + self[key] = val = str(self.index) + self.index += 1 + return val + + +class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): + __slots__ = "stack", "cache", "anon_map" + + def __init__(self): + self.stack = deque() + self.cache = set() + + def _memoized_attr_anon_map(self): + return (anon_map(), anon_map()) + + def compare(self, obj1, obj2, **kw): + stack = self.stack + cache = self.cache + + compare_annotations = kw.get("compare_annotations", False) + + stack.append((obj1, obj2)) + + while stack: + left, right = stack.popleft() + + if left is right: + continue + elif left is None or right is None: + # we know they are different so no match + return False + elif (left, right) in cache: + continue + cache.add((left, right)) + + visit_name = left.__visit_name__ + if visit_name != right.__visit_name__: + return False + + meth = getattr(self, "compare_%s" % visit_name, None) + + if meth: + attributes_compared = meth(left, right, **kw) + if attributes_compared is COMPARE_FAILED: + return False + elif attributes_compared is SKIP_TRAVERSE: + continue + + # attributes_compared is returned as a list of attribute + # names that were "handled" by the comparison method above. + # remaining attribute names in the _traverse_internals + # will be compared. + else: + attributes_compared = () + + for ( + (left_attrname, left_visit_sym), + (right_attrname, right_visit_sym), + ) in util.zip_longest( + left._traverse_internals, + right._traverse_internals, + fillvalue=(None, None), + ): + if ( + left_attrname != right_attrname + or left_visit_sym is not right_visit_sym + ): + if not compare_annotations and ( + ( + left_visit_sym + is InternalTraversal.dp_annotations_state, + ) + or ( + right_visit_sym + is InternalTraversal.dp_annotations_state, + ) + ): + continue + + return False + elif left_attrname in attributes_compared: + continue + + dispatch = self.dispatch(left_visit_sym) + left_child = getattr(left, left_attrname) + right_child = getattr(right, right_attrname) + if left_child is None: + if right_child is not None: + return False + else: + continue + + comparison = dispatch( + left, left_child, right, right_child, **kw + ) + if comparison is COMPARE_FAILED: + return False + + return True + + def compare_inner(self, obj1, obj2, **kw): + comparator = self.__class__() + return comparator.compare(obj1, obj2, **kw) + + def visit_has_cache_key( + self, left_parent, left, right_parent, right, **kw + ): + if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key( + self.anon_map[1], [] + ): + return COMPARE_FAILED + + def visit_clauseelement( + self, left_parent, left, right_parent, right, **kw + ): + self.stack.append((left, right)) + + def visit_fromclause_canonical_column_collection( + self, left_parent, left, right_parent, right, **kw + ): + for lcol, rcol in util.zip_longest(left, right, fillvalue=None): + self.stack.append((lcol, rcol)) + + def visit_fromclause_derived_column_collection( + self, left_parent, left, right_parent, right, **kw + ): + pass + + def visit_string_clauseelement_dict( + self, left_parent, left, right_parent, right, **kw + ): + for lstr, rstr in util.zip_longest( + sorted(left), sorted(right), fillvalue=None + ): + if lstr != rstr: + return COMPARE_FAILED + self.stack.append((left[lstr], right[rstr])) + + def visit_annotations_state( + self, left_parent, left, right_parent, right, **kw + ): + if not kw.get("compare_annotations", False): + return + + for (lstr, lmeth), (rstr, rmeth) in util.zip_longest( + left_parent._annotation_traversals, + right_parent._annotation_traversals, + fillvalue=(None, None), + ): + if lstr != rstr or (lmeth is not rmeth): + return COMPARE_FAILED + + dispatch = self.dispatch(lmeth) + left_child = left[lstr] + right_child = right[rstr] + if left_child is None: + if right_child is not None: + return False + else: + continue + + comparison = dispatch(None, left_child, None, right_child, **kw) + if comparison is COMPARE_FAILED: + return comparison + + def visit_clauseelement_tuples( + self, left_parent, left, right_parent, right, **kw + ): + for ltup, rtup in util.zip_longest(left, right, fillvalue=None): + if ltup is None or rtup is None: + return COMPARE_FAILED + + for l, r in util.zip_longest(ltup, rtup, fillvalue=None): + self.stack.append((l, r)) + + def visit_clauseelement_list( + self, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + self.stack.append((l, r)) + + def _compare_unordered_sequences(self, seq1, seq2, **kw): + if seq1 is None: + return seq2 is None + + completed = set() + for clause in seq1: + for other_clause in set(seq2).difference(completed): + if self.compare_inner(clause, other_clause, **kw): + completed.add(other_clause) + break + return len(completed) == len(seq1) == len(seq2) + + def visit_clauseelement_unordered_set( + self, left_parent, left, right_parent, right, **kw + ): + return self._compare_unordered_sequences(left, right, **kw) + + def visit_fromclause_ordered_set( + self, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + self.stack.append((l, r)) + + def visit_string(self, left_parent, left, right_parent, right, **kw): + return left == right + + def visit_anon_name(self, left_parent, left, right_parent, right, **kw): + return _resolve_name_for_compare( + left_parent, left, self.anon_map[0], **kw + ) == _resolve_name_for_compare( + right_parent, right, self.anon_map[1], **kw + ) + + def visit_boolean(self, left_parent, left, right_parent, right, **kw): + return left == right + + def visit_operator(self, left_parent, left, right_parent, right, **kw): + return left is right + + def visit_type(self, left_parent, left, right_parent, right, **kw): + return left._compare_type_affinity(right) + + def visit_plain_dict(self, left_parent, left, right_parent, right, **kw): + return left == right + + def visit_plain_obj(self, left_parent, left, right_parent, right, **kw): + return left == right + + def visit_named_ddl_element( + self, left_parent, left, right_parent, right, **kw + ): + if left is None: + if right is not None: + return COMPARE_FAILED + + return left.name == right.name + + def visit_prefix_sequence( + self, left_parent, left, right_parent, right, **kw + ): + for (l_clause, l_str), (r_clause, r_str) in util.zip_longest( + left, right, fillvalue=(None, None) + ): + if l_str != r_str: + return COMPARE_FAILED + else: + self.stack.append((l_clause, r_clause)) + + def visit_table_hint_list( + self, left_parent, left, right_parent, right, **kw + ): + left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1])) + right_keys = sorted( + right, key=lambda elem: (elem[0].fullname, elem[1]) + ) + for (ltable, ldialect), (rtable, rdialect) in util.zip_longest( + left_keys, right_keys, fillvalue=(None, None) + ): + if ldialect != rdialect: + return COMPARE_FAILED + elif left[(ltable, ldialect)] != right[(rtable, rdialect)]: + return COMPARE_FAILED + else: + self.stack.append((ltable, rtable)) + + def visit_statement_hint_list( + self, left_parent, left, right_parent, right, **kw + ): + return left == right + + def visit_unknown_structure( + self, left_parent, left, right_parent, right, **kw + ): + raise NotImplementedError() + + def compare_clauselist(self, left, right, **kw): + if left.operator is right.operator: + if operators.is_associative(left.operator): + if self._compare_unordered_sequences( + left.clauses, right.clauses, **kw + ): + return ["operator", "clauses"] + else: + return COMPARE_FAILED + else: + return ["operator"] + else: + return COMPARE_FAILED + + def compare_binary(self, left, right, **kw): + if left.operator == right.operator: + if operators.is_commutative(left.operator): + if ( + compare(left.left, right.left, **kw) + and compare(left.right, right.right, **kw) + ) or ( + compare(left.left, right.right, **kw) + and compare(left.right, right.left, **kw) + ): + return ["operator", "negate", "left", "right"] + else: + return COMPARE_FAILED + else: + return ["operator", "negate"] + else: + return COMPARE_FAILED + + +class ColIdentityComparatorStrategy(TraversalComparatorStrategy): + def compare_column_element( + self, left, right, use_proxies=True, equivalents=(), **kw + ): + """Compare ColumnElements using proxies and equivalent collections. + + This is a comparison strategy specific to the ORM. + """ + + to_compare = (right,) + if equivalents and right in equivalents: + to_compare = equivalents[right].union(to_compare) + + for oth in to_compare: + if use_proxies and left.shares_lineage(oth): + return SKIP_TRAVERSE + elif hash(left) == hash(right): + return SKIP_TRAVERSE + else: + return COMPARE_FAILED + + def compare_column(self, left, right, **kw): + return self.compare_column_element(left, right, **kw) + + def compare_label(self, left, right, **kw): + return self.compare_column_element(left, right, **kw) + + def compare_table(self, left, right, **kw): + # tables compare on identity, since it's not really feasible to + # compare them column by column with the above rules + return SKIP_TRAVERSE if left is right else COMPARE_FAILED diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9c5f5dd47..d09bb28bb 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -12,8 +12,8 @@ from . import operators from .base import SchemaEventTarget -from .visitors import Visitable -from .visitors import VisitableType +from .visitors import Traversible +from .visitors import TraversibleType from .. import exc from .. import util @@ -28,7 +28,7 @@ INDEXABLE = None _resolve_value_to_type = None -class TypeEngine(Visitable): +class TypeEngine(Traversible): """The ultimate base class for all SQL datatypes. Common subclasses of :class:`.TypeEngine` include @@ -535,8 +535,13 @@ class TypeEngine(Visitable): return dialect.type_descriptor(self) @util.memoized_property - def _cache_key(self): - return util.constructor_key(self, self.__class__) + def _gen_cache_key(self): + names = util.get_cls_kwargs(self.__class__) + return (self.__class__,) + tuple( + (k, self.__dict__[k]) + for k in names + if k in self.__dict__ and not k.startswith("_") + ) def adapt(self, cls, **kw): """Produce an "adapted" form of this type, given an "impl" class @@ -617,7 +622,7 @@ class TypeEngine(Visitable): return util.generic_repr(self) -class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType): +class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType): pass diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index e109852a2..8539f4845 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -734,7 +734,7 @@ def criterion_as_pairs( return pairs -class ClauseAdapter(visitors.ReplacingCloningVisitor): +class ClauseAdapter(visitors.ReplacingExternalTraversal): """Clones and modifies clauses based on column correspondence. E.g.:: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 7b2ac285a..8c06eb8af 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -28,14 +28,10 @@ import operator from .. import exc from .. import util - +from ..util import langhelpers +from ..util import symbol __all__ = [ - "VisitableType", - "Visitable", - "ClauseVisitor", - "CloningVisitor", - "ReplacingCloningVisitor", "iterate", "iterate_depthfirst", "traverse_using", @@ -43,85 +39,382 @@ __all__ = [ "traverse_depthfirst", "cloned_traverse", "replacement_traverse", + "Traversible", + "TraversibleType", + "ExternalTraversal", + "InternalTraversal", ] -class VisitableType(type): - """Metaclass which assigns a ``_compiler_dispatch`` method to classes - having a ``__visit_name__`` attribute. +def _generate_compiler_dispatch(cls): + """Generate a _compiler_dispatch() external traversal on classes with a + __visit_name__ attribute. + + """ + visit_name = cls.__visit_name__ + + if isinstance(visit_name, util.compat.string_types): + # There is an optimization opportunity here because the + # the string name of the class's __visit_name__ is known at + # this early stage (import time) so it can be pre-constructed. + getter = operator.attrgetter("visit_%s" % visit_name) + + def _compiler_dispatch(self, visitor, **kw): + try: + meth = getter(visitor) + except AttributeError: + raise exc.UnsupportedCompilationError(visitor, cls) + else: + return meth(self, **kw) + + else: + # The optimization opportunity is lost for this case because the + # __visit_name__ is not yet a string. As a result, the visit + # string has to be recalculated with each compilation. + def _compiler_dispatch(self, visitor, **kw): + visit_attr = "visit_%s" % self.__visit_name__ + try: + meth = getattr(visitor, visit_attr) + except AttributeError: + raise exc.UnsupportedCompilationError(visitor, cls) + else: + return meth(self, **kw) + + _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + + self.__visit_name__ on the visitor, and call it with the same + kw params. + """ + cls._compiler_dispatch = _compiler_dispatch + + +class TraversibleType(type): + """Metaclass which assigns dispatch attributes to various kinds of + "visitable" classes. - The ``_compiler_dispatch`` attribute becomes an instance method which - looks approximately like the following:: + Attributes include: - def _compiler_dispatch (self, visitor, **kw): - '''Look for an attribute named "visit_" + self.__visit_name__ - on the visitor, and call it with the same kw params.''' - visit_attr = 'visit_%s' % self.__visit_name__ - return getattr(visitor, visit_attr)(self, **kw) + * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``. + This is called "external traversal" because the caller of each visit() + method is responsible for sub-traversing the inner elements of each + object. This is appropriate for string compilers and other traversals + that need to call upon the inner elements in a specific pattern. - Classes having no ``__visit_name__`` attribute will remain unaffected. + * internal traversal collections ``_children_traversal``, + ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from + an optional ``_traverse_internals`` collection of symbols which comes + from the :class:`.InternalTraversal` list of symbols. This is called + "internal traversal" MARKMARK """ def __init__(cls, clsname, bases, clsdict): - if clsname != "Visitable" and hasattr(cls, "__visit_name__"): - _generate_dispatch(cls) + if clsname != "Traversible": + if "__visit_name__" in clsdict: + _generate_compiler_dispatch(cls) + + super(TraversibleType, cls).__init__(clsname, bases, clsdict) - super(VisitableType, cls).__init__(clsname, bases, clsdict) +class Traversible(util.with_metaclass(TraversibleType)): + """Base class for visitable objects, applies the + :class:`.visitors.TraversibleType` metaclass. -def _generate_dispatch(cls): - """Return an optimized visit dispatch function for the cls - for use by the compiler. """ - if "__visit_name__" in cls.__dict__: - visit_name = cls.__visit_name__ - if isinstance(visit_name, util.compat.string_types): - # There is an optimization opportunity here because the - # the string name of the class's __visit_name__ is known at - # this early stage (import time) so it can be pre-constructed. - getter = operator.attrgetter("visit_%s" % visit_name) - def _compiler_dispatch(self, visitor, **kw): - try: - meth = getter(visitor) - except AttributeError: - raise exc.UnsupportedCompilationError(visitor, cls) - else: - return meth(self, **kw) +class _InternalTraversalType(type): + def __init__(cls, clsname, bases, clsdict): + if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"): + lookup = {} + for key, sym in clsdict.items(): + if key.startswith("dp_"): + visit_key = key.replace("dp_", "visit_") + sym_name = sym.name + assert sym_name not in lookup, sym_name + lookup[sym] = lookup[sym_name] = visit_key + if hasattr(cls, "_dispatch_lookup"): + lookup.update(cls._dispatch_lookup) + cls._dispatch_lookup = lookup + + super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict) + + +def _generate_dispatcher(visitor, internal_dispatch, method_name): + names = [] + for attrname, visit_sym in internal_dispatch: + meth = visitor.dispatch(visit_sym) + if meth: + visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym] + names.append((attrname, visit_name)) + + code = ( + (" return [\n") + + ( + ", \n".join( + " (%r, self.%s, visitor.%s)" + % (attrname, attrname, visit_name) + for attrname, visit_name in names + ) + ) + + ("\n ]\n") + ) + meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n" + # print(meth_text) + return langhelpers._exec_code_in_env(meth_text, {}, method_name) - else: - # The optimization opportunity is lost for this case because the - # __visit_name__ is not yet a string. As a result, the visit - # string has to be recalculated with each compilation. - def _compiler_dispatch(self, visitor, **kw): - visit_attr = "visit_%s" % self.__visit_name__ - try: - meth = getattr(visitor, visit_attr) - except AttributeError: - raise exc.UnsupportedCompilationError(visitor, cls) - else: - return meth(self, **kw) - - _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__ - on the visitor, and call it with the same kw params. - """ - cls._compiler_dispatch = _compiler_dispatch - - -class Visitable(util.with_metaclass(VisitableType, object)): - """Base class for visitable objects, applies the - :class:`.visitors.VisitableType` metaclass. - The :class:`.Visitable` class is essentially at the base of the - :class:`.ClauseElement` hierarchy. +class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): + r"""Defines visitor symbols used for internal traversal. + + The :class:`.InternalTraversal` class is used in two ways. One is that + it can serve as the superclass for an object that implements the + various visit methods of the class. The other is that the symbols + themselves of :class:`.InternalTraversal` are used within + the ``_traverse_internals`` collection. Such as, the :class:`.Case` + object defines ``_travserse_internals`` as :: + + _traverse_internals = [ + ("value", InternalTraversal.dp_clauseelement), + ("whens", InternalTraversal.dp_clauseelement_tuples), + ("else_", InternalTraversal.dp_clauseelement), + ] + + Above, the :class:`.Case` class indicates its internal state as the + attribtues named ``value``, ``whens``, and ``else\_``. They each + link to an :class:`.InternalTraversal` method which indicates the type + of datastructure referred towards. + + Using the ``_traverse_internals`` structure, objects of type + :class:`.InternalTraversible` will have the following methods automatically + implemented: + + * :meth:`.Traversible.get_children` + + * :meth:`.Traversible._copy_internals` + + * :meth:`.Traversible._gen_cache_key` + + Subclasses can also implement these methods directly, particularly for the + :meth:`.Traversible._copy_internals` method, when special steps + are needed. + + .. versionadded:: 1.4 """ + def dispatch(self, visit_symbol): + """Given a method from :class:`.InternalTraversal`, return the + corresponding method on a subclass. -class ClauseVisitor(object): - """Base class for visitor objects which can traverse using + """ + name = self._dispatch_lookup[visit_symbol] + return getattr(self, name, None) + + def run_generated_dispatch( + self, target, internal_dispatch, generate_dispatcher_name + ): + try: + dispatcher = target.__class__.__dict__[generate_dispatcher_name] + except KeyError: + dispatcher = _generate_dispatcher( + self, internal_dispatch, generate_dispatcher_name + ) + setattr(target.__class__, generate_dispatcher_name, dispatcher) + return dispatcher(target, self) + + dp_has_cache_key = symbol("HC") + """Visit a :class:`.HasCacheKey` object.""" + + dp_clauseelement = symbol("CE") + """Visit a :class:`.ClauseElement` object.""" + + dp_fromclause_canonical_column_collection = symbol("FC") + """Visit a :class:`.FromClause` object in the context of the + ``columns`` attribute. + + The column collection is "canonical", meaning it is the originally + defined location of the :class:`.ColumnClause` objects. Right now + this means that the object being visited is a :class:`.TableClause` + or :class:`.Table` object only. + + """ + + dp_clauseelement_tuples = symbol("CT") + """Visit a list of tuples which contain :class:`.ClauseElement` + objects. + + """ + + dp_clauseelement_list = symbol("CL") + """Visit a list of :class:`.ClauseElement` objects. + + """ + + dp_clauseelement_unordered_set = symbol("CU") + """Visit an unordered set of :class:`.ClauseElement` objects. """ + + dp_fromclause_ordered_set = symbol("CO") + """Visit an ordered set of :class:`.FromClause` objects. """ + + dp_string = symbol("S") + """Visit a plain string value. + + Examples include table and column names, bound parameter keys, special + keywords such as "UNION", "UNION ALL". + + The string value is considered to be significant for cache key + generation. + + """ + + dp_anon_name = symbol("AN") + """Visit a potentially "anonymized" string value. + + The string value is considered to be significant for cache key + generation. + + """ + + dp_boolean = symbol("B") + """Visit a boolean value. + + The boolean value is considered to be significant for cache key + generation. + + """ + + dp_operator = symbol("O") + """Visit an operator. + + The operator is a function from the :mod:`sqlalchemy.sql.operators` + module. + + The operator value is considered to be significant for cache key + generation. + + """ + + dp_type = symbol("T") + """Visit a :class:`.TypeEngine` object + + The type object is considered to be significant for cache key + generation. + + """ + + dp_plain_dict = symbol("PD") + """Visit a dictionary with string keys. + + The keys of the dictionary should be strings, the values should + be immutable and hashable. The dictionary is considered to be + significant for cache key generation. + + """ + + dp_string_clauseelement_dict = symbol("CD") + """Visit a dictionary of string keys to :class:`.ClauseElement` + objects. + + """ + + dp_string_multi_dict = symbol("MD") + """Visit a dictionary of string keys to values which may either be + plain immutable/hashable or :class:`.HasCacheKey` objects. + + """ + + dp_plain_obj = symbol("PO") + """Visit a plain python object. + + The value should be immutable and hashable, such as an integer. + The value is considered to be significant for cache key generation. + + """ + + dp_annotations_state = symbol("A") + """Visit the state of the :class:`.Annotatated` version of an object. + + """ + + dp_named_ddl_element = symbol("DD") + """Visit a simple named DDL element. + + The current object used by this method is the :class:`.Sequence`. + + The object is only considered to be important for cache key generation + as far as its name, but not any other aspects of it. + + """ + + dp_prefix_sequence = symbol("PS") + """Visit the sequence represented by :class:`.HasPrefixes` + or :class:`.HasSuffixes`. + + """ + + dp_table_hint_list = symbol("TH") + """Visit the ``_hints`` collection of a :class:`.Select` object. + + """ + + dp_statement_hint_list = symbol("SH") + """Visit the ``_statement_hints`` collection of a :class:`.Select` + object. + + """ + + dp_unknown_structure = symbol("UK") + """Visit an unknown structure. + + """ + + +class ExtendedInternalTraversal(InternalTraversal): + """defines additional symbols that are useful in caching applications. + + Traversals for :class:`.ClauseElement` objects only need to use + those symbols present in :class:`.InternalTraversal`. However, for + additional caching use cases within the ORM, symbols dealing with the + :class:`.HasCacheKey` class are added here. + + """ + + dp_ignore = symbol("IG") + """Specify an object that should be ignored entirely. + + This currently applies function call argument caching where some + arguments should not be considered to be part of a cache key. + + """ + + dp_inspectable = symbol("IS") + """Visit an inspectable object where the return value is a HasCacheKey` + object.""" + + dp_multi = symbol("M") + """Visit an object that may be a :class:`.HasCacheKey` or may be a + plain hashable object.""" + + dp_multi_list = symbol("MT") + """Visit a tuple containing elements that may be :class:`.HasCacheKey` or + may be a plain hashable object.""" + + dp_has_cache_key_tuples = symbol("HT") + """Visit a list of tuples which contain :class:`.HasCacheKey` + objects. + + """ + + dp_has_cache_key_list = symbol("HL") + """Visit a list of :class:`.HasCacheKey` objects.""" + + dp_inspectable_list = symbol("IL") + """Visit a list of inspectable objects which upon inspection are + HasCacheKey objects.""" + + +class ExternalTraversal(object): + """Base class for visitor objects which can traverse externally using the :func:`.visitors.traverse` function. Direct usage of the :func:`.visitors.traverse` function is usually @@ -178,7 +471,7 @@ class ClauseVisitor(object): return self -class CloningVisitor(ClauseVisitor): +class CloningExternalTraversal(ExternalTraversal): """Base class for visitor objects which can traverse using the :func:`.visitors.cloned_traverse` function. @@ -203,7 +496,7 @@ class CloningVisitor(ClauseVisitor): ) -class ReplacingCloningVisitor(CloningVisitor): +class ReplacingExternalTraversal(CloningExternalTraversal): """Base class for visitor objects which can traverse using the :func:`.visitors.replacement_traverse` function. @@ -233,6 +526,14 @@ class ReplacingCloningVisitor(CloningVisitor): return replacement_traverse(obj, self.__traverse_options__, replace) +# backwards compatibility +Visitable = Traversible +VisitableType = TraversibleType +ClauseVisitor = ExternalTraversal +CloningVisitor = CloningExternalTraversal +ReplacingCloningVisitor = ReplacingExternalTraversal + + def iterate(obj, opts): r"""traverse the given expression structure, returning an iterator. @@ -405,11 +706,18 @@ def cloned_traverse(obj, opts, visitors): cloned = {} stop_on = set(opts.get("stop_on", [])) - def clone(elem): + def clone(elem, **kw): if elem in stop_on: return elem else: if id(elem) not in cloned: + + if "replace" in kw: + newelem = kw["replace"](elem) + if newelem is not None: + cloned[id(elem)] = newelem + return newelem + cloned[id(elem)] = newelem = elem._clone() newelem._copy_internals(clone=clone) meth = visitors.get(newelem.__visit_name__, None) @@ -461,7 +769,14 @@ def replacement_traverse(obj, opts, replace): stop_on.add(id(newelem)) return newelem else: + if elem not in cloned: + if "replace" in kw: + newelem = kw["replace"](elem) + if newelem is not None: + cloned[elem] = newelem + return newelem + cloned[elem] = newelem = elem._clone() newelem._copy_internals(clone=clone, **kw) return cloned[elem] |
