summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-04-17 13:37:39 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2019-04-29 11:54:25 -0400
commit099522075088a3e1a333a2285c10a8a33b203c19 (patch)
treeb1359c9ff50d19e4787d8ead0bfe5b03ad1fb69a /lib/sqlalchemy/sql
parent2f55c844051d9fe8865576bd77107e94c6de16c1 (diff)
downloadsqlalchemy-099522075088a3e1a333a2285c10a8a33b203c19.tar.gz
Reimplement .compare() in terms of a visitor
Reworked the :meth:`.ClauseElement.compare` methods in terms of a new visitor-based approach, and additionally added test coverage ensuring that all :class:`.ClauseElement` subclasses can be accurately compared against each other in terms of structure. Structural comparison capability is used to a small degree within the ORM currently, however it also may form the basis for new caching features. Fixes: #4336 Change-Id: I581b667d8e1642a6c27165cc9f4aded1c66effc6
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/clause_compare.py331
-rw-r--r--lib/sqlalchemy/sql/crud.py6
-rw-r--r--lib/sqlalchemy/sql/elements.py123
-rw-r--r--lib/sqlalchemy/sql/functions.py34
-rw-r--r--lib/sqlalchemy/sql/operators.py5
-rw-r--r--lib/sqlalchemy/sql/selectable.py9
6 files changed, 380 insertions, 128 deletions
diff --git a/lib/sqlalchemy/sql/clause_compare.py b/lib/sqlalchemy/sql/clause_compare.py
new file mode 100644
index 000000000..87f9fb2df
--- /dev/null
+++ b/lib/sqlalchemy/sql/clause_compare.py
@@ -0,0 +1,331 @@
+from collections import deque
+
+from . import operators
+from .. import util
+
+
+SKIP_TRAVERSE = util.symbol("skip_traverse")
+
+
+def compare(obj1, obj2, **kw):
+ if kw.get("use_proxies", False):
+ strategy = ColIdentityComparatorStrategy()
+ else:
+ strategy = StructureComparatorStrategy()
+
+ return strategy.compare(obj1, obj2, **kw)
+
+
+class StructureComparatorStrategy(object):
+ __slots__ = "compare_stack", "cache"
+
+ def __init__(self):
+ self.compare_stack = deque()
+ self.cache = set()
+
+ def compare(self, obj1, obj2, **kw):
+ stack = self.compare_stack
+ cache = self.cache
+
+ stack.append((obj1, obj2))
+
+ while stack:
+ left, right = stack.popleft()
+
+ if left is right:
+ continue
+ elif left is None or right is None:
+ # we know they are different so no match
+ return False
+ elif (left, right) in cache:
+ continue
+ cache.add((left, right))
+
+ visit_name = left.__visit_name__
+
+ # we're not exactly looking for identical types, because
+ # there are things like Column and AnnotatedColumn. So the
+ # visit_name has to at least match up
+ if visit_name != right.__visit_name__:
+ return False
+
+ meth = getattr(self, "compare_%s" % visit_name, None)
+
+ if meth:
+ comparison = meth(left, right, **kw)
+ if comparison is False:
+ return False
+ elif comparison is SKIP_TRAVERSE:
+ continue
+
+ for c1, c2 in util.zip_longest(
+ left.get_children(column_collections=False),
+ right.get_children(column_collections=False),
+ fillvalue=None,
+ ):
+ if c1 is None or c2 is None:
+ # collections are different sizes, comparison fails
+ return False
+ stack.append((c1, c2))
+
+ return True
+
+ def compare_inner(self, obj1, obj2, **kw):
+ stack = self.compare_stack
+ try:
+ self.compare_stack = deque()
+ return self.compare(obj1, obj2, **kw)
+ finally:
+ self.compare_stack = stack
+
+ def _compare_unordered_sequences(self, seq1, seq2, **kw):
+ if seq1 is None:
+ return seq2 is None
+
+ completed = set()
+ for clause in seq1:
+ for other_clause in set(seq2).difference(completed):
+ if self.compare_inner(clause, other_clause, **kw):
+ completed.add(other_clause)
+ break
+ return len(completed) == len(seq1) == len(seq2)
+
+ def compare_bindparam(self, left, right, **kw):
+ # note the ".key" is often generated from id(self) so can't
+ # be compared, as far as determining structure.
+ return (
+ left.type._compare_type_affinity(right.type)
+ and left.value == right.value
+ and left.callable == right.callable
+ and left._orig_key == right._orig_key
+ )
+
+ def compare_clauselist(self, left, right, **kw):
+ if left.operator is right.operator:
+ if operators.is_associative(left.operator):
+ if self._compare_unordered_sequences(
+ left.clauses, right.clauses
+ ):
+ return SKIP_TRAVERSE
+ else:
+ return False
+ else:
+ # normal ordered traversal
+ return True
+ else:
+ return False
+
+ def compare_unary(self, left, right, **kw):
+ if left.operator:
+ disp = self._get_operator_dispatch(
+ left.operator, "unary", "operator"
+ )
+ if disp is not None:
+ result = disp(left, right, left.operator, **kw)
+ if result is not True:
+ return result
+ elif left.modifier:
+ disp = self._get_operator_dispatch(
+ left.modifier, "unary", "modifier"
+ )
+ if disp is not None:
+ result = disp(left, right, left.operator, **kw)
+ if result is not True:
+ return result
+ return (
+ left.operator == right.operator and left.modifier == right.modifier
+ )
+
+ def compare_binary(self, left, right, **kw):
+ disp = self._get_operator_dispatch(left.operator, "binary", None)
+ if disp:
+ result = disp(left, right, left.operator, **kw)
+ if result is not True:
+ return result
+
+ if left.operator == right.operator:
+ if operators.is_commutative(left.operator):
+ if (
+ compare(left.left, right.left, **kw)
+ and compare(left.right, right.right, **kw)
+ ) or (
+ compare(left.left, right.right, **kw)
+ and compare(left.right, right.left, **kw)
+ ):
+ return SKIP_TRAVERSE
+ else:
+ return False
+ else:
+ return True
+ else:
+ return False
+
+ def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
+ # used by compare_binary, compare_unary
+ attrname = "visit_%s_%s%s" % (
+ operator_.__name__,
+ qualifier1,
+ "_" + qualifier2 if qualifier2 else "",
+ )
+ return getattr(self, attrname, None)
+
+ def visit_function_as_comparison_op_binary(
+ self, left, right, operator, **kw
+ ):
+ return (
+ left.left_index == right.left_index
+ and left.right_index == right.right_index
+ )
+
+ def compare_function(self, left, right, **kw):
+ return left.name == right.name
+
+ def compare_column(self, left, right, **kw):
+ if left.table is not None:
+ self.compare_stack.appendleft((left.table, right.table))
+ return (
+ left.key == right.key
+ and left.name == right.name
+ and (
+ left.type._compare_type_affinity(right.type)
+ if left.type is not None
+ else right.type is None
+ )
+ and left.is_literal == right.is_literal
+ )
+
+ def compare_collation(self, left, right, **kw):
+ return left.collation == right.collation
+
+ def compare_type_coerce(self, left, right, **kw):
+ return left.type._compare_type_affinity(right.type)
+
+ @util.dependencies("sqlalchemy.sql.elements")
+ def compare_alias(self, elements, left, right, **kw):
+ return (
+ left.name == right.name
+ if not isinstance(left.name, elements._anonymous_label)
+ else isinstance(right.name, elements._anonymous_label)
+ )
+
+ def compare_extract(self, left, right, **kw):
+ return left.field == right.field
+
+ def compare_textual_label_reference(self, left, right, **kw):
+ return left.element == right.element
+
+ def compare_slice(self, left, right, **kw):
+ return (
+ left.start == right.start
+ and left.stop == right.stop
+ and left.step == right.step
+ )
+
+ def compare_over(self, left, right, **kw):
+ return left.range_ == right.range_ and left.rows == right.rows
+
+ @util.dependencies("sqlalchemy.sql.elements")
+ def compare_label(self, elements, left, right, **kw):
+ return left._type._compare_type_affinity(right._type) and (
+ left.name == right.name
+ if not isinstance(left, elements._anonymous_label)
+ else isinstance(right.name, elements._anonymous_label)
+ )
+
+ def compare_typeclause(self, left, right, **kw):
+ return left.type._compare_type_affinity(right.type)
+
+ def compare_join(self, left, right, **kw):
+ return left.isouter == right.isouter and left.full == right.full
+
+ def compare_table(self, left, right, **kw):
+ if left.name != right.name:
+ return False
+
+ self.compare_stack.extendleft(
+ util.zip_longest(left.columns, right.columns)
+ )
+
+ def compare_compound_select(self, left, right, **kw):
+
+ if not self._compare_unordered_sequences(
+ left.selects, right.selects, **kw
+ ):
+ return False
+
+ if left.keyword != right.keyword:
+ return False
+
+ if left._for_update_arg != right._for_update_arg:
+ return False
+
+ if not self.compare_inner(
+ left._order_by_clause, right._order_by_clause, **kw
+ ):
+ return False
+
+ if not self.compare_inner(
+ left._group_by_clause, right._group_by_clause, **kw
+ ):
+ return False
+
+ return SKIP_TRAVERSE
+
+ def compare_select(self, left, right, **kw):
+ if not self._compare_unordered_sequences(
+ left._correlate, right._correlate
+ ):
+ return False
+ if not self._compare_unordered_sequences(
+ left._correlate_except, right._correlate_except
+ ):
+ return False
+
+ if not self._compare_unordered_sequences(
+ left._from_obj, right._from_obj
+ ):
+ return False
+
+ if left._for_update_arg != right._for_update_arg:
+ return False
+
+ return True
+
+ def compare_text_as_from(self, left, right, **kw):
+ self.compare_stack.extendleft(
+ util.zip_longest(left.column_args, right.column_args)
+ )
+ return left.positional == right.positional
+
+
+class ColIdentityComparatorStrategy(StructureComparatorStrategy):
+ def compare_column_element(
+ self, left, right, use_proxies=True, equivalents=(), **kw
+ ):
+ """Compare ColumnElements using proxies and equivalent collections.
+
+ This is a comparison strategy specific to the ORM.
+ """
+
+ to_compare = (right,)
+ if equivalents and right in equivalents:
+ to_compare = equivalents[right].union(to_compare)
+
+ for oth in to_compare:
+ if use_proxies and left.shares_lineage(oth):
+ return True
+ elif hash(left) == hash(right):
+ return True
+ else:
+ return False
+
+ def compare_column(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_label(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_table(self, left, right, **kw):
+ # tables compare on identity, since it's not really feasible to
+ # compare them column by column with the above rules
+ return left is right
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 6c9b8ee5b..552f61b4a 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -482,6 +482,12 @@ class _multiparam_column(elements.ColumnElement):
self.default = original.default
self.type = original.type
+ def compare(self, other, **kw):
+ raise NotImplementedError()
+
+ def _copy_internals(self, other, **kw):
+ raise NotImplementedError()
+
def __eq__(self, other):
return (
isinstance(other, _multiparam_column)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index b0d0feff5..38c7cf840 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -17,6 +17,7 @@ import numbers
import operator
import re
+from . import clause_compare
from . import operators
from . import type_api
from .annotation import Annotated
@@ -341,7 +342,7 @@ class ClauseElement(Visitable):
(see :class:`.ColumnElement`)
"""
- return self is other
+ return clause_compare.compare(self, other, **kw)
def _copy_internals(self, clone=_clone, **kw):
"""Reassign internal elements to be clones of themselves.
@@ -810,34 +811,6 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
selectable._columns[key] = co
return co
- def compare(self, other, use_proxies=False, equivalents=None, **kw):
- """Compare this ColumnElement to another.
-
- Special arguments understood:
-
- :param use_proxies: when True, consider two columns that
- share a common base column as equivalent (i.e. shares_lineage())
-
- :param equivalents: a dictionary of columns as keys mapped to sets
- of columns. If the given "other" column is present in this
- dictionary, if any of the columns in the corresponding set() pass
- the comparison test, the result is True. This is used to expand the
- comparison to other columns that may be known to be equivalent to
- this one via foreign key or other criterion.
-
- """
- to_compare = (other,)
- if equivalents and other in equivalents:
- to_compare = equivalents[other].union(to_compare)
-
- for oth in to_compare:
- if use_proxies and self.shares_lineage(oth):
- return True
- elif hash(oth) == hash(self):
- return True
- else:
- return False
-
def cast(self, type_):
"""Produce a type cast, i.e. ``CAST(<expression> AS <type>)``.
@@ -1226,17 +1199,6 @@ class BindParameter(ColumnElement):
"%%(%d %s)s" % (id(self), self._orig_key or "param")
)
- def compare(self, other, **kw):
- """Compare this :class:`BindParameter` to the given
- clause."""
-
- return (
- isinstance(other, BindParameter)
- and self.type._compare_type_affinity(other.type)
- and self.value == other.value
- and self.callable == other.callable
- )
-
def __getstate__(self):
"""execute a deferred value for serialization purposes."""
@@ -1696,9 +1658,6 @@ class TextClause(Executable, ClauseElement):
def get_children(self, **kwargs):
return list(self._bindparams.values())
- def compare(self, other):
- return isinstance(other, TextClause) and other.text == self.text
-
class Null(ColumnElement):
"""Represent the NULL keyword in a SQL statement.
@@ -1720,9 +1679,6 @@ class Null(ColumnElement):
return Null()
- def compare(self, other):
- return isinstance(other, Null)
-
class False_(ColumnElement):
"""Represent the ``false`` keyword, or equivalent, in a SQL statement.
@@ -1779,9 +1735,6 @@ class False_(ColumnElement):
return False_()
- def compare(self, other):
- return isinstance(other, False_)
-
class True_(ColumnElement):
"""Represent the ``true`` keyword, or equivalent, in a SQL statement.
@@ -1845,9 +1798,6 @@ class True_(ColumnElement):
return True_()
- def compare(self, other):
- return isinstance(other, True_)
-
class ClauseList(ClauseElement):
"""Describe a list of clauses, separated by an operator.
@@ -1908,38 +1858,6 @@ class ClauseList(ClauseElement):
else:
return self
- def compare(self, other, **kw):
- """Compare this :class:`.ClauseList` to the given :class:`.ClauseList`,
- including a comparison of all the clause items.
-
- """
- if not isinstance(other, ClauseList) and len(self.clauses) == 1:
- return self.clauses[0].compare(other, **kw)
- elif (
- isinstance(other, ClauseList)
- and len(self.clauses) == len(other.clauses)
- and self.operator is other.operator
- ):
-
- if self.operator in (operators.and_, operators.or_):
- completed = set()
- for clause in self.clauses:
- for other_clause in set(other.clauses).difference(
- completed
- ):
- if clause.compare(other_clause, **kw):
- completed.add(other_clause)
- break
- return len(completed) == len(other.clauses)
- else:
- for i in range(0, len(self.clauses)):
- if not self.clauses[i].compare(other.clauses[i], **kw):
- return False
- else:
- return True
- else:
- return False
-
class BooleanClauseList(ClauseList, ColumnElement):
__visit_name__ = "clauselist"
@@ -2606,6 +2524,9 @@ class _label_reference(ColumnElement):
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
+ def get_children(self, **kwargs):
+ return [self.element]
+
@property
def _from_objects(self):
return ()
@@ -2885,17 +2806,6 @@ class UnaryExpression(ColumnElement):
def get_children(self, **kwargs):
return (self.element,)
- def compare(self, other, **kw):
- """Compare this :class:`UnaryExpression` against the given
- :class:`.ClauseElement`."""
-
- return (
- isinstance(other, UnaryExpression)
- and self.operator == other.operator
- and self.modifier == other.modifier
- and self.element.compare(other.element, **kw)
- )
-
def _negate(self):
if self.negate is not None:
return UnaryExpression(
@@ -3103,24 +3013,6 @@ class BinaryExpression(ColumnElement):
def get_children(self, **kwargs):
return self.left, self.right
- def compare(self, other, **kw):
- """Compare this :class:`BinaryExpression` against the
- given :class:`BinaryExpression`."""
-
- return (
- isinstance(other, BinaryExpression)
- and self.operator == other.operator
- and (
- self.left.compare(other.left, **kw)
- and self.right.compare(other.right, **kw)
- or (
- operators.is_commutative(self.operator)
- and self.left.compare(other.right, **kw)
- and self.right.compare(other.left, **kw)
- )
- )
- )
-
def self_group(self, against=None):
if operators.is_precedent(self.operator, against):
return Grouping(self)
@@ -3213,11 +3105,6 @@ class Grouping(ColumnElement):
self.element = state["element"]
self.type = state["type"]
- def compare(self, other, **kw):
- return isinstance(other, Grouping) and self.element.compare(
- other.element
- )
-
RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED")
RANGE_CURRENT = util.symbol("RANGE_CURRENT")
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index fcc843d91..f48a20ec7 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -74,7 +74,9 @@ class FunctionElement(Executable, ColumnElement, FromClause):
def __init__(self, *clauses, **kwargs):
"""Construct a :class:`.FunctionElement`.
"""
- args = [_literal_as_binds(c, self.name) for c in clauses]
+ args = [
+ _literal_as_binds(c, getattr(self, "name", None)) for c in clauses
+ ]
self._has_args = self._has_args or bool(args)
self.clause_expr = ClauseList(
operator=operators.comma_op, group_contents=True, *args
@@ -376,12 +378,11 @@ class FunctionAsBinary(BinaryExpression):
self.left_index = left_index
self.right_index = right_index
- super(FunctionAsBinary, self).__init__(
- left,
- right,
- operators.function_as_comparison_op,
- type_=sqltypes.BOOLEANTYPE,
- )
+ self.operator = operators.function_as_comparison_op
+ self.type = sqltypes.BOOLEANTYPE
+ self.negate = None
+ self._is_implicitly_boolean = True
+ self.modifiers = {}
@property
def left(self):
@@ -399,10 +400,11 @@ class FunctionAsBinary(BinaryExpression):
def right(self, value):
self.sql_function.clauses.clauses[self.right_index - 1] = value
- def _copy_internals(self, **kw):
- clone = kw.pop("clone")
+ def _copy_internals(self, clone=_clone, **kw):
self.sql_function = clone(self.sql_function, **kw)
- super(FunctionAsBinary, self)._copy_internals(**kw)
+
+ def get_children(self, **kw):
+ yield self.sql_function
class _FunctionGenerator(object):
@@ -682,6 +684,18 @@ class next_value(GenericFunction):
self._bind = kw.get("bind", None)
self.sequence = seq
+ def compare(self, other, **kw):
+ return (
+ isinstance(other, next_value)
+ and self.sequence.name == other.sequence.name
+ )
+
+ def get_children(self, **kwargs):
+ return []
+
+ def _copy_internals(self, **kw):
+ pass
+
@property
def _from_objects(self):
return []
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
index 4206de460..8479c1d59 100644
--- a/lib/sqlalchemy/sql/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
@@ -1414,6 +1414,11 @@ def mirror(op):
_associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne])
+
+def is_associative(op):
+ return op in _associative
+
+
_natural_self_precedent = _associative.union(
[getitem, json_getitem_op, json_path_getitem_op]
)
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index d4528f0c3..796e2b272 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -1994,6 +1994,9 @@ class ForUpdateArg(ClauseElement):
and other.of is self.of
)
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
def __hash__(self):
return id(self)
@@ -3941,6 +3944,12 @@ class TextAsFrom(SelectBase):
self._reset_exported()
self.element = clone(self.element, **kw)
+ def get_children(self, column_collections=True, **kw):
+ if column_collections:
+ for c in self.column_args:
+ yield c
+ yield self.element
+
def _scalar_type(self):
return self.column_args[0].type