summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/bulk_persistence.py66
1 files changed, 40 insertions, 26 deletions
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py
index 257d71db4..b75285ebd 100644
--- a/lib/sqlalchemy/orm/bulk_persistence.py
+++ b/lib/sqlalchemy/orm/bulk_persistence.py
@@ -1397,9 +1397,11 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
"dml_strategy", "unspecified"
)
- if dml_strategy == "bulk":
+ toplevel = not compiler.stack
+
+ if toplevel and dml_strategy == "bulk":
self._setup_for_bulk_update(statement, compiler)
- elif dml_strategy in ("orm", "unspecified"):
+ elif not toplevel or dml_strategy in ("orm", "unspecified"):
self._setup_for_orm_update(statement, compiler)
return self
@@ -1407,6 +1409,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
def _setup_for_orm_update(self, statement, compiler, **kw):
orm_level_statement = statement
+ toplevel = not compiler.stack
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
@@ -1416,8 +1420,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
self._init_global_attributes(
statement,
compiler,
- toplevel=True,
- process_criteria_for_toplevel=True,
+ toplevel=toplevel,
+ process_criteria_for_toplevel=toplevel,
)
if statement._values:
@@ -1451,9 +1455,12 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
use_supplemental_cols = False
- synchronize_session = compiler._annotations.get(
- "synchronize_session", None
- )
+ if not toplevel:
+ synchronize_session = None
+ else:
+ synchronize_session = compiler._annotations.get(
+ "synchronize_session", None
+ )
can_use_returning = compiler._annotations.get(
"can_use_returning", None
)
@@ -1486,13 +1493,14 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
*(list(mapper.local_table.primary_key))
)
- new_stmt = self._setup_orm_returning(
- compiler,
- orm_level_statement,
- new_stmt,
- dml_mapper=mapper,
- use_supplemental_cols=use_supplemental_cols,
- )
+ if toplevel:
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ dml_mapper=mapper,
+ use_supplemental_cols=use_supplemental_cols,
+ )
self.statement = new_stmt
@@ -1814,6 +1822,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
+ toplevel = not compiler.stack
+
orm_level_statement = statement
ext_info = statement.table._annotations["parententity"]
@@ -1822,8 +1832,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
self._init_global_attributes(
statement,
compiler,
- toplevel=True,
- process_criteria_for_toplevel=True,
+ toplevel=toplevel,
+ process_criteria_for_toplevel=toplevel,
)
new_stmt = statement._clone()
@@ -1841,9 +1851,12 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
use_supplemental_cols = False
- synchronize_session = compiler._annotations.get(
- "synchronize_session", None
- )
+ if not toplevel:
+ synchronize_session = None
+ else:
+ synchronize_session = compiler._annotations.get(
+ "synchronize_session", None
+ )
can_use_returning = compiler._annotations.get(
"can_use_returning", None
)
@@ -1870,13 +1883,14 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
- new_stmt = self._setup_orm_returning(
- compiler,
- orm_level_statement,
- new_stmt,
- dml_mapper=mapper,
- use_supplemental_cols=use_supplemental_cols,
- )
+ if toplevel:
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ dml_mapper=mapper,
+ use_supplemental_cols=use_supplemental_cols,
+ )
self.statement = new_stmt