diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 1186 |
1 files changed, 746 insertions, 440 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 7f9b7db0c..dc86a60e5 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -25,8 +25,13 @@ from . import loading def _bulk_insert( - mapper, mappings, session_transaction, isstates, return_defaults, - render_nulls): + mapper, + mappings, + session_transaction, + isstates, + return_defaults, + render_nulls, +): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) @@ -34,7 +39,8 @@ def _bulk_insert( if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " - "not supported in bulk_insert()") + "not supported in bulk_insert()" + ) if isstates: if return_defaults: @@ -51,22 +57,33 @@ def _bulk_insert( continue records = ( - (None, state_dict, params, mapper, - connection, value_params, has_all_pks, has_all_defaults) - for - state, state_dict, params, mp, - conn, value_params, has_all_pks, - has_all_defaults in _collect_insert_commands(table, ( - (None, mapping, mapper, connection) - for mapping in mappings), - bulk=True, return_defaults=return_defaults, - render_nulls=render_nulls + ( + None, + state_dict, + params, + mapper, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) + for state, state_dict, params, mp, conn, value_params, has_all_pks, has_all_defaults in _collect_insert_commands( + table, + ((None, mapping, mapper, connection) for mapping in mappings), + bulk=True, + return_defaults=return_defaults, + render_nulls=render_nulls, ) ) - _emit_insert_statements(base_mapper, None, - cached_connections, - super_mapper, table, records, - bookkeeping=return_defaults) + _emit_insert_statements( + base_mapper, + None, + cached_connections, + super_mapper, + table, + records, + bookkeeping=return_defaults, + ) if return_defaults and isstates: identity_cls = mapper._identity_class @@ -74,12 +91,13 @@ def _bulk_insert( for state, dict_ in states: state.key = ( identity_cls, - tuple([dict_[key] for key in identity_props]) + tuple([dict_[key] for key in identity_props]), ) -def _bulk_update(mapper, mappings, session_transaction, - isstates, update_changed_only): +def _bulk_update( + mapper, mappings, session_transaction, isstates, update_changed_only +): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) @@ -91,9 +109,8 @@ def _bulk_update(mapper, mappings, session_transaction, def _changed_dict(mapper, state): return dict( (k, v) - for k, v in state.dict.items() if k in state.committed_state or k - in search_keys - + for k, v in state.dict.items() + if k in state.committed_state or k in search_keys ) if isstates: @@ -107,7 +124,8 @@ def _bulk_update(mapper, mappings, session_transaction, if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " - "not supported in bulk_update()") + "not supported in bulk_update()" + ) connection = session_transaction.connection(base_mapper) @@ -115,21 +133,38 @@ def _bulk_update(mapper, mappings, session_transaction, if not mapper.isa(super_mapper): continue - records = _collect_update_commands(None, table, ( - (None, mapping, mapper, connection, - (mapping[mapper._version_id_prop.key] - if mapper._version_id_prop else None)) - for mapping in mappings - ), bulk=True) + records = _collect_update_commands( + None, + table, + ( + ( + None, + mapping, + mapper, + connection, + ( + mapping[mapper._version_id_prop.key] + if mapper._version_id_prop + else None + ), + ) + for mapping in mappings + ), + bulk=True, + ) - _emit_update_statements(base_mapper, None, - cached_connections, - super_mapper, table, records, - bookkeeping=False) + _emit_update_statements( + base_mapper, + None, + cached_connections, + super_mapper, + table, + records, + bookkeeping=False, + ) -def save_obj( - base_mapper, states, uowtransaction, single=False): +def save_obj(base_mapper, states, uowtransaction, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. @@ -150,19 +185,21 @@ def save_obj( states_to_insert = [] cached_connections = _cached_connection_dict(base_mapper) - for (state, dict_, mapper, connection, - has_identity, - row_switch, update_version_id) in _organize_states_for_save( - base_mapper, states, uowtransaction - ): + for ( + state, + dict_, + mapper, + connection, + has_identity, + row_switch, + update_version_id, + ) in _organize_states_for_save(base_mapper, states, uowtransaction): if has_identity or row_switch: states_to_update.append( (state, dict_, mapper, connection, update_version_id) ) else: - states_to_insert.append( - (state, dict_, mapper, connection) - ) + states_to_insert.append((state, dict_, mapper, connection)) for table, mapper in base_mapper._sorted_tables.items(): if table not in mapper._pks_by_table: @@ -170,18 +207,30 @@ def save_obj( insert = _collect_insert_commands(table, states_to_insert) update = _collect_update_commands( - uowtransaction, table, states_to_update) + uowtransaction, table, states_to_update + ) - _emit_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) + _emit_update_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + update, + ) - _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, insert) + _emit_insert_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + insert, + ) _finalize_insert_update_commands( - base_mapper, uowtransaction, + base_mapper, + uowtransaction, chain( ( (state, state_dict, mapper, connection, False) @@ -189,10 +238,9 @@ def save_obj( ), ( (state, state_dict, mapper, connection, True) - for state, state_dict, mapper, connection, - update_version_id in states_to_update - ) - ) + for state, state_dict, mapper, connection, update_version_id in states_to_update + ), + ), ) @@ -203,9 +251,9 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): """ cached_connections = _cached_connection_dict(base_mapper) - states_to_update = list(_organize_states_for_post_update( - base_mapper, - states, uowtransaction)) + states_to_update = list( + _organize_states_for_post_update(base_mapper, states, uowtransaction) + ) for table, mapper in base_mapper._sorted_tables.items(): if table not in mapper._pks_by_table: @@ -213,25 +261,32 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): update = ( ( - state, state_dict, sub_mapper, connection, + state, + state_dict, + sub_mapper, + connection, mapper._get_committed_state_attr_by_column( state, state_dict, mapper.version_id_col - ) if mapper.version_id_col is not None else None + ) + if mapper.version_id_col is not None + else None, ) - for - state, state_dict, sub_mapper, connection in states_to_update + for state, state_dict, sub_mapper, connection in states_to_update if table in sub_mapper._pks_by_table ) update = _collect_post_update_commands( - base_mapper, uowtransaction, - table, update, - post_update_cols + base_mapper, uowtransaction, table, update, post_update_cols ) - _emit_post_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) + _emit_post_update_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + update, + ) def delete_obj(base_mapper, states, uowtransaction): @@ -244,10 +299,9 @@ def delete_obj(base_mapper, states, uowtransaction): cached_connections = _cached_connection_dict(base_mapper) - states_to_delete = list(_organize_states_for_delete( - base_mapper, - states, - uowtransaction)) + states_to_delete = list( + _organize_states_for_delete(base_mapper, states, uowtransaction) + ) table_to_mapper = base_mapper._sorted_tables @@ -258,14 +312,26 @@ def delete_obj(base_mapper, states, uowtransaction): elif mapper.inherits and mapper.passive_deletes: continue - delete = _collect_delete_commands(base_mapper, uowtransaction, - table, states_to_delete) + delete = _collect_delete_commands( + base_mapper, uowtransaction, table, states_to_delete + ) - _emit_delete_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, delete) + _emit_delete_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + delete, + ) - for state, state_dict, mapper, connection, \ - update_version_id in states_to_delete: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_delete: mapper.dispatch.after_delete(mapper, connection, state) @@ -282,8 +348,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): """ for state, dict_, mapper, connection in _connections_for_states( - base_mapper, uowtransaction, - states): + base_mapper, uowtransaction, states + ): has_identity = bool(state.key) @@ -304,25 +370,29 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): # no instance_key attached to it), and another instance # with the same identity key already exists as persistent. # convert to an UPDATE if so. - if not has_identity and \ - instance_key in uowtransaction.session.identity_map: - instance = \ - uowtransaction.session.identity_map[instance_key] + if ( + not has_identity + and instance_key in uowtransaction.session.identity_map + ): + instance = uowtransaction.session.identity_map[instance_key] existing = attributes.instance_state(instance) if not uowtransaction.was_already_deleted(existing): if not uowtransaction.is_deleted(existing): raise orm_exc.FlushError( "New instance %s with identity key %s conflicts " - "with persistent instance %s" % - (state_str(state), instance_key, - state_str(existing))) + "with persistent instance %s" + % (state_str(state), instance_key, state_str(existing)) + ) base_mapper._log_debug( "detected row switch for identity %s. " "will update %s, remove %s from " - "transaction", instance_key, - state_str(state), state_str(existing)) + "transaction", + instance_key, + state_str(state), + state_str(existing), + ) # remove the "delete" flag from the existing element uowtransaction.remove_state_actions(existing) @@ -332,14 +402,21 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): update_version_id = mapper._get_committed_state_attr_by_column( row_switch if row_switch else state, row_switch.dict if row_switch else dict_, - mapper.version_id_col) + mapper.version_id_col, + ) - yield (state, dict_, mapper, connection, - has_identity, row_switch, update_version_id) + yield ( + state, + dict_, + mapper, + connection, + has_identity, + row_switch, + update_version_id, + ) -def _organize_states_for_post_update(base_mapper, states, - uowtransaction): +def _organize_states_for_post_update(base_mapper, states, uowtransaction): """Make an initial pass across a set of states for UPDATE corresponding to post_update. @@ -360,26 +437,28 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): """ for state, dict_, mapper, connection in _connections_for_states( - base_mapper, uowtransaction, - states): + base_mapper, uowtransaction, states + ): mapper.dispatch.before_delete(mapper, connection, state) if mapper.version_id_col is not None: - update_version_id = \ - mapper._get_committed_state_attr_by_column( - state, dict_, - mapper.version_id_col) + update_version_id = mapper._get_committed_state_attr_by_column( + state, dict_, mapper.version_id_col + ) else: update_version_id = None - yield ( - state, dict_, mapper, connection, update_version_id) + yield (state, dict_, mapper, connection, update_version_id) def _collect_insert_commands( - table, states_to_insert, - bulk=False, return_defaults=False, render_nulls=False): + table, + states_to_insert, + bulk=False, + return_defaults=False, + render_nulls=False, +): """Identify sets of values to use in INSERT statements for a list of states. @@ -400,10 +479,16 @@ def _collect_insert_commands( col = propkey_to_col[propkey] if value is None and col not in eval_none and not render_nulls: continue - elif not bulk and hasattr(value, '__clause_element__') or \ - isinstance(value, sql.ClauseElement): - value_params[col.key] = value.__clause_element__() \ - if hasattr(value, '__clause_element__') else value + elif ( + not bulk + and hasattr(value, "__clause_element__") + or isinstance(value, sql.ClauseElement) + ): + value_params[col.key] = ( + value.__clause_element__() + if hasattr(value, "__clause_element__") + else value + ) else: params[col.key] = value @@ -414,8 +499,11 @@ def _collect_insert_commands( # which might be worth removing, as it should not be necessary # and also produces confusion, given that "missing" and None # now have distinct meanings - for colkey in mapper._insert_cols_as_none[table].\ - difference(params).difference(value_params): + for colkey in ( + mapper._insert_cols_as_none[table] + .difference(params) + .difference(value_params) + ): params[colkey] = None if not bulk or return_defaults: @@ -424,28 +512,38 @@ def _collect_insert_commands( has_all_pks = mapper._pk_keys_by_table[table].issubset(params) if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_default_cols[table].\ - issubset(params) + has_all_defaults = mapper._server_default_cols[table].issubset( + params + ) else: has_all_defaults = True else: has_all_defaults = has_all_pks = True - if mapper.version_id_generator is not False \ - and mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: - params[mapper.version_id_col.key] = \ - mapper.version_id_generator(None) + if ( + mapper.version_id_generator is not False + and mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + params[mapper.version_id_col.key] = mapper.version_id_generator( + None + ) yield ( - state, state_dict, params, mapper, - connection, value_params, has_all_pks, - has_all_defaults) + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) def _collect_update_commands( - uowtransaction, table, states_to_update, - bulk=False): + uowtransaction, table, states_to_update, bulk=False +): """Identify sets of values to use in UPDATE statements for a list of states. @@ -457,8 +555,13 @@ def _collect_update_commands( """ - for state, state_dict, mapper, connection, \ - update_version_id in states_to_update: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update: if table not in mapper._pks_by_table: continue @@ -474,36 +577,48 @@ def _collect_update_commands( # look at mapper attribute keys for pk params = dict( (propkey_to_col[propkey].key, state_dict[propkey]) - for propkey in - set(propkey_to_col).intersection(state_dict).difference( - mapper._pk_attr_keys_by_table[table]) + for propkey in set(propkey_to_col) + .intersection(state_dict) + .difference(mapper._pk_attr_keys_by_table[table]) ) has_all_defaults = True else: params = {} for propkey in set(propkey_to_col).intersection( - state.committed_state): + state.committed_state + ): value = state_dict[propkey] col = propkey_to_col[propkey] - if hasattr(value, '__clause_element__') or \ - isinstance(value, sql.ClauseElement): - value_params[col] = value.__clause_element__() \ - if hasattr(value, '__clause_element__') else value + if hasattr(value, "__clause_element__") or isinstance( + value, sql.ClauseElement + ): + value_params[col] = ( + value.__clause_element__() + if hasattr(value, "__clause_element__") + else value + ) # guard against values that generate non-__nonzero__ # objects for __eq__() - elif state.manager[propkey].impl.is_equal( - value, state.committed_state[propkey]) is not True: + elif ( + state.manager[propkey].impl.is_equal( + value, state.committed_state[propkey] + ) + is not True + ): params[col.key] = value if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_onupdate_default_cols[table].\ - issubset(params) + has_all_defaults = mapper._server_onupdate_default_cols[ + table + ].issubset(params) else: has_all_defaults = True - if update_version_id is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): if not bulk and not (params or value_params): # HACK: check for history in other tables, in case the @@ -511,10 +626,9 @@ def _collect_update_commands( # where the version_id_col is. This logic was lost # from 0.9 -> 1.0.0 and restored in 1.0.6. for prop in mapper._columntoproperty.values(): - history = ( - state.manager[prop.key].impl.get_history( - state, state_dict, - attributes.PASSIVE_NO_INITIALIZE)) + history = state.manager[prop.key].impl.get_history( + state, state_dict, attributes.PASSIVE_NO_INITIALIZE + ) if history.added: break else: @@ -525,8 +639,9 @@ def _collect_update_commands( no_params = not params and not value_params params[col._label] = update_version_id - if (bulk or col.key not in params) and \ - mapper.version_id_generator is not False: + if ( + bulk or col.key not in params + ) and mapper.version_id_generator is not False: val = mapper.version_id_generator(update_version_id) params[col.key] = val elif mapper.version_id_generator is False and no_params: @@ -545,9 +660,9 @@ def _collect_update_commands( # look at mapper attribute keys for pk pk_params = dict( (propkey_to_col[propkey]._label, state_dict.get(propkey)) - for propkey in - set(propkey_to_col). - intersection(mapper._pk_attr_keys_by_table[table]) + for propkey in set(propkey_to_col).intersection( + mapper._pk_attr_keys_by_table[table] + ) ) else: pk_params = {} @@ -555,12 +670,15 @@ def _collect_update_commands( propkey = mapper._columntoproperty[col].key history = state.manager[propkey].impl.get_history( - state, state_dict, attributes.PASSIVE_OFF) + state, state_dict, attributes.PASSIVE_OFF + ) if history.added: - if not history.deleted or \ - ("pk_cascaded", state, col) in \ - uowtransaction.attributes: + if ( + not history.deleted + or ("pk_cascaded", state, col) + in uowtransaction.attributes + ): pk_params[col._label] = history.added[0] params.pop(col.key, None) else: @@ -573,24 +691,38 @@ def _collect_update_commands( if pk_params[col._label] is None: raise orm_exc.FlushError( "Can't update table %s using NULL for primary " - "key value on column %s" % (table, col)) + "key value on column %s" % (table, col) + ) if params or value_params: params.update(pk_params) yield ( - state, state_dict, params, mapper, - connection, value_params, has_all_defaults, has_all_pks) + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) -def _collect_post_update_commands(base_mapper, uowtransaction, table, - states_to_update, post_update_cols): +def _collect_post_update_commands( + base_mapper, uowtransaction, table, states_to_update, post_update_cols +): """Identify sets of values to use in UPDATE statements for a list of states within a post_update operation. """ - for state, state_dict, mapper, connection, \ - update_version_id in states_to_update: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update: # assert table in mapper._pks_by_table @@ -600,100 +732,128 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, for col in mapper._cols_by_table[table]: if col in pks: - params[col._label] = \ - mapper._get_state_attr_by_column( - state, - state_dict, col, passive=attributes.PASSIVE_OFF) + params[col._label] = mapper._get_state_attr_by_column( + state, state_dict, col, passive=attributes.PASSIVE_OFF + ) elif col in post_update_cols or col.onupdate is not None: prop = mapper._columntoproperty[col] history = state.manager[prop.key].impl.get_history( - state, state_dict, - attributes.PASSIVE_NO_INITIALIZE) + state, state_dict, attributes.PASSIVE_NO_INITIALIZE + ) if history.added: value = history.added[0] params[col.key] = value hasdata = True if hasdata: - if update_version_id is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): col = mapper.version_id_col params[col._label] = update_version_id - if bool(state.key) and col.key not in params and \ - mapper.version_id_generator is not False: + if ( + bool(state.key) + and col.key not in params + and mapper.version_id_generator is not False + ): val = mapper.version_id_generator(update_version_id) params[col.key] = val yield state, state_dict, mapper, connection, params -def _collect_delete_commands(base_mapper, uowtransaction, table, - states_to_delete): +def _collect_delete_commands( + base_mapper, uowtransaction, table, states_to_delete +): """Identify values to use in DELETE statements for a list of states to be deleted.""" - for state, state_dict, mapper, connection, \ - update_version_id in states_to_delete: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_delete: if table not in mapper._pks_by_table: continue params = {} for col in mapper._pks_by_table[table]: - params[col.key] = \ - value = \ - mapper._get_committed_state_attr_by_column( - state, state_dict, col) + params[ + col.key + ] = value = mapper._get_committed_state_attr_by_column( + state, state_dict, col + ) if value is None: raise orm_exc.FlushError( "Can't delete from table %s " "using NULL for primary " - "key value on column %s" % (table, col)) + "key value on column %s" % (table, col) + ) - if update_version_id is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): params[mapper.version_id_col.key] = update_version_id yield params, connection -def _emit_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update, - bookkeeping=True): +def _emit_update_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + update, + bookkeeping=True, +): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" - needs_version_id = mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table] + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) def update_stmt(): clause = sql.and_() for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) + clause.clauses.append( + col == sql.bindparam(col._label, type_=col.type) + ) if needs_version_id: clause.clauses.append( - mapper.version_id_col == sql.bindparam( + mapper.version_id_col + == sql.bindparam( mapper.version_id_col._label, - type_=mapper.version_id_col.type)) + type_=mapper.version_id_col.type, + ) + ) stmt = table.update(clause) return stmt - cached_stmt = base_mapper._memo(('update', table), update_stmt) - - for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \ - records in groupby( - update, - lambda rec: ( - rec[4], # connection - set(rec[2]), # set of parameter keys - bool(rec[5]), # whether or not we have "value" parameters - rec[6], # has_all_defaults - rec[7] # has all pks - ) + cached_stmt = base_mapper._memo(("update", table), update_stmt) + + for ( + (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), + records, + ) in groupby( + update, + lambda rec: ( + rec[4], # connection + set(rec[2]), # set of parameter keys + bool(rec[5]), # whether or not we have "value" parameters + rec[6], # has_all_defaults + rec[7], # has all pks + ), ): rows = 0 records = list(records) @@ -704,8 +864,11 @@ def _emit_update_statements(base_mapper, uowtransaction, if not has_all_pks: statement = statement.return_defaults() return_defaults = True - elif bookkeeping and not has_all_defaults and \ - mapper.base_mapper.eager_defaults: + elif ( + bookkeeping + and not has_all_defaults + and mapper.base_mapper.eager_defaults + ): statement = statement.return_defaults() return_defaults = True elif mapper.version_id_col is not None: @@ -718,17 +881,24 @@ def _emit_update_statements(base_mapper, uowtransaction, else connection.dialect.supports_sane_rowcount_returning ) - assert_multirow = assert_singlerow and \ - connection.dialect.supports_sane_multi_rowcount + assert_multirow = ( + assert_singlerow + and connection.dialect.supports_sane_multi_rowcount + ) allow_multirow = has_all_defaults and not needs_version_id if hasvalue: - for state, state_dict, params, mapper, \ - connection, value_params, \ - has_all_defaults, has_all_pks in records: - c = connection.execute( - statement.values(value_params), - params) + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + c = connection.execute(statement.values(value_params), params) if bookkeeping: _postfetch( mapper, @@ -738,17 +908,26 @@ def _emit_update_statements(base_mapper, uowtransaction, state_dict, c, c.context.compiled_parameters[0], - value_params) + value_params, + ) rows += c.rowcount check_rowcount = assert_singlerow else: if not allow_multirow: check_rowcount = assert_singlerow - for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults, \ - has_all_pks in records: - c = cached_connections[connection].\ - execute(statement, params) + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + c = cached_connections[connection].execute( + statement, params + ) # TODO: why with bookkeeping=False? if bookkeeping: @@ -760,24 +939,32 @@ def _emit_update_statements(base_mapper, uowtransaction, state_dict, c, c.context.compiled_parameters[0], - value_params) + value_params, + ) rows += c.rowcount else: multiparams = [rec[2] for rec in records] check_rowcount = assert_multirow or ( - assert_singlerow and - len(multiparams) == 1 + assert_singlerow and len(multiparams) == 1 ) - c = cached_connections[connection].\ - execute(statement, multiparams) + c = cached_connections[connection].execute( + statement, multiparams + ) rows += c.rowcount - for state, state_dict, params, mapper, \ - connection, value_params, \ - has_all_defaults, has_all_pks in records: + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: if bookkeeping: _postfetch( mapper, @@ -787,59 +974,85 @@ def _emit_update_statements(base_mapper, uowtransaction, state_dict, c, c.context.compiled_parameters[0], - value_params) + value_params, + ) if check_rowcount: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(records), rows)) + "update %d row(s); %d were matched." + % (table.description, len(records), rows) + ) elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description) + util.warn( + "Dialect %s does not support updated rowcount " + "- versioning cannot be verified." + % c.dialect.dialect_description + ) -def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, insert, - bookkeeping=True): +def _emit_insert_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + insert, + bookkeeping=True, +): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" - cached_stmt = base_mapper._memo(('insert', table), table.insert) - - for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ - records in groupby( - insert, - lambda rec: ( - rec[4], # connection - set(rec[2]), # parameter keys - bool(rec[5]), # whether we have "value" parameters - rec[6], - rec[7])): + cached_stmt = base_mapper._memo(("insert", table), table.insert) + + for ( + (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), + records, + ) in groupby( + insert, + lambda rec: ( + rec[4], # connection + set(rec[2]), # parameter keys + bool(rec[5]), # whether we have "value" parameters + rec[6], + rec[7], + ), + ): statement = cached_stmt - if not bookkeeping or \ - ( - has_all_defaults - or not base_mapper.eager_defaults - or not connection.dialect.implicit_returning - ) and has_all_pks and not hasvalue: + if ( + not bookkeeping + or ( + has_all_defaults + or not base_mapper.eager_defaults + or not connection.dialect.implicit_returning + ) + and has_all_pks + and not hasvalue + ): records = list(records) multiparams = [rec[2] for rec in records] - c = cached_connections[connection].\ - execute(statement, multiparams) + c = cached_connections[connection].execute(statement, multiparams) if bookkeeping: - for (state, state_dict, params, mapper_rec, - conn, value_params, has_all_pks, has_all_defaults), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): + for ( + ( + state, + state_dict, + params, + mapper_rec, + conn, + value_params, + has_all_pks, + has_all_defaults, + ), + last_inserted_params, + ) in zip(records, c.context.compiled_parameters): if state: _postfetch( mapper_rec, @@ -849,7 +1062,8 @@ def _emit_insert_statements(base_mapper, uowtransaction, state_dict, c, last_inserted_params, - value_params) + value_params, + ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) @@ -859,24 +1073,33 @@ def _emit_insert_statements(base_mapper, uowtransaction, elif mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) - for state, state_dict, params, mapper_rec, \ - connection, value_params, \ - has_all_pks, has_all_defaults in records: + for ( + state, + state_dict, + params, + mapper_rec, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) in records: if value_params: result = connection.execute( - statement.values(value_params), - params) + statement.values(value_params), params + ) else: - result = cached_connections[connection].\ - execute(statement, params) + result = cached_connections[connection].execute( + statement, params + ) primary_key = result.context.inserted_primary_key if primary_key is not None: # set primary key attributes - for pk, col in zip(primary_key, - mapper._pks_by_table[table]): + for pk, col in zip( + primary_key, mapper._pks_by_table[table] + ): prop = mapper_rec._columntoproperty[col] if state_dict.get(prop.key) is None: state_dict[prop.key] = pk @@ -890,31 +1113,39 @@ def _emit_insert_statements(base_mapper, uowtransaction, state_dict, result, result.context.compiled_parameters[0], - value_params) + value_params, + ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) -def _emit_post_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update): +def _emit_post_update_statements( + base_mapper, uowtransaction, cached_connections, mapper, table, update +): """Emit UPDATE statements corresponding to value lists collected by _collect_post_update_commands().""" - needs_version_id = mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table] + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) def update_stmt(): clause = sql.and_() for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) + clause.clauses.append( + col == sql.bindparam(col._label, type_=col.type) + ) if needs_version_id: clause.clauses.append( - mapper.version_id_col == sql.bindparam( + mapper.version_id_col + == sql.bindparam( mapper.version_id_col._label, - type_=mapper.version_id_col.type)) + type_=mapper.version_id_col.type, + ) + ) stmt = table.update(clause) @@ -923,17 +1154,15 @@ def _emit_post_update_statements(base_mapper, uowtransaction, return stmt - statement = base_mapper._memo(('post_update', table), update_stmt) + statement = base_mapper._memo(("post_update", table), update_stmt) # execute each UPDATE in the order according to the original # list of states to guarantee row access order, but # also group them into common (connection, cols) sets # to support executemany(). for key, records in groupby( - update, lambda rec: ( - rec[3], # connection - set(rec[4]), # parameter keys - ) + update, + lambda rec: (rec[3], set(rec[4])), # connection # parameter keys ): rows = 0 @@ -945,84 +1174,96 @@ def _emit_post_update_statements(base_mapper, uowtransaction, if mapper.version_id_col is None else connection.dialect.supports_sane_rowcount_returning ) - assert_multirow = assert_singlerow and \ - connection.dialect.supports_sane_multi_rowcount + assert_multirow = ( + assert_singlerow + and connection.dialect.supports_sane_multi_rowcount + ) allow_multirow = not needs_version_id or assert_multirow - if not allow_multirow: check_rowcount = assert_singlerow - for state, state_dict, mapper_rec, \ - connection, params in records: - c = cached_connections[connection].\ - execute(statement, params) + for state, state_dict, mapper_rec, connection, params in records: + c = cached_connections[connection].execute(statement, params) _postfetch_post_update( - mapper_rec, uowtransaction, table, state, state_dict, - c, c.context.compiled_parameters[0]) + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + ) rows += c.rowcount else: multiparams = [ - params for - state, state_dict, mapper_rec, conn, params in records] + params + for state, state_dict, mapper_rec, conn, params in records + ] check_rowcount = assert_multirow or ( - assert_singlerow and - len(multiparams) == 1 + assert_singlerow and len(multiparams) == 1 ) - c = cached_connections[connection].\ - execute(statement, multiparams) + c = cached_connections[connection].execute(statement, multiparams) rows += c.rowcount - for state, state_dict, mapper_rec, \ - connection, params in records: + for state, state_dict, mapper_rec, connection, params in records: _postfetch_post_update( - mapper_rec, uowtransaction, table, state, state_dict, - c, c.context.compiled_parameters[0]) + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + ) if check_rowcount: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(records), rows)) + "update %d row(s); %d were matched." + % (table.description, len(records), rows) + ) elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description) + util.warn( + "Dialect %s does not support updated rowcount " + "- versioning cannot be verified." + % c.dialect.dialect_description + ) -def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, - mapper, table, delete): +def _emit_delete_statements( + base_mapper, uowtransaction, cached_connections, mapper, table, delete +): """Emit DELETE statements corresponding to value lists collected by _collect_delete_commands().""" - need_version_id = mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table] + need_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) def delete_stmt(): clause = sql.and_() for col in mapper._pks_by_table[table]: clause.clauses.append( - col == sql.bindparam(col.key, type_=col.type)) + col == sql.bindparam(col.key, type_=col.type) + ) if need_version_id: clause.clauses.append( - mapper.version_id_col == - sql.bindparam( - mapper.version_id_col.key, - type_=mapper.version_id_col.type + mapper.version_id_col + == sql.bindparam( + mapper.version_id_col.key, type_=mapper.version_id_col.type ) ) return table.delete(clause) - statement = base_mapper._memo(('delete', table), delete_stmt) - for connection, recs in groupby( - delete, - lambda rec: rec[1] # connection - ): + statement = base_mapper._memo(("delete", table), delete_stmt) + for connection, recs in groupby(delete, lambda rec: rec[1]): # connection del_objects = [params for params, connection in recs] connection = cached_connections[connection] @@ -1049,9 +1290,10 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, else: util.warn( "Dialect %s does not support deleted rowcount " - "- versioning cannot be verified." % - connection.dialect.dialect_description, - stacklevel=12) + "- versioning cannot be verified." + % connection.dialect.dialect_description, + stacklevel=12, + ) connection.execute(statement, del_objects) else: c = connection.execute(statement, del_objects) @@ -1061,23 +1303,26 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, rows_matched = c.rowcount - if base_mapper.confirm_deleted_rows and \ - rows_matched > -1 and expected != rows_matched: + if ( + base_mapper.confirm_deleted_rows + and rows_matched > -1 + and expected != rows_matched + ): if only_warn: util.warn( "DELETE statement on table '%s' expected to " "delete %d row(s); %d were matched. Please set " "confirm_deleted_rows=False within the mapper " - "configuration to prevent this warning." % - (table.description, expected, rows_matched) + "configuration to prevent this warning." + % (table.description, expected, rows_matched) ) else: raise orm_exc.StaleDataError( "DELETE statement on table '%s' expected to " "delete %d row(s); %d were matched. Please set " "confirm_deleted_rows=False within the mapper " - "configuration to prevent this warning." % - (table.description, expected, rows_matched) + "configuration to prevent this warning." + % (table.description, expected, rows_matched) ) @@ -1091,13 +1336,16 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): if mapper._readonly_props: readonly = state.unmodified_intersection( [ - p.key for p in mapper._readonly_props + p.key + for p in mapper._readonly_props if ( - p.expire_on_flush and - (not p.deferred or p.key in state.dict) - ) or ( - not p.expire_on_flush and - not p.deferred and p.key not in state.dict + p.expire_on_flush + and (not p.deferred or p.key in state.dict) + ) + or ( + not p.expire_on_flush + and not p.deferred + and p.key not in state.dict ) ] ) @@ -1112,11 +1360,14 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): if base_mapper.eager_defaults: toload_now.extend( state._unloaded_non_object.intersection( - mapper._server_default_plus_onupdate_propkeys) + mapper._server_default_plus_onupdate_propkeys + ) ) - if mapper.version_id_col is not None and \ - mapper.version_id_generator is False: + if ( + mapper.version_id_col is not None + and mapper.version_id_generator is False + ): if mapper._version_id_prop.key in state.unloaded: toload_now.extend([mapper._version_id_prop.key]) @@ -1124,8 +1375,10 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): state.key = base_mapper._identity_key_from_state(state) loading.load_on_ident( uowtransaction.session.query(mapper), - state.key, refresh_state=state, - only_load_props=toload_now) + state.key, + refresh_state=state, + only_load_props=toload_now, + ) # call after_XXX extensions if not has_identity: @@ -1133,23 +1386,29 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): else: mapper.dispatch.after_update(mapper, connection, state) - if mapper.version_id_generator is False and \ - mapper.version_id_col is not None: + if ( + mapper.version_id_generator is False + and mapper.version_id_col is not None + ): if state_dict[mapper._version_id_prop.key] is None: raise orm_exc.FlushError( - "Instance does not contain a non-NULL version value") + "Instance does not contain a non-NULL version value" + ) -def _postfetch_post_update(mapper, uowtransaction, table, - state, dict_, result, params): +def _postfetch_post_update( + mapper, uowtransaction, table, state, dict_, result, params +): if uowtransaction.is_deleted(state): return prefetch_cols = result.context.compiled.prefetch postfetch_cols = result.context.compiled.postfetch - if mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) @@ -1164,18 +1423,23 @@ def _postfetch_post_update(mapper, uowtransaction, table, if refresh_flush and load_evt_attrs: mapper.class_manager.dispatch.refresh_flush( - state, uowtransaction, load_evt_attrs) + state, uowtransaction, load_evt_attrs + ) if postfetch_cols: - state._expire_attributes(state.dict, - [mapper._columntoproperty[c].key - for c in postfetch_cols if c in - mapper._columntoproperty] - ) + state._expire_attributes( + state.dict, + [ + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + ], + ) -def _postfetch(mapper, uowtransaction, table, - state, dict_, result, params, value_params): +def _postfetch( + mapper, uowtransaction, table, state, dict_, result, params, value_params +): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that state.""" @@ -1184,8 +1448,10 @@ def _postfetch(mapper, uowtransaction, table, postfetch_cols = result.context.compiled.postfetch returning_cols = result.context.compiled.returning - if mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) @@ -1219,23 +1485,32 @@ def _postfetch(mapper, uowtransaction, table, if refresh_flush and load_evt_attrs: mapper.class_manager.dispatch.refresh_flush( - state, uowtransaction, load_evt_attrs) + state, uowtransaction, load_evt_attrs + ) if postfetch_cols: - state._expire_attributes(state.dict, - [mapper._columntoproperty[c].key - for c in postfetch_cols if c in - mapper._columntoproperty] - ) + state._expire_attributes( + state.dict, + [ + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + ], + ) # synchronize newly inserted ids from one table to the next # TODO: this still goes a little too often. would be nice to # have definitive list of "columns that changed" here for m, equated_pairs in mapper._table_to_equated[table]: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - mapper.passive_updates) + sync.populate( + state, + m, + state, + m, + equated_pairs, + uowtransaction, + mapper.passive_updates, + ) def _postfetch_bulk_save(mapper, dict_, table): @@ -1255,8 +1530,7 @@ def _connections_for_states(base_mapper, uowtransaction, states): # organize individual states with the connection # to use for update if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable + connection_callable = uowtransaction.session.connection_callable else: connection = uowtransaction.transaction.connection(base_mapper) connection_callable = None @@ -1275,7 +1549,8 @@ def _cached_connection_dict(base_mapper): return util.PopulateDict( lambda conn: conn.execution_options( compiled_cache=base_mapper._compiled_cache - )) + ) + ) def _sort_states(states): @@ -1287,9 +1562,12 @@ def _sort_states(states): except TypeError as err: raise sa_exc.InvalidRequestError( "Could not sort objects by primary key; primary key " - "values must be sortable in Python (was: %s)" % err) - return sorted(pending, key=operator.attrgetter("insert_order")) + \ - persistent_sorted + "values must be sortable in Python (was: %s)" % err + ) + return ( + sorted(pending, key=operator.attrgetter("insert_order")) + + persistent_sorted + ) class BulkUD(object): @@ -1302,21 +1580,22 @@ class BulkUD(object): def _validate_query_state(self): for attr, methname, notset, op in ( - ('_limit', 'limit()', None, operator.is_), - ('_offset', 'offset()', None, operator.is_), - ('_order_by', 'order_by()', False, operator.is_), - ('_group_by', 'group_by()', False, operator.is_), - ('_distinct', 'distinct()', False, operator.is_), + ("_limit", "limit()", None, operator.is_), + ("_offset", "offset()", None, operator.is_), + ("_order_by", "order_by()", False, operator.is_), + ("_group_by", "group_by()", False, operator.is_), + ("_distinct", "distinct()", False, operator.is_), ( - '_from_obj', - 'join(), outerjoin(), select_from(), or from_self()', - (), operator.eq) + "_from_obj", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), ): if not op(getattr(self.query, attr), notset): raise sa_exc.InvalidRequestError( "Can't call Query.update() or Query.delete() " - "when %s has been called" % - (methname, ) + "when %s has been called" % (methname,) ) @property @@ -1330,8 +1609,8 @@ class BulkUD(object): except KeyError: raise sa_exc.ArgumentError( "Valid strategies for session synchronization " - "are %s" % (", ".join(sorted(repr(x) - for x in lookup)))) + "are %s" % (", ".join(sorted(repr(x) for x in lookup))) + ) else: return klass(*arg) @@ -1400,9 +1679,9 @@ class BulkEvaluate(BulkUD): try: evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) if query.whereclause is not None: - eval_condition = evaluator_compiler.process( - query.whereclause) + eval_condition = evaluator_compiler.process(query.whereclause) else: + def eval_condition(obj): return True @@ -1411,15 +1690,20 @@ class BulkEvaluate(BulkUD): except evaluator.UnevaluatableError as err: raise sa_exc.InvalidRequestError( 'Could not evaluate current criteria in Python: "%s". ' - 'Specify \'fetch\' or False for the ' - 'synchronize_session parameter.' % err) + "Specify 'fetch' or False for the " + "synchronize_session parameter." % err + ) # TODO: detect when the where clause is a trivial primary key match self.matched_objects = [ - obj for (cls, pk, identity_token), obj in - query.session.identity_map.items() - if issubclass(cls, target_cls) and - eval_condition(obj)] + obj + for ( + cls, + pk, + identity_token, + ), obj in query.session.identity_map.items() + if issubclass(cls, target_cls) and eval_condition(obj) + ] class BulkFetch(BulkUD): @@ -1430,11 +1714,11 @@ class BulkFetch(BulkUD): session = query.session context = query._compile_context() select_stmt = context.statement.with_only_columns( - self.primary_table.primary_key) + self.primary_table.primary_key + ) self.matched_rows = session.execute( - select_stmt, - mapper=self.mapper, - params=query._params).fetchall() + select_stmt, mapper=self.mapper, params=query._params + ).fetchall() class BulkUpdate(BulkUD): @@ -1447,18 +1731,26 @@ class BulkUpdate(BulkUD): @classmethod def factory(cls, query, synchronize_session, values, update_kwargs): - return BulkUD._factory({ - "evaluate": BulkUpdateEvaluate, - "fetch": BulkUpdateFetch, - False: BulkUpdate - }, synchronize_session, query, values, update_kwargs) + return BulkUD._factory( + { + "evaluate": BulkUpdateEvaluate, + "fetch": BulkUpdateFetch, + False: BulkUpdate, + }, + synchronize_session, + query, + values, + update_kwargs, + ) @property def _resolved_values(self): values = [] for k, v in ( - self.values.items() if hasattr(self.values, 'items') - else self.values): + self.values.items() + if hasattr(self.values, "items") + else self.values + ): if self.mapper: if isinstance(k, util.string_types): desc = _entity_descriptor(self.mapper, k) @@ -1478,7 +1770,7 @@ class BulkUpdate(BulkUD): if isinstance(k, attributes.QueryableAttribute): values.append((k.key, v)) continue - elif hasattr(k, '__clause_element__'): + elif hasattr(k, "__clause_element__"): k = k.__clause_element__() if self.mapper and isinstance(k, expression.ColumnElement): @@ -1490,18 +1782,22 @@ class BulkUpdate(BulkUD): values.append((attr.key, v)) else: raise sa_exc.InvalidRequestError( - "Invalid expression type: %r" % k) + "Invalid expression type: %r" % k + ) return values def _do_exec(self): values = self._resolved_values - if not self.update_kwargs.get('preserve_parameter_order', False): + if not self.update_kwargs.get("preserve_parameter_order", False): values = dict(values) - update_stmt = sql.update(self.primary_table, - self.context.whereclause, values, - **self.update_kwargs) + update_stmt = sql.update( + self.primary_table, + self.context.whereclause, + values, + **self.update_kwargs + ) self._execute_stmt(update_stmt) @@ -1518,15 +1814,18 @@ class BulkDelete(BulkUD): @classmethod def factory(cls, query, synchronize_session): - return BulkUD._factory({ - "evaluate": BulkDeleteEvaluate, - "fetch": BulkDeleteFetch, - False: BulkDelete - }, synchronize_session, query) + return BulkUD._factory( + { + "evaluate": BulkDeleteEvaluate, + "fetch": BulkDeleteFetch, + False: BulkDelete, + }, + synchronize_session, + query, + ) def _do_exec(self): - delete_stmt = sql.delete(self.primary_table, - self.context.whereclause) + delete_stmt = sql.delete(self.primary_table, self.context.whereclause) self._execute_stmt(delete_stmt) @@ -1544,32 +1843,33 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): values = self._resolved_values_keys_as_propnames for key, value in values: self.value_evaluators[key] = evaluator_compiler.process( - expression._literal_as_binds(value)) + expression._literal_as_binds(value) + ) def _do_post_synchronize(self): session = self.query.session states = set() evaluated_keys = list(self.value_evaluators.keys()) for obj in self.matched_objects: - state, dict_ = attributes.instance_state(obj),\ - attributes.instance_dict(obj) + state, dict_ = ( + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) # only evaluate unmodified attributes - to_evaluate = state.unmodified.intersection( - evaluated_keys) + to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: dict_[key] = self.value_evaluators[key](obj) - state.manager.dispatch.refresh( - state, None, to_evaluate) + state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) # expire attributes with pending changes # (there was no autoflush, so they are overwritten) - state._expire_attributes(dict_, - set(evaluated_keys). - difference(to_evaluate)) + state._expire_attributes( + dict_, set(evaluated_keys).difference(to_evaluate) + ) states.add(state) session._register_altered(states) @@ -1580,8 +1880,8 @@ class BulkDeleteEvaluate(BulkEvaluate, BulkDelete): def _do_post_synchronize(self): self.query.session._remove_newly_deleted( - [attributes.instance_state(obj) - for obj in self.matched_objects]) + [attributes.instance_state(obj) for obj in self.matched_objects] + ) class BulkUpdateFetch(BulkFetch, BulkUpdate): @@ -1592,15 +1892,18 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate): session = self.query.session target_mapper = self.query._mapper_zero() - states = set([ - attributes.instance_state(session.identity_map[identity_key]) - for identity_key in [ - target_mapper.identity_key_from_primary_key( - list(primary_key)) - for primary_key in self.matched_rows + states = set( + [ + attributes.instance_state(session.identity_map[identity_key]) + for identity_key in [ + target_mapper.identity_key_from_primary_key( + list(primary_key) + ) + for primary_key in self.matched_rows + ] + if identity_key in session.identity_map ] - if identity_key in session.identity_map - ]) + ) values = self._resolved_values_keys_as_propnames attrib = set(k for k, v in values) @@ -1622,10 +1925,13 @@ class BulkDeleteFetch(BulkFetch, BulkDelete): # TODO: inline this and call remove_newly_deleted # once identity_key = target_mapper.identity_key_from_primary_key( - list(primary_key)) + list(primary_key) + ) if identity_key in session.identity_map: session._remove_newly_deleted( - [attributes.instance_state( - session.identity_map[identity_key] - )] + [ + attributes.instance_state( + session.identity_map[identity_key] + ) + ] ) |
