summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-03-11 18:08:03 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-03-11 18:08:03 +0000
commit2aa9a8043b4982d4d7b53e8b11371ea27fccd09c (patch)
tree0fc1f0ddd3a6defdda5888ee48bd4e69bd162c4c /lib/sqlalchemy/sql/compiler.py
parent59ca6e5fcc6974ea1fac82d05157aa58e550b332 (diff)
parent693938dd6fb2f3ee3e031aed4c62355ac97f3ceb (diff)
downloadsqlalchemy-2aa9a8043b4982d4d7b53e8b11371ea27fccd09c.tar.gz
Merge "Rework select(), CompoundSelect() in terms of CompileState"
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py183
1 files changed, 131 insertions, 52 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b37c46216..1f183b5c1 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -28,6 +28,7 @@ import contextlib
import itertools
import re
+from . import base
from . import coercions
from . import crud
from . import elements
@@ -1046,9 +1047,13 @@ class SQLCompiler(Compiled):
self, element, within_columns_clause=False, **kwargs
):
if self.stack and self.dialect.supports_simple_order_by_label:
- selectable = self.stack[-1]["selectable"]
+ compile_state = self.stack[-1]["compile_state"]
- with_cols, only_froms, only_cols = selectable._label_resolve_dict
+ (
+ with_cols,
+ only_froms,
+ only_cols,
+ ) = compile_state._label_resolve_dict
if within_columns_clause:
resolve_dict = only_froms
else:
@@ -1083,8 +1088,8 @@ class SQLCompiler(Compiled):
# compiling the element outside of the context of a SELECT
return self.process(element._text_clause)
- selectable = self.stack[-1]["selectable"]
- with_cols, only_froms, only_cols = selectable._label_resolve_dict
+ compile_state = self.stack[-1]["compile_state"]
+ with_cols, only_froms, only_cols = compile_state._label_resolve_dict
try:
if within_columns_clause:
col = only_froms[element.element]
@@ -1314,6 +1319,24 @@ class SQLCompiler(Compiled):
if s
)
+ def _generate_delimited_and_list(self, clauses, **kw):
+
+ lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
+ operators.and_,
+ elements.True_._singleton,
+ elements.False_._singleton,
+ clauses,
+ )
+ if lcc == 1:
+ return clauses[0]._compiler_dispatch(self, **kw)
+ else:
+ separator = OPERATORS[operators.and_]
+ return separator.join(
+ s
+ for s in (c._compiler_dispatch(self, **kw) for c in clauses)
+ if s
+ )
+
def visit_clauselist(self, clauselist, **kw):
sep = clauselist.operator
if sep is None:
@@ -1474,6 +1497,12 @@ class SQLCompiler(Compiled):
self, cs, asfrom=False, compound_index=0, **kwargs
):
toplevel = not self.stack
+
+ compile_state = cs._compile_state_factory(cs, self, **kwargs)
+
+ if toplevel:
+ self.compile_state = compile_state
+
entry = self._default_stack_entry if toplevel else self.stack[-1]
need_result_map = toplevel or (
compound_index == 0
@@ -1485,6 +1514,7 @@ class SQLCompiler(Compiled):
"correlate_froms": entry["correlate_froms"],
"asfrom_froms": entry["asfrom_froms"],
"selectable": cs,
+ "compile_state": compile_state,
"need_result_map_for_compound": need_result_map,
}
)
@@ -1666,7 +1696,6 @@ class SQLCompiler(Compiled):
from_linter=None,
**kw
):
-
if from_linter and operators.is_comparison(binary.operator):
from_linter.edges.update(
itertools.product(
@@ -2274,7 +2303,6 @@ class SQLCompiler(Compiled):
need_column_expressions=False,
):
"""produce labeled columns present in a select()."""
-
impl = column.type.dialect_impl(self.dialect)
if impl._has_column_expression and (
@@ -2350,7 +2378,12 @@ class SQLCompiler(Compiled):
or isinstance(column, functions.FunctionElement)
)
):
- result_expr = _CompileLabel(col_expr, column.anon_label)
+ result_expr = _CompileLabel(
+ col_expr,
+ column.anon_label
+ if not column_is_repeated
+ else column._dedupe_label_anon_label,
+ )
elif col_expr is not column:
# TODO: are we sure "column" has a .name and .key here ?
# assert isinstance(column, elements.ColumnClause)
@@ -2390,7 +2423,9 @@ class SQLCompiler(Compiled):
[("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
)
- def _display_froms_for_select(self, select, asfrom, lateral=False):
+ def _display_froms_for_select(
+ self, select_stmt, asfrom, lateral=False, **kw
+ ):
# utility method to help external dialects
# get the correct from list for a select.
# specifically the oracle dialect needs this feature
@@ -2398,18 +2433,20 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
+ compile_state = select_stmt._compile_state_factory(select_stmt, self)
+
correlate_froms = entry["correlate_froms"]
asfrom_froms = entry["asfrom_froms"]
if asfrom and not lateral:
- froms = select._get_display_froms(
+ froms = compile_state._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
asfrom_froms
),
implicit_correlate_froms=(),
)
else:
- froms = select._get_display_froms(
+ froms = compile_state._get_display_froms(
explicit_correlate_froms=correlate_froms,
implicit_correlate_froms=asfrom_froms,
)
@@ -2417,7 +2454,7 @@ class SQLCompiler(Compiled):
def visit_select(
self,
- select,
+ select_stmt,
asfrom=False,
fromhints=None,
compound_index=0,
@@ -2427,7 +2464,16 @@ class SQLCompiler(Compiled):
**kwargs
):
+ compile_state = select_stmt._compile_state_factory(
+ select_stmt, self, **kwargs
+ )
+ select_stmt = compile_state.statement
+
toplevel = not self.stack
+
+ if toplevel:
+ self.compile_state = compile_state
+
entry = self._default_stack_entry if toplevel else self.stack[-1]
populate_result_map = need_column_expressions = (
@@ -2446,7 +2492,7 @@ class SQLCompiler(Compiled):
del kwargs["add_to_result_map"]
froms = self._setup_select_stack(
- select, entry, asfrom, lateral, compound_index
+ select_stmt, compile_state, entry, asfrom, lateral, compound_index
)
column_clause_args = kwargs.copy()
@@ -2456,23 +2502,25 @@ class SQLCompiler(Compiled):
text = "SELECT " # we're off to a good start !
- if select._hints:
- hint_text, byfrom = self._setup_select_hints(select)
+ if select_stmt._hints:
+ hint_text, byfrom = self._setup_select_hints(select_stmt)
if hint_text:
text += hint_text + " "
else:
byfrom = None
- if select._prefixes:
- text += self._generate_prefixes(select, select._prefixes, **kwargs)
+ if select_stmt._prefixes:
+ text += self._generate_prefixes(
+ select_stmt, select_stmt._prefixes, **kwargs
+ )
- text += self.get_select_precolumns(select, **kwargs)
+ text += self.get_select_precolumns(select_stmt, **kwargs)
# the actual list of columns to print in the SELECT column list.
inner_columns = [
c
for c in [
self._label_select_column(
- select,
+ select_stmt,
column,
populate_result_map,
asfrom,
@@ -2481,7 +2529,7 @@ class SQLCompiler(Compiled):
column_is_repeated=repeated,
need_column_expressions=need_column_expressions,
)
- for name, column, repeated in select._columns_plus_names
+ for name, column, repeated in compile_state.columns_plus_names
]
if c is not None
]
@@ -2490,11 +2538,19 @@ class SQLCompiler(Compiled):
# if this select is a compiler-generated wrapper,
# rewrite the targeted columns in the result map
+ compile_state_wraps_for = select_wraps_for._compile_state_factory(
+ select_wraps_for, self, **kwargs
+ )
+
translate = dict(
zip(
[
name
- for (key, name, repeated) in select._columns_plus_names
+ for (
+ key,
+ name,
+ repeated,
+ ) in compile_state.columns_plus_names
],
[
name
@@ -2502,7 +2558,7 @@ class SQLCompiler(Compiled):
key,
name,
repeated,
- ) in select_wraps_for._columns_plus_names
+ ) in compile_state_wraps_for.columns_plus_names
],
)
)
@@ -2513,13 +2569,20 @@ class SQLCompiler(Compiled):
]
text = self._compose_select_body(
- text, select, inner_columns, froms, byfrom, toplevel, kwargs
+ text,
+ select_stmt,
+ compile_state,
+ inner_columns,
+ froms,
+ byfrom,
+ toplevel,
+ kwargs,
)
- if select._statement_hints:
+ if select_stmt._statement_hints:
per_dialect = [
ht
- for (dialect_name, ht) in select._statement_hints
+ for (dialect_name, ht) in select_stmt._statement_hints
if dialect_name in ("*", self.dialect.name)
]
if per_dialect:
@@ -2528,9 +2591,9 @@ class SQLCompiler(Compiled):
if self.ctes and toplevel:
text = self._render_cte_clause() + text
- if select._suffixes:
+ if select_stmt._suffixes:
text += " " + self._generate_prefixes(
- select, select._suffixes, **kwargs
+ select_stmt, select_stmt._suffixes, **kwargs
)
self.stack.pop(-1)
@@ -2553,7 +2616,7 @@ class SQLCompiler(Compiled):
return hint_text, byfrom
def _setup_select_stack(
- self, select, entry, asfrom, lateral, compound_index
+ self, select, compile_state, entry, asfrom, lateral, compound_index
):
correlate_froms = entry["correlate_froms"]
asfrom_froms = entry["asfrom_froms"]
@@ -2564,8 +2627,8 @@ class SQLCompiler(Compiled):
if select_0._is_select_container:
select_0 = select_0.element
numcols = len(select_0.selected_columns)
- # numcols = len(select_0._columns_plus_names)
- if len(select._columns_plus_names) != numcols:
+
+ if len(compile_state.columns_plus_names) != numcols:
raise exc.CompileError(
"All selectables passed to "
"CompoundSelect must have identical numbers of "
@@ -2580,14 +2643,14 @@ class SQLCompiler(Compiled):
)
if asfrom and not lateral:
- froms = select._get_display_froms(
+ froms = compile_state._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
asfrom_froms
),
implicit_correlate_froms=(),
)
else:
- froms = select._get_display_froms(
+ froms = compile_state._get_display_froms(
explicit_correlate_froms=correlate_froms,
implicit_correlate_froms=asfrom_froms,
)
@@ -2599,13 +2662,22 @@ class SQLCompiler(Compiled):
"asfrom_froms": new_correlate_froms,
"correlate_froms": all_correlate_froms,
"selectable": select,
+ "compile_state": compile_state,
}
self.stack.append(new_entry)
return froms
def _compose_select_body(
- self, text, select, inner_columns, froms, byfrom, toplevel, kwargs
+ self,
+ text,
+ select,
+ compile_state,
+ inner_columns,
+ froms,
+ byfrom,
+ toplevel,
+ kwargs,
):
text += ", ".join(inner_columns)
@@ -2647,9 +2719,9 @@ class SQLCompiler(Compiled):
else:
text += self.default_from()
- if select._whereclause is not None:
- t = select._whereclause._compiler_dispatch(
- self, from_linter=from_linter, **kwargs
+ if select._where_criteria:
+ t = self._generate_delimited_and_list(
+ select._where_criteria, from_linter=from_linter, **kwargs
)
if t:
text += " \nWHERE " + t
@@ -2660,15 +2732,17 @@ class SQLCompiler(Compiled):
):
from_linter.warn()
- if select._group_by_clause.clauses:
+ if select._group_by_clauses:
text += self.group_by_clause(select, **kwargs)
- if select._having is not None:
- t = select._having._compiler_dispatch(self, **kwargs)
+ if select._having_criteria:
+ t = self._generate_delimited_and_list(
+ select._having_criteria, **kwargs
+ )
if t:
text += " \nHAVING " + t
- if select._order_by_clause.clauses:
+ if select._order_by_clauses:
text += self.order_by_clause(select, **kwargs)
if (
@@ -2719,7 +2793,9 @@ class SQLCompiler(Compiled):
def group_by_clause(self, select, **kw):
"""allow dialects to customize how GROUP BY is rendered."""
- group_by = select._group_by_clause._compiler_dispatch(self, **kw)
+ group_by = self._generate_delimited_list(
+ select._group_by_clauses, OPERATORS[operators.comma_op], **kw
+ )
if group_by:
return " GROUP BY " + group_by
else:
@@ -2728,7 +2804,10 @@ class SQLCompiler(Compiled):
def order_by_clause(self, select, **kw):
"""allow dialects to customize how ORDER BY is rendered."""
- order_by = select._order_by_clause._compiler_dispatch(self, **kw)
+ order_by = self._generate_delimited_list(
+ select._order_by_clauses, OPERATORS[operators.comma_op], **kw
+ )
+
if order_by:
return " ORDER BY " + order_by
else:
@@ -2827,8 +2906,8 @@ class SQLCompiler(Compiled):
def visit_insert(self, insert_stmt, **kw):
- compile_state = insert_stmt._compile_state_cls(
- insert_stmt, self, isinsert=True, **kw
+ compile_state = insert_stmt._compile_state_factory(
+ insert_stmt, self, **kw
)
insert_stmt = compile_state.statement
@@ -2973,8 +3052,8 @@ class SQLCompiler(Compiled):
)
def visit_update(self, update_stmt, **kw):
- compile_state = update_stmt._compile_state_cls(
- update_stmt, self, isupdate=True, **kw
+ compile_state = update_stmt._compile_state_factory(
+ update_stmt, self, **kw
)
update_stmt = compile_state.statement
@@ -3056,8 +3135,8 @@ class SQLCompiler(Compiled):
text += " " + extra_from_text
if update_stmt._where_criteria:
- t = self._generate_delimited_list(
- update_stmt._where_criteria, OPERATORS[operators.and_], **kw
+ t = self._generate_delimited_and_list(
+ update_stmt._where_criteria, **kw
)
if t:
text += " WHERE " + t
@@ -3100,8 +3179,8 @@ class SQLCompiler(Compiled):
return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
def visit_delete(self, delete_stmt, **kw):
- compile_state = delete_stmt._compile_state_cls(
- delete_stmt, self, isdelete=True, **kw
+ compile_state = delete_stmt._compile_state_factory(
+ delete_stmt, self, **kw
)
delete_stmt = compile_state.statement
@@ -3159,8 +3238,8 @@ class SQLCompiler(Compiled):
text += " " + extra_from_text
if delete_stmt._where_criteria:
- t = self._generate_delimited_list(
- delete_stmt._where_criteria, OPERATORS[operators.and_], **kw
+ t = self._generate_delimited_and_list(
+ delete_stmt._where_criteria, **kw
)
if t:
text += " WHERE " + t
@@ -3230,7 +3309,7 @@ class StrSQLCompiler(SQLCompiler):
def returning_clause(self, stmt, returning_cols):
columns = [
self._label_select_column(None, c, True, False, {})
- for c in elements._select_iterables(returning_cols)
+ for c in base._select_iterables(returning_cols)
]
return "RETURNING " + ", ".join(columns)