diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-02-25 17:48:30 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-02-25 17:48:30 +0000 |
commit | 60fca2ac8cf44bdaf68552ab5c69854a6776c73c (patch) | |
tree | 2b9bc005223c6c58009762dc120fccf309c1ba92 /lib/sqlalchemy/sql/compiler.py | |
parent | 2d97c388eae4345840f745337ec033045651b36d (diff) | |
parent | 0fe8f4a3e79c8fc805e7a84849920c7258177f41 (diff) | |
download | sqlalchemy-60fca2ac8cf44bdaf68552ab5c69854a6776c73c.tar.gz |
Merge "Add more nesting features to add_cte()" into main
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 148 |
1 files changed, 95 insertions, 53 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 131281a16..d0f114d6c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -31,6 +31,13 @@ import itertools import operator import re from time import perf_counter +import typing +from typing import Any +from typing import Dict +from typing import List +from typing import MutableMapping +from typing import Optional +from typing import Tuple from . import base from . import coercions @@ -47,6 +54,12 @@ from .elements import quoted_name from .. import exc from .. import util +if typing.TYPE_CHECKING: + from .selectable import CTE + from .selectable import FromClause + +_FromHintsType = Dict["FromClause", str] + RESERVED_WORDS = set( [ "all", @@ -842,7 +855,7 @@ class SQLCompiler(Compiled): return {} @util.memoized_instancemethod - def _init_cte_state(self): + def _init_cte_state(self) -> None: """Initialize collections related to CTEs only if a CTE is located, to save on the overhead of these collections otherwise. @@ -850,19 +863,21 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT # To store the query to print - Dict[cte, text_query] - self.ctes = util.OrderedDict() + self.ctes: MutableMapping[CTE, str] = util.OrderedDict() # Detect same CTE references - Dict[(level, name), cte] # Level is required for supporting nesting - self.ctes_by_level_name = {} + self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {} # To retrieve key/level in ctes_by_level_name - - # Dict[cte_reference, (level, cte_name)] - self.level_name_by_cte = {} + # Dict[cte_reference, (level, cte_name, cte_opts)] + self.level_name_by_cte: Dict[ + CTE, Tuple[int, str, selectable._CTEOpts] + ] = {} - self.ctes_recursive = False + self.ctes_recursive: bool = False if self.positional: - self.cte_positional = {} + self.cte_positional: Dict[CTE, List[str]] = {} @contextlib.contextmanager def _nested_result(self): @@ -1604,8 +1619,7 @@ class SQLCompiler(Compiled): self.stack.append(new_entry) if taf._independent_ctes: - for cte in taf._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(taf, kw) populate_result_map = ( toplevel @@ -1879,8 +1893,7 @@ class SQLCompiler(Compiled): ) if compound_stmt._independent_ctes: - for cte in compound_stmt._independent_ctes: - cte._compiler_dispatch(self, **kwargs) + self._dispatch_independent_ctes(compound_stmt, kwargs) keyword = self.compound_keywords.get(cs.keyword) @@ -2671,16 +2684,25 @@ class SQLCompiler(Compiled): return ret + def _dispatch_independent_ctes(self, stmt, kw): + local_kw = kw.copy() + local_kw.pop("cte_opts", None) + for cte, opt in zip( + stmt._independent_ctes, stmt._independent_ctes_opts + ): + cte._compiler_dispatch(self, cte_opts=opt, **local_kw) + def visit_cte( self, - cte, - asfrom=False, - ashint=False, - fromhints=None, - visiting_cte=None, - from_linter=None, - **kwargs, - ): + cte: CTE, + asfrom: bool = False, + ashint: bool = False, + fromhints: Optional[_FromHintsType] = None, + visiting_cte: Optional[CTE] = None, + from_linter: Optional[FromLinter] = None, + cte_opts: selectable._CTEOpts = selectable._CTEOpts(False), + **kwargs: Any, + ) -> Optional[str]: self._init_cte_state() kwargs["visiting_cte"] = cte @@ -2695,15 +2717,48 @@ class SQLCompiler(Compiled): _reference_cte = cte._get_reference_cte() + nesting = cte.nesting or cte_opts.nesting + + # check for CTE already encountered if _reference_cte in self.level_name_by_cte: - cte_level, _ = self.level_name_by_cte[_reference_cte] + cte_level, _, existing_cte_opts = self.level_name_by_cte[ + _reference_cte + ] assert _ == cte_name - else: - cte_level = len(self.stack) if cte.nesting else 1 - cte_level_name = (cte_level, cte_name) - if cte_level_name in self.ctes_by_level_name: + cte_level_name = (cte_level, cte_name) existing_cte = self.ctes_by_level_name[cte_level_name] + + # check if we are receiving it here with a specific + # "nest_here" location; if so, move it to this location + + if cte_opts.nesting: + if existing_cte_opts.nesting: + raise exc.CompileError( + "CTE is stated as 'nest_here' in " + "more than one location" + ) + + old_level_name = (cte_level, cte_name) + cte_level = len(self.stack) if nesting else 1 + cte_level_name = new_level_name = (cte_level, cte_name) + + del self.ctes_by_level_name[old_level_name] + self.ctes_by_level_name[new_level_name] = existing_cte + self.level_name_by_cte[_reference_cte] = new_level_name + ( + cte_opts, + ) + + else: + cte_level = len(self.stack) if nesting else 1 + cte_level_name = (cte_level, cte_name) + + if cte_level_name in self.ctes_by_level_name: + existing_cte = self.ctes_by_level_name[cte_level_name] + else: + existing_cte = None + + if existing_cte is not None: embedded_in_current_named_cte = visiting_cte is existing_cte # we've generated a same-named CTE that we are enclosed in, @@ -2718,10 +2773,8 @@ class SQLCompiler(Compiled): existing_cte_reference_cte = existing_cte._get_reference_cte() - # TODO: determine if these assertions are correct. they - # pass for current test cases - # assert existing_cte_reference_cte is _reference_cte - # assert existing_cte_reference_cte is existing_cte + assert existing_cte_reference_cte is _reference_cte + assert existing_cte_reference_cte is existing_cte del self.level_name_by_cte[existing_cte_reference_cte] else: @@ -2746,19 +2799,9 @@ class SQLCompiler(Compiled): if is_new_cte: self.ctes_by_level_name[cte_level_name] = cte - self.level_name_by_cte[_reference_cte] = cte_level_name - - if ( - "autocommit" in cte.element._execution_options - and "autocommit" not in self.execution_options - ): - self.execution_options = self.execution_options.union( - { - "autocommit": cte.element._execution_options[ - "autocommit" - ] - } - ) + self.level_name_by_cte[_reference_cte] = cte_level_name + ( + cte_opts, + ) if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -3378,8 +3421,7 @@ class SQLCompiler(Compiled): byfrom = None if select_stmt._independent_ctes: - for cte in select_stmt._independent_ctes: - cte._compiler_dispatch(self, **kwargs) + self._dispatch_independent_ctes(select_stmt, kwargs) if select_stmt._prefixes: text += self._generate_prefixes( @@ -3485,7 +3527,9 @@ class SQLCompiler(Compiled): return text - def _setup_select_hints(self, select): + def _setup_select_hints( + self, select: Select + ) -> Tuple[str, _FromHintsType]: byfrom = dict( [ ( @@ -3663,13 +3707,14 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): - cte_level, cte_name = self.level_name_by_cte[ + cte_level, cte_name, cte_opts = self.level_name_by_cte[ cte._get_reference_cte() ] + nesting = cte.nesting or cte_opts.nesting is_rendered_level = cte_level == nesting_level or ( include_following_stack and cte_level == nesting_level + 1 ) - if not (cte.nesting and is_rendered_level): + if not (nesting and is_rendered_level): continue ctes[cte] = self.ctes[cte] @@ -3693,7 +3738,7 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: for cte in list(ctes.keys()): - cte_level, cte_name = self.level_name_by_cte[ + cte_level, cte_name, cte_opts = self.level_name_by_cte[ cte._get_reference_cte() ] del self.ctes[cte] @@ -3939,8 +3984,7 @@ class SQLCompiler(Compiled): _, table_text = self._setup_crud_hints(insert_stmt, table_text) if insert_stmt._independent_ctes: - for cte in insert_stmt._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(insert_stmt, kw) text += table_text @@ -4108,8 +4152,7 @@ class SQLCompiler(Compiled): dialect_hints = None if update_stmt._independent_ctes: - for cte in update_stmt._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(update_stmt, kw) text += table_text @@ -4221,8 +4264,7 @@ class SQLCompiler(Compiled): dialect_hints = None if delete_stmt._independent_ctes: - for cte in delete_stmt._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(delete_stmt, kw) text += table_text |