summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-03-04 02:32:21 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-03-04 02:32:21 +0000
commit7768a89d6e373639b63ecb35639dd3cec15e1407 (patch)
tree3adc3008a115f4a877a6acb1b3155513a03f82a9
parent3c682d225d21faf751fe6c4d5bcb1efc0c5bf5f8 (diff)
parent7fe400f54632835695f7b98f0c1a54424953dfad (diff)
downloadsqlalchemy-7768a89d6e373639b63ecb35639dd3cec15e1407.tar.gz
Merge "Restore crud flags if visiting_cte is set"
-rw-r--r--doc/build/changelog/unreleased_13/5181.rst9
-rw-r--r--lib/sqlalchemy/sql/compiler.py1
-rw-r--r--lib/sqlalchemy/sql/crud.py6
-rw-r--r--test/sql/test_cte.py35
4 files changed, 49 insertions, 2 deletions
diff --git a/doc/build/changelog/unreleased_13/5181.rst b/doc/build/changelog/unreleased_13/5181.rst
new file mode 100644
index 000000000..046dc4f38
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/5181.rst
@@ -0,0 +1,9 @@
+.. change::
+ :tags: bug, sql, postgresql
+ :tickets: 5181
+
+ Fixed bug where a CTE of an INSERT/UPDATE/DELETE that also uses RETURNING
+ could then not be SELECTed from directly, as the internal state of the
+ compiler would try to treat the outer SELECT as a DELETE statement itself
+ and access nonexistent state.
+
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d31cf67f8..424282951 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -723,6 +723,7 @@ class SQLCompiler(Compiled):
# a map which tracks "truncated" names based on
# dialect.label_length or dialect.max_identifier_length
self.truncated_names = {}
+
Compiled.__init__(self, dialect, statement, **kwargs)
if (
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 433a5fdfa..e474952ce 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -44,8 +44,10 @@ def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
restore_isdelete = compiler.isdelete
should_restore = (
- restore_isinsert or restore_isupdate or restore_isdelete
- ) or len(compiler.stack) > 1
+ (restore_isinsert or restore_isupdate or restore_isdelete)
+ or len(compiler.stack) > 1
+ or "visiting_cte" in kw
+ )
if local_stmt_type is ISINSERT:
compiler.isupdate = False
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py
index 4a7a80e77..c9178d580 100644
--- a/test/sql/test_cte.py
+++ b/test/sql/test_cte.py
@@ -999,6 +999,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
"upsert.quantity FROM upsert))",
)
+ eq_(insert.compile().isinsert, True)
+
def test_anon_update_cte(self):
orders = table("orders", column("region"))
stmt = (
@@ -1016,6 +1018,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT anon_1.region FROM anon_1",
)
+ eq_(stmt.select().compile().isupdate, False)
+
def test_anon_insert_cte(self):
orders = table("orders", column("region"))
stmt = (
@@ -1028,6 +1032,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
"VALUES (:region) RETURNING orders.region) "
"SELECT anon_1.region FROM anon_1",
)
+ eq_(stmt.select().compile().isinsert, False)
def test_pg_example_one(self):
products = table("products", column("id"), column("date"))
@@ -1054,6 +1059,33 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
"INSERT INTO products_log (id, date) "
"SELECT moved_rows.id, moved_rows.date FROM moved_rows",
)
+ eq_(stmt.compile().isinsert, True)
+ eq_(stmt.compile().isdelete, False)
+
+ def test_pg_example_one_select_only(self):
+ products = table("products", column("id"), column("date"))
+
+ moved_rows = (
+ products.delete()
+ .where(
+ and_(products.c.date >= "dateone", products.c.date < "datetwo")
+ )
+ .returning(*products.c)
+ .cte("moved_rows")
+ )
+
+ stmt = moved_rows.select()
+
+ self.assert_compile(
+ stmt,
+ "WITH moved_rows AS "
+ "(DELETE FROM products WHERE products.date >= :date_1 "
+ "AND products.date < :date_2 "
+ "RETURNING products.id, products.date) "
+ "SELECT moved_rows.id, moved_rows.date FROM moved_rows",
+ )
+
+ eq_(stmt.compile().isdelete, False)
def test_pg_example_two(self):
products = table("products", column("id"), column("price"))
@@ -1076,6 +1108,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT t.id, t.price "
"FROM t",
)
+ eq_(stmt.compile().isupdate, False)
def test_pg_example_three(self):
@@ -1136,6 +1169,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT pd.id, pd.price "
"FROM pd",
)
+ eq_(stmt.compile().isinsert, False)
def test_update_pulls_from_cte(self):
products = table("products", column("id"), column("price"))
@@ -1154,6 +1188,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
"UPDATE products SET id=:id, price=:price FROM pd "
"WHERE products.price = pd.price",
)
+ eq_(stmt.compile().isupdate, True)
def test_standalone_function(self):
a = table("a", column("x"))