diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 547 |
1 files changed, 281 insertions, 266 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 295d4a3d0..511a9cef0 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -18,7 +18,7 @@ import operator from itertools import groupby from .. import sql, util, exc as sa_exc, schema from . import attributes, sync, exc as orm_exc, evaluator -from .base import _state_mapper, state_str, _attr_as_key +from .base import state_str, _attr_as_key from ..sql import expression from . import loading @@ -40,32 +40,55 @@ def save_obj(base_mapper, states, uowtransaction, single=False): save_obj(base_mapper, [state], uowtransaction, single=True) return - states_to_insert, states_to_update = _organize_states_for_save( - base_mapper, - states, - uowtransaction) - + states_to_update = [] + states_to_insert = [] cached_connections = _cached_connection_dict(base_mapper) - for table, mapper in base_mapper._sorted_tables.items(): - insert = _collect_insert_commands(base_mapper, uowtransaction, - table, states_to_insert) - - update = _collect_update_commands(base_mapper, uowtransaction, - table, states_to_update) - - if update: - _emit_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) - - if insert: - _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, insert) + for (state, dict_, mapper, connection, + has_identity, + row_switch, update_version_id) in _organize_states_for_save( + base_mapper, states, uowtransaction + ): + if has_identity or row_switch: + states_to_update.append( + (state, dict_, mapper, connection, update_version_id) + ) + else: + states_to_insert.append( + (state, dict_, mapper, connection) + ) - _finalize_insert_update_commands(base_mapper, uowtransaction, - states_to_insert, states_to_update) + for table, mapper in base_mapper._sorted_tables.items(): + if table not in mapper._pks_by_table: + continue + insert = _collect_insert_commands(table, states_to_insert) + + update = _collect_update_commands( + uowtransaction, table, states_to_update) + + _emit_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + + _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, insert) + + _finalize_insert_update_commands( + base_mapper, uowtransaction, + ( + (state, state_dict, mapper, connection, False) + for state, state_dict, mapper, connection in states_to_insert + ) + ) + _finalize_insert_update_commands( + base_mapper, uowtransaction, + ( + (state, state_dict, mapper, connection, True) + for state, state_dict, mapper, connection, + update_version_id in states_to_update + ) + ) def post_update(base_mapper, states, uowtransaction, post_update_cols): @@ -75,19 +98,28 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): """ cached_connections = _cached_connection_dict(base_mapper) - states_to_update = _organize_states_for_post_update( + states_to_update = list(_organize_states_for_post_update( base_mapper, - states, uowtransaction) + states, uowtransaction)) for table, mapper in base_mapper._sorted_tables.items(): + if table not in mapper._pks_by_table: + continue + + update = ( + (state, state_dict, sub_mapper, connection) + for + state, state_dict, sub_mapper, connection in states_to_update + if table in sub_mapper._pks_by_table + ) + update = _collect_post_update_commands(base_mapper, uowtransaction, - table, states_to_update, + table, update, post_update_cols) - if update: - _emit_post_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) + _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) def delete_obj(base_mapper, states, uowtransaction): @@ -100,24 +132,26 @@ def delete_obj(base_mapper, states, uowtransaction): cached_connections = _cached_connection_dict(base_mapper) - states_to_delete = _organize_states_for_delete( + states_to_delete = list(_organize_states_for_delete( base_mapper, states, - uowtransaction) + uowtransaction)) table_to_mapper = base_mapper._sorted_tables for table in reversed(list(table_to_mapper.keys())): + mapper = table_to_mapper[table] + if table not in mapper._pks_by_table: + continue + delete = _collect_delete_commands(base_mapper, uowtransaction, table, states_to_delete) - mapper = table_to_mapper[table] - _emit_delete_statements(base_mapper, uowtransaction, cached_connections, mapper, table, delete) - for state, state_dict, mapper, has_identity, connection \ - in states_to_delete: + for state, state_dict, mapper, connection, \ + update_version_id in states_to_delete: mapper.dispatch.after_delete(mapper, connection, state) @@ -133,17 +167,15 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): """ - states_to_insert = [] - states_to_update = [] - for state, dict_, mapper, connection in _connections_for_states( base_mapper, uowtransaction, states): has_identity = bool(state.key) + instance_key = state.key or mapper._identity_key_from_state(state) - row_switch = None + row_switch = update_version_id = None # call before_XXX extensions if not has_identity: @@ -180,18 +212,14 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): uowtransaction.remove_state_actions(existing) row_switch = existing - if not has_identity and not row_switch: - states_to_insert.append( - (state, dict_, mapper, connection, - has_identity, instance_key, row_switch) - ) - else: - states_to_update.append( - (state, dict_, mapper, connection, - has_identity, instance_key, row_switch) - ) + if (has_identity or row_switch) and mapper.version_id_col is not None: + update_version_id = mapper._get_committed_state_attr_by_column( + row_switch if row_switch else state, + row_switch.dict if row_switch else dict_, + mapper.version_id_col) - return states_to_insert, states_to_update + yield (state, dict_, mapper, connection, + has_identity, row_switch, update_version_id) def _organize_states_for_post_update(base_mapper, states, @@ -204,8 +232,7 @@ def _organize_states_for_post_update(base_mapper, states, the execution per state. """ - return list(_connections_for_states(base_mapper, uowtransaction, - states)) + return _connections_for_states(base_mapper, uowtransaction, states) def _organize_states_for_delete(base_mapper, states, uowtransaction): @@ -216,72 +243,73 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): mapper, the connection to use for the execution per state. """ - states_to_delete = [] - for state, dict_, mapper, connection in _connections_for_states( base_mapper, uowtransaction, states): mapper.dispatch.before_delete(mapper, connection, state) - states_to_delete.append((state, dict_, mapper, - bool(state.key), connection)) - return states_to_delete + if mapper.version_id_col is not None: + update_version_id = \ + mapper._get_committed_state_attr_by_column( + state, dict_, + mapper.version_id_col) + else: + update_version_id = None + + yield ( + state, dict_, mapper, connection, update_version_id) -def _collect_insert_commands(base_mapper, uowtransaction, table, - states_to_insert): +def _collect_insert_commands(table, states_to_insert): """Identify sets of values to use in INSERT statements for a list of states. """ - insert = [] - for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_insert: + for state, state_dict, mapper, connection in states_to_insert: if table not in mapper._pks_by_table: continue - pks = mapper._pks_by_table[table] - params = {} value_params = {} - has_all_pks = True - has_all_defaults = True - for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col and \ - mapper.version_id_generator is not False: - val = mapper.version_id_generator(None) - params[col.key] = val + propkey_to_col = mapper._propkey_to_col[table] + + for propkey in set(propkey_to_col).intersection(state_dict): + value = state_dict[propkey] + col = propkey_to_col[propkey] + if value is None: + continue + elif isinstance(value, sql.ClauseElement): + value_params[col.key] = value else: - # pull straight from the dict for - # pending objects - prop = mapper._columntoproperty[col] - value = state_dict.get(prop.key, None) - - if value is None: - if col in pks: - has_all_pks = False - elif col.default is None and \ - col.server_default is None: - params[col.key] = value - elif col.server_default is not None and \ - mapper.base_mapper.eager_defaults: - has_all_defaults = False - - elif isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value + params[col.key] = value + + for colkey in mapper._insert_cols_as_none[table].\ + difference(params).difference(value_params): + params[colkey] = None + + has_all_pks = mapper._pk_keys_by_table[table].issubset(params) + + if mapper.base_mapper.eager_defaults: + has_all_defaults = mapper._server_default_cols[table].\ + issubset(params) + else: + has_all_defaults = True + + if mapper.version_id_generator is not False \ + and mapper.version_id_col is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: + params[mapper.version_id_col.key] = \ + mapper.version_id_generator(None) - insert.append((state, state_dict, params, mapper, - connection, value_params, has_all_pks, - has_all_defaults)) - return insert + yield ( + state, state_dict, params, mapper, + connection, value_params, has_all_pks, + has_all_defaults) -def _collect_update_commands(base_mapper, uowtransaction, - table, states_to_update): +def _collect_update_commands(uowtransaction, table, states_to_update): """Identify sets of values to use in UPDATE statements for a list of states. @@ -293,9 +321,9 @@ def _collect_update_commands(base_mapper, uowtransaction, """ - update = [] - for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_update: + for state, state_dict, mapper, connection, \ + update_version_id in states_to_update: + if table not in mapper._pks_by_table: continue @@ -304,98 +332,59 @@ def _collect_update_commands(base_mapper, uowtransaction, params = {} value_params = {} - hasdata = hasnull = False - for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col: - params[col._label] = \ - mapper._get_committed_state_attr_by_column( - row_switch or state, - row_switch and row_switch.dict - or state_dict, - col) + propkey_to_col = mapper._propkey_to_col[table] - prop = mapper._columntoproperty[col] - history = state.manager[prop.key].impl.get_history( - state, state_dict, attributes.PASSIVE_NO_INITIALIZE - ) - if history.added: - params[col.key] = history.added[0] - hasdata = True + for propkey in set(propkey_to_col).intersection(state.committed_state): + value = state_dict[propkey] + col = propkey_to_col[propkey] + + if not state.manager[propkey].impl.is_equal( + value, state.committed_state[propkey]): + if isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value + + if update_version_id is not None: + col = mapper.version_id_col + params[col._label] = update_version_id + + if col.key not in params and \ + mapper.version_id_generator is not False: + val = mapper.version_id_generator(update_version_id) + params[col.key] = val + + if not (params or value_params): + continue + + pk_params = {} + for col in pks: + propkey = mapper._columntoproperty[col].key + history = state.manager[propkey].impl.get_history( + state, state_dict, attributes.PASSIVE_OFF) + + if history.added: + if not history.deleted or \ + ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + pk_params[col._label] = history.added[0] + params.pop(col.key, None) else: - if mapper.version_id_generator is not False: - val = mapper.version_id_generator(params[col._label]) - params[col.key] = val - - # HACK: check for history, in case the - # history is only - # in a different table than the one - # where the version_id_col is. - for prop in mapper._columntoproperty.values(): - history = ( - state.manager[prop.key].impl.get_history( - state, state_dict, - attributes.PASSIVE_NO_INITIALIZE)) - if history.added: - hasdata = True + # else, use the old value to locate the row + pk_params[col._label] = history.deleted[0] + params[col.key] = history.added[0] else: - prop = mapper._columntoproperty[col] - history = state.manager[prop.key].impl.get_history( - state, state_dict, - attributes.PASSIVE_NO_INITIALIZE) - if history.added: - if isinstance(history.added[0], - sql.ClauseElement): - value_params[col] = history.added[0] - else: - value = history.added[0] - params[col.key] = value - - if col in pks: - if history.deleted and \ - not row_switch: - # if passive_updates and sync detected - # this was a pk->pk sync, use the new - # value to locate the row, since the - # DB would already have set this - if ("pk_cascaded", state, col) in \ - uowtransaction.attributes: - value = history.added[0] - params[col._label] = value - else: - # use the old value to - # locate the row - value = history.deleted[0] - params[col._label] = value - hasdata = True - else: - # row switch logic can reach us here - # remove the pk from the update params - # so the update doesn't - # attempt to include the pk in the - # update statement - del params[col.key] - value = history.added[0] - params[col._label] = value - if value is None: - hasnull = True - else: - hasdata = True - elif col in pks: - value = state.manager[prop.key].impl.get( - state, state_dict) - if value is None: - hasnull = True - params[col._label] = value + pk_params[col._label] = history.unchanged[0] - if hasdata: - if hasnull: + if params or value_params: + if None in pk_params.values(): raise orm_exc.FlushError( - "Can't update table " - "using NULL for primary " + "Can't update table using NULL for primary " "key value") - update.append((state, state_dict, params, mapper, - connection, value_params)) - return update + params.update(pk_params) + yield ( + state, state_dict, params, mapper, + connection, value_params) def _collect_post_update_commands(base_mapper, uowtransaction, table, @@ -405,10 +394,10 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, """ - update = [] for state, state_dict, mapper, connection in states_to_update: - if table not in mapper._pks_by_table: - continue + + # assert table in mapper._pks_by_table + pks = mapper._pks_by_table[table] params = {} hasdata = False @@ -417,8 +406,8 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, if col in pks: params[col._label] = \ mapper._get_state_attr_by_column( - state, - state_dict, col) + state, + state_dict, col) elif col in post_update_cols: prop = mapper._columntoproperty[col] @@ -430,9 +419,7 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, params[col.key] = value hasdata = True if hasdata: - update.append((state, state_dict, params, mapper, - connection)) - return update + yield params, connection def _collect_delete_commands(base_mapper, uowtransaction, table, @@ -440,33 +427,28 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, """Identify values to use in DELETE statements for a list of states to be deleted.""" - delete = util.defaultdict(list) + for state, state_dict, mapper, connection, \ + update_version_id in states_to_delete: - for state, state_dict, mapper, has_identity, connection \ - in states_to_delete: - if not has_identity or table not in mapper._pks_by_table: + if table not in mapper._pks_by_table: continue params = {} - delete[connection].append(params) for col in mapper._pks_by_table[table]: params[col.key] = \ value = \ mapper._get_committed_state_attr_by_column( - state, state_dict, col) + state, state_dict, col) if value is None: raise orm_exc.FlushError( "Can't delete from table " "using NULL for primary " "key value") - if mapper.version_id_col is not None and \ + if update_version_id is not None and \ table.c.contains_column(mapper.version_id_col): - params[mapper.version_id_col.key] = \ - mapper._get_committed_state_attr_by_column( - state, state_dict, - mapper.version_id_col) - return delete + params[mapper.version_id_col.key] = update_version_id + yield params, connection def _emit_update_statements(base_mapper, uowtransaction, @@ -500,41 +482,80 @@ def _emit_update_statements(base_mapper, uowtransaction, statement = base_mapper._memo(('update', table), update_stmt) - rows = 0 - for state, state_dict, params, mapper, \ - connection, value_params in update: - - if value_params: - c = connection.execute( - statement.values(value_params), - params) + for (connection, paramkeys, hasvalue), \ + records in groupby( + update, + lambda rec: ( + rec[4], + tuple(sorted(rec[2])), + bool(rec[5]))): + + rows = 0 + records = list(records) + if hasvalue: + for state, state_dict, params, mapper, \ + connection, value_params in records: + c = connection.execute( + statement.values(value_params), + params) + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) + rows += c.rowcount else: - c = cached_connections[connection].\ - execute(statement, params) - - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params) - rows += c.rowcount - - if connection.dialect.supports_sane_rowcount: - if rows != len(update): - raise orm_exc.StaleDataError( - "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(update), rows)) - - elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description, - stacklevel=12) + if needs_version_id and \ + not connection.dialect.supports_sane_multi_rowcount and \ + connection.dialect.supports_sane_rowcount: + for state, state_dict, params, mapper, \ + connection, value_params in records: + c = cached_connections[connection].\ + execute(statement, params) + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) + rows += c.rowcount + else: + multiparams = [rec[2] for rec in records] + c = cached_connections[connection].\ + execute(statement, multiparams) + + rows += c.rowcount + for state, state_dict, params, mapper, \ + connection, value_params in records: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) + + if connection.dialect.supports_sane_rowcount: + if rows != len(records): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(records), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) def _emit_insert_statements(base_mapper, uowtransaction, @@ -547,7 +568,7 @@ def _emit_insert_statements(base_mapper, uowtransaction, for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ records in groupby(insert, lambda rec: (rec[4], - list(rec[2].keys()), + tuple(sorted(rec[2].keys())), bool(rec[5]), rec[6], rec[7]) ): @@ -604,13 +625,7 @@ def _emit_insert_statements(base_mapper, uowtransaction, mapper._pks_by_table[table]): prop = mapper_rec._columntoproperty[col] if state_dict.get(prop.key) is None: - # TODO: would rather say: - # state_dict[prop.key] = pk - mapper_rec._set_state_attr_by_column( - state, - state_dict, - col, pk) - + state_dict[prop.key] = pk _postfetch( mapper_rec, uowtransaction, @@ -643,11 +658,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction, # also group them into common (connection, cols) sets # to support executemany(). for key, grouper in groupby( - update, lambda rec: (rec[4], list(rec[2].keys())) + update, lambda rec: (rec[1], sorted(rec[0])) ): connection = key[0] - multiparams = [params for state, state_dict, - params, mapper, conn in grouper] + multiparams = [params for params, conn in grouper] cached_connections[connection].\ execute(statement, multiparams) @@ -677,8 +691,15 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, return table.delete(clause) - for connection, del_objects in delete.items(): - statement = base_mapper._memo(('delete', table), delete_stmt) + statement = base_mapper._memo(('delete', table), delete_stmt) + for connection, recs in groupby( + delete, + lambda rec: rec[1] + ): + del_objects = [ + params + for params, connection in recs + ] connection = cached_connections[connection] @@ -731,15 +752,12 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, ) -def _finalize_insert_update_commands(base_mapper, uowtransaction, - states_to_insert, states_to_update): +def _finalize_insert_update_commands(base_mapper, uowtransaction, states): """finalize state on states that have been inserted or updated, including calling after_insert/after_update events. """ - for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_insert + \ - states_to_update: + for state, state_dict, mapper, connection, has_identity in states: if mapper._readonly_props: readonly = state.unmodified_intersection( @@ -795,11 +813,11 @@ def _postfetch(mapper, uowtransaction, table, for col in returning_cols: if col.primary_key: continue - mapper._set_state_attr_by_column(state, dict_, col, row[col]) + dict_[mapper._columntoproperty[col].key] = row[col] for c in prefetch_cols: if c.key in params and c in mapper._columntoproperty: - mapper._set_state_attr_by_column(state, dict_, c, params[c.key]) + dict_[mapper._columntoproperty[c].key] = params[c.key] if postfetch_cols: state._expire_attributes(state.dict, @@ -833,17 +851,14 @@ def _connections_for_states(base_mapper, uowtransaction, states): connection_callable = \ uowtransaction.session.connection_callable else: - connection = None + connection = uowtransaction.transaction.connection(base_mapper) connection_callable = None for state in _sort_states(states): if connection_callable: connection = connection_callable(base_mapper, state.obj()) - elif not connection: - connection = uowtransaction.transaction.connection( - base_mapper) - mapper = _state_mapper(state) + mapper = state.manager.mapper yield state, state.dict, mapper, connection |