diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 92 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 18 |
2 files changed, 72 insertions, 38 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 333ed36f4..efcfe0e51 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -840,10 +840,17 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT + # To store the query to print - Dict[cte, text_query] self.ctes = util.OrderedDict() - # Detect same CTE references - self.ctes_by_name = {} - self.level_by_ctes = {} + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + self.ctes_by_level_name = {} + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name)] + self.level_name_by_cte = {} + self.ctes_recursive = False if self.positional: self.cte_positional = {} @@ -2515,8 +2522,6 @@ class SQLCompiler(Compiled): ): self._init_cte_state() - cte_level = len(self.stack) if cte.nesting else 1 - kwargs["visiting_cte"] = cte cte_name = cte.name @@ -2527,44 +2532,60 @@ class SQLCompiler(Compiled): is_new_cte = True embedded_in_current_named_cte = False - if cte in self.level_by_ctes: - cte_level = self.level_by_ctes[cte] + _reference_cte = cte._get_reference_cte() + + if _reference_cte in self.level_name_by_cte: + cte_level, _ = self.level_name_by_cte[_reference_cte] + assert _ == cte_name + else: + cte_level = len(self.stack) if cte.nesting else 1 cte_level_name = (cte_level, cte_name) - if cte_level_name in self.ctes_by_name: - existing_cte = self.ctes_by_name[cte_level_name] + if cte_level_name in self.ctes_by_level_name: + existing_cte = self.ctes_by_level_name[cte_level_name] embedded_in_current_named_cte = visiting_cte is existing_cte # we've generated a same-named CTE that we are enclosed in, # or this is the same CTE. just return the name. - if cte in existing_cte._restates or cte is existing_cte: + if cte is existing_cte._restates or cte is existing_cte: is_new_cte = False - elif existing_cte in cte._restates: + elif existing_cte is cte._restates: # we've generated a same-named CTE that is # enclosed in us - we take precedence, so # discard the text for the "inner". del self.ctes[existing_cte] - del self.level_by_ctes[existing_cte] + + existing_cte_reference_cte = existing_cte._get_reference_cte() + + # TODO: determine if these assertions are correct. they + # pass for current test cases + # assert existing_cte_reference_cte is _reference_cte + # assert existing_cte_reference_cte is existing_cte + + del self.level_name_by_cte[existing_cte_reference_cte] else: raise exc.CompileError( "Multiple, unrelated CTEs found with " "the same name: %r" % cte_name ) - if asfrom or is_new_cte: - if cte._cte_alias is not None: - pre_alias_cte = cte._cte_alias - cte_pre_alias_name = cte._cte_alias.name - if isinstance(cte_pre_alias_name, elements._truncated_label): - cte_pre_alias_name = self._truncated_identifier( - "alias", cte_pre_alias_name - ) - else: - pre_alias_cte = cte - cte_pre_alias_name = None + if not asfrom and not is_new_cte: + return None + + if cte._cte_alias is not None: + pre_alias_cte = cte._cte_alias + cte_pre_alias_name = cte._cte_alias.name + if isinstance(cte_pre_alias_name, elements._truncated_label): + cte_pre_alias_name = self._truncated_identifier( + "alias", cte_pre_alias_name + ) + else: + pre_alias_cte = cte + cte_pre_alias_name = None if is_new_cte: - self.ctes_by_name[cte_level_name] = cte + self.ctes_by_level_name[cte_level_name] = cte + self.level_name_by_cte[_reference_cte] = cte_level_name if ( "autocommit" in cte.element._execution_options @@ -2649,7 +2670,6 @@ class SQLCompiler(Compiled): ) self.ctes[cte] = text - self.level_by_ctes[cte] = cte_level if asfrom: if from_linter: @@ -3475,7 +3495,9 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): - cte_level = self.level_by_ctes[cte] + cte_level, cte_name = self.level_name_by_cte[ + cte._get_reference_cte() + ] is_rendered_level = cte_level == nesting_level or ( include_following_stack and cte_level == nesting_level + 1 ) @@ -3484,14 +3506,6 @@ class SQLCompiler(Compiled): ctes[cte] = self.ctes[cte] - del self.ctes[cte] - del self.level_by_ctes[cte] - - cte_name = cte.name - if isinstance(cte_name, elements._truncated_label): - cte_name = self._truncated_identifier("alias", cte_name) - - del self.ctes_by_name[(cte_level, cte_name)] else: ctes = self.ctes @@ -3508,6 +3522,16 @@ class SQLCompiler(Compiled): cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " + + if nesting_level and nesting_level > 1: + for cte in list(ctes.keys()): + cte_level, cte_name = self.level_name_by_cte[ + cte._get_reference_cte() + ] + del self.ctes[cte] + del self.ctes_by_level_name[(cte_level, cte_name)] + del self.level_name_by_cte[cte._get_reference_cte()] + return cte_text def get_cte_preamble(self, recursive): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8e71dfb97..616df0d05 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2049,8 +2049,9 @@ class CTE( AliasedReturnsRows._traverse_internals + [ ("_cte_alias", InternalTraversal.dp_clauseelement), - ("_restates", InternalTraversal.dp_clauseelement_list), + ("_restates", InternalTraversal.dp_clauseelement), ("recursive", InternalTraversal.dp_boolean), + ("nesting", InternalTraversal.dp_boolean), ] + HasPrefixes._has_prefixes_traverse_internals + HasSuffixes._has_suffixes_traverse_internals @@ -2075,13 +2076,14 @@ class CTE( recursive=False, nesting=False, _cte_alias=None, - _restates=(), + _restates=None, _prefixes=None, _suffixes=None, ): self.recursive = recursive self.nesting = nesting self._cte_alias = _cte_alias + # Keep recursivity reference with union/union_all self._restates = _restates if _prefixes: self._prefixes = _prefixes @@ -2125,7 +2127,7 @@ class CTE( name=self.name, recursive=self.recursive, nesting=self.nesting, - _restates=self._restates + (self,), + _restates=self, _prefixes=self._prefixes, _suffixes=self._suffixes, ) @@ -2136,11 +2138,19 @@ class CTE( name=self.name, recursive=self.recursive, nesting=self.nesting, - _restates=self._restates + (self,), + _restates=self, _prefixes=self._prefixes, _suffixes=self._suffixes, ) + def _get_reference_cte(self): + """ + A recursive CTE is updated to attach the recursive part. + Updated CTEs should still refer to the original CTE. + This function returns this reference identifier. + """ + return self._restates if self._restates is not None else self + class HasCTE(roles.HasCTERole): """Mixin that declares a class to include CTE support. |
