diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/bulk_persistence.py | 66 |
1 files changed, 40 insertions, 26 deletions
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 257d71db4..b75285ebd 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -1397,9 +1397,11 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): "dml_strategy", "unspecified" ) - if dml_strategy == "bulk": + toplevel = not compiler.stack + + if toplevel and dml_strategy == "bulk": self._setup_for_bulk_update(statement, compiler) - elif dml_strategy in ("orm", "unspecified"): + elif not toplevel or dml_strategy in ("orm", "unspecified"): self._setup_for_orm_update(statement, compiler) return self @@ -1407,6 +1409,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): def _setup_for_orm_update(self, statement, compiler, **kw): orm_level_statement = statement + toplevel = not compiler.stack + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper @@ -1416,8 +1420,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): self._init_global_attributes( statement, compiler, - toplevel=True, - process_criteria_for_toplevel=True, + toplevel=toplevel, + process_criteria_for_toplevel=toplevel, ) if statement._values: @@ -1451,9 +1455,12 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): use_supplemental_cols = False - synchronize_session = compiler._annotations.get( - "synchronize_session", None - ) + if not toplevel: + synchronize_session = None + else: + synchronize_session = compiler._annotations.get( + "synchronize_session", None + ) can_use_returning = compiler._annotations.get( "can_use_returning", None ) @@ -1486,13 +1493,14 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): *(list(mapper.local_table.primary_key)) ) - new_stmt = self._setup_orm_returning( - compiler, - orm_level_statement, - new_stmt, - dml_mapper=mapper, - use_supplemental_cols=use_supplemental_cols, - ) + if toplevel: + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + dml_mapper=mapper, + use_supplemental_cols=use_supplemental_cols, + ) self.statement = new_stmt @@ -1814,6 +1822,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) + toplevel = not compiler.stack + orm_level_statement = statement ext_info = statement.table._annotations["parententity"] @@ -1822,8 +1832,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): self._init_global_attributes( statement, compiler, - toplevel=True, - process_criteria_for_toplevel=True, + toplevel=toplevel, + process_criteria_for_toplevel=toplevel, ) new_stmt = statement._clone() @@ -1841,9 +1851,12 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): use_supplemental_cols = False - synchronize_session = compiler._annotations.get( - "synchronize_session", None - ) + if not toplevel: + synchronize_session = None + else: + synchronize_session = compiler._annotations.get( + "synchronize_session", None + ) can_use_returning = compiler._annotations.get( "can_use_returning", None ) @@ -1870,13 +1883,14 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) - new_stmt = self._setup_orm_returning( - compiler, - orm_level_statement, - new_stmt, - dml_mapper=mapper, - use_supplemental_cols=use_supplemental_cols, - ) + if toplevel: + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + dml_mapper=mapper, + use_supplemental_cols=use_supplemental_cols, + ) self.statement = new_stmt |
