diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-04-01 18:31:16 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-04-01 19:25:23 -0400 |
| commit | 49b6c50016c8a038a6df7104560bb3945debe064 (patch) | |
| tree | 9b5b6b9ad6a6aba5374768afd52783fd8c2170f3 /lib/sqlalchemy/sql | |
| parent | a9b62055bfa61c11e9fe0b2984437e2c3e32bf0e (diff) | |
| download | sqlalchemy-49b6c50016c8a038a6df7104560bb3945debe064.tar.gz | |
Repair caching / traversals for values
The test suite wasn't running the copy_internals most fixtures,
enable that and try to get all cases working.
Set up selectable.values to do tuple conversion within compilation
step. at the same time, disable caching for selectable.values
for the moment and make it equivalent to dml_multi_values.
fix cache / compare / copy cases for dml_values and dml_multi_values
which weren't fully tested or covered.
Change-Id: I484ca6e9cb2b66c2e6a321698f2abc0838db1460
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 77 |
3 files changed, 54 insertions, 56 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 799fca2f5..b93ed8890 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2324,9 +2324,14 @@ class SQLCompiler(Compiled): return text def visit_values(self, element, asfrom=False, from_linter=None, **kw): + v = "VALUES %s" % ", ".join( - self.process(elem, literal_binds=element.literal_binds) - for elem in element._data + self.process( + elements.Tuple(*elem).self_group(), + literal_binds=element.literal_binds, + ) + for chunk in element._data + for elem in chunk ) if isinstance(element.name, elements._truncated_label): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e39d61fdb..a0df45b52 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -47,7 +47,6 @@ from .elements import ColumnClause from .elements import GroupedElement from .elements import Grouping from .elements import literal_column -from .elements import Tuple from .elements import UnaryExpression from .visitors import InternalTraversal from .. import exc @@ -1264,14 +1263,16 @@ class AliasedReturnsRows(NoInit, FromClause): self.element._generate_fromclause_column_proxies(self) def _copy_internals(self, clone=_clone, **kw): - element = clone(self.element, **kw) + existing_element = self.element + + super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw) # the element clone is usually against a Table that returns the # same object. don't reset exported .c. collections and other - # memoized details if nothing changed - if element is not self.element: + # memoized details if it was not changed. this saves a lot on + # performance. + if existing_element is not self.element: self._reset_column_collection() - self.element = element @property def _from_objects(self): @@ -1528,15 +1529,6 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): self._suffixes = _suffixes super(CTE, self)._init(selectable, name=name) - def _copy_internals(self, clone=_clone, **kw): - super(CTE, self)._copy_internals(clone, **kw) - # TODO: I don't like that we can't use the traversal data here - if self._cte_alias is not None: - self._cte_alias = clone(self._cte_alias, **kw) - self._restates = frozenset( - [clone(elem, **kw) for elem in self._restates] - ) - def alias(self, name=None, flat=False): """Return an :class:`.Alias` of this :class:`.CTE`. @@ -2064,7 +2056,7 @@ class Values(Generative, FromClause): _traverse_internals = [ ("_column_args", InternalTraversal.dp_clauseelement_list,), - ("_data", InternalTraversal.dp_clauseelement_list), + ("_data", InternalTraversal.dp_dml_multi_values), ("name", InternalTraversal.dp_string), ("literal_binds", InternalTraversal.dp_boolean), ] @@ -2155,7 +2147,7 @@ class Values(Generative, FromClause): """ - self._data += tuple(Tuple(*row).self_group() for row in values) + self._data += (values,) def _populate_column_collection(self): for c in self._column_args: diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 9ac6cda97..032488826 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -7,6 +7,7 @@ from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal from .. import util from ..inspection import inspect +from ..util import collections_abc from ..util import HasMemoized SKIP_TRAVERSE = util.symbol("skip_traverse") @@ -533,18 +534,12 @@ class _CopyInternals(InternalTraversal): ] def visit_dml_values(self, parent, element, clone=_clone, **kw): - # sequence of dictionaries - return [ - { - ( - clone(key, **kw) - if hasattr(key, "__clause_element__") - else key - ): clone(value, **kw) - for key, value in sub_element.items() - } - for sub_element in element - ] + return { + ( + clone(key, **kw) if hasattr(key, "__clause_element__") else key + ): clone(value, **kw) + for key, value in element.items() + } def visit_dml_multi_values(self, parent, element, clone=_clone, **kw): # sequence of sequences, each sequence contains a list/dict/tuple @@ -552,15 +547,10 @@ class _CopyInternals(InternalTraversal): def copy(elem): if isinstance(elem, (list, tuple)): return [ - ( - clone(key, **kw) - if hasattr(key, "__clause_element__") - else key, - clone(value, **kw) - if hasattr(value, "__clause_element__") - else value, - ) - for key, value in elem + clone(value, **kw) + if hasattr(value, "__clause_element__") + else value + for value in elem ] elif isinstance(elem, dict): return { @@ -573,7 +563,7 @@ class _CopyInternals(InternalTraversal): if hasattr(value, "__clause_element__") else value ) - for key, value in elem + for key, value in elem.items() } else: # TODO: use abc classes @@ -939,30 +929,41 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for (lk, lv), (rk, rv) in util.zip_longest( left, right, fillvalue=(None, None) ): - lkce = hasattr(lk, "__clause_element__") - rkce = hasattr(rk, "__clause_element__") - if lkce != rkce: - return COMPARE_FAILED - elif lkce and not self.compare_inner(lk, rk, **kw): - return COMPARE_FAILED - elif not lkce and lk != rk: - return COMPARE_FAILED - elif not self.compare_inner(lv, rv, **kw): + if not self._compare_dml_values_or_ce(lk, rk, **kw): return COMPARE_FAILED + def _compare_dml_values_or_ce(self, lv, rv, **kw): + lvce = hasattr(lv, "__clause_element__") + rvce = hasattr(rv, "__clause_element__") + if lvce != rvce: + return False + elif lvce and not self.compare_inner(lv, rv, **kw): + return False + elif not lvce and lv != rv: + return False + elif not self.compare_inner(lv, rv, **kw): + return False + + return True + def visit_dml_values(self, left_parent, left, right_parent, right, **kw): if left is None or right is None or len(left) != len(right): return COMPARE_FAILED - for lk in left: - lv = left[lk] + if isinstance(left, collections_abc.Sequence): + for lv, rv in zip(left, right): + if not self._compare_dml_values_or_ce(lv, rv, **kw): + return COMPARE_FAILED + else: + for lk in left: + lv = left[lk] - if lk not in right: - return COMPARE_FAILED - rv = right[lk] + if lk not in right: + return COMPARE_FAILED + rv = right[lk] - if not self.compare_inner(lv, rv, **kw): - return COMPARE_FAILED + if not self._compare_dml_values_or_ce(lv, rv, **kw): + return COMPARE_FAILED def visit_dml_multi_values( self, left_parent, left, right_parent, right, **kw |
