summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2018-06-14 22:17:00 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2018-06-18 09:12:19 -0400
commit3619edcb8aa3ceef2a44925b85315fc0e90c5982 (patch)
treeef01478a8de04a145675ad442006ff6cb52dfafe /lib/sqlalchemy
parentda5323c2fae39aab45d305f723a73483563b2307 (diff)
downloadsqlalchemy-3619edcb8aa3ceef2a44925b85315fc0e90c5982.tar.gz
render WITH clause after INSERT for INSERT..SELECT on Oracle, MySQL
Fixed INSERT FROM SELECT with CTEs for the Oracle and MySQL dialects, where the CTE was being placed above the entire statement as is typical with other databases, however Oracle and MariaDB 10.2 wants the CTE underneath the "INSERT" segment. Note that the Oracle and MySQL dialects don't yet work when a CTE is applied to a subquery inside of an UPDATE or DELETE statement, as the CTE is still applied to the top rather than inside the subquery. Also adds test suite support CTEs against backends. Change-Id: I8ac337104d5c546dd4f0cd305632ffb56ac8bf90 Fixes: #4275 Fixes: #4230
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py2
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py1
-rw-r--r--lib/sqlalchemy/engine/default.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py9
-rw-r--r--lib/sqlalchemy/testing/requirements.py11
-rw-r--r--lib/sqlalchemy/testing/suite/__init__.py1
-rw-r--r--lib/sqlalchemy/testing/suite/test_cte.py193
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py2
8 files changed, 217 insertions, 3 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index c8a3d3322..62753e1a5 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1684,6 +1684,8 @@ class MySQLDialect(default.DefaultDialect):
default_paramstyle = 'format'
colspecs = colspecs
+ cte_follows_insert = True
+
statement_compiler = MySQLCompiler
ddl_compiler = MySQLDDLCompiler
type_compiler = MySQLTypeCompiler
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 39acbf28d..356c2a2bf 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -1030,6 +1030,7 @@ class OracleDialect(default.DefaultDialect):
max_identifier_length = 30
supports_simple_order_by_label = False
+ cte_follows_insert = True
supports_sequences = True
sequences_optional = False
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 4d5f338bf..54fb25c16 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -60,6 +60,7 @@ class DefaultDialect(interfaces.Dialect):
implicit_returning = False
supports_right_nested_joins = True
+ cte_follows_insert = False
supports_native_enum = False
supports_native_boolean = False
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index a442c65fd..0b98dc51c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -2105,7 +2105,12 @@ class SQLCompiler(Compiled):
returning_clause = None
if insert_stmt.select is not None:
- text += " %s" % self.process(self._insert_from_select, **kw)
+ select_text = self.process(self._insert_from_select, **kw)
+
+ if self.ctes and toplevel and self.dialect.cte_follows_insert:
+ text += " %s%s" % (self._render_cte_clause(), select_text)
+ else:
+ text += " %s" % select_text
elif not crud_params and supports_default_values:
text += " DEFAULT VALUES"
elif insert_stmt._has_multi_parameters:
@@ -2130,7 +2135,7 @@ class SQLCompiler(Compiled):
if returning_clause and not self.returning_precedes_values:
text += " " + returning_clause
- if self.ctes and toplevel:
+ if self.ctes and toplevel and not self.dialect.cte_follows_insert:
text = self._render_cte_clause() + text
self.stack.pop(-1)
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index b509c94d6..19d80e028 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -180,9 +180,18 @@ class SuiteRequirements(Requirements):
return exclusions.closed()
@property
+ def ctes_with_update_delete(self):
+ """target database supports CTES that ride on top of a normal UPDATE
+ or DELETE statement which refers to the CTE in a correlated subquery.
+
+ """
+
+ return exclusions.closed()
+
+ @property
def ctes_on_dml(self):
"""target database supports CTES which consist of INSERT, UPDATE
- or DELETE"""
+ or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)"""
return exclusions.closed()
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py
index 9eeffd4cb..748d9722d 100644
--- a/lib/sqlalchemy/testing/suite/__init__.py
+++ b/lib/sqlalchemy/testing/suite/__init__.py
@@ -1,4 +1,5 @@
+from sqlalchemy.testing.suite.test_cte import *
from sqlalchemy.testing.suite.test_dialect import *
from sqlalchemy.testing.suite.test_ddl import *
from sqlalchemy.testing.suite.test_insert import *
diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py
new file mode 100644
index 000000000..cc72278e6
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_cte.py
@@ -0,0 +1,193 @@
+from .. import fixtures, config
+from ..assertions import eq_
+
+from sqlalchemy import Integer, String, select
+from sqlalchemy import ForeignKey
+from sqlalchemy import testing
+
+from ..schema import Table, Column
+
+
+class CTETest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = 'ctes',
+
+ run_inserts = 'each'
+ run_deletes = 'each'
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("some_table", metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column("parent_id", ForeignKey("some_table.id")))
+
+ Table("some_other_table", metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column("parent_id", Integer))
+
+ @classmethod
+ def insert_data(cls):
+ config.db.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "d1", "parent_id": None},
+ {"id": 2, "data": "d2", "parent_id": 1},
+ {"id": 3, "data": "d3", "parent_id": 1},
+ {"id": 4, "data": "d4", "parent_id": 3},
+ {"id": 5, "data": "d5", "parent_id": 3}
+ ]
+ )
+
+ def test_select_nonrecursive_round_trip(self):
+ some_table = self.tables.some_table
+
+ with config.db.connect() as conn:
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte")
+ result = conn.execute(
+ select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"]))
+ )
+ eq_(result.fetchall(), [("d4", )])
+
+ def test_select_recursive_round_trip(self):
+ some_table = self.tables.some_table
+
+ with config.db.connect() as conn:
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])).cte(
+ "some_cte", recursive=True)
+
+ cte_alias = cte.alias("c1")
+ st1 = some_table.alias()
+ # note that SQL Server requires this to be UNION ALL,
+ # can't be UNION
+ cte = cte.union_all(
+ select([st1]).where(st1.c.id == cte_alias.c.parent_id)
+ )
+ result = conn.execute(
+ select([cte.c.data]).where(
+ cte.c.data != "d2").order_by(cte.c.data.desc())
+ )
+ eq_(
+ result.fetchall(),
+ [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)]
+ )
+
+ def test_insert_from_select_round_trip(self):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"],
+ select([cte])
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)]
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.update_from
+ def test_update_from_round_trip(self):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ conn.execute(
+ some_other_table.insert().from_select(
+ ['id', 'data', 'parent_id'],
+ select([some_table])
+ )
+ )
+
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.update().values(parent_id=5).where(
+ some_other_table.c.data == cte.c.data
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None), (2, "d2", 5),
+ (3, "d3", 5), (4, "d4", 5), (5, "d5", 3)
+ ]
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.delete_from
+ def test_delete_from_round_trip(self):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ conn.execute(
+ some_other_table.insert().from_select(
+ ['id', 'data', 'parent_id'],
+ select([some_table])
+ )
+ )
+
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data == cte.c.data
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None), (5, "d5", 3)
+ ]
+ )
+
+ @testing.requires.ctes_with_update_delete
+ def test_delete_scalar_subq_round_trip(self):
+
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ conn.execute(
+ some_other_table.insert().from_select(
+ ['id', 'data', 'parent_id'],
+ select([some_table])
+ )
+ )
+
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data ==
+ select([cte.c.data]).where(
+ cte.c.id == some_other_table.c.id)
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None), (5, "d5", 3)
+ ]
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
index d9755c8f9..05b9162de 100644
--- a/lib/sqlalchemy/testing/suite/test_select.py
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -511,3 +511,5 @@ class LikeFunctionsTest(fixtures.TablesTest):
col = self.tables.some_table.c.data
self._test(col.contains("b%cd", autoescape=True, escape="#"), {3})
self._test(col.contains("b#cd", autoescape=True, escape="#"), {7})
+
+