diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 189 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 1 |
7 files changed, 197 insertions, 21 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 73e35d4bb..c3e69584a 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2800,6 +2800,8 @@ class MSDialect(default.DefaultDialect): insert_returning = True update_returning = True delete_returning = True + update_returning_multifrom = True + delete_returning_multifrom = True colspecs = { sqltypes.DateTime: _MSDateTime, diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index dcd03e625..ae2f12d3f 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2820,6 +2820,8 @@ class PGDialect(default.DefaultDialect): update_returning = True delete_returning = True insert_returning = True + update_returning_multifrom = True + delete_returning_multifrom = True connection_characteristics = ( default.DefaultDialect.connection_characteristics diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 4b312dceb..80e687c32 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -143,6 +143,8 @@ class DefaultDialect(Dialect): insert_null_pk_still_autoincrements = False update_returning = False delete_returning = False + update_returning_multifrom = False + delete_returning_multifrom = False insert_returning = False insert_executemany_returning = False diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 208c4f6b0..778c07592 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -781,6 +781,13 @@ class Dialect(EventTarget): """ + update_returning_multifrom: bool + """if the dialect supports RETURNING with UPDATE..FROM + + .. versionadded:: 2.0 + + """ + delete_returning: bool """if the dialect supports RETURNING with DELETE @@ -788,6 +795,13 @@ class Dialect(EventTarget): """ + delete_returning_multifrom: bool + """if the dialect supports RETURNING with DELETE..FROM + + .. versionadded:: 2.0 + + """ + favor_returning_over_lastrowid: bool """for backends that support both a lastrowid and a RETURNING insert strategy, favor RETURNING for simple single-int pk inserts. diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 7cd66513b..59a0a3d81 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1805,6 +1805,8 @@ _EMPTY_DICT = util.immutabledict() class BulkUDCompileState(CompileState): class default_update_options(Options): _synchronize_session = "evaluate" + _is_delete_using = False + _is_update_from = False _autoflush = True _subject_mapper = None _resolved_values = _EMPTY_DICT @@ -1815,7 +1817,15 @@ class BulkUDCompileState(CompileState): _refresh_identity_token = None @classmethod - def can_use_returning(cls, dialect: Dialect, mapper: Mapper[Any]) -> bool: + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + ) -> bool: raise NotImplementedError() @classmethod @@ -1836,7 +1846,7 @@ class BulkUDCompileState(CompileState): execution_options, ) = BulkUDCompileState.default_update_options.from_execution_options( "_sa_orm_update_options", - {"synchronize_session"}, + {"synchronize_session", "is_delete_using", "is_update_from"}, execution_options, statement._execution_options, ) @@ -1863,7 +1873,11 @@ class BulkUDCompileState(CompileState): session._autoflush() statement = statement._annotate( - {"synchronize_session": update_options._synchronize_session} + { + "synchronize_session": update_options._synchronize_session, + "is_delete_using": update_options._is_delete_using, + "is_update_from": update_options._is_update_from, + } ) # this stage of the execution is called before the do_orm_execute event @@ -1964,6 +1978,56 @@ class BulkUDCompileState(CompileState): return return_crit @classmethod + def _interpret_returning_rows(cls, mapper, rows): + """translate from local inherited table columns to base mapper + primary key columns. + + Joined inheritance mappers always establish the primary key in terms of + the base table. When we UPDATE a sub-table, we can only get + RETURNING for the sub-table's columns. + + Here, we create a lookup from the local sub table's primary key + columns to the base table PK columns so that we can get identity + key values from RETURNING that's against the joined inheritance + sub-table. + + the complexity here is to support more than one level deep of + inheritance, where we have to link columns to each other across + the inheritance hierarchy. + + """ + + if mapper.local_table is not mapper.base_mapper.local_table: + return rows + + # this starts as a mapping of + # local_pk_col: local_pk_col. + # we will then iteratively rewrite the "value" of the dict with + # each successive superclass column + local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key} + + for mp in mapper.iterate_to_root(): + if mp.inherits is None: + break + elif mp.local_table is mp.inherits.local_table: + continue + + t_to_e = dict(mp._table_to_equated[mp.inherits.local_table]) + col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]} + for pk, super_ in local_pk_to_base_pk.items(): + local_pk_to_base_pk[pk] = col_to_col[super_] + + lookup = { + local_pk_to_base_pk[lpk]: idx + for idx, lpk in enumerate(mapper.local_table.primary_key) + } + primary_key_convert = [ + lookup[bpk] for bpk in mapper.base_mapper.primary_key + ] + + return [tuple(row[idx] for idx in primary_key_convert) for row in rows] + + @classmethod def _do_pre_synchronize_evaluate( cls, session, @@ -2111,8 +2175,12 @@ class BulkUDCompileState(CompileState): def skip_for_returning(orm_context: ORMExecuteState) -> Any: bind = orm_context.session.get_bind(**orm_context.bind_arguments) - - if cls.can_use_returning(bind.dialect, mapper): + if cls.can_use_returning( + bind.dialect, + mapper, + is_update_from=update_options._is_update_from, + is_delete_using=update_options._is_delete_using, + ): return _result.null_result() else: return None @@ -2300,25 +2368,60 @@ class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState): # if we are against a lambda statement we might not be the # topmost object that received per-execute annotations + # do this first as we need to determine if there is + # UPDATE..FROM + + UpdateDMLState.__init__(self, new_stmt, compiler, **kw) + if compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning(compiler.dialect, mapper): + ) == "fetch" and self.can_use_returning( + compiler.dialect, mapper, is_multitable=self.is_multitable + ): if new_stmt._returning: raise sa_exc.InvalidRequestError( "Can't use synchronize_session='fetch' " "with explicit returning()" ) - new_stmt = new_stmt.returning(*mapper.primary_key) - - UpdateDMLState.__init__(self, new_stmt, compiler, **kw) + self.statement = self.statement.returning( + *mapper.local_table.primary_key + ) return self @classmethod - def can_use_returning(cls, dialect: Dialect, mapper: Mapper[Any]) -> bool: - return ( + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + ) -> bool: + + # normal answer for "should we use RETURNING" at all. + normal_answer = ( dialect.update_returning and mapper.local_table.implicit_returning ) + if not normal_answer: + return False + + # these workarounds are currently hypothetical for UPDATE, + # unlike DELETE where they impact MariaDB + if is_update_from: + return dialect.update_returning_multifrom + + elif is_multitable and not dialect.update_returning_multifrom: + raise sa_exc.CompileError( + f'Dialect "{dialect.name}" does not support RETURNING ' + "with UPDATE..FROM; for synchronize_session='fetch', " + "please add the additional execution option " + "'is_update_from=True' to the statement to indicate that " + "a separate SELECT should be used for this backend." + ) + + return True @classmethod def _get_crud_kv_pairs(cls, statement, kv_iterator): @@ -2429,9 +2532,11 @@ class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState): evaluated_keys = list(update_options._value_evaluators.keys()) if result.returns_rows: + rows = cls._interpret_returning_rows(target_mapper, result.all()) + matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in result.all() + for row in rows ] else: matched_rows = update_options._matched_rows @@ -2500,20 +2605,64 @@ class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState): if new_crit: statement = statement.where(*new_crit) + # do this first as we need to determine if there is + # DELETE..FROM + DeleteDMLState.__init__(self, statement, compiler, **kw) + if compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning(compiler.dialect, mapper): - statement = statement.returning(*mapper.primary_key) - - DeleteDMLState.__init__(self, statement, compiler, **kw) + ) == "fetch" and self.can_use_returning( + compiler.dialect, + mapper, + is_multitable=self.is_multitable, + is_delete_using=compiler._annotations.get( + "is_delete_using", False + ), + ): + self.statement = statement.returning(*statement.table.primary_key) return self @classmethod - def can_use_returning(cls, dialect: Dialect, mapper: Mapper[Any]) -> bool: - return ( + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + ) -> bool: + + # normal answer for "should we use RETURNING" at all. + normal_answer = ( dialect.delete_returning and mapper.local_table.implicit_returning ) + if not normal_answer: + return False + + # now get into special workarounds because MariaDB supports + # DELETE...RETURNING but not DELETE...USING...RETURNING. + if is_delete_using: + # is_delete_using hint was passed. use + # additional dialect feature (True for PG, False for MariaDB) + return dialect.delete_returning_multifrom + + elif is_multitable and not dialect.delete_returning_multifrom: + # is_delete_using hint was not passed, but we determined + # at compile time that this is in fact a DELETE..USING. + # it's too late to continue since we did not pre-SELECT. + # raise that we need that hint up front. + + raise sa_exc.CompileError( + f'Dialect "{dialect.name}" does not support RETURNING ' + "with DELETE..USING; for synchronize_session='fetch', " + "please add the additional execution option " + "'is_delete_using=True' to the statement to indicate that " + "a separate SELECT should be used for this backend." + ) + + return True @classmethod def _do_post_synchronize_evaluate(cls, session, result, update_options): @@ -2530,9 +2679,11 @@ class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState): target_mapper = update_options._subject_mapper if result.returns_rows: + rows = cls._interpret_returning_rows(target_mapper, result.all()) + matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in result.all() + for row in rows ] else: matched_rows = update_options._matched_rows diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 374ae5cc8..b17b05371 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -3028,7 +3028,9 @@ class Query( self.session.execute( delete_, self._params, - execution_options={"synchronize_session": synchronize_session}, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} + ), ), ) bulk_del.result = result # type: ignore @@ -3120,7 +3122,9 @@ class Query( self.session.execute( upd, self._params, - execution_options={"synchronize_session": synchronize_session}, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} + ), ), ) bulk_ud.result = result # type: ignore diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 9d489ed98..eb612f394 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -359,6 +359,7 @@ class DeleteDMLState(DMLState): t, ef = self._make_extra_froms(statement) self._primary_table = t self._extra_froms = ef + self.is_multitable = ef SelfUpdateBase = typing.TypeVar("SelfUpdateBase", bound="UpdateBase") |
