Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r--  lib/sqlalchemy/orm/persistence.py  385
1 file changed, 194 insertions, 191 deletions
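
The hunks below change indentation and comment spacing only (194 insertions against 191 deletions, no behavioral change), presumably to bring the module in line with pep8/flake8 continuation-line rules. A minimal sketch of the convention being applied follows; the stub body is a hypothetical stand-in so the snippet runs, not the real SQLAlchemy implementation:

    # Hypothetical stub, defined only so this illustration executes;
    # the real _organize_states_for_save lives in this module.
    def _organize_states_for_save(base_mapper, states, uowtransaction):
        return [], []

    # The style this patch enforces: continuation arguments indented
    # consistently relative to the opening call, rather than left at
    # an arbitrary column (the pep8 E12x family of warnings).
    states_to_insert, states_to_update = _organize_states_for_save(
        "base_mapper",
        ["state"],
        "uowtransaction")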
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 6669efc56..295d4a3d0 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -41,18 +41,18 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
return
states_to_insert, states_to_update = _organize_states_for_save(
- base_mapper,
- states,
- uowtransaction)
+ base_mapper,
+ states,
+ uowtransaction)
cached_connections = _cached_connection_dict(base_mapper)
for table, mapper in base_mapper._sorted_tables.items():
insert = _collect_insert_commands(base_mapper, uowtransaction,
- table, states_to_insert)
+ table, states_to_insert)
update = _collect_update_commands(base_mapper, uowtransaction,
- table, states_to_update)
+ table, states_to_update)
if update:
_emit_update_statements(base_mapper, uowtransaction,
@@ -65,7 +65,7 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
mapper, table, insert)
_finalize_insert_update_commands(base_mapper, uowtransaction,
- states_to_insert, states_to_update)
+ states_to_insert, states_to_update)
def post_update(base_mapper, states, uowtransaction, post_update_cols):
@@ -76,18 +76,18 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
cached_connections = _cached_connection_dict(base_mapper)
states_to_update = _organize_states_for_post_update(
- base_mapper,
- states, uowtransaction)
+ base_mapper,
+ states, uowtransaction)
for table, mapper in base_mapper._sorted_tables.items():
update = _collect_post_update_commands(base_mapper, uowtransaction,
- table, states_to_update,
- post_update_cols)
+ table, states_to_update,
+ post_update_cols)
if update:
_emit_post_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
+ cached_connections,
+ mapper, table, update)
def delete_obj(base_mapper, states, uowtransaction):
@@ -101,23 +101,23 @@ def delete_obj(base_mapper, states, uowtransaction):
cached_connections = _cached_connection_dict(base_mapper)
states_to_delete = _organize_states_for_delete(
- base_mapper,
- states,
- uowtransaction)
+ base_mapper,
+ states,
+ uowtransaction)
table_to_mapper = base_mapper._sorted_tables
for table in reversed(list(table_to_mapper.keys())):
delete = _collect_delete_commands(base_mapper, uowtransaction,
- table, states_to_delete)
+ table, states_to_delete)
mapper = table_to_mapper[table]
_emit_delete_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, delete)
+ cached_connections, mapper, table, delete)
for state, state_dict, mapper, has_identity, connection \
- in states_to_delete:
+ in states_to_delete:
mapper.dispatch.after_delete(mapper, connection, state)
@@ -137,8 +137,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
states_to_update = []
for state, dict_, mapper, connection in _connections_for_states(
- base_mapper, uowtransaction,
- states):
+ base_mapper, uowtransaction,
+ states):
has_identity = bool(state.key)
instance_key = state.key or mapper._identity_key_from_state(state)
@@ -183,19 +183,19 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
if not has_identity and not row_switch:
states_to_insert.append(
(state, dict_, mapper, connection,
- has_identity, instance_key, row_switch)
+ has_identity, instance_key, row_switch)
)
else:
states_to_update.append(
(state, dict_, mapper, connection,
- has_identity, instance_key, row_switch)
+ has_identity, instance_key, row_switch)
)
return states_to_insert, states_to_update
def _organize_states_for_post_update(base_mapper, states,
- uowtransaction):
+ uowtransaction):
"""Make an initial pass across a set of states for UPDATE
corresponding to post_update.
@@ -205,7 +205,7 @@ def _organize_states_for_post_update(base_mapper, states,
"""
return list(_connections_for_states(base_mapper, uowtransaction,
- states))
+ states))
def _organize_states_for_delete(base_mapper, states, uowtransaction):
@@ -219,25 +219,25 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
states_to_delete = []
for state, dict_, mapper, connection in _connections_for_states(
- base_mapper, uowtransaction,
- states):
+ base_mapper, uowtransaction,
+ states):
mapper.dispatch.before_delete(mapper, connection, state)
states_to_delete.append((state, dict_, mapper,
- bool(state.key), connection))
+ bool(state.key), connection))
return states_to_delete
def _collect_insert_commands(base_mapper, uowtransaction, table,
- states_to_insert):
+ states_to_insert):
"""Identify sets of values to use in INSERT statements for a
list of states.
"""
insert = []
for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in states_to_insert:
+ instance_key, row_switch in states_to_insert:
if table not in mapper._pks_by_table:
continue
@@ -250,7 +250,7 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
has_all_defaults = True
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col and \
- mapper.version_id_generator is not False:
+ mapper.version_id_generator is not False:
val = mapper.version_id_generator(None)
params[col.key] = val
else:
@@ -263,10 +263,10 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
if col in pks:
has_all_pks = False
elif col.default is None and \
- col.server_default is None:
+ col.server_default is None:
params[col.key] = value
elif col.server_default is not None and \
- mapper.base_mapper.eager_defaults:
+ mapper.base_mapper.eager_defaults:
has_all_defaults = False
elif isinstance(value, sql.ClauseElement):
@@ -275,13 +275,13 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
params[col.key] = value
insert.append((state, state_dict, params, mapper,
- connection, value_params, has_all_pks,
- has_all_defaults))
+ connection, value_params, has_all_pks,
+ has_all_defaults))
return insert
def _collect_update_commands(base_mapper, uowtransaction,
- table, states_to_update):
+ table, states_to_update):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -295,7 +295,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
update = []
for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in states_to_update:
+ instance_key, row_switch in states_to_update:
if table not in mapper._pks_by_table:
continue
@@ -309,10 +309,10 @@ def _collect_update_commands(base_mapper, uowtransaction,
if col is mapper.version_id_col:
params[col._label] = \
mapper._get_committed_state_attr_by_column(
- row_switch or state,
- row_switch and row_switch.dict
- or state_dict,
- col)
+ row_switch or state,
+ row_switch and row_switch.dict
+ or state_dict,
+ col)
prop = mapper._columntoproperty[col]
history = state.manager[prop.key].impl.get_history(
@@ -331,19 +331,20 @@ def _collect_update_commands(base_mapper, uowtransaction,
# in a different table than the one
# where the version_id_col is.
for prop in mapper._columntoproperty.values():
- history = state.manager[prop.key].impl.get_history(
+ history = (
+ state.manager[prop.key].impl.get_history(
state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE)
+ attributes.PASSIVE_NO_INITIALIZE))
if history.added:
hasdata = True
else:
prop = mapper._columntoproperty[col]
history = state.manager[prop.key].impl.get_history(
- state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE)
+ state, state_dict,
+ attributes.PASSIVE_NO_INITIALIZE)
if history.added:
if isinstance(history.added[0],
- sql.ClauseElement):
+ sql.ClauseElement):
value_params[col] = history.added[0]
else:
value = history.added[0]
@@ -351,13 +352,13 @@ def _collect_update_commands(base_mapper, uowtransaction,
if col in pks:
if history.deleted and \
- not row_switch:
+ not row_switch:
# if passive_updates and sync detected
# this was a pk->pk sync, use the new
# value to locate the row, since the
# DB would already have set this
if ("pk_cascaded", state, col) in \
- uowtransaction.attributes:
+ uowtransaction.attributes:
value = history.added[0]
params[col._label] = value
else:
@@ -381,7 +382,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
hasdata = True
elif col in pks:
value = state.manager[prop.key].impl.get(
- state, state_dict)
+ state, state_dict)
if value is None:
hasnull = True
params[col._label] = value
@@ -389,16 +390,16 @@ def _collect_update_commands(base_mapper, uowtransaction,
if hasdata:
if hasnull:
raise orm_exc.FlushError(
- "Can't update table "
- "using NULL for primary "
- "key value")
+ "Can't update table "
+ "using NULL for primary "
+ "key value")
update.append((state, state_dict, params, mapper,
- connection, value_params))
+ connection, value_params))
return update
def _collect_post_update_commands(base_mapper, uowtransaction, table,
- states_to_update, post_update_cols):
+ states_to_update, post_update_cols):
"""Identify sets of values to use in UPDATE statements for a
list of states within a post_update operation.
@@ -415,34 +416,34 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
for col in mapper._cols_by_table[table]:
if col in pks:
params[col._label] = \
- mapper._get_state_attr_by_column(
- state,
- state_dict, col)
+ mapper._get_state_attr_by_column(
+ state,
+ state_dict, col)
elif col in post_update_cols:
prop = mapper._columntoproperty[col]
history = state.manager[prop.key].impl.get_history(
- state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE)
+ state, state_dict,
+ attributes.PASSIVE_NO_INITIALIZE)
if history.added:
value = history.added[0]
params[col.key] = value
hasdata = True
if hasdata:
update.append((state, state_dict, params, mapper,
- connection))
+ connection))
return update
def _collect_delete_commands(base_mapper, uowtransaction, table,
- states_to_delete):
+ states_to_delete):
"""Identify values to use in DELETE statements for a list of
states to be deleted."""
delete = util.defaultdict(list)
for state, state_dict, mapper, has_identity, connection \
- in states_to_delete:
+ in states_to_delete:
if not has_identity or table not in mapper._pks_by_table:
continue
@@ -450,43 +451,44 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
delete[connection].append(params)
for col in mapper._pks_by_table[table]:
params[col.key] = \
- value = \
- mapper._get_committed_state_attr_by_column(
- state, state_dict, col)
+ value = \
+ mapper._get_committed_state_attr_by_column(
+ state, state_dict, col)
if value is None:
raise orm_exc.FlushError(
- "Can't delete from table "
- "using NULL for primary "
- "key value")
+ "Can't delete from table "
+ "using NULL for primary "
+ "key value")
if mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col):
+ table.c.contains_column(mapper.version_id_col):
params[mapper.version_id_col.key] = \
- mapper._get_committed_state_attr_by_column(
- state, state_dict,
- mapper.version_id_col)
+ mapper._get_committed_state_attr_by_column(
+ state, state_dict,
+ mapper.version_id_col)
return delete
def _emit_update_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, update):
+ cached_connections, mapper, table, update):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
needs_version_id = mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col)
+ table.c.contains_column(mapper.version_id_col)
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
+ type_=col.type))
if needs_version_id:
- clause.clauses.append(mapper.version_id_col ==\
- sql.bindparam(mapper.version_id_col._label,
- type_=mapper.version_id_col.type))
+ clause.clauses.append(
+ mapper.version_id_col == sql.bindparam(
+ mapper.version_id_col._label,
+ type_=mapper.version_id_col.type))
stmt = table.update(clause)
if mapper.base_mapper.eager_defaults:
@@ -500,43 +502,43 @@ def _emit_update_statements(base_mapper, uowtransaction,
rows = 0
for state, state_dict, params, mapper, \
- connection, value_params in update:
+ connection, value_params in update:
if value_params:
c = connection.execute(
- statement.values(value_params),
- params)
+ statement.values(value_params),
+ params)
else:
c = cached_connections[connection].\
- execute(statement, params)
+ execute(statement, params)
_postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- value_params)
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
rows += c.rowcount
if connection.dialect.supports_sane_rowcount:
if rows != len(update):
raise orm_exc.StaleDataError(
- "UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(update), rows))
+ "UPDATE statement on table '%s' expected to "
+ "update %d row(s); %d were matched." %
+ (table.description, len(update), rows))
elif needs_version_id:
util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description,
- stacklevel=12)
+ "- versioning cannot be verified." %
+ c.dialect.dialect_description,
+ stacklevel=12)
def _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, insert):
+ cached_connections, mapper, table, insert):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
@@ -544,37 +546,37 @@ def _emit_insert_statements(base_mapper, uowtransaction,
for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
records in groupby(insert,
- lambda rec: (rec[4],
- list(rec[2].keys()),
- bool(rec[5]),
- rec[6], rec[7])
- ):
+ lambda rec: (rec[4],
+ list(rec[2].keys()),
+ bool(rec[5]),
+ rec[6], rec[7])
+ ):
if \
- (
- has_all_defaults
- or not base_mapper.eager_defaults
- or not connection.dialect.implicit_returning
- ) and has_all_pks and not hasvalue:
+ (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not connection.dialect.implicit_returning
+ ) and has_all_pks and not hasvalue:
records = list(records)
multiparams = [rec[2] for rec in records]
c = cached_connections[connection].\
- execute(statement, multiparams)
+ execute(statement, multiparams)
for (state, state_dict, params, mapper_rec,
conn, value_params, has_all_pks, has_all_defaults), \
last_inserted_params in \
zip(records, c.context.compiled_parameters):
_postfetch(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- last_inserted_params,
- value_params)
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ last_inserted_params,
+ value_params)
else:
if not has_all_defaults and base_mapper.eager_defaults:
@@ -583,45 +585,45 @@ def _emit_insert_statements(base_mapper, uowtransaction,
statement = statement.return_defaults(mapper.version_id_col)
for state, state_dict, params, mapper_rec, \
- connection, value_params, \
- has_all_pks, has_all_defaults in records:
+ connection, value_params, \
+ has_all_pks, has_all_defaults in records:
if value_params:
result = connection.execute(
- statement.values(value_params),
- params)
+ statement.values(value_params),
+ params)
else:
result = cached_connections[connection].\
- execute(statement, params)
+ execute(statement, params)
primary_key = result.context.inserted_primary_key
if primary_key is not None:
# set primary key attributes
for pk, col in zip(primary_key,
- mapper._pks_by_table[table]):
+ mapper._pks_by_table[table]):
prop = mapper_rec._columntoproperty[col]
if state_dict.get(prop.key) is None:
# TODO: would rather say:
- #state_dict[prop.key] = pk
+ # state_dict[prop.key] = pk
mapper_rec._set_state_attr_by_column(
- state,
- state_dict,
- col, pk)
+ state,
+ state_dict,
+ col, pk)
_postfetch(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- result,
- result.context.compiled_parameters[0],
- value_params)
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ result,
+ result.context.compiled_parameters[0],
+ value_params)
def _emit_post_update_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, update):
+ cached_connections, mapper, table, update):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_post_update_commands()."""
@@ -630,7 +632,7 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
+ type_=col.type))
return table.update(clause)
@@ -645,13 +647,13 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
):
connection = key[0]
multiparams = [params for state, state_dict,
- params, mapper, conn in grouper]
+ params, mapper, conn in grouper]
cached_connections[connection].\
- execute(statement, multiparams)
+ execute(statement, multiparams)
def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
- mapper, table, delete):
+ mapper, table, delete):
"""Emit DELETE statements corresponding to value lists collected
by _collect_delete_commands()."""
@@ -662,14 +664,14 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(
- col == sql.bindparam(col.key, type_=col.type))
+ col == sql.bindparam(col.key, type_=col.type))
if need_version_id:
clause.clauses.append(
mapper.version_id_col ==
sql.bindparam(
- mapper.version_id_col.key,
- type_=mapper.version_id_col.type
+ mapper.version_id_col.key,
+ type_=mapper.version_id_col.type
)
)
@@ -710,7 +712,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
connection.execute(statement, del_objects)
if base_mapper.confirm_deleted_rows and \
- rows_matched > -1 and expected != rows_matched:
+ rows_matched > -1 and expected != rows_matched:
if only_warn:
util.warn(
"DELETE statement on table '%s' expected to "
@@ -728,15 +730,16 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
(table.description, expected, rows_matched)
)
+
def _finalize_insert_update_commands(base_mapper, uowtransaction,
- states_to_insert, states_to_update):
+ states_to_insert, states_to_update):
"""finalize state on states that have been inserted or updated,
including calling after_insert/after_update events.
"""
for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in states_to_insert + \
- states_to_update:
+ instance_key, row_switch in states_to_insert + \
+ states_to_update:
if mapper._readonly_props:
readonly = state.unmodified_intersection(
@@ -754,7 +757,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
if base_mapper.eager_defaults:
toload_now.extend(state._unloaded_non_object)
elif mapper.version_id_col is not None and \
- mapper.version_id_generator is False:
+ mapper.version_id_generator is False:
prop = mapper._columntoproperty[mapper.version_id_col]
if prop.key in state.unloaded:
toload_now.extend([prop.key])
@@ -774,7 +777,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
def _postfetch(mapper, uowtransaction, table,
- state, dict_, result, params, value_params):
+ state, dict_, result, params, value_params):
"""Expire attributes in need of newly persisted database state,
after an INSERT or UPDATE statement has proceeded for that
state."""
@@ -800,19 +803,19 @@ def _postfetch(mapper, uowtransaction, table,
if postfetch_cols:
state._expire_attributes(state.dict,
- [mapper._columntoproperty[c].key
- for c in postfetch_cols if c in
- mapper._columntoproperty]
- )
+ [mapper._columntoproperty[c].key
+ for c in postfetch_cols if c in
+ mapper._columntoproperty]
+ )
# synchronize newly inserted ids from one table to the next
# TODO: this still goes a little too often. would be nice to
# have definitive list of "columns that changed" here
for m, equated_pairs in mapper._table_to_equated[table]:
sync.populate(state, m, state, m,
- equated_pairs,
- uowtransaction,
- mapper.passive_updates)
+ equated_pairs,
+ uowtransaction,
+ mapper.passive_updates)
def _connections_for_states(base_mapper, uowtransaction, states):
@@ -828,7 +831,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
# to use for update
if uowtransaction.session.connection_callable:
connection_callable = \
- uowtransaction.session.connection_callable
+ uowtransaction.session.connection_callable
else:
connection = None
connection_callable = None
@@ -838,7 +841,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
connection = connection_callable(base_mapper, state.obj())
elif not connection:
connection = uowtransaction.transaction.connection(
- base_mapper)
+ base_mapper)
mapper = _state_mapper(state)
@@ -849,8 +852,8 @@ def _cached_connection_dict(base_mapper):
# dictionary of connection->connection_with_cache_options.
return util.PopulateDict(
lambda conn: conn.execution_options(
- compiled_cache=base_mapper._compiled_cache
- ))
+ compiled_cache=base_mapper._compiled_cache
+ ))
def _sort_states(states):
@@ -858,7 +861,7 @@ def _sort_states(states):
persistent = set(s for s in pending if s.key is not None)
pending.difference_update(persistent)
return sorted(pending, key=operator.attrgetter("insert_order")) + \
- sorted(persistent, key=lambda q: q.key[1])
+ sorted(persistent, key=lambda q: q.key[1])
class BulkUD(object):
@@ -877,9 +880,9 @@ class BulkUD(object):
klass = lookup[synchronize_session]
except KeyError:
raise sa_exc.ArgumentError(
- "Valid strategies for session synchronization "
- "are %s" % (", ".join(sorted(repr(x)
- for x in lookup))))
+ "Valid strategies for session synchronization "
+ "are %s" % (", ".join(sorted(repr(x)
+ for x in lookup))))
else:
return klass(*arg)
@@ -894,12 +897,12 @@ class BulkUD(object):
query = self.query
self.context = context = query._compile_context()
if len(context.statement.froms) != 1 or \
- not isinstance(context.statement.froms[0], schema.Table):
+ not isinstance(context.statement.froms[0], schema.Table):
self.primary_table = query._only_entity_zero(
- "This operation requires only one Table or "
- "entity be specified as the target."
- ).mapper.local_table
+ "This operation requires only one Table or "
+ "entity be specified as the target."
+ ).mapper.local_table
else:
self.primary_table = context.statement.froms[0]
@@ -929,7 +932,7 @@ class BulkEvaluate(BulkUD):
evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
if query.whereclause is not None:
eval_condition = evaluator_compiler.process(
- query.whereclause)
+ query.whereclause)
else:
def eval_condition(obj):
return True
@@ -938,16 +941,16 @@ class BulkEvaluate(BulkUD):
except evaluator.UnevaluatableError:
raise sa_exc.InvalidRequestError(
- "Could not evaluate current criteria in Python. "
- "Specify 'fetch' or False for the "
- "synchronize_session parameter.")
+ "Could not evaluate current criteria in Python. "
+ "Specify 'fetch' or False for the "
+ "synchronize_session parameter.")
- #TODO: detect when the where clause is a trivial primary key match
+ # TODO: detect when the where clause is a trivial primary key match
self.matched_objects = [
- obj for (cls, pk), obj in
- query.session.identity_map.items()
- if issubclass(cls, target_cls) and
- eval_condition(obj)]
+ obj for (cls, pk), obj in
+ query.session.identity_map.items()
+ if issubclass(cls, target_cls) and
+ eval_condition(obj)]
class BulkFetch(BulkUD):
@@ -957,10 +960,10 @@ class BulkFetch(BulkUD):
query = self.query
session = query.session
select_stmt = self.context.statement.with_only_columns(
- self.primary_table.primary_key)
+ self.primary_table.primary_key)
self.matched_rows = session.execute(
- select_stmt,
- params=query._params).fetchall()
+ select_stmt,
+ params=query._params).fetchall()
class BulkUpdate(BulkUD):
@@ -981,10 +984,10 @@ class BulkUpdate(BulkUD):
def _do_exec(self):
update_stmt = sql.update(self.primary_table,
- self.context.whereclause, self.values)
+ self.context.whereclause, self.values)
self.result = self.query.session.execute(
- update_stmt, params=self.query._params)
+ update_stmt, params=self.query._params)
self.rowcount = self.result.rowcount
def _do_post(self):
@@ -1009,10 +1012,10 @@ class BulkDelete(BulkUD):
def _do_exec(self):
delete_stmt = sql.delete(self.primary_table,
- self.context.whereclause)
+ self.context.whereclause)
self.result = self.query.session.execute(delete_stmt,
- params=self.query._params)
+ params=self.query._params)
self.rowcount = self.result.rowcount
def _do_post(self):
@@ -1029,7 +1032,7 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
for key, value in self.values.items():
key = _attr_as_key(key)
self.value_evaluators[key] = evaluator_compiler.process(
- expression._literal_as_binds(value))
+ expression._literal_as_binds(value))
def _do_post_synchronize(self):
session = self.query.session
@@ -1037,11 +1040,11 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
evaluated_keys = list(self.value_evaluators.keys())
for obj in self.matched_objects:
state, dict_ = attributes.instance_state(obj),\
- attributes.instance_dict(obj)
+ attributes.instance_dict(obj)
# only evaluate unmodified attributes
to_evaluate = state.unmodified.intersection(
- evaluated_keys)
+ evaluated_keys)
for key in to_evaluate:
dict_[key] = self.value_evaluators[key](obj)
@@ -1050,8 +1053,8 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
# expire attributes with pending changes
# (there was no autoflush, so they are overwritten)
state._expire_attributes(dict_,
- set(evaluated_keys).
- difference(to_evaluate))
+ set(evaluated_keys).
+ difference(to_evaluate))
states.add(state)
session._register_altered(states)
@@ -1062,8 +1065,8 @@ class BulkDeleteEvaluate(BulkEvaluate, BulkDelete):
def _do_post_synchronize(self):
self.query.session._remove_newly_deleted(
- [attributes.instance_state(obj)
- for obj in self.matched_objects])
+ [attributes.instance_state(obj)
+ for obj in self.matched_objects])
class BulkUpdateFetch(BulkFetch, BulkUpdate):
@@ -1078,7 +1081,7 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate):
attributes.instance_state(session.identity_map[identity_key])
for identity_key in [
target_mapper.identity_key_from_primary_key(
- list(primary_key))
+ list(primary_key))
for primary_key in self.matched_rows
]
if identity_key in session.identity_map
@@ -1100,7 +1103,7 @@ class BulkDeleteFetch(BulkFetch, BulkDelete):
# TODO: inline this and call remove_newly_deleted
# once
identity_key = target_mapper.identity_key_from_primary_key(
- list(primary_key))
+ list(primary_key))
if identity_key in session.identity_map:
session._remove_newly_deleted(
[attributes.instance_state(