diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-02-10 15:33:10 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-02-10 15:33:10 -0500 |
commit | 630d2a4e284780e100a2f9e1acd9c25412bfe85a (patch) | |
tree | ac5b4d610dec755dff8cd981f116a547ef6e5ef6 | |
parent | 2d7938c3864a75f056ade70db803c021e631827a (diff) | |
download | sqlalchemy-630d2a4e284780e100a2f9e1acd9c25412bfe85a.tar.gz |
- move out isinsert/isupdate/isdelete state
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 70 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 45 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 2 | ||||
-rw-r--r-- | test/sql/test_cte.py | 2 |
4 files changed, 79 insertions, 40 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 43276ac4f..2f962be0b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -418,6 +418,11 @@ class SQLCompiler(Compiled): self.truncated_names = {} Compiled.__init__(self, dialect, statement, **kwargs) + if ( + self.isinsert or self.isupdate or self.isdelete + ) and statement._returning: + self.returning = statement._returning + if self.positional and dialect.paramstyle == 'numeric': self._apply_numbered_params() @@ -1659,7 +1664,7 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes and self._is_toplevel_select(select): + if self.ctes and toplevel: text = self._render_cte_clause() + text if select._suffixes: @@ -1673,20 +1678,6 @@ class SQLCompiler(Compiled): else: return text - def _is_toplevel_select(self, select): - """Return True if the stack is placed at the given select, and - is also the outermost SELECT, meaning there is either no stack - before this one, or the enclosing stack is a topmost INSERT. - - """ - return ( - self.stack[-1]['selectable'] is select and - ( - len(self.stack) == 1 or self.isinsert and len(self.stack) == 2 - and self.statement is self.stack[0]['selectable'] - ) - ) - def _setup_select_hints(self, select): byfrom = dict([ (from_, hinttext % { @@ -1877,13 +1868,15 @@ class SQLCompiler(Compiled): return dialect_hints, table_text def visit_insert(self, insert_stmt, **kw): + toplevel = not self.stack + self.stack.append( {'correlate_froms': set(), "asfrom_froms": set(), "selectable": insert_stmt}) - self.isinsert = True - crud_params = crud._get_crud_params(self, insert_stmt, **kw) + crud_params = crud._setup_crud_params( + self, insert_stmt, crud.ISINSERT, **kw) if not crud_params and \ not self.dialect.supports_default_values and \ @@ -1928,10 +1921,9 @@ class SQLCompiler(Compiled): text += " (%s)" % ', '.join([preparer.format_column(c[0]) for c in crud_params_single]) - if self.returning or insert_stmt._returning: - self.returning = self.returning or insert_stmt._returning + if insert_stmt._returning: returning_clause = self.returning_clause( - insert_stmt, self.returning) + insert_stmt, insert_stmt._returning) if self.returning_precedes_values: text += " " + returning_clause @@ -1953,9 +1945,12 @@ class SQLCompiler(Compiled): text += " VALUES (%s)" % \ ', '.join([c[1] for c in crud_params]) - if self.returning and not self.returning_precedes_values: + if insert_stmt._returning and not self.returning_precedes_values: text += " " + returning_clause + if self.ctes and toplevel: + text = self._render_cte_clause() + text + self.stack.pop(-1) return text @@ -1991,13 +1986,13 @@ class SQLCompiler(Compiled): for t in extra_froms) def visit_update(self, update_stmt, **kw): + toplevel = not self.stack + self.stack.append( {'correlate_froms': set([update_stmt.table]), "asfrom_froms": set([update_stmt.table]), "selectable": update_stmt}) - self.isupdate = True - extra_froms = update_stmt._extra_froms text = "UPDATE " @@ -2009,7 +2004,8 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - crud_params = crud._get_crud_params(self, update_stmt, **kw) + crud_params = crud._setup_crud_params( + self, update_stmt, crud.ISUPDATE, **kw) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( @@ -2028,12 +2024,10 @@ class SQLCompiler(Compiled): '=' + c[1] for c in crud_params ) - if self.returning or update_stmt._returning: - if not self.returning: - self.returning = update_stmt._returning + if update_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, update_stmt._returning) if extra_froms: extra_from_text = self.update_from_clause( @@ -2053,9 +2047,12 @@ class SQLCompiler(Compiled): if limit_clause: text += " " + limit_clause - if self.returning and not self.returning_precedes_values: + if update_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, update_stmt._returning) + + if self.ctes and toplevel: + text = self._render_cte_clause() + text self.stack.pop(-1) @@ -2063,13 +2060,16 @@ class SQLCompiler(Compiled): @util.memoized_property def _key_getters_for_crud_column(self): - return crud._key_getters_for_crud_column(self) + return crud._key_getters_for_crud_column(self, self.statement) def visit_delete(self, delete_stmt, **kw): + toplevel = not self.stack + self.stack.append({'correlate_froms': set([delete_stmt.table]), "asfrom_froms": set([delete_stmt.table]), "selectable": delete_stmt}) - self.isdelete = True + + crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw) text = "DELETE " @@ -2088,7 +2088,6 @@ class SQLCompiler(Compiled): text += table_text if delete_stmt._returning: - self.returning = delete_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) @@ -2098,10 +2097,13 @@ class SQLCompiler(Compiled): if t: text += " WHERE " + t - if self.returning and not self.returning_precedes_values: + if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) + if self.ctes and toplevel: + text = self._render_cte_clause() + text + self.stack.pop(-1) return text diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index a01b72e61..9110ae6f2 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -25,6 +25,39 @@ values present. """) +ISINSERT = util.symbol('ISINSERT') +ISUPDATE = util.symbol('ISUPDATE') +ISDELETE = util.symbol('ISDELETE') + + +def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): + restore_isinsert = compiler.isinsert + restore_isupdate = compiler.isupdate + restore_isdelete = compiler.isdelete + + should_restore = restore_isinsert or restore_isupdate or restore_isdelete + + if local_stmt_type is ISINSERT: + compiler.isupdate = False + compiler.isinsert = True + elif local_stmt_type is ISUPDATE: + compiler.isupdate = True + compiler.isinsert = False + elif local_stmt_type is ISDELETE: + if not should_restore: + compiler.isdelete = True + else: + assert False, "ISINSERT, ISUPDATE, or ISDELETE expected" + + try: + if compiler.isupdate or compiler.isinsert: + return _get_crud_params(compiler, stmt, **kw) + finally: + if should_restore: + compiler.isinsert = restore_isinsert + compiler.isupdate = restore_isupdate + compiler.isdelete = restore_isdelete + def _get_crud_params(compiler, stmt, **kw): """create a set of tuples representing column/string pairs for use @@ -59,7 +92,7 @@ def _get_crud_params(compiler, stmt, **kw): # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account _column_as_key, _getattr_col_key, _col_bind_name = \ - _key_getters_for_crud_column(compiler) + _key_getters_for_crud_column(compiler, stmt) # if we have statement parameters - set defaults in the # compiled params @@ -128,15 +161,15 @@ def _create_bind_param( return bindparam -def _key_getters_for_crud_column(compiler): - if compiler.isupdate and compiler.statement._extra_froms: +def _key_getters_for_crud_column(compiler, stmt): + if compiler.isupdate and stmt._extra_froms: # when extra tables are present, refer to the columns # in those extra tables as table-qualified, including in # dictionaries and when rendering bind param names. # the "main" table of the statement remains unqualified, # allowing the most compatibility with a non-multi-table # statement. - _et = set(compiler.statement._extra_froms) + _et = set(stmt._extra_froms) def _column_as_key(key): str_key = elements._column_as_key(key) @@ -609,7 +642,9 @@ def _get_returning_modifiers(compiler, stmt): stmt.table.implicit_returning and stmt._return_defaults) else: - implicit_return_defaults = False + # this line is unused, currently we are always + # isinsert or isupdate + implicit_return_defaults = False # pragma: no cover if implicit_return_defaults: if stmt._return_defaults is True: diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index bb5a96256..21f9f68fb 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -296,6 +296,8 @@ class AssertsCompiledSQL(object): dialect = config.db.dialect elif dialect == 'default': dialect = default.DefaultDialect() + elif dialect == 'default_enhanced': + dialect = default.StrCompileDialect() elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 885b15599..e6bc7c0f0 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -8,7 +8,7 @@ from sqlalchemy.exc import CompileError class CTETest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = 'default_enhanced' def test_nonrecursive(self): orders = table('orders', |