diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 385 |
1 files changed, 194 insertions, 191 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 6669efc56..295d4a3d0 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -41,18 +41,18 @@ def save_obj(base_mapper, states, uowtransaction, single=False): return states_to_insert, states_to_update = _organize_states_for_save( - base_mapper, - states, - uowtransaction) + base_mapper, + states, + uowtransaction) cached_connections = _cached_connection_dict(base_mapper) for table, mapper in base_mapper._sorted_tables.items(): insert = _collect_insert_commands(base_mapper, uowtransaction, - table, states_to_insert) + table, states_to_insert) update = _collect_update_commands(base_mapper, uowtransaction, - table, states_to_update) + table, states_to_update) if update: _emit_update_statements(base_mapper, uowtransaction, @@ -65,7 +65,7 @@ def save_obj(base_mapper, states, uowtransaction, single=False): mapper, table, insert) _finalize_insert_update_commands(base_mapper, uowtransaction, - states_to_insert, states_to_update) + states_to_insert, states_to_update) def post_update(base_mapper, states, uowtransaction, post_update_cols): @@ -76,18 +76,18 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): cached_connections = _cached_connection_dict(base_mapper) states_to_update = _organize_states_for_post_update( - base_mapper, - states, uowtransaction) + base_mapper, + states, uowtransaction) for table, mapper in base_mapper._sorted_tables.items(): update = _collect_post_update_commands(base_mapper, uowtransaction, - table, states_to_update, - post_update_cols) + table, states_to_update, + post_update_cols) if update: _emit_post_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) + cached_connections, + mapper, table, update) def delete_obj(base_mapper, states, uowtransaction): @@ -101,23 +101,23 @@ def delete_obj(base_mapper, states, uowtransaction): cached_connections = _cached_connection_dict(base_mapper) states_to_delete = _organize_states_for_delete( - base_mapper, - states, - uowtransaction) + base_mapper, + states, + uowtransaction) table_to_mapper = base_mapper._sorted_tables for table in reversed(list(table_to_mapper.keys())): delete = _collect_delete_commands(base_mapper, uowtransaction, - table, states_to_delete) + table, states_to_delete) mapper = table_to_mapper[table] _emit_delete_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, delete) + cached_connections, mapper, table, delete) for state, state_dict, mapper, has_identity, connection \ - in states_to_delete: + in states_to_delete: mapper.dispatch.after_delete(mapper, connection, state) @@ -137,8 +137,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): states_to_update = [] for state, dict_, mapper, connection in _connections_for_states( - base_mapper, uowtransaction, - states): + base_mapper, uowtransaction, + states): has_identity = bool(state.key) instance_key = state.key or mapper._identity_key_from_state(state) @@ -183,19 +183,19 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): if not has_identity and not row_switch: states_to_insert.append( (state, dict_, mapper, connection, - has_identity, instance_key, row_switch) + has_identity, instance_key, row_switch) ) else: states_to_update.append( (state, dict_, mapper, connection, - has_identity, instance_key, row_switch) + has_identity, instance_key, row_switch) ) return states_to_insert, states_to_update def _organize_states_for_post_update(base_mapper, states, - uowtransaction): + uowtransaction): """Make an initial pass across a set of states for UPDATE corresponding to post_update. @@ -205,7 +205,7 @@ def _organize_states_for_post_update(base_mapper, states, """ return list(_connections_for_states(base_mapper, uowtransaction, - states)) + states)) def _organize_states_for_delete(base_mapper, states, uowtransaction): @@ -219,25 +219,25 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): states_to_delete = [] for state, dict_, mapper, connection in _connections_for_states( - base_mapper, uowtransaction, - states): + base_mapper, uowtransaction, + states): mapper.dispatch.before_delete(mapper, connection, state) states_to_delete.append((state, dict_, mapper, - bool(state.key), connection)) + bool(state.key), connection)) return states_to_delete def _collect_insert_commands(base_mapper, uowtransaction, table, - states_to_insert): + states_to_insert): """Identify sets of values to use in INSERT statements for a list of states. """ insert = [] for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_insert: + instance_key, row_switch in states_to_insert: if table not in mapper._pks_by_table: continue @@ -250,7 +250,7 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, has_all_defaults = True for col in mapper._cols_by_table[table]: if col is mapper.version_id_col and \ - mapper.version_id_generator is not False: + mapper.version_id_generator is not False: val = mapper.version_id_generator(None) params[col.key] = val else: @@ -263,10 +263,10 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, if col in pks: has_all_pks = False elif col.default is None and \ - col.server_default is None: + col.server_default is None: params[col.key] = value elif col.server_default is not None and \ - mapper.base_mapper.eager_defaults: + mapper.base_mapper.eager_defaults: has_all_defaults = False elif isinstance(value, sql.ClauseElement): @@ -275,13 +275,13 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, params[col.key] = value insert.append((state, state_dict, params, mapper, - connection, value_params, has_all_pks, - has_all_defaults)) + connection, value_params, has_all_pks, + has_all_defaults)) return insert def _collect_update_commands(base_mapper, uowtransaction, - table, states_to_update): + table, states_to_update): """Identify sets of values to use in UPDATE statements for a list of states. @@ -295,7 +295,7 @@ def _collect_update_commands(base_mapper, uowtransaction, update = [] for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_update: + instance_key, row_switch in states_to_update: if table not in mapper._pks_by_table: continue @@ -309,10 +309,10 @@ def _collect_update_commands(base_mapper, uowtransaction, if col is mapper.version_id_col: params[col._label] = \ mapper._get_committed_state_attr_by_column( - row_switch or state, - row_switch and row_switch.dict - or state_dict, - col) + row_switch or state, + row_switch and row_switch.dict + or state_dict, + col) prop = mapper._columntoproperty[col] history = state.manager[prop.key].impl.get_history( @@ -331,19 +331,20 @@ def _collect_update_commands(base_mapper, uowtransaction, # in a different table than the one # where the version_id_col is. for prop in mapper._columntoproperty.values(): - history = state.manager[prop.key].impl.get_history( + history = ( + state.manager[prop.key].impl.get_history( state, state_dict, - attributes.PASSIVE_NO_INITIALIZE) + attributes.PASSIVE_NO_INITIALIZE)) if history.added: hasdata = True else: prop = mapper._columntoproperty[col] history = state.manager[prop.key].impl.get_history( - state, state_dict, - attributes.PASSIVE_NO_INITIALIZE) + state, state_dict, + attributes.PASSIVE_NO_INITIALIZE) if history.added: if isinstance(history.added[0], - sql.ClauseElement): + sql.ClauseElement): value_params[col] = history.added[0] else: value = history.added[0] @@ -351,13 +352,13 @@ def _collect_update_commands(base_mapper, uowtransaction, if col in pks: if history.deleted and \ - not row_switch: + not row_switch: # if passive_updates and sync detected # this was a pk->pk sync, use the new # value to locate the row, since the # DB would already have set this if ("pk_cascaded", state, col) in \ - uowtransaction.attributes: + uowtransaction.attributes: value = history.added[0] params[col._label] = value else: @@ -381,7 +382,7 @@ def _collect_update_commands(base_mapper, uowtransaction, hasdata = True elif col in pks: value = state.manager[prop.key].impl.get( - state, state_dict) + state, state_dict) if value is None: hasnull = True params[col._label] = value @@ -389,16 +390,16 @@ def _collect_update_commands(base_mapper, uowtransaction, if hasdata: if hasnull: raise orm_exc.FlushError( - "Can't update table " - "using NULL for primary " - "key value") + "Can't update table " + "using NULL for primary " + "key value") update.append((state, state_dict, params, mapper, - connection, value_params)) + connection, value_params)) return update def _collect_post_update_commands(base_mapper, uowtransaction, table, - states_to_update, post_update_cols): + states_to_update, post_update_cols): """Identify sets of values to use in UPDATE statements for a list of states within a post_update operation. @@ -415,34 +416,34 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, for col in mapper._cols_by_table[table]: if col in pks: params[col._label] = \ - mapper._get_state_attr_by_column( - state, - state_dict, col) + mapper._get_state_attr_by_column( + state, + state_dict, col) elif col in post_update_cols: prop = mapper._columntoproperty[col] history = state.manager[prop.key].impl.get_history( - state, state_dict, - attributes.PASSIVE_NO_INITIALIZE) + state, state_dict, + attributes.PASSIVE_NO_INITIALIZE) if history.added: value = history.added[0] params[col.key] = value hasdata = True if hasdata: update.append((state, state_dict, params, mapper, - connection)) + connection)) return update def _collect_delete_commands(base_mapper, uowtransaction, table, - states_to_delete): + states_to_delete): """Identify values to use in DELETE statements for a list of states to be deleted.""" delete = util.defaultdict(list) for state, state_dict, mapper, has_identity, connection \ - in states_to_delete: + in states_to_delete: if not has_identity or table not in mapper._pks_by_table: continue @@ -450,43 +451,44 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, delete[connection].append(params) for col in mapper._pks_by_table[table]: params[col.key] = \ - value = \ - mapper._get_committed_state_attr_by_column( - state, state_dict, col) + value = \ + mapper._get_committed_state_attr_by_column( + state, state_dict, col) if value is None: raise orm_exc.FlushError( - "Can't delete from table " - "using NULL for primary " - "key value") + "Can't delete from table " + "using NULL for primary " + "key value") if mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col): + table.c.contains_column(mapper.version_id_col): params[mapper.version_id_col.key] = \ - mapper._get_committed_state_attr_by_column( - state, state_dict, - mapper.version_id_col) + mapper._get_committed_state_attr_by_column( + state, state_dict, + mapper.version_id_col) return delete def _emit_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update): + cached_connections, mapper, table, update): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" needs_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) + table.c.contains_column(mapper.version_id_col) def update_stmt(): clause = sql.and_() for col in mapper._pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) + type_=col.type)) if needs_version_id: - clause.clauses.append(mapper.version_id_col ==\ - sql.bindparam(mapper.version_id_col._label, - type_=mapper.version_id_col.type)) + clause.clauses.append( + mapper.version_id_col == sql.bindparam( + mapper.version_id_col._label, + type_=mapper.version_id_col.type)) stmt = table.update(clause) if mapper.base_mapper.eager_defaults: @@ -500,43 +502,43 @@ def _emit_update_statements(base_mapper, uowtransaction, rows = 0 for state, state_dict, params, mapper, \ - connection, value_params in update: + connection, value_params in update: if value_params: c = connection.execute( - statement.values(value_params), - params) + statement.values(value_params), + params) else: c = cached_connections[connection].\ - execute(statement, params) + execute(statement, params) _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params) + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) rows += c.rowcount if connection.dialect.supports_sane_rowcount: if rows != len(update): raise orm_exc.StaleDataError( - "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(update), rows)) + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(update), rows)) elif needs_version_id: util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description, - stacklevel=12) + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, insert): + cached_connections, mapper, table, insert): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" @@ -544,37 +546,37 @@ def _emit_insert_statements(base_mapper, uowtransaction, for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ records in groupby(insert, - lambda rec: (rec[4], - list(rec[2].keys()), - bool(rec[5]), - rec[6], rec[7]) - ): + lambda rec: (rec[4], + list(rec[2].keys()), + bool(rec[5]), + rec[6], rec[7]) + ): if \ - ( - has_all_defaults - or not base_mapper.eager_defaults - or not connection.dialect.implicit_returning - ) and has_all_pks and not hasvalue: + ( + has_all_defaults + or not base_mapper.eager_defaults + or not connection.dialect.implicit_returning + ) and has_all_pks and not hasvalue: records = list(records) multiparams = [rec[2] for rec in records] c = cached_connections[connection].\ - execute(statement, multiparams) + execute(statement, multiparams) for (state, state_dict, params, mapper_rec, conn, value_params, has_all_pks, has_all_defaults), \ last_inserted_params in \ zip(records, c.context.compiled_parameters): _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - c, - last_inserted_params, - value_params) + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params) else: if not has_all_defaults and base_mapper.eager_defaults: @@ -583,45 +585,45 @@ def _emit_insert_statements(base_mapper, uowtransaction, statement = statement.return_defaults(mapper.version_id_col) for state, state_dict, params, mapper_rec, \ - connection, value_params, \ - has_all_pks, has_all_defaults in records: + connection, value_params, \ + has_all_pks, has_all_defaults in records: if value_params: result = connection.execute( - statement.values(value_params), - params) + statement.values(value_params), + params) else: result = cached_connections[connection].\ - execute(statement, params) + execute(statement, params) primary_key = result.context.inserted_primary_key if primary_key is not None: # set primary key attributes for pk, col in zip(primary_key, - mapper._pks_by_table[table]): + mapper._pks_by_table[table]): prop = mapper_rec._columntoproperty[col] if state_dict.get(prop.key) is None: # TODO: would rather say: - #state_dict[prop.key] = pk + # state_dict[prop.key] = pk mapper_rec._set_state_attr_by_column( - state, - state_dict, - col, pk) + state, + state_dict, + col, pk) _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - result, - result.context.compiled_parameters[0], - value_params) + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + result.context.compiled_parameters[0], + value_params) def _emit_post_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update): + cached_connections, mapper, table, update): """Emit UPDATE statements corresponding to value lists collected by _collect_post_update_commands().""" @@ -630,7 +632,7 @@ def _emit_post_update_statements(base_mapper, uowtransaction, for col in mapper._pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) + type_=col.type)) return table.update(clause) @@ -645,13 +647,13 @@ def _emit_post_update_statements(base_mapper, uowtransaction, ): connection = key[0] multiparams = [params for state, state_dict, - params, mapper, conn in grouper] + params, mapper, conn in grouper] cached_connections[connection].\ - execute(statement, multiparams) + execute(statement, multiparams) def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, - mapper, table, delete): + mapper, table, delete): """Emit DELETE statements corresponding to value lists collected by _collect_delete_commands().""" @@ -662,14 +664,14 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, clause = sql.and_() for col in mapper._pks_by_table[table]: clause.clauses.append( - col == sql.bindparam(col.key, type_=col.type)) + col == sql.bindparam(col.key, type_=col.type)) if need_version_id: clause.clauses.append( mapper.version_id_col == sql.bindparam( - mapper.version_id_col.key, - type_=mapper.version_id_col.type + mapper.version_id_col.key, + type_=mapper.version_id_col.type ) ) @@ -710,7 +712,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, connection.execute(statement, del_objects) if base_mapper.confirm_deleted_rows and \ - rows_matched > -1 and expected != rows_matched: + rows_matched > -1 and expected != rows_matched: if only_warn: util.warn( "DELETE statement on table '%s' expected to " @@ -728,15 +730,16 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, (table.description, expected, rows_matched) ) + def _finalize_insert_update_commands(base_mapper, uowtransaction, - states_to_insert, states_to_update): + states_to_insert, states_to_update): """finalize state on states that have been inserted or updated, including calling after_insert/after_update events. """ for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_insert + \ - states_to_update: + instance_key, row_switch in states_to_insert + \ + states_to_update: if mapper._readonly_props: readonly = state.unmodified_intersection( @@ -754,7 +757,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, if base_mapper.eager_defaults: toload_now.extend(state._unloaded_non_object) elif mapper.version_id_col is not None and \ - mapper.version_id_generator is False: + mapper.version_id_generator is False: prop = mapper._columntoproperty[mapper.version_id_col] if prop.key in state.unloaded: toload_now.extend([prop.key]) @@ -774,7 +777,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, def _postfetch(mapper, uowtransaction, table, - state, dict_, result, params, value_params): + state, dict_, result, params, value_params): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that state.""" @@ -800,19 +803,19 @@ def _postfetch(mapper, uowtransaction, table, if postfetch_cols: state._expire_attributes(state.dict, - [mapper._columntoproperty[c].key - for c in postfetch_cols if c in - mapper._columntoproperty] - ) + [mapper._columntoproperty[c].key + for c in postfetch_cols if c in + mapper._columntoproperty] + ) # synchronize newly inserted ids from one table to the next # TODO: this still goes a little too often. would be nice to # have definitive list of "columns that changed" here for m, equated_pairs in mapper._table_to_equated[table]: sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - mapper.passive_updates) + equated_pairs, + uowtransaction, + mapper.passive_updates) def _connections_for_states(base_mapper, uowtransaction, states): @@ -828,7 +831,7 @@ def _connections_for_states(base_mapper, uowtransaction, states): # to use for update if uowtransaction.session.connection_callable: connection_callable = \ - uowtransaction.session.connection_callable + uowtransaction.session.connection_callable else: connection = None connection_callable = None @@ -838,7 +841,7 @@ def _connections_for_states(base_mapper, uowtransaction, states): connection = connection_callable(base_mapper, state.obj()) elif not connection: connection = uowtransaction.transaction.connection( - base_mapper) + base_mapper) mapper = _state_mapper(state) @@ -849,8 +852,8 @@ def _cached_connection_dict(base_mapper): # dictionary of connection->connection_with_cache_options. return util.PopulateDict( lambda conn: conn.execution_options( - compiled_cache=base_mapper._compiled_cache - )) + compiled_cache=base_mapper._compiled_cache + )) def _sort_states(states): @@ -858,7 +861,7 @@ def _sort_states(states): persistent = set(s for s in pending if s.key is not None) pending.difference_update(persistent) return sorted(pending, key=operator.attrgetter("insert_order")) + \ - sorted(persistent, key=lambda q: q.key[1]) + sorted(persistent, key=lambda q: q.key[1]) class BulkUD(object): @@ -877,9 +880,9 @@ class BulkUD(object): klass = lookup[synchronize_session] except KeyError: raise sa_exc.ArgumentError( - "Valid strategies for session synchronization " - "are %s" % (", ".join(sorted(repr(x) - for x in lookup)))) + "Valid strategies for session synchronization " + "are %s" % (", ".join(sorted(repr(x) + for x in lookup)))) else: return klass(*arg) @@ -894,12 +897,12 @@ class BulkUD(object): query = self.query self.context = context = query._compile_context() if len(context.statement.froms) != 1 or \ - not isinstance(context.statement.froms[0], schema.Table): + not isinstance(context.statement.froms[0], schema.Table): self.primary_table = query._only_entity_zero( - "This operation requires only one Table or " - "entity be specified as the target." - ).mapper.local_table + "This operation requires only one Table or " + "entity be specified as the target." + ).mapper.local_table else: self.primary_table = context.statement.froms[0] @@ -929,7 +932,7 @@ class BulkEvaluate(BulkUD): evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) if query.whereclause is not None: eval_condition = evaluator_compiler.process( - query.whereclause) + query.whereclause) else: def eval_condition(obj): return True @@ -938,16 +941,16 @@ class BulkEvaluate(BulkUD): except evaluator.UnevaluatableError: raise sa_exc.InvalidRequestError( - "Could not evaluate current criteria in Python. " - "Specify 'fetch' or False for the " - "synchronize_session parameter.") + "Could not evaluate current criteria in Python. " + "Specify 'fetch' or False for the " + "synchronize_session parameter.") - #TODO: detect when the where clause is a trivial primary key match + # TODO: detect when the where clause is a trivial primary key match self.matched_objects = [ - obj for (cls, pk), obj in - query.session.identity_map.items() - if issubclass(cls, target_cls) and - eval_condition(obj)] + obj for (cls, pk), obj in + query.session.identity_map.items() + if issubclass(cls, target_cls) and + eval_condition(obj)] class BulkFetch(BulkUD): @@ -957,10 +960,10 @@ class BulkFetch(BulkUD): query = self.query session = query.session select_stmt = self.context.statement.with_only_columns( - self.primary_table.primary_key) + self.primary_table.primary_key) self.matched_rows = session.execute( - select_stmt, - params=query._params).fetchall() + select_stmt, + params=query._params).fetchall() class BulkUpdate(BulkUD): @@ -981,10 +984,10 @@ class BulkUpdate(BulkUD): def _do_exec(self): update_stmt = sql.update(self.primary_table, - self.context.whereclause, self.values) + self.context.whereclause, self.values) self.result = self.query.session.execute( - update_stmt, params=self.query._params) + update_stmt, params=self.query._params) self.rowcount = self.result.rowcount def _do_post(self): @@ -1009,10 +1012,10 @@ class BulkDelete(BulkUD): def _do_exec(self): delete_stmt = sql.delete(self.primary_table, - self.context.whereclause) + self.context.whereclause) self.result = self.query.session.execute(delete_stmt, - params=self.query._params) + params=self.query._params) self.rowcount = self.result.rowcount def _do_post(self): @@ -1029,7 +1032,7 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): for key, value in self.values.items(): key = _attr_as_key(key) self.value_evaluators[key] = evaluator_compiler.process( - expression._literal_as_binds(value)) + expression._literal_as_binds(value)) def _do_post_synchronize(self): session = self.query.session @@ -1037,11 +1040,11 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): evaluated_keys = list(self.value_evaluators.keys()) for obj in self.matched_objects: state, dict_ = attributes.instance_state(obj),\ - attributes.instance_dict(obj) + attributes.instance_dict(obj) # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection( - evaluated_keys) + evaluated_keys) for key in to_evaluate: dict_[key] = self.value_evaluators[key](obj) @@ -1050,8 +1053,8 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): # expire attributes with pending changes # (there was no autoflush, so they are overwritten) state._expire_attributes(dict_, - set(evaluated_keys). - difference(to_evaluate)) + set(evaluated_keys). + difference(to_evaluate)) states.add(state) session._register_altered(states) @@ -1062,8 +1065,8 @@ class BulkDeleteEvaluate(BulkEvaluate, BulkDelete): def _do_post_synchronize(self): self.query.session._remove_newly_deleted( - [attributes.instance_state(obj) - for obj in self.matched_objects]) + [attributes.instance_state(obj) + for obj in self.matched_objects]) class BulkUpdateFetch(BulkFetch, BulkUpdate): @@ -1078,7 +1081,7 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate): attributes.instance_state(session.identity_map[identity_key]) for identity_key in [ target_mapper.identity_key_from_primary_key( - list(primary_key)) + list(primary_key)) for primary_key in self.matched_rows ] if identity_key in session.identity_map @@ -1100,7 +1103,7 @@ class BulkDeleteFetch(BulkFetch, BulkDelete): # TODO: inline this and call remove_newly_deleted # once identity_key = target_mapper.identity_key_from_primary_key( - list(primary_key)) + list(primary_key)) if identity_key in session.identity_map: session._remove_newly_deleted( [attributes.instance_state( |
