diff options
Diffstat (limited to 'lib/sqlalchemy/sql/traversals.py')
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 77 |
1 files changed, 39 insertions, 38 deletions
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 9ac6cda97..032488826 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -7,6 +7,7 @@ from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal from .. import util from ..inspection import inspect +from ..util import collections_abc from ..util import HasMemoized SKIP_TRAVERSE = util.symbol("skip_traverse") @@ -533,18 +534,12 @@ class _CopyInternals(InternalTraversal): ] def visit_dml_values(self, parent, element, clone=_clone, **kw): - # sequence of dictionaries - return [ - { - ( - clone(key, **kw) - if hasattr(key, "__clause_element__") - else key - ): clone(value, **kw) - for key, value in sub_element.items() - } - for sub_element in element - ] + return { + ( + clone(key, **kw) if hasattr(key, "__clause_element__") else key + ): clone(value, **kw) + for key, value in element.items() + } def visit_dml_multi_values(self, parent, element, clone=_clone, **kw): # sequence of sequences, each sequence contains a list/dict/tuple @@ -552,15 +547,10 @@ class _CopyInternals(InternalTraversal): def copy(elem): if isinstance(elem, (list, tuple)): return [ - ( - clone(key, **kw) - if hasattr(key, "__clause_element__") - else key, - clone(value, **kw) - if hasattr(value, "__clause_element__") - else value, - ) - for key, value in elem + clone(value, **kw) + if hasattr(value, "__clause_element__") + else value + for value in elem ] elif isinstance(elem, dict): return { @@ -573,7 +563,7 @@ class _CopyInternals(InternalTraversal): if hasattr(value, "__clause_element__") else value ) - for key, value in elem + for key, value in elem.items() } else: # TODO: use abc classes @@ -939,30 +929,41 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for (lk, lv), (rk, rv) in util.zip_longest( left, right, fillvalue=(None, None) ): - lkce = hasattr(lk, "__clause_element__") - rkce = hasattr(rk, "__clause_element__") - if lkce != rkce: - return COMPARE_FAILED - elif lkce and not self.compare_inner(lk, rk, **kw): - return COMPARE_FAILED - elif not lkce and lk != rk: - return COMPARE_FAILED - elif not self.compare_inner(lv, rv, **kw): + if not self._compare_dml_values_or_ce(lk, rk, **kw): return COMPARE_FAILED + def _compare_dml_values_or_ce(self, lv, rv, **kw): + lvce = hasattr(lv, "__clause_element__") + rvce = hasattr(rv, "__clause_element__") + if lvce != rvce: + return False + elif lvce and not self.compare_inner(lv, rv, **kw): + return False + elif not lvce and lv != rv: + return False + elif not self.compare_inner(lv, rv, **kw): + return False + + return True + def visit_dml_values(self, left_parent, left, right_parent, right, **kw): if left is None or right is None or len(left) != len(right): return COMPARE_FAILED - for lk in left: - lv = left[lk] + if isinstance(left, collections_abc.Sequence): + for lv, rv in zip(left, right): + if not self._compare_dml_values_or_ce(lv, rv, **kw): + return COMPARE_FAILED + else: + for lk in left: + lv = left[lk] - if lk not in right: - return COMPARE_FAILED - rv = right[lk] + if lk not in right: + return COMPARE_FAILED + rv = right[lk] - if not self.compare_inner(lv, rv, **kw): - return COMPARE_FAILED + if not self._compare_dml_values_or_ce(lv, rv, **kw): + return COMPARE_FAILED def visit_dml_multi_values( self, left_parent, left, right_parent, right, **kw |
