summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/traversals.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/traversals.py')
-rw-r--r--lib/sqlalchemy/sql/traversals.py77
1 files changed, 39 insertions, 38 deletions
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 9ac6cda97..032488826 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -7,6 +7,7 @@ from .visitors import ExtendedInternalTraversal
from .visitors import InternalTraversal
from .. import util
from ..inspection import inspect
+from ..util import collections_abc
from ..util import HasMemoized
SKIP_TRAVERSE = util.symbol("skip_traverse")
@@ -533,18 +534,12 @@ class _CopyInternals(InternalTraversal):
]
def visit_dml_values(self, parent, element, clone=_clone, **kw):
- # sequence of dictionaries
- return [
- {
- (
- clone(key, **kw)
- if hasattr(key, "__clause_element__")
- else key
- ): clone(value, **kw)
- for key, value in sub_element.items()
- }
- for sub_element in element
- ]
+ return {
+ (
+ clone(key, **kw) if hasattr(key, "__clause_element__") else key
+ ): clone(value, **kw)
+ for key, value in element.items()
+ }
def visit_dml_multi_values(self, parent, element, clone=_clone, **kw):
# sequence of sequences, each sequence contains a list/dict/tuple
@@ -552,15 +547,10 @@ class _CopyInternals(InternalTraversal):
def copy(elem):
if isinstance(elem, (list, tuple)):
return [
- (
- clone(key, **kw)
- if hasattr(key, "__clause_element__")
- else key,
- clone(value, **kw)
- if hasattr(value, "__clause_element__")
- else value,
- )
- for key, value in elem
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value
+ for value in elem
]
elif isinstance(elem, dict):
return {
@@ -573,7 +563,7 @@ class _CopyInternals(InternalTraversal):
if hasattr(value, "__clause_element__")
else value
)
- for key, value in elem
+ for key, value in elem.items()
}
else:
# TODO: use abc classes
@@ -939,30 +929,41 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
for (lk, lv), (rk, rv) in util.zip_longest(
left, right, fillvalue=(None, None)
):
- lkce = hasattr(lk, "__clause_element__")
- rkce = hasattr(rk, "__clause_element__")
- if lkce != rkce:
- return COMPARE_FAILED
- elif lkce and not self.compare_inner(lk, rk, **kw):
- return COMPARE_FAILED
- elif not lkce and lk != rk:
- return COMPARE_FAILED
- elif not self.compare_inner(lv, rv, **kw):
+ if not self._compare_dml_values_or_ce(lk, rk, **kw):
return COMPARE_FAILED
+ def _compare_dml_values_or_ce(self, lv, rv, **kw):
+ lvce = hasattr(lv, "__clause_element__")
+ rvce = hasattr(rv, "__clause_element__")
+ if lvce != rvce:
+ return False
+ elif lvce and not self.compare_inner(lv, rv, **kw):
+ return False
+ elif not lvce and lv != rv:
+ return False
+ elif not self.compare_inner(lv, rv, **kw):
+ return False
+
+ return True
+
def visit_dml_values(self, left_parent, left, right_parent, right, **kw):
if left is None or right is None or len(left) != len(right):
return COMPARE_FAILED
- for lk in left:
- lv = left[lk]
+ if isinstance(left, collections_abc.Sequence):
+ for lv, rv in zip(left, right):
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
+ else:
+ for lk in left:
+ lv = left[lk]
- if lk not in right:
- return COMPARE_FAILED
- rv = right[lk]
+ if lk not in right:
+ return COMPARE_FAILED
+ rv = right[lk]
- if not self.compare_inner(lv, rv, **kw):
- return COMPARE_FAILED
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
def visit_dml_multi_values(
self, left_parent, left, right_parent, right, **kw