diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 250 |
1 files changed, 190 insertions, 60 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 8393eaf74..bd8efe77f 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -28,6 +28,7 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. import util +from ..engine import result as _result from ..future import select as future_select from ..sql import coercions from ..sql import expression @@ -1672,8 +1673,17 @@ class BulkUDCompileState(CompileState): @classmethod def orm_pre_session_exec( - cls, session, statement, params, execution_options, bind_arguments + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_reentrant_invoke, ): + if is_reentrant_invoke: + return statement, execution_options + sync = execution_options.get("synchronize_session", None) if sync is None: sync = statement._execution_options.get( @@ -1706,6 +1716,17 @@ class BulkUDCompileState(CompileState): if update_options._autoflush: session._autoflush() + statement = statement._annotate( + {"synchronize_session": update_options._synchronize_session} + ) + + # this stage of the execution is called before the do_orm_execute event + # hook. meaning for an extension like horizontal sharding, this step + # happens before the extension splits out into multiple backends and + # runs only once. if we do pre_sync_fetch, we execute a SELECT + # statement, which the horizontal sharding extension splits amongst the + # shards and combines the results together. + if update_options._synchronize_session == "evaluate": update_options = cls._do_pre_synchronize_evaluate( session, @@ -1725,19 +1746,31 @@ class BulkUDCompileState(CompileState): update_options, ) - return util.immutabledict(execution_options).union( - dict(_sa_orm_update_options=update_options) + return ( + statement, + util.immutabledict(execution_options).union( + dict(_sa_orm_update_options=update_options) + ), ) @classmethod def orm_setup_cursor_result( cls, session, statement, execution_options, bind_arguments, result ): + + # this stage of the execution is called after the + # do_orm_execute event hook. meaning for an extension like + # horizontal sharding, this step happens *within* the horizontal + # sharding event handler which calls session.execute() re-entrantly + # and will occur for each backend individually. + # the sharding extension then returns its own merged result from the + # individual ones we return here. + update_options = execution_options["_sa_orm_update_options"] if update_options._synchronize_session == "evaluate": - cls._do_post_synchronize_evaluate(session, update_options) + cls._do_post_synchronize_evaluate(session, result, update_options) elif update_options._synchronize_session == "fetch": - cls._do_post_synchronize_fetch(session, update_options) + cls._do_post_synchronize_fetch(session, result, update_options) return result @@ -1767,18 +1800,6 @@ class BulkUDCompileState(CompileState): def eval_condition(obj): return True - # TODO: something more robust for this conditional - if statement.__visit_name__ == "update": - resolved_values = cls._get_resolved_values(mapper, statement) - value_evaluators = {} - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - for key, value in resolved_keys_as_propnames: - value_evaluators[key] = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError as err: util.raise_( sa_exc.InvalidRequestError( @@ -1789,13 +1810,35 @@ class BulkUDCompileState(CompileState): from_=err, ) - # TODO: detect when the where clause is a trivial primary key match + if statement.__visit_name__ == "update": + resolved_values = cls._get_resolved_values(mapper, statement) + value_evaluators = {} + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + + # TODO: detect when the where clause is a trivial primary key match. matched_objects = [ obj for (cls, pk, identity_token,), obj in session.identity_map.items() if issubclass(cls, target_cls) and eval_condition(obj) - and identity_token == update_options._refresh_identity_token + and ( + update_options._refresh_identity_token is None + # TODO: coverage for the case where horiziontal sharding + # invokes an update() or delete() given an explicit identity + # token up front + or identity_token == update_options._refresh_identity_token + ) ] return update_options + { "_matched_objects": matched_objects, @@ -1868,29 +1911,56 @@ class BulkUDCompileState(CompileState): ): mapper = update_options._subject_mapper - if mapper: - primary_table = mapper.local_table - else: - primary_table = statement._raw_columns[0] - - # note this creates a Select() *without* the ORM plugin. - # we don't want that here. - select_stmt = future_select(*primary_table.primary_key) + select_stmt = future_select( + *(mapper.primary_key + (mapper.select_identity_token,)) + ) select_stmt._where_criteria = statement._where_criteria - matched_rows = session.execute( - select_stmt, params, execution_options, bind_arguments - ).fetchall() + def skip_for_full_returning(orm_context): + bind = orm_context.session.get_bind(**orm_context.bind_arguments) + if bind.dialect.full_returning: + return _result.null_result() + else: + return None + + result = session.execute( + select_stmt, + params, + execution_options, + bind_arguments, + _add_event=skip_for_full_returning, + ) + matched_rows = result.fetchall() + + value_evaluators = _EMPTY_DICT if statement.__visit_name__ == "update": + target_cls = mapper.class_ + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) resolved_values = cls._get_resolved_values(mapper, statement) resolved_keys_as_propnames = cls._resolved_keys_as_propnames( mapper, resolved_values ) + + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + value_evaluators = {} + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + else: resolved_keys_as_propnames = _EMPTY_DICT return update_options + { + "_value_evaluators": value_evaluators, "_matched_rows": matched_rows, "_resolved_keys_as_propnames": resolved_keys_as_propnames, } @@ -1925,15 +1995,23 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): elif statement._values: new_stmt._values = self._resolved_values + if ( + statement._annotations.get("synchronize_session", None) == "fetch" + and compiler.dialect.full_returning + ): + new_stmt = new_stmt.returning(*mapper.primary_key) + UpdateDMLState.__init__(self, new_stmt, compiler, **kw) return self @classmethod - def _do_post_synchronize_evaluate(cls, session, update_options): + def _do_post_synchronize_evaluate(cls, session, result, update_options): states = set() evaluated_keys = list(update_options._value_evaluators.keys()) + values = update_options._resolved_keys_as_propnames + attrib = set(k for k, v in values) for obj in update_options._matched_objects: state, dict_ = ( @@ -1941,9 +2019,15 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): attributes.instance_dict(obj), ) - assert ( - state.identity_token == update_options._refresh_identity_token - ) + # the evaluated states were gathered across all identity tokens. + # however the post_sync events are called per identity token, + # so filter. + if ( + update_options._refresh_identity_token is not None + and state.identity_token + != update_options._refresh_identity_token + ): + continue # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) @@ -1954,38 +2038,64 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): state._commit(dict_, list(to_evaluate)) - # expire attributes with pending changes - # (there was no autoflush, so they are overwritten) - state._expire_attributes( - dict_, set(evaluated_keys).difference(to_evaluate) - ) + to_expire = attrib.intersection(dict_).difference(to_evaluate) + if to_expire: + state._expire_attributes(dict_, to_expire) + states.add(state) session._register_altered(states) @classmethod - def _do_post_synchronize_fetch(cls, session, update_options): + def _do_post_synchronize_fetch(cls, session, result, update_options): target_mapper = update_options._subject_mapper - states = set( - [ - attributes.instance_state(session.identity_map[identity_key]) - for identity_key in [ - target_mapper.identity_key_from_primary_key( - list(primary_key), - identity_token=update_options._refresh_identity_token, - ) - for primary_key in update_options._matched_rows + states = set() + evaluated_keys = list(update_options._value_evaluators.keys()) + + if result.returns_rows: + matched_rows = [ + tuple(row) + (update_options._refresh_identity_token,) + for row in result.all() + ] + else: + matched_rows = update_options._matched_rows + + objs = [ + session.identity_map[identity_key] + for identity_key in [ + target_mapper.identity_key_from_primary_key( + list(primary_key), identity_token=identity_token, + ) + for primary_key, identity_token in [ + (row[0:-1], row[-1]) for row in matched_rows ] - if identity_key in session.identity_map + if update_options._refresh_identity_token is None + or identity_token == update_options._refresh_identity_token ] - ) + if identity_key in session.identity_map + ] values = update_options._resolved_keys_as_propnames attrib = set(k for k, v in values) - for state in states: - to_expire = attrib.intersection(state.dict) + + for obj in objs: + state, dict_ = ( + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) + + to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: + dict_[key] = update_options._value_evaluators[key](obj) + state.manager.dispatch.refresh(state, None, to_evaluate) + + state._commit(dict_, list(to_evaluate)) + + to_expire = attrib.intersection(dict_).difference(to_evaluate) if to_expire: - session._expire_state(state, to_expire) + state._expire_attributes(dict_, to_expire) + + states.add(state) session._register_altered(states) @@ -1995,14 +2105,24 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) - self.mapper = statement.table._annotations.get("parentmapper", None) + self.mapper = mapper = statement.table._annotations.get( + "parentmapper", None + ) + + if ( + mapper + and statement._annotations.get("synchronize_session", None) + == "fetch" + and compiler.dialect.full_returning + ): + statement = statement.returning(*mapper.primary_key) DeleteDMLState.__init__(self, statement, compiler, **kw) return self @classmethod - def _do_post_synchronize_evaluate(cls, session, update_options): + def _do_post_synchronize_evaluate(cls, session, result, update_options): session._remove_newly_deleted( [ @@ -2012,15 +2132,25 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): ) @classmethod - def _do_post_synchronize_fetch(cls, session, update_options): + def _do_post_synchronize_fetch(cls, session, result, update_options): target_mapper = update_options._subject_mapper - for primary_key in update_options._matched_rows: + if result.returns_rows: + matched_rows = [ + tuple(row) + (update_options._refresh_identity_token,) + for row in result.all() + ] + else: + matched_rows = update_options._matched_rows + + for row in matched_rows: + primary_key = row[0:-1] + identity_token = row[-1] + # TODO: inline this and call remove_newly_deleted # once identity_key = target_mapper.identity_key_from_primary_key( - list(primary_key), - identity_token=update_options._refresh_identity_token, + list(primary_key), identity_token=identity_token, ) if identity_key in session.identity_map: session._remove_newly_deleted( |
