diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-04-17 13:37:39 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-04-29 11:54:25 -0400 |
| commit | 099522075088a3e1a333a2285c10a8a33b203c19 (patch) | |
| tree | b1359c9ff50d19e4787d8ead0bfe5b03ad1fb69a /lib/sqlalchemy/sql | |
| parent | 2f55c844051d9fe8865576bd77107e94c6de16c1 (diff) | |
| download | sqlalchemy-099522075088a3e1a333a2285c10a8a33b203c19.tar.gz | |
Reimplement .compare() in terms of a visitor
Reworked the :meth:`.ClauseElement.compare` methods in terms of a new
visitor-based approach, and additionally added test coverage ensuring that
all :class:`.ClauseElement` subclasses can be accurately compared
against each other in terms of structure. Structural comparison
capability is currently used to a small degree within the ORM; however,
it may also form the basis for new caching features.
Fixes: #4336
Change-Id: I581b667d8e1642a6c27165cc9f4aded1c66effc6
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/clause_compare.py | 331 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 123 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 34 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/operators.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 9 |
6 files changed, 380 insertions, 128 deletions
diff --git a/lib/sqlalchemy/sql/clause_compare.py b/lib/sqlalchemy/sql/clause_compare.py new file mode 100644 index 000000000..87f9fb2df --- /dev/null +++ b/lib/sqlalchemy/sql/clause_compare.py @@ -0,0 +1,331 @@ +from collections import deque + +from . import operators +from .. import util + + +SKIP_TRAVERSE = util.symbol("skip_traverse") + + +def compare(obj1, obj2, **kw): + if kw.get("use_proxies", False): + strategy = ColIdentityComparatorStrategy() + else: + strategy = StructureComparatorStrategy() + + return strategy.compare(obj1, obj2, **kw) + + +class StructureComparatorStrategy(object): + __slots__ = "compare_stack", "cache" + + def __init__(self): + self.compare_stack = deque() + self.cache = set() + + def compare(self, obj1, obj2, **kw): + stack = self.compare_stack + cache = self.cache + + stack.append((obj1, obj2)) + + while stack: + left, right = stack.popleft() + + if left is right: + continue + elif left is None or right is None: + # we know they are different so no match + return False + elif (left, right) in cache: + continue + cache.add((left, right)) + + visit_name = left.__visit_name__ + + # we're not exactly looking for identical types, because + # there are things like Column and AnnotatedColumn. 
So the + # visit_name has to at least match up + if visit_name != right.__visit_name__: + return False + + meth = getattr(self, "compare_%s" % visit_name, None) + + if meth: + comparison = meth(left, right, **kw) + if comparison is False: + return False + elif comparison is SKIP_TRAVERSE: + continue + + for c1, c2 in util.zip_longest( + left.get_children(column_collections=False), + right.get_children(column_collections=False), + fillvalue=None, + ): + if c1 is None or c2 is None: + # collections are different sizes, comparison fails + return False + stack.append((c1, c2)) + + return True + + def compare_inner(self, obj1, obj2, **kw): + stack = self.compare_stack + try: + self.compare_stack = deque() + return self.compare(obj1, obj2, **kw) + finally: + self.compare_stack = stack + + def _compare_unordered_sequences(self, seq1, seq2, **kw): + if seq1 is None: + return seq2 is None + + completed = set() + for clause in seq1: + for other_clause in set(seq2).difference(completed): + if self.compare_inner(clause, other_clause, **kw): + completed.add(other_clause) + break + return len(completed) == len(seq1) == len(seq2) + + def compare_bindparam(self, left, right, **kw): + # note the ".key" is often generated from id(self) so can't + # be compared, as far as determining structure. 
+ return ( + left.type._compare_type_affinity(right.type) + and left.value == right.value + and left.callable == right.callable + and left._orig_key == right._orig_key + ) + + def compare_clauselist(self, left, right, **kw): + if left.operator is right.operator: + if operators.is_associative(left.operator): + if self._compare_unordered_sequences( + left.clauses, right.clauses + ): + return SKIP_TRAVERSE + else: + return False + else: + # normal ordered traversal + return True + else: + return False + + def compare_unary(self, left, right, **kw): + if left.operator: + disp = self._get_operator_dispatch( + left.operator, "unary", "operator" + ) + if disp is not None: + result = disp(left, right, left.operator, **kw) + if result is not True: + return result + elif left.modifier: + disp = self._get_operator_dispatch( + left.modifier, "unary", "modifier" + ) + if disp is not None: + result = disp(left, right, left.operator, **kw) + if result is not True: + return result + return ( + left.operator == right.operator and left.modifier == right.modifier + ) + + def compare_binary(self, left, right, **kw): + disp = self._get_operator_dispatch(left.operator, "binary", None) + if disp: + result = disp(left, right, left.operator, **kw) + if result is not True: + return result + + if left.operator == right.operator: + if operators.is_commutative(left.operator): + if ( + compare(left.left, right.left, **kw) + and compare(left.right, right.right, **kw) + ) or ( + compare(left.left, right.right, **kw) + and compare(left.right, right.left, **kw) + ): + return SKIP_TRAVERSE + else: + return False + else: + return True + else: + return False + + def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): + # used by compare_binary, compare_unary + attrname = "visit_%s_%s%s" % ( + operator_.__name__, + qualifier1, + "_" + qualifier2 if qualifier2 else "", + ) + return getattr(self, attrname, None) + + def visit_function_as_comparison_op_binary( + self, left, right, operator, 
**kw + ): + return ( + left.left_index == right.left_index + and left.right_index == right.right_index + ) + + def compare_function(self, left, right, **kw): + return left.name == right.name + + def compare_column(self, left, right, **kw): + if left.table is not None: + self.compare_stack.appendleft((left.table, right.table)) + return ( + left.key == right.key + and left.name == right.name + and ( + left.type._compare_type_affinity(right.type) + if left.type is not None + else right.type is None + ) + and left.is_literal == right.is_literal + ) + + def compare_collation(self, left, right, **kw): + return left.collation == right.collation + + def compare_type_coerce(self, left, right, **kw): + return left.type._compare_type_affinity(right.type) + + @util.dependencies("sqlalchemy.sql.elements") + def compare_alias(self, elements, left, right, **kw): + return ( + left.name == right.name + if not isinstance(left.name, elements._anonymous_label) + else isinstance(right.name, elements._anonymous_label) + ) + + def compare_extract(self, left, right, **kw): + return left.field == right.field + + def compare_textual_label_reference(self, left, right, **kw): + return left.element == right.element + + def compare_slice(self, left, right, **kw): + return ( + left.start == right.start + and left.stop == right.stop + and left.step == right.step + ) + + def compare_over(self, left, right, **kw): + return left.range_ == right.range_ and left.rows == right.rows + + @util.dependencies("sqlalchemy.sql.elements") + def compare_label(self, elements, left, right, **kw): + return left._type._compare_type_affinity(right._type) and ( + left.name == right.name + if not isinstance(left, elements._anonymous_label) + else isinstance(right.name, elements._anonymous_label) + ) + + def compare_typeclause(self, left, right, **kw): + return left.type._compare_type_affinity(right.type) + + def compare_join(self, left, right, **kw): + return left.isouter == right.isouter and left.full == right.full + 
+ def compare_table(self, left, right, **kw): + if left.name != right.name: + return False + + self.compare_stack.extendleft( + util.zip_longest(left.columns, right.columns) + ) + + def compare_compound_select(self, left, right, **kw): + + if not self._compare_unordered_sequences( + left.selects, right.selects, **kw + ): + return False + + if left.keyword != right.keyword: + return False + + if left._for_update_arg != right._for_update_arg: + return False + + if not self.compare_inner( + left._order_by_clause, right._order_by_clause, **kw + ): + return False + + if not self.compare_inner( + left._group_by_clause, right._group_by_clause, **kw + ): + return False + + return SKIP_TRAVERSE + + def compare_select(self, left, right, **kw): + if not self._compare_unordered_sequences( + left._correlate, right._correlate + ): + return False + if not self._compare_unordered_sequences( + left._correlate_except, right._correlate_except + ): + return False + + if not self._compare_unordered_sequences( + left._from_obj, right._from_obj + ): + return False + + if left._for_update_arg != right._for_update_arg: + return False + + return True + + def compare_text_as_from(self, left, right, **kw): + self.compare_stack.extendleft( + util.zip_longest(left.column_args, right.column_args) + ) + return left.positional == right.positional + + +class ColIdentityComparatorStrategy(StructureComparatorStrategy): + def compare_column_element( + self, left, right, use_proxies=True, equivalents=(), **kw + ): + """Compare ColumnElements using proxies and equivalent collections. + + This is a comparison strategy specific to the ORM. 
+ """ + + to_compare = (right,) + if equivalents and right in equivalents: + to_compare = equivalents[right].union(to_compare) + + for oth in to_compare: + if use_proxies and left.shares_lineage(oth): + return True + elif hash(left) == hash(right): + return True + else: + return False + + def compare_column(self, left, right, **kw): + return self.compare_column_element(left, right, **kw) + + def compare_label(self, left, right, **kw): + return self.compare_column_element(left, right, **kw) + + def compare_table(self, left, right, **kw): + # tables compare on identity, since it's not really feasible to + # compare them column by column with the above rules + return left is right diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 6c9b8ee5b..552f61b4a 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -482,6 +482,12 @@ class _multiparam_column(elements.ColumnElement): self.default = original.default self.type = original.type + def compare(self, other, **kw): + raise NotImplementedError() + + def _copy_internals(self, other, **kw): + raise NotImplementedError() + def __eq__(self, other): return ( isinstance(other, _multiparam_column) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index b0d0feff5..38c7cf840 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -17,6 +17,7 @@ import numbers import operator import re +from . import clause_compare from . import operators from . import type_api from .annotation import Annotated @@ -341,7 +342,7 @@ class ClauseElement(Visitable): (see :class:`.ColumnElement`) """ - return self is other + return clause_compare.compare(self, other, **kw) def _copy_internals(self, clone=_clone, **kw): """Reassign internal elements to be clones of themselves. 
@@ -810,34 +811,6 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): selectable._columns[key] = co return co - def compare(self, other, use_proxies=False, equivalents=None, **kw): - """Compare this ColumnElement to another. - - Special arguments understood: - - :param use_proxies: when True, consider two columns that - share a common base column as equivalent (i.e. shares_lineage()) - - :param equivalents: a dictionary of columns as keys mapped to sets - of columns. If the given "other" column is present in this - dictionary, if any of the columns in the corresponding set() pass - the comparison test, the result is True. This is used to expand the - comparison to other columns that may be known to be equivalent to - this one via foreign key or other criterion. - - """ - to_compare = (other,) - if equivalents and other in equivalents: - to_compare = equivalents[other].union(to_compare) - - for oth in to_compare: - if use_proxies and self.shares_lineage(oth): - return True - elif hash(oth) == hash(self): - return True - else: - return False - def cast(self, type_): """Produce a type cast, i.e. ``CAST(<expression> AS <type>)``. @@ -1226,17 +1199,6 @@ class BindParameter(ColumnElement): "%%(%d %s)s" % (id(self), self._orig_key or "param") ) - def compare(self, other, **kw): - """Compare this :class:`BindParameter` to the given - clause.""" - - return ( - isinstance(other, BindParameter) - and self.type._compare_type_affinity(other.type) - and self.value == other.value - and self.callable == other.callable - ) - def __getstate__(self): """execute a deferred value for serialization purposes.""" @@ -1696,9 +1658,6 @@ class TextClause(Executable, ClauseElement): def get_children(self, **kwargs): return list(self._bindparams.values()) - def compare(self, other): - return isinstance(other, TextClause) and other.text == self.text - class Null(ColumnElement): """Represent the NULL keyword in a SQL statement. 
@@ -1720,9 +1679,6 @@ class Null(ColumnElement): return Null() - def compare(self, other): - return isinstance(other, Null) - class False_(ColumnElement): """Represent the ``false`` keyword, or equivalent, in a SQL statement. @@ -1779,9 +1735,6 @@ class False_(ColumnElement): return False_() - def compare(self, other): - return isinstance(other, False_) - class True_(ColumnElement): """Represent the ``true`` keyword, or equivalent, in a SQL statement. @@ -1845,9 +1798,6 @@ class True_(ColumnElement): return True_() - def compare(self, other): - return isinstance(other, True_) - class ClauseList(ClauseElement): """Describe a list of clauses, separated by an operator. @@ -1908,38 +1858,6 @@ class ClauseList(ClauseElement): else: return self - def compare(self, other, **kw): - """Compare this :class:`.ClauseList` to the given :class:`.ClauseList`, - including a comparison of all the clause items. - - """ - if not isinstance(other, ClauseList) and len(self.clauses) == 1: - return self.clauses[0].compare(other, **kw) - elif ( - isinstance(other, ClauseList) - and len(self.clauses) == len(other.clauses) - and self.operator is other.operator - ): - - if self.operator in (operators.and_, operators.or_): - completed = set() - for clause in self.clauses: - for other_clause in set(other.clauses).difference( - completed - ): - if clause.compare(other_clause, **kw): - completed.add(other_clause) - break - return len(completed) == len(other.clauses) - else: - for i in range(0, len(self.clauses)): - if not self.clauses[i].compare(other.clauses[i], **kw): - return False - else: - return True - else: - return False - class BooleanClauseList(ClauseList, ColumnElement): __visit_name__ = "clauselist" @@ -2606,6 +2524,9 @@ class _label_reference(ColumnElement): def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) + def get_children(self, **kwargs): + return [self.element] + @property def _from_objects(self): return () @@ -2885,17 +2806,6 @@ class 
UnaryExpression(ColumnElement): def get_children(self, **kwargs): return (self.element,) - def compare(self, other, **kw): - """Compare this :class:`UnaryExpression` against the given - :class:`.ClauseElement`.""" - - return ( - isinstance(other, UnaryExpression) - and self.operator == other.operator - and self.modifier == other.modifier - and self.element.compare(other.element, **kw) - ) - def _negate(self): if self.negate is not None: return UnaryExpression( @@ -3103,24 +3013,6 @@ class BinaryExpression(ColumnElement): def get_children(self, **kwargs): return self.left, self.right - def compare(self, other, **kw): - """Compare this :class:`BinaryExpression` against the - given :class:`BinaryExpression`.""" - - return ( - isinstance(other, BinaryExpression) - and self.operator == other.operator - and ( - self.left.compare(other.left, **kw) - and self.right.compare(other.right, **kw) - or ( - operators.is_commutative(self.operator) - and self.left.compare(other.right, **kw) - and self.right.compare(other.left, **kw) - ) - ) - ) - def self_group(self, against=None): if operators.is_precedent(self.operator, against): return Grouping(self) @@ -3213,11 +3105,6 @@ class Grouping(ColumnElement): self.element = state["element"] self.type = state["type"] - def compare(self, other, **kw): - return isinstance(other, Grouping) and self.element.compare( - other.element - ) - RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED") RANGE_CURRENT = util.symbol("RANGE_CURRENT") diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index fcc843d91..f48a20ec7 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -74,7 +74,9 @@ class FunctionElement(Executable, ColumnElement, FromClause): def __init__(self, *clauses, **kwargs): """Construct a :class:`.FunctionElement`. 
""" - args = [_literal_as_binds(c, self.name) for c in clauses] + args = [ + _literal_as_binds(c, getattr(self, "name", None)) for c in clauses + ] self._has_args = self._has_args or bool(args) self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *args @@ -376,12 +378,11 @@ class FunctionAsBinary(BinaryExpression): self.left_index = left_index self.right_index = right_index - super(FunctionAsBinary, self).__init__( - left, - right, - operators.function_as_comparison_op, - type_=sqltypes.BOOLEANTYPE, - ) + self.operator = operators.function_as_comparison_op + self.type = sqltypes.BOOLEANTYPE + self.negate = None + self._is_implicitly_boolean = True + self.modifiers = {} @property def left(self): @@ -399,10 +400,11 @@ class FunctionAsBinary(BinaryExpression): def right(self, value): self.sql_function.clauses.clauses[self.right_index - 1] = value - def _copy_internals(self, **kw): - clone = kw.pop("clone") + def _copy_internals(self, clone=_clone, **kw): self.sql_function = clone(self.sql_function, **kw) - super(FunctionAsBinary, self)._copy_internals(**kw) + + def get_children(self, **kw): + yield self.sql_function class _FunctionGenerator(object): @@ -682,6 +684,18 @@ class next_value(GenericFunction): self._bind = kw.get("bind", None) self.sequence = seq + def compare(self, other, **kw): + return ( + isinstance(other, next_value) + and self.sequence.name == other.sequence.name + ) + + def get_children(self, **kwargs): + return [] + + def _copy_internals(self, **kw): + pass + @property def _from_objects(self): return [] diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 4206de460..8479c1d59 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -1414,6 +1414,11 @@ def mirror(op): _associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne]) + +def is_associative(op): + return op in _associative + + _natural_self_precedent = _associative.union( [getitem, 
json_getitem_op, json_path_getitem_op] ) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index d4528f0c3..796e2b272 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1994,6 +1994,9 @@ class ForUpdateArg(ClauseElement): and other.of is self.of ) + def __ne__(self, other): + return not self.__eq__(other) + def __hash__(self): return id(self) @@ -3941,6 +3944,12 @@ class TextAsFrom(SelectBase): self._reset_exported() self.element = clone(self.element, **kw) + def get_children(self, column_collections=True, **kw): + if column_collections: + for c in self.column_args: + yield c + yield self.element + def _scalar_type(self): return self.column_args[0].type |
