summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2016-02-10 15:33:10 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2016-02-10 15:33:10 -0500
commit630d2a4e284780e100a2f9e1acd9c25412bfe85a (patch)
treeac5b4d610dec755dff8cd981f116a547ef6e5ef6
parent2d7938c3864a75f056ade70db803c021e631827a (diff)
downloadsqlalchemy-630d2a4e284780e100a2f9e1acd9c25412bfe85a.tar.gz
- move out isinsert/isupdate/isdelete state
-rw-r--r--lib/sqlalchemy/sql/compiler.py70
-rw-r--r--lib/sqlalchemy/sql/crud.py45
-rw-r--r--lib/sqlalchemy/testing/assertions.py2
-rw-r--r--test/sql/test_cte.py2
4 files changed, 79 insertions, 40 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 43276ac4f..2f962be0b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -418,6 +418,11 @@ class SQLCompiler(Compiled):
self.truncated_names = {}
Compiled.__init__(self, dialect, statement, **kwargs)
+ if (
+ self.isinsert or self.isupdate or self.isdelete
+ ) and statement._returning:
+ self.returning = statement._returning
+
if self.positional and dialect.paramstyle == 'numeric':
self._apply_numbered_params()
@@ -1659,7 +1664,7 @@ class SQLCompiler(Compiled):
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
- if self.ctes and self._is_toplevel_select(select):
+ if self.ctes and toplevel:
text = self._render_cte_clause() + text
if select._suffixes:
@@ -1673,20 +1678,6 @@ class SQLCompiler(Compiled):
else:
return text
- def _is_toplevel_select(self, select):
- """Return True if the stack is placed at the given select, and
- is also the outermost SELECT, meaning there is either no stack
- before this one, or the enclosing stack is a topmost INSERT.
-
- """
- return (
- self.stack[-1]['selectable'] is select and
- (
- len(self.stack) == 1 or self.isinsert and len(self.stack) == 2
- and self.statement is self.stack[0]['selectable']
- )
- )
-
def _setup_select_hints(self, select):
byfrom = dict([
(from_, hinttext % {
@@ -1877,13 +1868,15 @@ class SQLCompiler(Compiled):
return dialect_hints, table_text
def visit_insert(self, insert_stmt, **kw):
+ toplevel = not self.stack
+
self.stack.append(
{'correlate_froms': set(),
"asfrom_froms": set(),
"selectable": insert_stmt})
- self.isinsert = True
- crud_params = crud._get_crud_params(self, insert_stmt, **kw)
+ crud_params = crud._setup_crud_params(
+ self, insert_stmt, crud.ISINSERT, **kw)
if not crud_params and \
not self.dialect.supports_default_values and \
@@ -1928,10 +1921,9 @@ class SQLCompiler(Compiled):
text += " (%s)" % ', '.join([preparer.format_column(c[0])
for c in crud_params_single])
- if self.returning or insert_stmt._returning:
- self.returning = self.returning or insert_stmt._returning
+ if insert_stmt._returning:
returning_clause = self.returning_clause(
- insert_stmt, self.returning)
+ insert_stmt, insert_stmt._returning)
if self.returning_precedes_values:
text += " " + returning_clause
@@ -1953,9 +1945,12 @@ class SQLCompiler(Compiled):
text += " VALUES (%s)" % \
', '.join([c[1] for c in crud_params])
- if self.returning and not self.returning_precedes_values:
+ if insert_stmt._returning and not self.returning_precedes_values:
text += " " + returning_clause
+ if self.ctes and toplevel:
+ text = self._render_cte_clause() + text
+
self.stack.pop(-1)
return text
@@ -1991,13 +1986,13 @@ class SQLCompiler(Compiled):
for t in extra_froms)
def visit_update(self, update_stmt, **kw):
+ toplevel = not self.stack
+
self.stack.append(
{'correlate_froms': set([update_stmt.table]),
"asfrom_froms": set([update_stmt.table]),
"selectable": update_stmt})
- self.isupdate = True
-
extra_froms = update_stmt._extra_froms
text = "UPDATE "
@@ -2009,7 +2004,8 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- crud_params = crud._get_crud_params(self, update_stmt, **kw)
+ crud_params = crud._setup_crud_params(
+ self, update_stmt, crud.ISUPDATE, **kw)
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
@@ -2028,12 +2024,10 @@ class SQLCompiler(Compiled):
'=' + c[1] for c in crud_params
)
- if self.returning or update_stmt._returning:
- if not self.returning:
- self.returning = update_stmt._returning
+ if update_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning)
+ update_stmt, update_stmt._returning)
if extra_froms:
extra_from_text = self.update_from_clause(
@@ -2053,9 +2047,12 @@ class SQLCompiler(Compiled):
if limit_clause:
text += " " + limit_clause
- if self.returning and not self.returning_precedes_values:
+ if update_stmt._returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning)
+ update_stmt, update_stmt._returning)
+
+ if self.ctes and toplevel:
+ text = self._render_cte_clause() + text
self.stack.pop(-1)
@@ -2063,13 +2060,16 @@ class SQLCompiler(Compiled):
@util.memoized_property
def _key_getters_for_crud_column(self):
- return crud._key_getters_for_crud_column(self)
+ return crud._key_getters_for_crud_column(self, self.statement)
def visit_delete(self, delete_stmt, **kw):
+ toplevel = not self.stack
+
self.stack.append({'correlate_froms': set([delete_stmt.table]),
"asfrom_froms": set([delete_stmt.table]),
"selectable": delete_stmt})
- self.isdelete = True
+
+ crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw)
text = "DELETE "
@@ -2088,7 +2088,6 @@ class SQLCompiler(Compiled):
text += table_text
if delete_stmt._returning:
- self.returning = delete_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt, delete_stmt._returning)
@@ -2098,10 +2097,13 @@ class SQLCompiler(Compiled):
if t:
text += " WHERE " + t
- if self.returning and not self.returning_precedes_values:
+ if delete_stmt._returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt, delete_stmt._returning)
+ if self.ctes and toplevel:
+ text = self._render_cte_clause() + text
+
self.stack.pop(-1)
return text
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index a01b72e61..9110ae6f2 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -25,6 +25,39 @@ values present.
""")
+ISINSERT = util.symbol('ISINSERT')
+ISUPDATE = util.symbol('ISUPDATE')
+ISDELETE = util.symbol('ISDELETE')
+
+
+def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
+ restore_isinsert = compiler.isinsert
+ restore_isupdate = compiler.isupdate
+ restore_isdelete = compiler.isdelete
+
+ should_restore = restore_isinsert or restore_isupdate or restore_isdelete
+
+ if local_stmt_type is ISINSERT:
+ compiler.isupdate = False
+ compiler.isinsert = True
+ elif local_stmt_type is ISUPDATE:
+ compiler.isupdate = True
+ compiler.isinsert = False
+ elif local_stmt_type is ISDELETE:
+ if not should_restore:
+ compiler.isdelete = True
+ else:
+ assert False, "ISINSERT, ISUPDATE, or ISDELETE expected"
+
+ try:
+ if compiler.isupdate or compiler.isinsert:
+ return _get_crud_params(compiler, stmt, **kw)
+ finally:
+ if should_restore:
+ compiler.isinsert = restore_isinsert
+ compiler.isupdate = restore_isupdate
+ compiler.isdelete = restore_isdelete
+
def _get_crud_params(compiler, stmt, **kw):
"""create a set of tuples representing column/string pairs for use
@@ -59,7 +92,7 @@ def _get_crud_params(compiler, stmt, **kw):
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
_column_as_key, _getattr_col_key, _col_bind_name = \
- _key_getters_for_crud_column(compiler)
+ _key_getters_for_crud_column(compiler, stmt)
# if we have statement parameters - set defaults in the
# compiled params
@@ -128,15 +161,15 @@ def _create_bind_param(
return bindparam
-def _key_getters_for_crud_column(compiler):
- if compiler.isupdate and compiler.statement._extra_froms:
+def _key_getters_for_crud_column(compiler, stmt):
+ if compiler.isupdate and stmt._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
# the "main" table of the statement remains unqualified,
# allowing the most compatibility with a non-multi-table
# statement.
- _et = set(compiler.statement._extra_froms)
+ _et = set(stmt._extra_froms)
def _column_as_key(key):
str_key = elements._column_as_key(key)
@@ -609,7 +642,9 @@ def _get_returning_modifiers(compiler, stmt):
stmt.table.implicit_returning and
stmt._return_defaults)
else:
- implicit_return_defaults = False
+ # this line is unused, currently we are always
+ # isinsert or isupdate
+ implicit_return_defaults = False # pragma: no cover
if implicit_return_defaults:
if stmt._return_defaults is True:
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index bb5a96256..21f9f68fb 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -296,6 +296,8 @@ class AssertsCompiledSQL(object):
dialect = config.db.dialect
elif dialect == 'default':
dialect = default.DefaultDialect()
+ elif dialect == 'default_enhanced':
+ dialect = default.StrCompileDialect()
elif isinstance(dialect, util.string_types):
dialect = url.URL(dialect).get_dialect()()
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py
index 885b15599..e6bc7c0f0 100644
--- a/test/sql/test_cte.py
+++ b/test/sql/test_cte.py
@@ -8,7 +8,7 @@ from sqlalchemy.exc import CompileError
class CTETest(fixtures.TestBase, AssertsCompiledSQL):
- __dialect__ = 'default'
+ __dialect__ = 'default_enhanced'
def test_nonrecursive(self):
orders = table('orders',