Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r--  lib/sqlalchemy/orm/persistence.py  |  250
1 file changed, 190 insertions(+), 60 deletions(-)
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 8393eaf74..bd8efe77f 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -28,6 +28,7 @@ from .. import exc as sa_exc
from .. import future
from .. import sql
from .. import util
+from ..engine import result as _result
from ..future import select as future_select
from ..sql import coercions
from ..sql import expression
@@ -1672,8 +1673,17 @@ class BulkUDCompileState(CompileState):
@classmethod
def orm_pre_session_exec(
- cls, session, statement, params, execution_options, bind_arguments
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ is_reentrant_invoke,
):
+ if is_reentrant_invoke:
+ return statement, execution_options
+
sync = execution_options.get("synchronize_session", None)
if sync is None:
sync = statement._execution_options.get(
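The `is_reentrant_invoke` guard above keeps this pre-exec step from running twice: when a do_orm_execute event handler calls back into session.execute() with the same statement, the options computed on the first pass are reused as-is. A minimal sketch of such a re-entrant handler, assuming a 1.4-style Session (the handler name is illustrative):

    from sqlalchemy import event
    from sqlalchemy.orm import Session

    # hedged sketch: invoke_statement() re-enters Session.execute(),
    # which calls orm_pre_session_exec() a second time with
    # is_reentrant_invoke=True, so the option setup above is skipped.
    @event.listens_for(Session, "do_orm_execute")
    def reinvoke(orm_execute_state):
        if not orm_execute_state.is_select:
            return orm_execute_state.invoke_statement()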
@@ -1706,6 +1716,17 @@ class BulkUDCompileState(CompileState):
if update_options._autoflush:
session._autoflush()
+ statement = statement._annotate(
+ {"synchronize_session": update_options._synchronize_session}
+ )
+
+ # this stage of the execution is called before the do_orm_execute event
+ # hook, meaning that for an extension like horizontal sharding, this
+ # step happens before the extension splits out into multiple backends
+ # and runs only once.  if we do the pre-synchronize fetch here, we
+ # execute a SELECT statement, which the horizontal sharding extension
+ # splits amongst the shards and combines the results together.
+
if update_options._synchronize_session == "evaluate":
update_options = cls._do_pre_synchronize_evaluate(
session,
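At the user-facing level, the strategy dispatched here comes from execution options; a usage sketch, assuming a mapped User class bound to an open Session:

    from sqlalchemy import update

    # "evaluate" mirrors the WHERE criteria in Python against objects
    # already present in the identity map; "fetch" determines affected
    # primary keys via a pre-SELECT or, with this change, RETURNING.
    session.execute(
        update(User).where(User.name == "sandy").values(name="squirrel"),
        execution_options={"synchronize_session": "evaluate"},
    )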
@@ -1725,19 +1746,31 @@ class BulkUDCompileState(CompileState):
update_options,
)
- return util.immutabledict(execution_options).union(
- dict(_sa_orm_update_options=update_options)
+ return (
+ statement,
+ util.immutabledict(execution_options).union(
+ dict(_sa_orm_update_options=update_options)
+ ),
)
@classmethod
def orm_setup_cursor_result(
cls, session, statement, execution_options, bind_arguments, result
):
+
+ # this stage of the execution is called after the
+ # do_orm_execute event hook, meaning that for an extension like
+ # horizontal sharding, this step happens *within* the horizontal
+ # sharding event handler, which calls session.execute() re-entrantly,
+ # and will occur once for each backend individually.
+ # the sharding extension then returns its own merged result built
+ # from the individual ones we return here.
+
update_options = execution_options["_sa_orm_update_options"]
if update_options._synchronize_session == "evaluate":
- cls._do_post_synchronize_evaluate(session, update_options)
+ cls._do_post_synchronize_evaluate(session, result, update_options)
elif update_options._synchronize_session == "fetch":
- cls._do_post_synchronize_fetch(session, update_options)
+ cls._do_post_synchronize_fetch(session, result, update_options)
return result
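For illustration, the per-backend pattern that comment describes, loosely modeled on the horizontal sharding extension (the shard identifiers here are made up):

    from sqlalchemy import event
    from sqlalchemy.orm import Session

    # hedged sketch: each invoke_statement() call re-enters
    # Session.execute(), so orm_setup_cursor_result() runs once per
    # shard; the handler then merges the per-shard results into one.
    @event.listens_for(Session, "do_orm_execute")
    def execute_on_each_shard(orm_execute_state):
        results = [
            orm_execute_state.invoke_statement(
                bind_arguments={"shard_id": shard_id}
            )
            for shard_id in ("asia", "europe")
        ]
        return results[0].merge(*results[1:])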
@@ -1767,18 +1800,6 @@ class BulkUDCompileState(CompileState):
def eval_condition(obj):
return True
- # TODO: something more robust for this conditional
- if statement.__visit_name__ == "update":
- resolved_values = cls._get_resolved_values(mapper, statement)
- value_evaluators = {}
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- for key, value in resolved_keys_as_propnames:
- value_evaluators[key] = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
-
except evaluator.UnevaluatableError as err:
util.raise_(
sa_exc.InvalidRequestError(
@@ -1789,13 +1810,35 @@ class BulkUDCompileState(CompileState):
from_=err,
)
- # TODO: detect when the where clause is a trivial primary key match
+ if statement.__visit_name__ == "update":
+ resolved_values = cls._get_resolved_values(mapper, statement)
+ value_evaluators = {}
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ for key, value in resolved_keys_as_propnames:
+ try:
+ _evaluator = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ value_evaluators[key] = _evaluator
+
+ # TODO: detect when the where clause is a trivial primary key match.
matched_objects = [
obj
for (cls, pk, identity_token,), obj in session.identity_map.items()
if issubclass(cls, target_cls)
and eval_condition(obj)
- and identity_token == update_options._refresh_identity_token
+ and (
+ update_options._refresh_identity_token is None
+ # TODO: coverage for the case where horizontal sharding
+ # invokes an update() or delete() given an explicit identity
+ # token up front
+ or identity_token == update_options._refresh_identity_token
+ )
]
return update_options + {
"_matched_objects": matched_objects,
@@ -1868,29 +1911,56 @@ class BulkUDCompileState(CompileState):
):
mapper = update_options._subject_mapper
- if mapper:
- primary_table = mapper.local_table
- else:
- primary_table = statement._raw_columns[0]
-
- # note this creates a Select() *without* the ORM plugin.
- # we don't want that here.
- select_stmt = future_select(*primary_table.primary_key)
+ select_stmt = future_select(
+ *(mapper.primary_key + (mapper.select_identity_token,))
+ )
select_stmt._where_criteria = statement._where_criteria
- matched_rows = session.execute(
- select_stmt, params, execution_options, bind_arguments
- ).fetchall()
+ def skip_for_full_returning(orm_context):
+ bind = orm_context.session.get_bind(**orm_context.bind_arguments)
+ if bind.dialect.full_returning:
+ return _result.null_result()
+ else:
+ return None
+
+ result = session.execute(
+ select_stmt,
+ params,
+ execution_options,
+ bind_arguments,
+ _add_event=skip_for_full_returning,
+ )
+ matched_rows = result.fetchall()
+
+ value_evaluators = _EMPTY_DICT
if statement.__visit_name__ == "update":
+ target_cls = mapper.class_
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
resolved_values = cls._get_resolved_values(mapper, statement)
resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
mapper, resolved_values
)
+
+ value_evaluators = {}
+ for key, value in resolved_keys_as_propnames:
+ try:
+ _evaluator = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ value_evaluators[key] = _evaluator
+
else:
resolved_keys_as_propnames = _EMPTY_DICT
return update_options + {
+ "_value_evaluators": value_evaluators,
"_matched_rows": matched_rows,
"_resolved_keys_as_propnames": resolved_keys_as_propnames,
}
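In SQL terms, the branch taken by skip_for_full_returning comes down to the following, sketched here for an UPDATE against an assumed users table:

    # hedged illustration of the two "fetch" code paths:
    #
    # dialect.full_returning is False (e.g. MySQL) -- pre-SELECT first:
    #
    #     SELECT users.id FROM users WHERE users.age > 29
    #     UPDATE users SET age=age+1 WHERE users.age > 29
    #
    # dialect.full_returning is True (e.g. PostgreSQL) -- the SELECT is
    # short-circuited to null_result() and the DML reports its own rows:
    #
    #     UPDATE users SET age=age+1 WHERE users.age > 29
    #     RETURNING users.id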
@@ -1925,15 +1995,23 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
elif statement._values:
new_stmt._values = self._resolved_values
+ if (
+ statement._annotations.get("synchronize_session", None) == "fetch"
+ and compiler.dialect.full_returning
+ ):
+ new_stmt = new_stmt.returning(*mapper.primary_key)
+
UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
return self
@classmethod
- def _do_post_synchronize_evaluate(cls, session, update_options):
+ def _do_post_synchronize_evaluate(cls, session, result, update_options):
states = set()
evaluated_keys = list(update_options._value_evaluators.keys())
+ values = update_options._resolved_keys_as_propnames
+ attrib = set(k for k, v in values)
for obj in update_options._matched_objects:
state, dict_ = (
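With the create_for_statement() change above, a "fetch"-synchronized UPDATE on a full-RETURNING backend gets the mapper's primary key appended at compile time; a usage sketch (PostgreSQL assumed):

    from sqlalchemy import update

    # compiles roughly as:
    #   UPDATE users SET age=(users.age + 1) WHERE users.age > 29
    #   RETURNING users.id
    session.execute(
        update(User).where(User.age > 29).values(age=User.age + 1),
        execution_options={"synchronize_session": "fetch"},
    )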
@@ -1941,9 +2019,15 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
attributes.instance_dict(obj),
)
- assert (
- state.identity_token == update_options._refresh_identity_token
- )
+ # the evaluated states were gathered across all identity tokens;
+ # however, the post-sync steps are called once per identity token,
+ # so filter the states down to the current token.
+ if (
+ update_options._refresh_identity_token is not None
+ and state.identity_token
+ != update_options._refresh_identity_token
+ ):
+ continue
# only evaluate unmodified attributes
to_evaluate = state.unmodified.intersection(evaluated_keys)
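The token comparison above relies on the 1.4 identity-key layout, in which every identity-map entry is keyed by a three-tuple; for illustration (the shard name is assumed):

    # identity keys are (class, primary-key tuple, identity_token); with
    # horizontal sharding the token is the shard id, so a post-sync pass
    # scoped to one shard must skip states that live on another.
    key = User.__mapper__.identity_key_from_primary_key(
        [5], identity_token="asia"
    )
    # -> (User, (5,), "asia")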
@@ -1954,38 +2038,64 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
state._commit(dict_, list(to_evaluate))
- # expire attributes with pending changes
- # (there was no autoflush, so they are overwritten)
- state._expire_attributes(
- dict_, set(evaluated_keys).difference(to_evaluate)
- )
+ to_expire = attrib.intersection(dict_).difference(to_evaluate)
+ if to_expire:
+ state._expire_attributes(dict_, to_expire)
+
states.add(state)
session._register_altered(states)
@classmethod
- def _do_post_synchronize_fetch(cls, session, update_options):
+ def _do_post_synchronize_fetch(cls, session, result, update_options):
target_mapper = update_options._subject_mapper
- states = set(
- [
- attributes.instance_state(session.identity_map[identity_key])
- for identity_key in [
- target_mapper.identity_key_from_primary_key(
- list(primary_key),
- identity_token=update_options._refresh_identity_token,
- )
- for primary_key in update_options._matched_rows
+ states = set()
+ evaluated_keys = list(update_options._value_evaluators.keys())
+
+ if result.returns_rows:
+ matched_rows = [
+ tuple(row) + (update_options._refresh_identity_token,)
+ for row in result.all()
+ ]
+ else:
+ matched_rows = update_options._matched_rows
+
+ objs = [
+ session.identity_map[identity_key]
+ for identity_key in [
+ target_mapper.identity_key_from_primary_key(
+ list(primary_key), identity_token=identity_token,
+ )
+ for primary_key, identity_token in [
+ (row[0:-1], row[-1]) for row in matched_rows
]
- if identity_key in session.identity_map
+ if update_options._refresh_identity_token is None
+ or identity_token == update_options._refresh_identity_token
]
- )
+ if identity_key in session.identity_map
+ ]
values = update_options._resolved_keys_as_propnames
attrib = set(k for k, v in values)
- for state in states:
- to_expire = attrib.intersection(state.dict)
+
+ for obj in objs:
+ state, dict_ = (
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
+
+ to_evaluate = state.unmodified.intersection(evaluated_keys)
+ for key in to_evaluate:
+ dict_[key] = update_options._value_evaluators[key](obj)
+ state.manager.dispatch.refresh(state, None, to_evaluate)
+
+ state._commit(dict_, list(to_evaluate))
+
+ to_expire = attrib.intersection(dict_).difference(to_evaluate)
if to_expire:
- session._expire_state(state, to_expire)
+ state._expire_attributes(dict_, to_expire)
+
+ states.add(state)
session._register_altered(states)
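Taken together, the post-fetch path now behaves like "evaluate" wherever a value evaluator exists and expires only what it cannot compute; a behavioral sketch, assuming a mapped User with a loaded instance:

    from sqlalchemy import func, update

    user = session.get(User, 5)          # loaded and unmodified
    session.execute(
        update(User).values(name="sandy", updated_at=func.now()),
        execution_options={"synchronize_session": "fetch"},
    )
    assert user.name == "sandy"          # applied by the value evaluator
    user.updated_at                      # expired -> refreshed on access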
@@ -1995,14 +2105,24 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
- self.mapper = statement.table._annotations.get("parentmapper", None)
+ self.mapper = mapper = statement.table._annotations.get(
+ "parentmapper", None
+ )
+
+ if (
+ mapper
+ and statement._annotations.get("synchronize_session", None)
+ == "fetch"
+ and compiler.dialect.full_returning
+ ):
+ statement = statement.returning(*mapper.primary_key)
DeleteDMLState.__init__(self, statement, compiler, **kw)
return self
@classmethod
- def _do_post_synchronize_evaluate(cls, session, update_options):
+ def _do_post_synchronize_evaluate(cls, session, result, update_options):
session._remove_newly_deleted(
[
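The analogous DELETE path: with "fetch" synchronization on a full-RETURNING dialect, the statement itself reports the deleted primary keys and no pre-SELECT is emitted; a usage sketch:

    from sqlalchemy import delete

    # on PostgreSQL this executes roughly as:
    #   DELETE FROM users WHERE users.active = false RETURNING users.id
    # and each returned key is evicted from the Session's identity map.
    session.execute(
        delete(User).where(User.active.is_(False)),
        execution_options={"synchronize_session": "fetch"},
    )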
@@ -2012,15 +2132,25 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
)
@classmethod
- def _do_post_synchronize_fetch(cls, session, update_options):
+ def _do_post_synchronize_fetch(cls, session, result, update_options):
target_mapper = update_options._subject_mapper
- for primary_key in update_options._matched_rows:
+ if result.returns_rows:
+ matched_rows = [
+ tuple(row) + (update_options._refresh_identity_token,)
+ for row in result.all()
+ ]
+ else:
+ matched_rows = update_options._matched_rows
+
+ for row in matched_rows:
+ primary_key = row[0:-1]
+ identity_token = row[-1]
+
# TODO: inline this and call remove_newly_deleted
# once
identity_key = target_mapper.identity_key_from_primary_key(
- list(primary_key),
- identity_token=update_options._refresh_identity_token,
+ list(primary_key), identity_token=identity_token,
)
if identity_key in session.identity_map:
session._remove_newly_deleted(