diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
| commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
| tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/orm/persistence.py | |
| parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
| download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz | |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 1186 |
1 files changed, 746 insertions, 440 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 7f9b7db0c..dc86a60e5 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -25,8 +25,13 @@ from . import loading def _bulk_insert( - mapper, mappings, session_transaction, isstates, return_defaults, - render_nulls): + mapper, + mappings, + session_transaction, + isstates, + return_defaults, + render_nulls, +): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) @@ -34,7 +39,8 @@ def _bulk_insert( if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " - "not supported in bulk_insert()") + "not supported in bulk_insert()" + ) if isstates: if return_defaults: @@ -51,22 +57,33 @@ def _bulk_insert( continue records = ( - (None, state_dict, params, mapper, - connection, value_params, has_all_pks, has_all_defaults) - for - state, state_dict, params, mp, - conn, value_params, has_all_pks, - has_all_defaults in _collect_insert_commands(table, ( - (None, mapping, mapper, connection) - for mapping in mappings), - bulk=True, return_defaults=return_defaults, - render_nulls=render_nulls + ( + None, + state_dict, + params, + mapper, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) + for state, state_dict, params, mp, conn, value_params, has_all_pks, has_all_defaults in _collect_insert_commands( + table, + ((None, mapping, mapper, connection) for mapping in mappings), + bulk=True, + return_defaults=return_defaults, + render_nulls=render_nulls, ) ) - _emit_insert_statements(base_mapper, None, - cached_connections, - super_mapper, table, records, - bookkeeping=return_defaults) + _emit_insert_statements( + base_mapper, + None, + cached_connections, + super_mapper, + table, + records, + bookkeeping=return_defaults, + ) if return_defaults and isstates: identity_cls = mapper._identity_class @@ -74,12 +91,13 @@ def _bulk_insert( for state, dict_ in states: state.key = ( identity_cls, - tuple([dict_[key] for key in identity_props]) + tuple([dict_[key] for key in identity_props]), ) -def _bulk_update(mapper, mappings, session_transaction, - isstates, update_changed_only): +def _bulk_update( + mapper, mappings, session_transaction, isstates, update_changed_only +): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) @@ -91,9 +109,8 @@ def _bulk_update(mapper, mappings, session_transaction, def _changed_dict(mapper, state): return dict( (k, v) - for k, v in state.dict.items() if k in state.committed_state or k - in search_keys - + for k, v in state.dict.items() + if k in state.committed_state or k in search_keys ) if isstates: @@ -107,7 +124,8 @@ def _bulk_update(mapper, mappings, session_transaction, if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " - "not supported in bulk_update()") + "not supported in bulk_update()" + ) connection = session_transaction.connection(base_mapper) @@ -115,21 +133,38 @@ def _bulk_update(mapper, mappings, session_transaction, if not mapper.isa(super_mapper): continue - records = _collect_update_commands(None, table, ( - (None, mapping, mapper, connection, - (mapping[mapper._version_id_prop.key] - if mapper._version_id_prop else None)) - for mapping in mappings - ), bulk=True) + records = _collect_update_commands( + None, + table, + ( + ( + None, + mapping, + mapper, + connection, + ( + mapping[mapper._version_id_prop.key] + if mapper._version_id_prop + else None + ), + ) + for mapping in mappings + ), + bulk=True, + ) - _emit_update_statements(base_mapper, None, - cached_connections, - super_mapper, table, records, - bookkeeping=False) + _emit_update_statements( + base_mapper, + None, + cached_connections, + super_mapper, + table, + records, + bookkeeping=False, + ) -def save_obj( - base_mapper, states, uowtransaction, single=False): +def save_obj(base_mapper, states, uowtransaction, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. @@ -150,19 +185,21 @@ def save_obj( states_to_insert = [] cached_connections = _cached_connection_dict(base_mapper) - for (state, dict_, mapper, connection, - has_identity, - row_switch, update_version_id) in _organize_states_for_save( - base_mapper, states, uowtransaction - ): + for ( + state, + dict_, + mapper, + connection, + has_identity, + row_switch, + update_version_id, + ) in _organize_states_for_save(base_mapper, states, uowtransaction): if has_identity or row_switch: states_to_update.append( (state, dict_, mapper, connection, update_version_id) ) else: - states_to_insert.append( - (state, dict_, mapper, connection) - ) + states_to_insert.append((state, dict_, mapper, connection)) for table, mapper in base_mapper._sorted_tables.items(): if table not in mapper._pks_by_table: @@ -170,18 +207,30 @@ def save_obj( insert = _collect_insert_commands(table, states_to_insert) update = _collect_update_commands( - uowtransaction, table, states_to_update) + uowtransaction, table, states_to_update + ) - _emit_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) + _emit_update_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + update, + ) - _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, insert) + _emit_insert_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + insert, + ) _finalize_insert_update_commands( - base_mapper, uowtransaction, + base_mapper, + uowtransaction, chain( ( (state, state_dict, mapper, connection, False) @@ -189,10 +238,9 @@ def save_obj( ), ( (state, state_dict, mapper, connection, True) - for state, state_dict, mapper, connection, - update_version_id in states_to_update - ) - ) + for state, state_dict, mapper, connection, update_version_id in states_to_update + ), + ), ) @@ -203,9 +251,9 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): """ cached_connections = _cached_connection_dict(base_mapper) - states_to_update = list(_organize_states_for_post_update( - base_mapper, - states, uowtransaction)) + states_to_update = list( + _organize_states_for_post_update(base_mapper, states, uowtransaction) + ) for table, mapper in base_mapper._sorted_tables.items(): if table not in mapper._pks_by_table: @@ -213,25 +261,32 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): update = ( ( - state, state_dict, sub_mapper, connection, + state, + state_dict, + sub_mapper, + connection, mapper._get_committed_state_attr_by_column( state, state_dict, mapper.version_id_col - ) if mapper.version_id_col is not None else None + ) + if mapper.version_id_col is not None + else None, ) - for - state, state_dict, sub_mapper, connection in states_to_update + for state, state_dict, sub_mapper, connection in states_to_update if table in sub_mapper._pks_by_table ) update = _collect_post_update_commands( - base_mapper, uowtransaction, - table, update, - post_update_cols + base_mapper, uowtransaction, table, update, post_update_cols ) - _emit_post_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) + _emit_post_update_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + update, + ) def delete_obj(base_mapper, states, uowtransaction): @@ -244,10 +299,9 @@ def delete_obj(base_mapper, states, uowtransaction): cached_connections = _cached_connection_dict(base_mapper) - states_to_delete = list(_organize_states_for_delete( - base_mapper, - states, - uowtransaction)) + states_to_delete = list( + _organize_states_for_delete(base_mapper, states, uowtransaction) + ) table_to_mapper = base_mapper._sorted_tables @@ -258,14 +312,26 @@ def delete_obj(base_mapper, states, uowtransaction): elif mapper.inherits and mapper.passive_deletes: continue - delete = _collect_delete_commands(base_mapper, uowtransaction, - table, states_to_delete) + delete = _collect_delete_commands( + base_mapper, uowtransaction, table, states_to_delete + ) - _emit_delete_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, delete) + _emit_delete_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + delete, + ) - for state, state_dict, mapper, connection, \ - update_version_id in states_to_delete: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_delete: mapper.dispatch.after_delete(mapper, connection, state) @@ -282,8 +348,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): """ for state, dict_, mapper, connection in _connections_for_states( - base_mapper, uowtransaction, - states): + base_mapper, uowtransaction, states + ): has_identity = bool(state.key) @@ -304,25 +370,29 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): # no instance_key attached to it), and another instance # with the same identity key already exists as persistent. # convert to an UPDATE if so. - if not has_identity and \ - instance_key in uowtransaction.session.identity_map: - instance = \ - uowtransaction.session.identity_map[instance_key] + if ( + not has_identity + and instance_key in uowtransaction.session.identity_map + ): + instance = uowtransaction.session.identity_map[instance_key] existing = attributes.instance_state(instance) if not uowtransaction.was_already_deleted(existing): if not uowtransaction.is_deleted(existing): raise orm_exc.FlushError( "New instance %s with identity key %s conflicts " - "with persistent instance %s" % - (state_str(state), instance_key, - state_str(existing))) + "with persistent instance %s" + % (state_str(state), instance_key, state_str(existing)) + ) base_mapper._log_debug( "detected row switch for identity %s. " "will update %s, remove %s from " - "transaction", instance_key, - state_str(state), state_str(existing)) + "transaction", + instance_key, + state_str(state), + state_str(existing), + ) # remove the "delete" flag from the existing element uowtransaction.remove_state_actions(existing) @@ -332,14 +402,21 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): update_version_id = mapper._get_committed_state_attr_by_column( row_switch if row_switch else state, row_switch.dict if row_switch else dict_, - mapper.version_id_col) + mapper.version_id_col, + ) - yield (state, dict_, mapper, connection, - has_identity, row_switch, update_version_id) + yield ( + state, + dict_, + mapper, + connection, + has_identity, + row_switch, + update_version_id, + ) -def _organize_states_for_post_update(base_mapper, states, - uowtransaction): +def _organize_states_for_post_update(base_mapper, states, uowtransaction): """Make an initial pass across a set of states for UPDATE corresponding to post_update. @@ -360,26 +437,28 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): """ for state, dict_, mapper, connection in _connections_for_states( - base_mapper, uowtransaction, - states): + base_mapper, uowtransaction, states + ): mapper.dispatch.before_delete(mapper, connection, state) if mapper.version_id_col is not None: - update_version_id = \ - mapper._get_committed_state_attr_by_column( - state, dict_, - mapper.version_id_col) + update_version_id = mapper._get_committed_state_attr_by_column( + state, dict_, mapper.version_id_col + ) else: update_version_id = None - yield ( - state, dict_, mapper, connection, update_version_id) + yield (state, dict_, mapper, connection, update_version_id) def _collect_insert_commands( - table, states_to_insert, - bulk=False, return_defaults=False, render_nulls=False): + table, + states_to_insert, + bulk=False, + return_defaults=False, + render_nulls=False, +): """Identify sets of values to use in INSERT statements for a list of states. @@ -400,10 +479,16 @@ def _collect_insert_commands( col = propkey_to_col[propkey] if value is None and col not in eval_none and not render_nulls: continue - elif not bulk and hasattr(value, '__clause_element__') or \ - isinstance(value, sql.ClauseElement): - value_params[col.key] = value.__clause_element__() \ - if hasattr(value, '__clause_element__') else value + elif ( + not bulk + and hasattr(value, "__clause_element__") + or isinstance(value, sql.ClauseElement) + ): + value_params[col.key] = ( + value.__clause_element__() + if hasattr(value, "__clause_element__") + else value + ) else: params[col.key] = value @@ -414,8 +499,11 @@ def _collect_insert_commands( # which might be worth removing, as it should not be necessary # and also produces confusion, given that "missing" and None # now have distinct meanings - for colkey in mapper._insert_cols_as_none[table].\ - difference(params).difference(value_params): + for colkey in ( + mapper._insert_cols_as_none[table] + .difference(params) + .difference(value_params) + ): params[colkey] = None if not bulk or return_defaults: @@ -424,28 +512,38 @@ def _collect_insert_commands( has_all_pks = mapper._pk_keys_by_table[table].issubset(params) if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_default_cols[table].\ - issubset(params) + has_all_defaults = mapper._server_default_cols[table].issubset( + params + ) else: has_all_defaults = True else: has_all_defaults = has_all_pks = True - if mapper.version_id_generator is not False \ - and mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: - params[mapper.version_id_col.key] = \ - mapper.version_id_generator(None) + if ( + mapper.version_id_generator is not False + and mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + params[mapper.version_id_col.key] = mapper.version_id_generator( + None + ) yield ( - state, state_dict, params, mapper, - connection, value_params, has_all_pks, - has_all_defaults) + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) def _collect_update_commands( - uowtransaction, table, states_to_update, - bulk=False): + uowtransaction, table, states_to_update, bulk=False +): """Identify sets of values to use in UPDATE statements for a list of states. @@ -457,8 +555,13 @@ def _collect_update_commands( """ - for state, state_dict, mapper, connection, \ - update_version_id in states_to_update: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update: if table not in mapper._pks_by_table: continue @@ -474,36 +577,48 @@ def _collect_update_commands( # look at mapper attribute keys for pk params = dict( (propkey_to_col[propkey].key, state_dict[propkey]) - for propkey in - set(propkey_to_col).intersection(state_dict).difference( - mapper._pk_attr_keys_by_table[table]) + for propkey in set(propkey_to_col) + .intersection(state_dict) + .difference(mapper._pk_attr_keys_by_table[table]) ) has_all_defaults = True else: params = {} for propkey in set(propkey_to_col).intersection( - state.committed_state): + state.committed_state + ): value = state_dict[propkey] col = propkey_to_col[propkey] - if hasattr(value, '__clause_element__') or \ - isinstance(value, sql.ClauseElement): - value_params[col] = value.__clause_element__() \ - if hasattr(value, '__clause_element__') else value + if hasattr(value, "__clause_element__") or isinstance( + value, sql.ClauseElement + ): + value_params[col] = ( + value.__clause_element__() + if hasattr(value, "__clause_element__") + else value + ) # guard against values that generate non-__nonzero__ # objects for __eq__() - elif state.manager[propkey].impl.is_equal( - value, state.committed_state[propkey]) is not True: + elif ( + state.manager[propkey].impl.is_equal( + value, state.committed_state[propkey] + ) + is not True + ): params[col.key] = value if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_onupdate_default_cols[table].\ - issubset(params) + has_all_defaults = mapper._server_onupdate_default_cols[ + table + ].issubset(params) else: has_all_defaults = True - if update_version_id is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): if not bulk and not (params or value_params): # HACK: check for history in other tables, in case the @@ -511,10 +626,9 @@ def _collect_update_commands( # where the version_id_col is. This logic was lost # from 0.9 -> 1.0.0 and restored in 1.0.6. for prop in mapper._columntoproperty.values(): - history = ( - state.manager[prop.key].impl.get_history( - state, state_dict, - attributes.PASSIVE_NO_INITIALIZE)) + history = state.manager[prop.key].impl.get_history( + state, state_dict, attributes.PASSIVE_NO_INITIALIZE + ) if history.added: break else: @@ -525,8 +639,9 @@ def _collect_update_commands( no_params = not params and not value_params params[col._label] = update_version_id - if (bulk or col.key not in params) and \ - mapper.version_id_generator is not False: + if ( + bulk or col.key not in params + ) and mapper.version_id_generator is not False: val = mapper.version_id_generator(update_version_id) params[col.key] = val elif mapper.version_id_generator is False and no_params: @@ -545,9 +660,9 @@ def _collect_update_commands( # look at mapper attribute keys for pk pk_params = dict( (propkey_to_col[propkey]._label, state_dict.get(propkey)) - for propkey in - set(propkey_to_col). - intersection(mapper._pk_attr_keys_by_table[table]) + for propkey in set(propkey_to_col).intersection( + mapper._pk_attr_keys_by_table[table] + ) ) else: pk_params = {} @@ -555,12 +670,15 @@ def _collect_update_commands( propkey = mapper._columntoproperty[col].key history = state.manager[propkey].impl.get_history( - state, state_dict, attributes.PASSIVE_OFF) + state, state_dict, attributes.PASSIVE_OFF + ) if history.added: - if not history.deleted or \ - ("pk_cascaded", state, col) in \ - uowtransaction.attributes: + if ( + not history.deleted + or ("pk_cascaded", state, col) + in uowtransaction.attributes + ): pk_params[col._label] = history.added[0] params.pop(col.key, None) else: @@ -573,24 +691,38 @@ def _collect_update_commands( if pk_params[col._label] is None: raise orm_exc.FlushError( "Can't update table %s using NULL for primary " - "key value on column %s" % (table, col)) + "key value on column %s" % (table, col) + ) if params or value_params: params.update(pk_params) yield ( - state, state_dict, params, mapper, - connection, value_params, has_all_defaults, has_all_pks) + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) -def _collect_post_update_commands(base_mapper, uowtransaction, table, - states_to_update, post_update_cols): +def _collect_post_update_commands( + base_mapper, uowtransaction, table, states_to_update, post_update_cols +): """Identify sets of values to use in UPDATE statements for a list of states within a post_update operation. """ - for state, state_dict, mapper, connection, \ - update_version_id in states_to_update: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update: # assert table in mapper._pks_by_table @@ -600,100 +732,128 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, for col in mapper._cols_by_table[table]: if col in pks: - params[col._label] = \ - mapper._get_state_attr_by_column( - state, - state_dict, col, passive=attributes.PASSIVE_OFF) + params[col._label] = mapper._get_state_attr_by_column( + state, state_dict, col, passive=attributes.PASSIVE_OFF + ) elif col in post_update_cols or col.onupdate is not None: prop = mapper._columntoproperty[col] history = state.manager[prop.key].impl.get_history( - state, state_dict, - attributes.PASSIVE_NO_INITIALIZE) + state, state_dict, attributes.PASSIVE_NO_INITIALIZE + ) if history.added: value = history.added[0] params[col.key] = value hasdata = True if hasdata: - if update_version_id is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): col = mapper.version_id_col params[col._label] = update_version_id - if bool(state.key) and col.key not in params and \ - mapper.version_id_generator is not False: + if ( + bool(state.key) + and col.key not in params + and mapper.version_id_generator is not False + ): val = mapper.version_id_generator(update_version_id) params[col.key] = val yield state, state_dict, mapper, connection, params -def _collect_delete_commands(base_mapper, uowtransaction, table, - states_to_delete): +def _collect_delete_commands( + base_mapper, uowtransaction, table, states_to_delete +): """Identify values to use in DELETE statements for a list of states to be deleted.""" - for state, state_dict, mapper, connection, \ - update_version_id in states_to_delete: + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_delete: if table not in mapper._pks_by_table: continue params = {} for col in mapper._pks_by_table[table]: - params[col.key] = \ - value = \ - mapper._get_committed_state_attr_by_column( - state, state_dict, col) + params[ + col.key + ] = value = mapper._get_committed_state_attr_by_column( + state, state_dict, col + ) if value is None: raise orm_exc.FlushError( "Can't delete from table %s " "using NULL for primary " - "key value on column %s" % (table, col)) + "key value on column %s" % (table, col) + ) - if update_version_id is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): params[mapper.version_id_col.key] = update_version_id yield params, connection -def _emit_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update, - bookkeeping=True): +def _emit_update_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + update, + bookkeeping=True, +): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" - needs_version_id = mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table] + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) def update_stmt(): clause = sql.and_() for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) + clause.clauses.append( + col == sql.bindparam(col._label, type_=col.type) + ) if needs_version_id: clause.clauses.append( - mapper.version_id_col == sql.bindparam( + mapper.version_id_col + == sql.bindparam( mapper.version_id_col._label, - type_=mapper.version_id_col.type)) + type_=mapper.version_id_col.type, + ) + ) stmt = table.update(clause) return stmt - cached_stmt = base_mapper._memo(('update', table), update_stmt) - - for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \ - records in groupby( - update, - lambda rec: ( - rec[4], # connection - set(rec[2]), # set of parameter keys - bool(rec[5]), # whether or not we have "value" parameters - rec[6], # has_all_defaults - rec[7] # has all pks - ) + cached_stmt = base_mapper._memo(("update", table), update_stmt) + + for ( + (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), + records, + ) in groupby( + update, + lambda rec: ( + rec[4], # connection + set(rec[2]), # set of parameter keys + bool(rec[5]), # whether or not we have "value" parameters + rec[6], # has_all_defaults + rec[7], # has all pks + ), ): rows = 0 records = list(records) @@ -704,8 +864,11 @@ def _emit_update_statements(base_mapper, uowtransaction, if not has_all_pks: statement = statement.return_defaults() return_defaults = True - elif bookkeeping and not has_all_defaults and \ - mapper.base_mapper.eager_defaults: + elif ( + bookkeeping + and not has_all_defaults + and mapper.base_mapper.eager_defaults + ): statement = statement.return_defaults() return_defaults = True elif mapper.version_id_col is not None: @@ -718,17 +881,24 @@ def _emit_update_statements(base_mapper, uowtransaction, else connection.dialect.supports_sane_rowcount_returning ) - assert_multirow = assert_singlerow and \ - connection.dialect.supports_sane_multi_rowcount + assert_multirow = ( + assert_singlerow + and connection.dialect.supports_sane_multi_rowcount + ) allow_multirow = has_all_defaults and not needs_version_id if hasvalue: - for state, state_dict, params, mapper, \ - connection, value_params, \ - has_all_defaults, has_all_pks in records: - c = connection.execute( - statement.values(value_params), - params) + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + c = connection.execute(statement.values(value_params), params) if bookkeeping: _postfetch( mapper, @@ -738,17 +908,26 @@ def _emit_update_statements(base_mapper, uowtransaction, state_dict, c, c.context.compiled_parameters[0], - value_params) + value_params, + ) rows += c.rowcount check_rowcount = assert_singlerow else: if not allow_multirow: check_rowcount = assert_singlerow - for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults, \ - has_all_pks in records: - c = cached_connections[connection].\ - execute(statement, params) + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + c = cached_connections[connection].execute( + statement, params + ) # TODO: why with bookkeeping=False? if bookkeeping: @@ -760,24 +939,32 @@ def _emit_update_statements(base_mapper, uowtransaction, state_dict, c, c.context.compiled_parameters[0], - value_params) + value_params, + ) rows += c.rowcount else: multiparams = [rec[2] for rec in records] check_rowcount = assert_multirow or ( - assert_singlerow and - len(multiparams) == 1 + assert_singlerow and len(multiparams) == 1 ) - c = cached_connections[connection].\ - execute(statement, multiparams) + c = cached_connections[connection].execute( + statement, multiparams + ) rows += c.rowcount - for state, state_dict, params, mapper, \ - connection, value_params, \ - has_all_defaults, has_all_pks in records: + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: if bookkeeping: _postfetch( mapper, @@ -787,59 +974,85 @@ def _emit_update_statements(base_mapper, uowtransaction, state_dict, c, c.context.compiled_parameters[0], - value_params) + value_params, + ) if check_rowcount: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(records), rows)) + "update %d row(s); %d were matched." + % (table.description, len(records), rows) + ) elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description) + util.warn( + "Dialect %s does not support updated rowcount " + "- versioning cannot be verified." + % c.dialect.dialect_description + ) -def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, insert, - bookkeeping=True): +def _emit_insert_statements( + base_mapper, + uowtransaction, + cached_connections, + mapper, + table, + insert, + bookkeeping=True, +): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" - cached_stmt = base_mapper._memo(('insert', table), table.insert) - - for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ - records in groupby( - insert, - lambda rec: ( - rec[4], # connection - set(rec[2]), # parameter keys - bool(rec[5]), # whether we have "value" parameters - rec[6], - rec[7])): + cached_stmt = base_mapper._memo(("insert", table), table.insert) + + for ( + (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), + records, + ) in groupby( + insert, + lambda rec: ( + rec[4], # connection + set(rec[2]), # parameter keys + bool(rec[5]), # whether we have "value" parameters + rec[6], + rec[7], + ), + ): statement = cached_stmt - if not bookkeeping or \ - ( - has_all_defaults - or not base_mapper.eager_defaults - or not connection.dialect.implicit_returning - ) and has_all_pks and not hasvalue: + if ( + not bookkeeping + or ( + has_all_defaults + or not base_mapper.eager_defaults + or not connection.dialect.implicit_returning + ) + and has_all_pks + and not hasvalue + ): records = list(records) multiparams = [rec[2] for rec in records] - c = cached_connections[connection].\ - execute(statement, multiparams) + c = cached_connections[connection].execute(statement, multiparams) if bookkeeping: - for (state, state_dict, params, mapper_rec, - conn, value_params, has_all_pks, has_all_defaults), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): + for ( + ( + state, + state_dict, + params, + mapper_rec, + conn, + value_params, + has_all_pks, + has_all_defaults, + ), + last_inserted_params, + ) in zip(records, c.context.compiled_parameters): if state: _postfetch( mapper_rec, @@ -849,7 +1062,8 @@ def _emit_insert_statements(base_mapper, uowtransaction, state_dict, c, last_inserted_params, - value_params) + value_params, + ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) @@ -859,24 +1073,33 @@ def _emit_insert_statements(base_mapper, uowtransaction, elif mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) - for state, state_dict, params, mapper_rec, \ - connection, value_params, \ - has_all_pks, has_all_defaults in records: + for ( + state, + state_dict, + params, + mapper_rec, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) in records: if value_params: result = connection.execute( - statement.values(value_params), - params) + statement.values(value_params), params + ) else: - result = cached_connections[connection].\ - execute(statement, params) + result = cached_connections[connection].execute( + statement, params + ) primary_key = result.context.inserted_primary_key if primary_key is not None: # set primary key attributes - for pk, col in zip(primary_key, - mapper._pks_by_table[table]): + for pk, col in zip( + primary_key, mapper._pks_by_table[table] + ): prop = mapper_rec._columntoproperty[col] if state_dict.get(prop.key) is None: state_dict[prop.key] = pk @@ -890,31 +1113,39 @@ def _emit_insert_statements(base_mapper, uowtransaction, state_dict, result, result.context.compiled_parameters[0], - value_params) + value_params, + ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) -def _emit_post_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update): +def _emit_post_update_statements( + base_mapper, uowtransaction, cached_connections, mapper, table, update +): """Emit UPDATE statements corresponding to value lists collected by _collect_post_update_commands().""" - needs_version_id = mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table] + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) def update_stmt(): clause = sql.and_() for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) + clause.clauses.append( + col == sql.bindparam(col._label, type_=col.type) + ) if needs_version_id: clause.clauses.append( - mapper.version_id_col == sql.bindparam( + mapper.version_id_col + == sql.bindparam( mapper.version_id_col._label, - type_=mapper.version_id_col.type)) + type_=mapper.version_id_col.type, + ) + ) stmt = table.update(clause) @@ -923,17 +1154,15 @@ def _emit_post_update_statements(base_mapper, uowtransaction, return stmt - statement = base_mapper._memo(('post_update', table), update_stmt) + statement = base_mapper._memo(("post_update", table), update_stmt) # execute each UPDATE in the order according to the original # list of states to guarantee row access order, but # also group them into common (connection, cols) sets # to support executemany(). for key, records in groupby( - update, lambda rec: ( - rec[3], # connection - set(rec[4]), # parameter keys - ) + update, + lambda rec: (rec[3], set(rec[4])), # connection # parameter keys ): rows = 0 @@ -945,84 +1174,96 @@ def _emit_post_update_statements(base_mapper, uowtransaction, if mapper.version_id_col is None else connection.dialect.supports_sane_rowcount_returning ) - assert_multirow = assert_singlerow and \ - connection.dialect.supports_sane_multi_rowcount + assert_multirow = ( + assert_singlerow + and connection.dialect.supports_sane_multi_rowcount + ) allow_multirow = not needs_version_id or assert_multirow - if not allow_multirow: check_rowcount = assert_singlerow - for state, state_dict, mapper_rec, \ - connection, params in records: - c = cached_connections[connection].\ - execute(statement, params) + for state, state_dict, mapper_rec, connection, params in records: + c = cached_connections[connection].execute(statement, params) _postfetch_post_update( - mapper_rec, uowtransaction, table, state, state_dict, - c, c.context.compiled_parameters[0]) + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + ) rows += c.rowcount else: multiparams = [ - params for - state, state_dict, mapper_rec, conn, params in records] + params + for state, state_dict, mapper_rec, conn, params in records + ] check_rowcount = assert_multirow or ( - assert_singlerow and - len(multiparams) == 1 + assert_singlerow and len(multiparams) == 1 ) - c = cached_connections[connection].\ - execute(statement, multiparams) + c = cached_connections[connection].execute(statement, multiparams) rows += c.rowcount - for state, state_dict, mapper_rec, \ - connection, params in records: + for state, state_dict, mapper_rec, connection, params in records: _postfetch_post_update( - mapper_rec, uowtransaction, table, state, state_dict, - c, c.context.compiled_parameters[0]) + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + ) if check_rowcount: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(records), rows)) + "update %d row(s); %d were matched." + % (table.description, len(records), rows) + ) elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description) + util.warn( + "Dialect %s does not support updated rowcount " + "- versioning cannot be verified." + % c.dialect.dialect_description + ) -def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, - mapper, table, delete): +def _emit_delete_statements( + base_mapper, uowtransaction, cached_connections, mapper, table, delete +): """Emit DELETE statements corresponding to value lists collected by _collect_delete_commands().""" - need_version_id = mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table] + need_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) def delete_stmt(): clause = sql.and_() for col in mapper._pks_by_table[table]: clause.clauses.append( - col == sql.bindparam(col.key, type_=col.type)) + col == sql.bindparam(col.key, type_=col.type) + ) if need_version_id: clause.clauses.append( - mapper.version_id_col == - sql.bindparam( - mapper.version_id_col.key, - type_=mapper.version_id_col.type + mapper.version_id_col + == sql.bindparam( + mapper.version_id_col.key, type_=mapper.version_id_col.type ) ) return table.delete(clause) - statement = base_mapper._memo(('delete', table), delete_stmt) - for connection, recs in groupby( - delete, - lambda rec: rec[1] # connection - ): + statement = base_mapper._memo(("delete", table), delete_stmt) + for connection, recs in groupby(delete, lambda rec: rec[1]): # connection del_objects = [params for params, connection in recs] connection = cached_connections[connection] @@ -1049,9 +1290,10 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, else: util.warn( "Dialect %s does not support deleted rowcount " - "- versioning cannot be verified." % - connection.dialect.dialect_description, - stacklevel=12) + "- versioning cannot be verified." + % connection.dialect.dialect_description, + stacklevel=12, + ) connection.execute(statement, del_objects) else: c = connection.execute(statement, del_objects) @@ -1061,23 +1303,26 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, rows_matched = c.rowcount - if base_mapper.confirm_deleted_rows and \ - rows_matched > -1 and expected != rows_matched: + if ( + base_mapper.confirm_deleted_rows + and rows_matched > -1 + and expected != rows_matched + ): if only_warn: util.warn( "DELETE statement on table '%s' expected to " "delete %d row(s); %d were matched. Please set " "confirm_deleted_rows=False within the mapper " - "configuration to prevent this warning." % - (table.description, expected, rows_matched) + "configuration to prevent this warning." + % (table.description, expected, rows_matched) ) else: raise orm_exc.StaleDataError( "DELETE statement on table '%s' expected to " "delete %d row(s); %d were matched. Please set " "confirm_deleted_rows=False within the mapper " - "configuration to prevent this warning." % - (table.description, expected, rows_matched) + "configuration to prevent this warning." + % (table.description, expected, rows_matched) ) @@ -1091,13 +1336,16 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): if mapper._readonly_props: readonly = state.unmodified_intersection( [ - p.key for p in mapper._readonly_props + p.key + for p in mapper._readonly_props if ( - p.expire_on_flush and - (not p.deferred or p.key in state.dict) - ) or ( - not p.expire_on_flush and - not p.deferred and p.key not in state.dict + p.expire_on_flush + and (not p.deferred or p.key in state.dict) + ) + or ( + not p.expire_on_flush + and not p.deferred + and p.key not in state.dict ) ] ) @@ -1112,11 +1360,14 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): if base_mapper.eager_defaults: toload_now.extend( state._unloaded_non_object.intersection( - mapper._server_default_plus_onupdate_propkeys) + mapper._server_default_plus_onupdate_propkeys + ) ) - if mapper.version_id_col is not None and \ - mapper.version_id_generator is False: + if ( + mapper.version_id_col is not None + and mapper.version_id_generator is False + ): if mapper._version_id_prop.key in state.unloaded: toload_now.extend([mapper._version_id_prop.key]) @@ -1124,8 +1375,10 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): state.key = base_mapper._identity_key_from_state(state) loading.load_on_ident( uowtransaction.session.query(mapper), - state.key, refresh_state=state, - only_load_props=toload_now) + state.key, + refresh_state=state, + only_load_props=toload_now, + ) # call after_XXX extensions if not has_identity: @@ -1133,23 +1386,29 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): else: mapper.dispatch.after_update(mapper, connection, state) - if mapper.version_id_generator is False and \ - mapper.version_id_col is not None: + if ( + mapper.version_id_generator is False + and mapper.version_id_col is not None + ): if state_dict[mapper._version_id_prop.key] is None: raise orm_exc.FlushError( - "Instance does not contain a non-NULL version value") + "Instance does not contain a non-NULL version value" + ) -def _postfetch_post_update(mapper, uowtransaction, table, - state, dict_, result, params): +def _postfetch_post_update( + mapper, uowtransaction, table, state, dict_, result, params +): if uowtransaction.is_deleted(state): return prefetch_cols = result.context.compiled.prefetch postfetch_cols = result.context.compiled.postfetch - if mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) @@ -1164,18 +1423,23 @@ def _postfetch_post_update(mapper, uowtransaction, table, if refresh_flush and load_evt_attrs: mapper.class_manager.dispatch.refresh_flush( - state, uowtransaction, load_evt_attrs) + state, uowtransaction, load_evt_attrs + ) if postfetch_cols: - state._expire_attributes(state.dict, - [mapper._columntoproperty[c].key - for c in postfetch_cols if c in - mapper._columntoproperty] - ) + state._expire_attributes( + state.dict, + [ + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + ], + ) -def _postfetch(mapper, uowtransaction, table, - state, dict_, result, params, value_params): +def _postfetch( + mapper, uowtransaction, table, state, dict_, result, params, value_params +): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that state.""" @@ -1184,8 +1448,10 @@ def _postfetch(mapper, uowtransaction, table, postfetch_cols = result.context.compiled.postfetch returning_cols = result.context.compiled.returning - if mapper.version_id_col is not None and \ - mapper.version_id_col in mapper._cols_by_table[table]: + if ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) @@ -1219,23 +1485,32 @@ def _postfetch(mapper, uowtransaction, table, if refresh_flush and load_evt_attrs: mapper.class_manager.dispatch.refresh_flush( - state, uowtransaction, load_evt_attrs) + state, uowtransaction, load_evt_attrs + ) if postfetch_cols: - state._expire_attributes(state.dict, - [mapper._columntoproperty[c].key - for c in postfetch_cols if c in - mapper._columntoproperty] - ) + state._expire_attributes( + state.dict, + [ + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + ], + ) # synchronize newly inserted ids from one table to the next # TODO: this still goes a little too often. would be nice to # have definitive list of "columns that changed" here for m, equated_pairs in mapper._table_to_equated[table]: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - mapper.passive_updates) + sync.populate( + state, + m, + state, + m, + equated_pairs, + uowtransaction, + mapper.passive_updates, + ) def _postfetch_bulk_save(mapper, dict_, table): @@ -1255,8 +1530,7 @@ def _connections_for_states(base_mapper, uowtransaction, states): # organize individual states with the connection # to use for update if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable + connection_callable = uowtransaction.session.connection_callable else: connection = uowtransaction.transaction.connection(base_mapper) connection_callable = None @@ -1275,7 +1549,8 @@ def _cached_connection_dict(base_mapper): return util.PopulateDict( lambda conn: conn.execution_options( compiled_cache=base_mapper._compiled_cache - )) + ) + ) def _sort_states(states): @@ -1287,9 +1562,12 @@ def _sort_states(states): except TypeError as err: raise sa_exc.InvalidRequestError( "Could not sort objects by primary key; primary key " - "values must be sortable in Python (was: %s)" % err) - return sorted(pending, key=operator.attrgetter("insert_order")) + \ - persistent_sorted + "values must be sortable in Python (was: %s)" % err + ) + return ( + sorted(pending, key=operator.attrgetter("insert_order")) + + persistent_sorted + ) class BulkUD(object): @@ -1302,21 +1580,22 @@ class BulkUD(object): def _validate_query_state(self): for attr, methname, notset, op in ( - ('_limit', 'limit()', None, operator.is_), - ('_offset', 'offset()', None, operator.is_), - ('_order_by', 'order_by()', False, operator.is_), - ('_group_by', 'group_by()', False, operator.is_), - ('_distinct', 'distinct()', False, operator.is_), + ("_limit", "limit()", None, operator.is_), + ("_offset", "offset()", None, operator.is_), + ("_order_by", "order_by()", False, operator.is_), + ("_group_by", "group_by()", False, operator.is_), + ("_distinct", "distinct()", False, operator.is_), ( - '_from_obj', - 'join(), outerjoin(), select_from(), or from_self()', - (), operator.eq) + "_from_obj", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), ): if not op(getattr(self.query, attr), notset): raise sa_exc.InvalidRequestError( "Can't call Query.update() or Query.delete() " - "when %s has been called" % - (methname, ) + "when %s has been called" % (methname,) ) @property @@ -1330,8 +1609,8 @@ class BulkUD(object): except KeyError: raise sa_exc.ArgumentError( "Valid strategies for session synchronization " - "are %s" % (", ".join(sorted(repr(x) - for x in lookup)))) + "are %s" % (", ".join(sorted(repr(x) for x in lookup))) + ) else: return klass(*arg) @@ -1400,9 +1679,9 @@ class BulkEvaluate(BulkUD): try: evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) if query.whereclause is not None: - eval_condition = evaluator_compiler.process( - query.whereclause) + eval_condition = evaluator_compiler.process(query.whereclause) else: + def eval_condition(obj): return True @@ -1411,15 +1690,20 @@ class BulkEvaluate(BulkUD): except evaluator.UnevaluatableError as err: raise sa_exc.InvalidRequestError( 'Could not evaluate current criteria in Python: "%s". ' - 'Specify \'fetch\' or False for the ' - 'synchronize_session parameter.' % err) + "Specify 'fetch' or False for the " + "synchronize_session parameter." % err + ) # TODO: detect when the where clause is a trivial primary key match self.matched_objects = [ - obj for (cls, pk, identity_token), obj in - query.session.identity_map.items() - if issubclass(cls, target_cls) and - eval_condition(obj)] + obj + for ( + cls, + pk, + identity_token, + ), obj in query.session.identity_map.items() + if issubclass(cls, target_cls) and eval_condition(obj) + ] class BulkFetch(BulkUD): @@ -1430,11 +1714,11 @@ class BulkFetch(BulkUD): session = query.session context = query._compile_context() select_stmt = context.statement.with_only_columns( - self.primary_table.primary_key) + self.primary_table.primary_key + ) self.matched_rows = session.execute( - select_stmt, - mapper=self.mapper, - params=query._params).fetchall() + select_stmt, mapper=self.mapper, params=query._params + ).fetchall() class BulkUpdate(BulkUD): @@ -1447,18 +1731,26 @@ class BulkUpdate(BulkUD): @classmethod def factory(cls, query, synchronize_session, values, update_kwargs): - return BulkUD._factory({ - "evaluate": BulkUpdateEvaluate, - "fetch": BulkUpdateFetch, - False: BulkUpdate - }, synchronize_session, query, values, update_kwargs) + return BulkUD._factory( + { + "evaluate": BulkUpdateEvaluate, + "fetch": BulkUpdateFetch, + False: BulkUpdate, + }, + synchronize_session, + query, + values, + update_kwargs, + ) @property def _resolved_values(self): values = [] for k, v in ( - self.values.items() if hasattr(self.values, 'items') - else self.values): + self.values.items() + if hasattr(self.values, "items") + else self.values + ): if self.mapper: if isinstance(k, util.string_types): desc = _entity_descriptor(self.mapper, k) @@ -1478,7 +1770,7 @@ class BulkUpdate(BulkUD): if isinstance(k, attributes.QueryableAttribute): values.append((k.key, v)) continue - elif hasattr(k, '__clause_element__'): + elif hasattr(k, "__clause_element__"): k = k.__clause_element__() if self.mapper and isinstance(k, expression.ColumnElement): @@ -1490,18 +1782,22 @@ class BulkUpdate(BulkUD): values.append((attr.key, v)) else: raise sa_exc.InvalidRequestError( - "Invalid expression type: %r" % k) + "Invalid expression type: %r" % k + ) return values def _do_exec(self): values = self._resolved_values - if not self.update_kwargs.get('preserve_parameter_order', False): + if not self.update_kwargs.get("preserve_parameter_order", False): values = dict(values) - update_stmt = sql.update(self.primary_table, - self.context.whereclause, values, - **self.update_kwargs) + update_stmt = sql.update( + self.primary_table, + self.context.whereclause, + values, + **self.update_kwargs + ) self._execute_stmt(update_stmt) @@ -1518,15 +1814,18 @@ class BulkDelete(BulkUD): @classmethod def factory(cls, query, synchronize_session): - return BulkUD._factory({ - "evaluate": BulkDeleteEvaluate, - "fetch": BulkDeleteFetch, - False: BulkDelete - }, synchronize_session, query) + return BulkUD._factory( + { + "evaluate": BulkDeleteEvaluate, + "fetch": BulkDeleteFetch, + False: BulkDelete, + }, + synchronize_session, + query, + ) def _do_exec(self): - delete_stmt = sql.delete(self.primary_table, - self.context.whereclause) + delete_stmt = sql.delete(self.primary_table, self.context.whereclause) self._execute_stmt(delete_stmt) @@ -1544,32 +1843,33 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): values = self._resolved_values_keys_as_propnames for key, value in values: self.value_evaluators[key] = evaluator_compiler.process( - expression._literal_as_binds(value)) + expression._literal_as_binds(value) + ) def _do_post_synchronize(self): session = self.query.session states = set() evaluated_keys = list(self.value_evaluators.keys()) for obj in self.matched_objects: - state, dict_ = attributes.instance_state(obj),\ - attributes.instance_dict(obj) + state, dict_ = ( + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) # only evaluate unmodified attributes - to_evaluate = state.unmodified.intersection( - evaluated_keys) + to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: dict_[key] = self.value_evaluators[key](obj) - state.manager.dispatch.refresh( - state, None, to_evaluate) + state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) # expire attributes with pending changes # (there was no autoflush, so they are overwritten) - state._expire_attributes(dict_, - set(evaluated_keys). - difference(to_evaluate)) + state._expire_attributes( + dict_, set(evaluated_keys).difference(to_evaluate) + ) states.add(state) session._register_altered(states) @@ -1580,8 +1880,8 @@ class BulkDeleteEvaluate(BulkEvaluate, BulkDelete): def _do_post_synchronize(self): self.query.session._remove_newly_deleted( - [attributes.instance_state(obj) - for obj in self.matched_objects]) + [attributes.instance_state(obj) for obj in self.matched_objects] + ) class BulkUpdateFetch(BulkFetch, BulkUpdate): @@ -1592,15 +1892,18 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate): session = self.query.session target_mapper = self.query._mapper_zero() - states = set([ - attributes.instance_state(session.identity_map[identity_key]) - for identity_key in [ - target_mapper.identity_key_from_primary_key( - list(primary_key)) - for primary_key in self.matched_rows + states = set( + [ + attributes.instance_state(session.identity_map[identity_key]) + for identity_key in [ + target_mapper.identity_key_from_primary_key( + list(primary_key) + ) + for primary_key in self.matched_rows + ] + if identity_key in session.identity_map ] - if identity_key in session.identity_map - ]) + ) values = self._resolved_values_keys_as_propnames attrib = set(k for k, v in values) @@ -1622,10 +1925,13 @@ class BulkDeleteFetch(BulkFetch, BulkDelete): # TODO: inline this and call remove_newly_deleted # once identity_key = target_mapper.identity_key_from_primary_key( - list(primary_key)) + list(primary_key) + ) if identity_key in session.identity_map: session._remove_newly_deleted( - [attributes.instance_state( - session.identity_map[identity_key] - )] + [ + attributes.instance_state( + session.identity_map[identity_key] + ) + ] ) |
