summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-02-25 17:48:30 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-02-25 17:48:30 +0000
commit60fca2ac8cf44bdaf68552ab5c69854a6776c73c (patch)
tree2b9bc005223c6c58009762dc120fccf309c1ba92 /lib/sqlalchemy/sql/compiler.py
parent2d97c388eae4345840f745337ec033045651b36d (diff)
parent0fe8f4a3e79c8fc805e7a84849920c7258177f41 (diff)
downloadsqlalchemy-60fca2ac8cf44bdaf68552ab5c69854a6776c73c.tar.gz
Merge "Add more nesting features to add_cte()" into main
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py148
1 files changed, 95 insertions, 53 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 131281a16..d0f114d6c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -31,6 +31,13 @@ import itertools
import operator
import re
from time import perf_counter
+import typing
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import MutableMapping
+from typing import Optional
+from typing import Tuple
from . import base
from . import coercions
@@ -47,6 +54,12 @@ from .elements import quoted_name
from .. import exc
from .. import util
+if typing.TYPE_CHECKING:
+ from .selectable import CTE
+ from .selectable import FromClause
+
+_FromHintsType = Dict["FromClause", str]
+
RESERVED_WORDS = set(
[
"all",
@@ -842,7 +855,7 @@ class SQLCompiler(Compiled):
return {}
@util.memoized_instancemethod
- def _init_cte_state(self):
+ def _init_cte_state(self) -> None:
"""Initialize collections related to CTEs only if
a CTE is located, to save on the overhead of
these collections otherwise.
@@ -850,19 +863,21 @@ class SQLCompiler(Compiled):
"""
# collect CTEs to tack on top of a SELECT
# To store the query to print - Dict[cte, text_query]
- self.ctes = util.OrderedDict()
+ self.ctes: MutableMapping[CTE, str] = util.OrderedDict()
# Detect same CTE references - Dict[(level, name), cte]
# Level is required for supporting nesting
- self.ctes_by_level_name = {}
+ self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {}
# To retrieve key/level in ctes_by_level_name -
- # Dict[cte_reference, (level, cte_name)]
- self.level_name_by_cte = {}
+ # Dict[cte_reference, (level, cte_name, cte_opts)]
+ self.level_name_by_cte: Dict[
+ CTE, Tuple[int, str, selectable._CTEOpts]
+ ] = {}
- self.ctes_recursive = False
+ self.ctes_recursive: bool = False
if self.positional:
- self.cte_positional = {}
+ self.cte_positional: Dict[CTE, List[str]] = {}
@contextlib.contextmanager
def _nested_result(self):
@@ -1604,8 +1619,7 @@ class SQLCompiler(Compiled):
self.stack.append(new_entry)
if taf._independent_ctes:
- for cte in taf._independent_ctes:
- cte._compiler_dispatch(self, **kw)
+ self._dispatch_independent_ctes(taf, kw)
populate_result_map = (
toplevel
@@ -1879,8 +1893,7 @@ class SQLCompiler(Compiled):
)
if compound_stmt._independent_ctes:
- for cte in compound_stmt._independent_ctes:
- cte._compiler_dispatch(self, **kwargs)
+ self._dispatch_independent_ctes(compound_stmt, kwargs)
keyword = self.compound_keywords.get(cs.keyword)
@@ -2671,16 +2684,25 @@ class SQLCompiler(Compiled):
return ret
+ def _dispatch_independent_ctes(self, stmt, kw):
+ local_kw = kw.copy()
+ local_kw.pop("cte_opts", None)
+ for cte, opt in zip(
+ stmt._independent_ctes, stmt._independent_ctes_opts
+ ):
+ cte._compiler_dispatch(self, cte_opts=opt, **local_kw)
+
def visit_cte(
self,
- cte,
- asfrom=False,
- ashint=False,
- fromhints=None,
- visiting_cte=None,
- from_linter=None,
- **kwargs,
- ):
+ cte: CTE,
+ asfrom: bool = False,
+ ashint: bool = False,
+ fromhints: Optional[_FromHintsType] = None,
+ visiting_cte: Optional[CTE] = None,
+ from_linter: Optional[FromLinter] = None,
+ cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
+ **kwargs: Any,
+ ) -> Optional[str]:
self._init_cte_state()
kwargs["visiting_cte"] = cte
@@ -2695,15 +2717,48 @@ class SQLCompiler(Compiled):
_reference_cte = cte._get_reference_cte()
+ nesting = cte.nesting or cte_opts.nesting
+
+ # check for CTE already encountered
if _reference_cte in self.level_name_by_cte:
- cte_level, _ = self.level_name_by_cte[_reference_cte]
+ cte_level, _, existing_cte_opts = self.level_name_by_cte[
+ _reference_cte
+ ]
assert _ == cte_name
- else:
- cte_level = len(self.stack) if cte.nesting else 1
- cte_level_name = (cte_level, cte_name)
- if cte_level_name in self.ctes_by_level_name:
+ cte_level_name = (cte_level, cte_name)
existing_cte = self.ctes_by_level_name[cte_level_name]
+
+ # check if we are receiving it here with a specific
+ # "nest_here" location; if so, move it to this location
+
+ if cte_opts.nesting:
+ if existing_cte_opts.nesting:
+ raise exc.CompileError(
+ "CTE is stated as 'nest_here' in "
+ "more than one location"
+ )
+
+ old_level_name = (cte_level, cte_name)
+ cte_level = len(self.stack) if nesting else 1
+ cte_level_name = new_level_name = (cte_level, cte_name)
+
+ del self.ctes_by_level_name[old_level_name]
+ self.ctes_by_level_name[new_level_name] = existing_cte
+ self.level_name_by_cte[_reference_cte] = new_level_name + (
+ cte_opts,
+ )
+
+ else:
+ cte_level = len(self.stack) if nesting else 1
+ cte_level_name = (cte_level, cte_name)
+
+ if cte_level_name in self.ctes_by_level_name:
+ existing_cte = self.ctes_by_level_name[cte_level_name]
+ else:
+ existing_cte = None
+
+ if existing_cte is not None:
embedded_in_current_named_cte = visiting_cte is existing_cte
# we've generated a same-named CTE that we are enclosed in,
@@ -2718,10 +2773,8 @@ class SQLCompiler(Compiled):
existing_cte_reference_cte = existing_cte._get_reference_cte()
- # TODO: determine if these assertions are correct. they
- # pass for current test cases
- # assert existing_cte_reference_cte is _reference_cte
- # assert existing_cte_reference_cte is existing_cte
+ assert existing_cte_reference_cte is _reference_cte
+ assert existing_cte_reference_cte is existing_cte
del self.level_name_by_cte[existing_cte_reference_cte]
else:
@@ -2746,19 +2799,9 @@ class SQLCompiler(Compiled):
if is_new_cte:
self.ctes_by_level_name[cte_level_name] = cte
- self.level_name_by_cte[_reference_cte] = cte_level_name
-
- if (
- "autocommit" in cte.element._execution_options
- and "autocommit" not in self.execution_options
- ):
- self.execution_options = self.execution_options.union(
- {
- "autocommit": cte.element._execution_options[
- "autocommit"
- ]
- }
- )
+ self.level_name_by_cte[_reference_cte] = cte_level_name + (
+ cte_opts,
+ )
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -3378,8 +3421,7 @@ class SQLCompiler(Compiled):
byfrom = None
if select_stmt._independent_ctes:
- for cte in select_stmt._independent_ctes:
- cte._compiler_dispatch(self, **kwargs)
+ self._dispatch_independent_ctes(select_stmt, kwargs)
if select_stmt._prefixes:
text += self._generate_prefixes(
@@ -3485,7 +3527,9 @@ class SQLCompiler(Compiled):
return text
- def _setup_select_hints(self, select):
+ def _setup_select_hints(
+ self, select: Select
+ ) -> Tuple[str, _FromHintsType]:
byfrom = dict(
[
(
@@ -3663,13 +3707,14 @@ class SQLCompiler(Compiled):
if nesting_level and nesting_level > 1:
ctes = util.OrderedDict()
for cte in list(self.ctes.keys()):
- cte_level, cte_name = self.level_name_by_cte[
+ cte_level, cte_name, cte_opts = self.level_name_by_cte[
cte._get_reference_cte()
]
+ nesting = cte.nesting or cte_opts.nesting
is_rendered_level = cte_level == nesting_level or (
include_following_stack and cte_level == nesting_level + 1
)
- if not (cte.nesting and is_rendered_level):
+ if not (nesting and is_rendered_level):
continue
ctes[cte] = self.ctes[cte]
@@ -3693,7 +3738,7 @@ class SQLCompiler(Compiled):
if nesting_level and nesting_level > 1:
for cte in list(ctes.keys()):
- cte_level, cte_name = self.level_name_by_cte[
+ cte_level, cte_name, cte_opts = self.level_name_by_cte[
cte._get_reference_cte()
]
del self.ctes[cte]
@@ -3939,8 +3984,7 @@ class SQLCompiler(Compiled):
_, table_text = self._setup_crud_hints(insert_stmt, table_text)
if insert_stmt._independent_ctes:
- for cte in insert_stmt._independent_ctes:
- cte._compiler_dispatch(self, **kw)
+ self._dispatch_independent_ctes(insert_stmt, kw)
text += table_text
@@ -4108,8 +4152,7 @@ class SQLCompiler(Compiled):
dialect_hints = None
if update_stmt._independent_ctes:
- for cte in update_stmt._independent_ctes:
- cte._compiler_dispatch(self, **kw)
+ self._dispatch_independent_ctes(update_stmt, kw)
text += table_text
@@ -4221,8 +4264,7 @@ class SQLCompiler(Compiled):
dialect_hints = None
if delete_stmt._independent_ctes:
- for cte in delete_stmt._independent_ctes:
- cte._compiler_dispatch(self, **kw)
+ self._dispatch_independent_ctes(delete_stmt, kw)
text += table_text