diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2020-03-11 18:08:03 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-03-11 18:08:03 +0000 |
commit | 2aa9a8043b4982d4d7b53e8b11371ea27fccd09c (patch) | |
tree | 0fc1f0ddd3a6defdda5888ee48bd4e69bd162c4c /lib/sqlalchemy/sql/compiler.py | |
parent | 59ca6e5fcc6974ea1fac82d05157aa58e550b332 (diff) | |
parent | 693938dd6fb2f3ee3e031aed4c62355ac97f3ceb (diff) | |
download | sqlalchemy-2aa9a8043b4982d4d7b53e8b11371ea27fccd09c.tar.gz |
Merge "Rework select(), CompoundSelect() in terms of CompileState"
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 183 |
1 files changed, 131 insertions, 52 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b37c46216..1f183b5c1 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -28,6 +28,7 @@ import contextlib import itertools import re +from . import base from . import coercions from . import crud from . import elements @@ -1046,9 +1047,13 @@ class SQLCompiler(Compiled): self, element, within_columns_clause=False, **kwargs ): if self.stack and self.dialect.supports_simple_order_by_label: - selectable = self.stack[-1]["selectable"] + compile_state = self.stack[-1]["compile_state"] - with_cols, only_froms, only_cols = selectable._label_resolve_dict + ( + with_cols, + only_froms, + only_cols, + ) = compile_state._label_resolve_dict if within_columns_clause: resolve_dict = only_froms else: @@ -1083,8 +1088,8 @@ class SQLCompiler(Compiled): # compiling the element outside of the context of a SELECT return self.process(element._text_clause) - selectable = self.stack[-1]["selectable"] - with_cols, only_froms, only_cols = selectable._label_resolve_dict + compile_state = self.stack[-1]["compile_state"] + with_cols, only_froms, only_cols = compile_state._label_resolve_dict try: if within_columns_clause: col = only_froms[element.element] @@ -1314,6 +1319,24 @@ class SQLCompiler(Compiled): if s ) + def _generate_delimited_and_list(self, clauses, **kw): + + lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean( + operators.and_, + elements.True_._singleton, + elements.False_._singleton, + clauses, + ) + if lcc == 1: + return clauses[0]._compiler_dispatch(self, **kw) + else: + separator = OPERATORS[operators.and_] + return separator.join( + s + for s in (c._compiler_dispatch(self, **kw) for c in clauses) + if s + ) + def visit_clauselist(self, clauselist, **kw): sep = clauselist.operator if sep is None: @@ -1474,6 +1497,12 @@ class SQLCompiler(Compiled): self, cs, asfrom=False, compound_index=0, **kwargs ): toplevel = not self.stack + + compile_state = cs._compile_state_factory(cs, self, **kwargs) + + if toplevel: + self.compile_state = compile_state + entry = self._default_stack_entry if toplevel else self.stack[-1] need_result_map = toplevel or ( compound_index == 0 @@ -1485,6 +1514,7 @@ class SQLCompiler(Compiled): "correlate_froms": entry["correlate_froms"], "asfrom_froms": entry["asfrom_froms"], "selectable": cs, + "compile_state": compile_state, "need_result_map_for_compound": need_result_map, } ) @@ -1666,7 +1696,6 @@ class SQLCompiler(Compiled): from_linter=None, **kw ): - if from_linter and operators.is_comparison(binary.operator): from_linter.edges.update( itertools.product( @@ -2274,7 +2303,6 @@ class SQLCompiler(Compiled): need_column_expressions=False, ): """produce labeled columns present in a select().""" - impl = column.type.dialect_impl(self.dialect) if impl._has_column_expression and ( @@ -2350,7 +2378,12 @@ class SQLCompiler(Compiled): or isinstance(column, functions.FunctionElement) ) ): - result_expr = _CompileLabel(col_expr, column.anon_label) + result_expr = _CompileLabel( + col_expr, + column.anon_label + if not column_is_repeated + else column._dedupe_label_anon_label, + ) elif col_expr is not column: # TODO: are we sure "column" has a .name and .key here ? # assert isinstance(column, elements.ColumnClause) @@ -2390,7 +2423,9 @@ class SQLCompiler(Compiled): [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] ) - def _display_froms_for_select(self, select, asfrom, lateral=False): + def _display_froms_for_select( + self, select_stmt, asfrom, lateral=False, **kw + ): # utility method to help external dialects # get the correct from list for a select. # specifically the oracle dialect needs this feature @@ -2398,18 +2433,20 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] + compile_state = select_stmt._compile_state_factory(select_stmt, self) + correlate_froms = entry["correlate_froms"] asfrom_froms = entry["asfrom_froms"] if asfrom and not lateral: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms.difference( asfrom_froms ), implicit_correlate_froms=(), ) else: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms, implicit_correlate_froms=asfrom_froms, ) @@ -2417,7 +2454,7 @@ class SQLCompiler(Compiled): def visit_select( self, - select, + select_stmt, asfrom=False, fromhints=None, compound_index=0, @@ -2427,7 +2464,16 @@ class SQLCompiler(Compiled): **kwargs ): + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + select_stmt = compile_state.statement + toplevel = not self.stack + + if toplevel: + self.compile_state = compile_state + entry = self._default_stack_entry if toplevel else self.stack[-1] populate_result_map = need_column_expressions = ( @@ -2446,7 +2492,7 @@ class SQLCompiler(Compiled): del kwargs["add_to_result_map"] froms = self._setup_select_stack( - select, entry, asfrom, lateral, compound_index + select_stmt, compile_state, entry, asfrom, lateral, compound_index ) column_clause_args = kwargs.copy() @@ -2456,23 +2502,25 @@ class SQLCompiler(Compiled): text = "SELECT " # we're off to a good start ! - if select._hints: - hint_text, byfrom = self._setup_select_hints(select) + if select_stmt._hints: + hint_text, byfrom = self._setup_select_hints(select_stmt) if hint_text: text += hint_text + " " else: byfrom = None - if select._prefixes: - text += self._generate_prefixes(select, select._prefixes, **kwargs) + if select_stmt._prefixes: + text += self._generate_prefixes( + select_stmt, select_stmt._prefixes, **kwargs + ) - text += self.get_select_precolumns(select, **kwargs) + text += self.get_select_precolumns(select_stmt, **kwargs) # the actual list of columns to print in the SELECT column list. inner_columns = [ c for c in [ self._label_select_column( - select, + select_stmt, column, populate_result_map, asfrom, @@ -2481,7 +2529,7 @@ class SQLCompiler(Compiled): column_is_repeated=repeated, need_column_expressions=need_column_expressions, ) - for name, column, repeated in select._columns_plus_names + for name, column, repeated in compile_state.columns_plus_names ] if c is not None ] @@ -2490,11 +2538,19 @@ class SQLCompiler(Compiled): # if this select is a compiler-generated wrapper, # rewrite the targeted columns in the result map + compile_state_wraps_for = select_wraps_for._compile_state_factory( + select_wraps_for, self, **kwargs + ) + translate = dict( zip( [ name - for (key, name, repeated) in select._columns_plus_names + for ( + key, + name, + repeated, + ) in compile_state.columns_plus_names ], [ name @@ -2502,7 +2558,7 @@ class SQLCompiler(Compiled): key, name, repeated, - ) in select_wraps_for._columns_plus_names + ) in compile_state_wraps_for.columns_plus_names ], ) ) @@ -2513,13 +2569,20 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, toplevel, kwargs + text, + select_stmt, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, ) - if select._statement_hints: + if select_stmt._statement_hints: per_dialect = [ ht - for (dialect_name, ht) in select._statement_hints + for (dialect_name, ht) in select_stmt._statement_hints if dialect_name in ("*", self.dialect.name) ] if per_dialect: @@ -2528,9 +2591,9 @@ class SQLCompiler(Compiled): if self.ctes and toplevel: text = self._render_cte_clause() + text - if select._suffixes: + if select_stmt._suffixes: text += " " + self._generate_prefixes( - select, select._suffixes, **kwargs + select_stmt, select_stmt._suffixes, **kwargs ) self.stack.pop(-1) @@ -2553,7 +2616,7 @@ class SQLCompiler(Compiled): return hint_text, byfrom def _setup_select_stack( - self, select, entry, asfrom, lateral, compound_index + self, select, compile_state, entry, asfrom, lateral, compound_index ): correlate_froms = entry["correlate_froms"] asfrom_froms = entry["asfrom_froms"] @@ -2564,8 +2627,8 @@ class SQLCompiler(Compiled): if select_0._is_select_container: select_0 = select_0.element numcols = len(select_0.selected_columns) - # numcols = len(select_0._columns_plus_names) - if len(select._columns_plus_names) != numcols: + + if len(compile_state.columns_plus_names) != numcols: raise exc.CompileError( "All selectables passed to " "CompoundSelect must have identical numbers of " @@ -2580,14 +2643,14 @@ class SQLCompiler(Compiled): ) if asfrom and not lateral: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms.difference( asfrom_froms ), implicit_correlate_froms=(), ) else: - froms = select._get_display_froms( + froms = compile_state._get_display_froms( explicit_correlate_froms=correlate_froms, implicit_correlate_froms=asfrom_froms, ) @@ -2599,13 +2662,22 @@ class SQLCompiler(Compiled): "asfrom_froms": new_correlate_froms, "correlate_froms": all_correlate_froms, "selectable": select, + "compile_state": compile_state, } self.stack.append(new_entry) return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, toplevel, kwargs + self, + text, + select, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, ): text += ", ".join(inner_columns) @@ -2647,9 +2719,9 @@ class SQLCompiler(Compiled): else: text += self.default_from() - if select._whereclause is not None: - t = select._whereclause._compiler_dispatch( - self, from_linter=from_linter, **kwargs + if select._where_criteria: + t = self._generate_delimited_and_list( + select._where_criteria, from_linter=from_linter, **kwargs ) if t: text += " \nWHERE " + t @@ -2660,15 +2732,17 @@ class SQLCompiler(Compiled): ): from_linter.warn() - if select._group_by_clause.clauses: + if select._group_by_clauses: text += self.group_by_clause(select, **kwargs) - if select._having is not None: - t = select._having._compiler_dispatch(self, **kwargs) + if select._having_criteria: + t = self._generate_delimited_and_list( + select._having_criteria, **kwargs + ) if t: text += " \nHAVING " + t - if select._order_by_clause.clauses: + if select._order_by_clauses: text += self.order_by_clause(select, **kwargs) if ( @@ -2719,7 +2793,9 @@ class SQLCompiler(Compiled): def group_by_clause(self, select, **kw): """allow dialects to customize how GROUP BY is rendered.""" - group_by = select._group_by_clause._compiler_dispatch(self, **kw) + group_by = self._generate_delimited_list( + select._group_by_clauses, OPERATORS[operators.comma_op], **kw + ) if group_by: return " GROUP BY " + group_by else: @@ -2728,7 +2804,10 @@ class SQLCompiler(Compiled): def order_by_clause(self, select, **kw): """allow dialects to customize how ORDER BY is rendered.""" - order_by = select._order_by_clause._compiler_dispatch(self, **kw) + order_by = self._generate_delimited_list( + select._order_by_clauses, OPERATORS[operators.comma_op], **kw + ) + if order_by: return " ORDER BY " + order_by else: @@ -2827,8 +2906,8 @@ class SQLCompiler(Compiled): def visit_insert(self, insert_stmt, **kw): - compile_state = insert_stmt._compile_state_cls( - insert_stmt, self, isinsert=True, **kw + compile_state = insert_stmt._compile_state_factory( + insert_stmt, self, **kw ) insert_stmt = compile_state.statement @@ -2973,8 +3052,8 @@ class SQLCompiler(Compiled): ) def visit_update(self, update_stmt, **kw): - compile_state = update_stmt._compile_state_cls( - update_stmt, self, isupdate=True, **kw + compile_state = update_stmt._compile_state_factory( + update_stmt, self, **kw ) update_stmt = compile_state.statement @@ -3056,8 +3135,8 @@ class SQLCompiler(Compiled): text += " " + extra_from_text if update_stmt._where_criteria: - t = self._generate_delimited_list( - update_stmt._where_criteria, OPERATORS[operators.and_], **kw + t = self._generate_delimited_and_list( + update_stmt._where_criteria, **kw ) if t: text += " WHERE " + t @@ -3100,8 +3179,8 @@ class SQLCompiler(Compiled): return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) def visit_delete(self, delete_stmt, **kw): - compile_state = delete_stmt._compile_state_cls( - delete_stmt, self, isdelete=True, **kw + compile_state = delete_stmt._compile_state_factory( + delete_stmt, self, **kw ) delete_stmt = compile_state.statement @@ -3159,8 +3238,8 @@ class SQLCompiler(Compiled): text += " " + extra_from_text if delete_stmt._where_criteria: - t = self._generate_delimited_list( - delete_stmt._where_criteria, OPERATORS[operators.and_], **kw + t = self._generate_delimited_and_list( + delete_stmt._where_criteria, **kw ) if t: text += " WHERE " + t @@ -3230,7 +3309,7 @@ class StrSQLCompiler(SQLCompiler): def returning_clause(self, stmt, returning_cols): columns = [ self._label_select_column(None, c, True, False, {}) - for c in elements._select_iterables(returning_cols) + for c in base._select_iterables(returning_cols) ] return "RETURNING " + ", ".join(columns) |