summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/9767.rst8
-rw-r--r--lib/sqlalchemy/orm/bulk_persistence.py66
-rw-r--r--test/orm/test_core_compilation.py34
3 files changed, 82 insertions, 26 deletions
diff --git a/doc/build/changelog/unreleased_20/9767.rst b/doc/build/changelog/unreleased_20/9767.rst
new file mode 100644
index 000000000..857d34987
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9767.rst
@@ -0,0 +1,8 @@
+.. change::
+ :tags: bug, orm, regression
+ :tickets: 9767
+
+ Fixed regression where use of :func:`_dml.update` or :func:`_dml_delete`
+ within a :class:`_sql.CTE` construct, then used in a :func:`_sql.select`,
+ would raise a :class:`.CompileError` as a result of ORM related rules for
+ performing ORM-level update/delete statements.
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py
index 257d71db4..b75285ebd 100644
--- a/lib/sqlalchemy/orm/bulk_persistence.py
+++ b/lib/sqlalchemy/orm/bulk_persistence.py
@@ -1397,9 +1397,11 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
"dml_strategy", "unspecified"
)
- if dml_strategy == "bulk":
+ toplevel = not compiler.stack
+
+ if toplevel and dml_strategy == "bulk":
self._setup_for_bulk_update(statement, compiler)
- elif dml_strategy in ("orm", "unspecified"):
+ elif not toplevel or dml_strategy in ("orm", "unspecified"):
self._setup_for_orm_update(statement, compiler)
return self
@@ -1407,6 +1409,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
def _setup_for_orm_update(self, statement, compiler, **kw):
orm_level_statement = statement
+ toplevel = not compiler.stack
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
@@ -1416,8 +1420,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
self._init_global_attributes(
statement,
compiler,
- toplevel=True,
- process_criteria_for_toplevel=True,
+ toplevel=toplevel,
+ process_criteria_for_toplevel=toplevel,
)
if statement._values:
@@ -1451,9 +1455,12 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
use_supplemental_cols = False
- synchronize_session = compiler._annotations.get(
- "synchronize_session", None
- )
+ if not toplevel:
+ synchronize_session = None
+ else:
+ synchronize_session = compiler._annotations.get(
+ "synchronize_session", None
+ )
can_use_returning = compiler._annotations.get(
"can_use_returning", None
)
@@ -1486,13 +1493,14 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
*(list(mapper.local_table.primary_key))
)
- new_stmt = self._setup_orm_returning(
- compiler,
- orm_level_statement,
- new_stmt,
- dml_mapper=mapper,
- use_supplemental_cols=use_supplemental_cols,
- )
+ if toplevel:
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ dml_mapper=mapper,
+ use_supplemental_cols=use_supplemental_cols,
+ )
self.statement = new_stmt
@@ -1814,6 +1822,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
+ toplevel = not compiler.stack
+
orm_level_statement = statement
ext_info = statement.table._annotations["parententity"]
@@ -1822,8 +1832,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
self._init_global_attributes(
statement,
compiler,
- toplevel=True,
- process_criteria_for_toplevel=True,
+ toplevel=toplevel,
+ process_criteria_for_toplevel=toplevel,
)
new_stmt = statement._clone()
@@ -1841,9 +1851,12 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
use_supplemental_cols = False
- synchronize_session = compiler._annotations.get(
- "synchronize_session", None
- )
+ if not toplevel:
+ synchronize_session = None
+ else:
+ synchronize_session = compiler._annotations.get(
+ "synchronize_session", None
+ )
can_use_returning = compiler._annotations.get(
"can_use_returning", None
)
@@ -1870,13 +1883,14 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
- new_stmt = self._setup_orm_returning(
- compiler,
- orm_level_statement,
- new_stmt,
- dml_mapper=mapper,
- use_supplemental_cols=use_supplemental_cols,
- )
+ if toplevel:
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ dml_mapper=mapper,
+ use_supplemental_cols=use_supplemental_cols,
+ )
self.statement = new_stmt
diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py
index 6736d5589..8b28de591 100644
--- a/test/orm/test_core_compilation.py
+++ b/test/orm/test_core_compilation.py
@@ -360,6 +360,40 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
)
+class DMLTest(QueryTest, AssertsCompiledSQL):
+ __dialect__ = "default"
+
+ @testing.variation("stmt_type", ["update", "delete"])
+ def test_dml_ctes(self, stmt_type: testing.Variation):
+ User = self.classes.User
+
+ if stmt_type.update:
+ fn = update
+ elif stmt_type.delete:
+ fn = delete
+ else:
+ stmt_type.fail()
+
+ inner_cte = fn(User).returning(User.id).cte("uid")
+
+ stmt = select(inner_cte)
+
+ if stmt_type.update:
+ self.assert_compile(
+ stmt,
+ "WITH uid AS (UPDATE users SET id=:id, name=:name "
+ "RETURNING users.id) SELECT uid.id FROM uid",
+ )
+ elif stmt_type.delete:
+ self.assert_compile(
+ stmt,
+ "WITH uid AS (DELETE FROM users "
+ "RETURNING users.id) SELECT uid.id FROM uid",
+ )
+ else:
+ stmt_type.fail()
+
+
class ColumnsClauseFromsTest(QueryTest, AssertsCompiledSQL):
__dialect__ = "default"