diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 780 |
1 files changed, 780 insertions, 0 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py new file mode 100644 index 000000000..746395919 --- /dev/null +++ b/lib/sqlalchemy/orm/persistence.py @@ -0,0 +1,780 @@ +# orm/persistence.py +# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""private module containing functions used to emit INSERT, UPDATE +and DELETE statements on behalf of a :class:`.Mapper` and its descending +mappers. + +The functions here are called only by the unit of work functions +in unitofwork.py. + +""" + +import operator +from itertools import groupby + +from sqlalchemy import sql, util, exc as sa_exc +from sqlalchemy.orm import attributes, sync, \ + exc as orm_exc + +from sqlalchemy.orm.util import _state_mapper, state_str + +def save_obj(base_mapper, states, uowtransaction, single=False): + """Issue ``INSERT`` and/or ``UPDATE`` statements for a list + of objects. + + This is called within the context of a UOWTransaction during a + flush operation, given a list of states to be flushed. The + base mapper in an inheritance hierarchy handles the inserts/ + updates for all descendant mappers. + + """ + + # if batch=false, call _save_obj separately for each object + if not single and not base_mapper.batch: + for state in _sort_states(states): + save_obj(base_mapper, [state], uowtransaction, single=True) + return + + states_to_insert, states_to_update = _organize_states_for_save( + base_mapper, + states, + uowtransaction) + + cached_connections = _cached_connection_dict(base_mapper) + + for table, mapper in base_mapper._sorted_tables.iteritems(): + insert = _collect_insert_commands(base_mapper, uowtransaction, + table, states_to_insert) + + update = _collect_update_commands(base_mapper, uowtransaction, + table, states_to_update) + + if update: + _emit_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + + if insert: + _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, + table, insert) + + _finalize_insert_update_commands(base_mapper, uowtransaction, + states_to_insert, states_to_update) + +def post_update(base_mapper, states, uowtransaction, post_update_cols): + """Issue UPDATE statements on behalf of a relationship() which + specifies post_update. + + """ + cached_connections = _cached_connection_dict(base_mapper) + + states_to_update = _organize_states_for_post_update( + base_mapper, + states, uowtransaction) + + + for table, mapper in base_mapper._sorted_tables.iteritems(): + update = _collect_post_update_commands(base_mapper, uowtransaction, + table, states_to_update, + post_update_cols) + + if update: + _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + +def delete_obj(base_mapper, states, uowtransaction): + """Issue ``DELETE`` statements for a list of objects. + + This is called within the context of a UOWTransaction during a + flush operation. + + """ + + cached_connections = _cached_connection_dict(base_mapper) + + states_to_delete = _organize_states_for_delete( + base_mapper, + states, + uowtransaction) + + table_to_mapper = base_mapper._sorted_tables + + for table in reversed(table_to_mapper.keys()): + delete = _collect_delete_commands(base_mapper, uowtransaction, + table, states_to_delete) + + mapper = table_to_mapper[table] + + _emit_delete_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, delete) + + for state, state_dict, mapper, has_identity, connection \ + in states_to_delete: + mapper.dispatch.after_delete(mapper, connection, state) + +def _organize_states_for_save(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for INSERT or + UPDATE. + + This includes splitting out into distinct lists for + each, calling before_insert/before_update, obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state, + and the identity flag. + + """ + + states_to_insert = [] + states_to_update = [] + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, + states): + + has_identity = bool(state.key) + instance_key = state.key or mapper._identity_key_from_state(state) + + row_switch = None + + # call before_XXX extensions + if not has_identity: + mapper.dispatch.before_insert(mapper, connection, state) + else: + mapper.dispatch.before_update(mapper, connection, state) + + # detect if we have a "pending" instance (i.e. has + # no instance_key attached to it), and another instance + # with the same identity key already exists as persistent. + # convert to an UPDATE if so. + if not has_identity and \ + instance_key in uowtransaction.session.identity_map: + instance = \ + uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + if not uowtransaction.is_deleted(existing): + raise orm_exc.FlushError( + "New instance %s with identity key %s conflicts " + "with persistent instance %s" % + (state_str(state), instance_key, + state_str(existing))) + + base_mapper._log_debug( + "detected row switch for identity %s. " + "will update %s, remove %s from " + "transaction", instance_key, + state_str(state), state_str(existing)) + + # remove the "delete" flag from the existing element + uowtransaction.remove_state_actions(existing) + row_switch = existing + + if not has_identity and not row_switch: + states_to_insert.append( + (state, dict_, mapper, connection, + has_identity, instance_key, row_switch) + ) + else: + states_to_update.append( + (state, dict_, mapper, connection, + has_identity, instance_key, row_switch) + ) + + return states_to_insert, states_to_update + +def _organize_states_for_post_update(base_mapper, states, + uowtransaction): + """Make an initial pass across a set of states for UPDATE + corresponding to post_update. + + This includes obtaining key information for each state + including its dictionary, mapper, the connection to use for + the execution per state. + + """ + return list(_connections_for_states(base_mapper, uowtransaction, + states)) + +def _organize_states_for_delete(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for DELETE. + + This includes calling out before_delete and obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state. + + """ + states_to_delete = [] + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, + states): + + mapper.dispatch.before_delete(mapper, connection, state) + + states_to_delete.append((state, dict_, mapper, + bool(state.key), connection)) + return states_to_delete + +def _collect_insert_commands(base_mapper, uowtransaction, table, + states_to_insert): + """Identify sets of values to use in INSERT statements for a + list of states. + + """ + insert = [] + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_insert: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + params = {} + value_params = {} + + has_all_pks = True + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col.key] = mapper.version_id_generator(None) + else: + # pull straight from the dict for + # pending objects + prop = mapper._columntoproperty[col] + value = state_dict.get(prop.key, None) + + if value is None: + if col in pks: + has_all_pks = False + elif col.default is None and \ + col.server_default is None: + params[col.key] = value + + elif isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value + + insert.append((state, state_dict, params, mapper, + connection, value_params, has_all_pks)) + return insert + +def _collect_update_commands(base_mapper, uowtransaction, + table, states_to_update): + """Identify sets of values to use in UPDATE statements for a + list of states. + + This function works intricately with the history system + to determine exactly what values should be updated + as well as how the row should be matched within an UPDATE + statement. Includes some tricky scenarios where the primary + key of an object might have been changed. + + """ + + update = [] + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_update: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + params = {} + value_params = {} + + hasdata = hasnull = False + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col._label] = \ + mapper._get_committed_state_attr_by_column( + row_switch or state, + row_switch and row_switch.dict + or state_dict, + col) + + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE + ) + if history.added: + params[col.key] = history.added[0] + hasdata = True + else: + params[col.key] = mapper.version_id_generator( + params[col._label]) + + # HACK: check for history, in case the + # history is only + # in a different table than the one + # where the version_id_col is. + for prop in mapper._columntoproperty.itervalues(): + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + hasdata = True + else: + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + if isinstance(history.added[0], + sql.ClauseElement): + value_params[col] = history.added[0] + else: + value = history.added[0] + params[col.key] = value + + if col in pks: + if history.deleted and \ + not row_switch: + # if passive_updates and sync detected + # this was a pk->pk sync, use the new + # value to locate the row, since the + # DB would already have set this + if ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + value = history.added[0] + params[col._label] = value + else: + # use the old value to + # locate the row + value = history.deleted[0] + params[col._label] = value + hasdata = True + else: + # row switch logic can reach us here + # remove the pk from the update params + # so the update doesn't + # attempt to include the pk in the + # update statement + del params[col.key] + value = history.added[0] + params[col._label] = value + if value is None: + hasnull = True + else: + hasdata = True + elif col in pks: + value = state.manager[prop.key].impl.get( + state, state_dict) + if value is None: + hasnull = True + params[col._label] = value + if hasdata: + if hasnull: + raise sa_exc.FlushError( + "Can't update table " + "using NULL for primary " + "key value") + update.append((state, state_dict, params, mapper, + connection, value_params)) + return update + + +def _collect_post_update_commands(base_mapper, uowtransaction, table, + states_to_update, post_update_cols): + """Identify sets of values to use in UPDATE statements for a + list of states within a post_update operation. + + """ + + update = [] + for state, state_dict, mapper, connection in states_to_update: + if table not in mapper._pks_by_table: + continue + pks = mapper._pks_by_table[table] + params = {} + hasdata = False + + for col in mapper._cols_by_table[table]: + if col in pks: + params[col._label] = \ + mapper._get_state_attr_by_column( + state, + state_dict, col) + elif col in post_update_cols: + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + value = history.added[0] + params[col.key] = value + hasdata = True + if hasdata: + update.append((state, state_dict, params, mapper, + connection)) + return update + +def _collect_delete_commands(base_mapper, uowtransaction, table, + states_to_delete): + """Identify values to use in DELETE statements for a list of + states to be deleted.""" + + delete = util.defaultdict(list) + + for state, state_dict, mapper, has_identity, connection \ + in states_to_delete: + if not has_identity or table not in mapper._pks_by_table: + continue + + params = {} + delete[connection].append(params) + for col in mapper._pks_by_table[table]: + params[col.key] = \ + value = \ + mapper._get_state_attr_by_column( + state, state_dict, col) + if value is None: + raise sa_exc.FlushError( + "Can't delete from table " + "using NULL for primary " + "key value") + + if mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col): + params[mapper.version_id_col.key] = \ + mapper._get_committed_state_attr_by_column( + state, state_dict, + mapper.version_id_col) + return delete + + +def _emit_update_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, update): + """Emit UPDATE statements corresponding to value lists collected + by _collect_update_commands().""" + + needs_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + def update_stmt(): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + if needs_version_id: + clause.clauses.append(mapper.version_id_col ==\ + sql.bindparam(mapper.version_id_col._label, + type_=col.type)) + + return table.update(clause) + + statement = base_mapper._memo(('update', table), update_stmt) + + rows = 0 + for state, state_dict, params, mapper, \ + connection, value_params in update: + + if value_params: + c = connection.execute( + statement.values(value_params), + params) + else: + c = cached_connections[connection].\ + execute(statement, params) + + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c.context.prefetch_cols, + c.context.postfetch_cols, + c.context.compiled_parameters[0], + value_params) + rows += c.rowcount + + if connection.dialect.supports_sane_rowcount: + if rows != len(update): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(update), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) + +def _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, table, insert): + """Emit INSERT statements corresponding to value lists collected + by _collect_insert_commands().""" + + statement = base_mapper._memo(('insert', table), table.insert) + + for (connection, pkeys, hasvalue, has_all_pks), \ + records in groupby(insert, + lambda rec: (rec[4], + rec[2].keys(), + bool(rec[5]), + rec[6]) + ): + if has_all_pks and not hasvalue: + records = list(records) + multiparams = [rec[2] for rec in records] + c = cached_connections[connection].\ + execute(statement, multiparams) + + for (state, state_dict, params, mapper, + conn, value_params, has_all_pks), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c.context.prefetch_cols, + c.context.postfetch_cols, + last_inserted_params, + value_params) + + else: + for state, state_dict, params, mapper, \ + connection, value_params, \ + has_all_pks in records: + + if value_params: + result = connection.execute( + statement.values(value_params), + params) + else: + result = cached_connections[connection].\ + execute(statement, params) + + primary_key = result.context.inserted_primary_key + + if primary_key is not None: + # set primary key attributes + for pk, col in zip(primary_key, + mapper._pks_by_table[table]): + prop = mapper._columntoproperty[col] + if state_dict.get(prop.key) is None: + # TODO: would rather say: + #state_dict[prop.key] = pk + mapper._set_state_attr_by_column( + state, + state_dict, + col, pk) + + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + result.context.prefetch_cols, + result.context.postfetch_cols, + result.context.compiled_parameters[0], + value_params) + + + +def _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, update): + """Emit UPDATE statements corresponding to value lists collected + by _collect_post_update_commands().""" + + def update_stmt(): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + return table.update(clause) + + statement = base_mapper._memo(('post_update', table), update_stmt) + + # execute each UPDATE in the order according to the original + # list of states to guarantee row access order, but + # also group them into common (connection, cols) sets + # to support executemany(). + for key, grouper in groupby( + update, lambda rec: (rec[4], rec[2].keys()) + ): + multiparams = [params for state, state_dict, + params, mapper, conn in grouper] + cached_connections[conn].\ + execute(statement, multiparams) + + +def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, + mapper, table, delete): + """Emit DELETE statements corresponding to value lists collected + by _collect_delete_commands().""" + + need_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + def delete_stmt(): + clause = sql.and_() + for col in mapper._pks_by_table[table]: + clause.clauses.append( + col == sql.bindparam(col.key, type_=col.type)) + + if need_version_id: + clause.clauses.append( + mapper.version_id_col == + sql.bindparam( + mapper.version_id_col.key, + type_=mapper.version_id_col.type + ) + ) + + return table.delete(clause) + + for connection, del_objects in delete.iteritems(): + statement = base_mapper._memo(('delete', table), delete_stmt) + rows = -1 + + connection = cached_connections[connection] + + if need_version_id and \ + not connection.dialect.supports_sane_multi_rowcount: + # TODO: need test coverage for this [ticket:1761] + if connection.dialect.supports_sane_rowcount: + rows = 0 + # execute deletes individually so that versioned + # rows can be verified + for params in del_objects: + c = connection.execute(statement, params) + rows += c.rowcount + else: + util.warn( + "Dialect %s does not support deleted rowcount " + "- versioning cannot be verified." % + connection.dialect.dialect_description, + stacklevel=12) + connection.execute(statement, del_objects) + else: + c = connection.execute(statement, del_objects) + if connection.dialect.supports_sane_multi_rowcount: + rows = c.rowcount + + if rows != -1 and rows != len(del_objects): + raise orm_exc.StaleDataError( + "DELETE statement on table '%s' expected to " + "delete %d row(s); %d were matched." % + (table.description, len(del_objects), c.rowcount) + ) + +def _finalize_insert_update_commands(base_mapper, uowtransaction, + states_to_insert, states_to_update): + """finalize state on states that have been inserted or updated, + including calling after_insert/after_update events. + + """ + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_insert + \ + states_to_update: + + if mapper._readonly_props: + readonly = state.unmodified_intersection( + [p.key for p in mapper._readonly_props + if p.expire_on_flush or p.key not in state.dict] + ) + if readonly: + state.expire_attributes(state.dict, readonly) + + # if eager_defaults option is enabled, + # refresh whatever has been expired. + if base_mapper.eager_defaults and state.unloaded: + state.key = base_mapper._identity_key_from_state(state) + uowtransaction.session.query(base_mapper)._load_on_ident( + state.key, refresh_state=state, + only_load_props=state.unloaded) + + # call after_XXX extensions + if not has_identity: + mapper.dispatch.after_insert(mapper, connection, state) + else: + mapper.dispatch.after_update(mapper, connection, state) + +def _postfetch(mapper, uowtransaction, table, + state, dict_, prefetch_cols, postfetch_cols, + params, value_params): + """Expire attributes in need of newly persisted database state, + after an INSERT or UPDATE statement has proceeded for that + state.""" + + if mapper.version_id_col is not None: + prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + + for c in prefetch_cols: + if c.key in params and c in mapper._columntoproperty: + mapper._set_state_attr_by_column(state, dict_, c, params[c.key]) + + if postfetch_cols: + state.expire_attributes(state.dict, + [mapper._columntoproperty[c].key + for c in postfetch_cols if c in + mapper._columntoproperty] + ) + + # synchronize newly inserted ids from one table to the next + # TODO: this still goes a little too often. would be nice to + # have definitive list of "columns that changed" here + for m, equated_pairs in mapper._table_to_equated[table]: + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + mapper.passive_updates) + +def _connections_for_states(base_mapper, uowtransaction, states): + """Return an iterator of (state, state.dict, mapper, connection). + + The states are sorted according to _sort_states, then paired + with the connection they should be using for the given + unit of work transaction. + + """ + # if session has a connection callable, + # organize individual states with the connection + # to use for update + if uowtransaction.session.connection_callable: + connection_callable = \ + uowtransaction.session.connection_callable + else: + connection = uowtransaction.transaction.connection( + base_mapper) + connection_callable = None + + for state in _sort_states(states): + if connection_callable: + connection = connection_callable(base_mapper, state.obj()) + + mapper = _state_mapper(state) + + yield state, state.dict, mapper, connection + +def _cached_connection_dict(base_mapper): + # dictionary of connection->connection_with_cache_options. + return util.PopulateDict( + lambda conn:conn.execution_options( + compiled_cache=base_mapper._compiled_cache + )) + +def _sort_states(states): + pending = set(states) + persistent = set(s for s in pending if s.key is not None) + pending.difference_update(persistent) + return sorted(pending, key=operator.attrgetter("insert_order")) + \ + sorted(persistent, key=lambda q:q.key[1]) + + |
