summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/mapper.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/mapper.py')
-rw-r--r--lib/sqlalchemy/orm/mapper.py629
1 files changed, 6 insertions, 623 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index bf277664a..42acb4928 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1875,500 +1875,6 @@ class Mapper(object):
self._memoized_values[key] = value = callable_()
return value
- def _post_update(self, states, uowtransaction, post_update_cols):
- """Issue UPDATE statements on behalf of a relationship() which
- specifies post_update.
-
- """
- cached_connections = util.PopulateDict(
- lambda conn:conn.execution_options(
- compiled_cache=self._compiled_cache
- ))
-
- # if session has a connection callable,
- # organize individual states with the connection
- # to use for update
- if uowtransaction.session.connection_callable:
- connection_callable = \
- uowtransaction.session.connection_callable
- else:
- connection = uowtransaction.transaction.connection(self)
- connection_callable = None
-
- tups = []
- for state in _sort_states(states):
- if connection_callable:
- conn = connection_callable(self, state.obj())
- else:
- conn = connection
-
- mapper = _state_mapper(state)
-
- tups.append((state, state.dict, mapper, conn))
-
- table_to_mapper = self._sorted_tables
-
- for table in table_to_mapper:
- update = []
-
- for state, state_dict, mapper, connection in tups:
- if table not in mapper._pks_by_table:
- continue
-
- pks = mapper._pks_by_table[table]
- params = {}
- hasdata = False
-
- for col in mapper._cols_by_table[table]:
- if col in pks:
- params[col._label] = \
- mapper._get_state_attr_by_column(
- state,
- state_dict, col)
- elif col in post_update_cols:
- prop = mapper._columntoproperty[col]
- history = attributes.get_state_history(
- state, prop.key,
- attributes.PASSIVE_NO_INITIALIZE)
- if history.added:
- value = history.added[0]
- params[col.key] = value
- hasdata = True
- if hasdata:
- update.append((state, state_dict, params, mapper,
- connection))
-
- if update:
- mapper = table_to_mapper[table]
-
- def update_stmt():
- clause = sql.and_()
-
- for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
-
- return table.update(clause)
-
- statement = self._memo(('post_update', table), update_stmt)
-
- # execute each UPDATE in the order according to the original
- # list of states to guarantee row access order, but
- # also group them into common (connection, cols) sets
- # to support executemany().
- for key, grouper in groupby(
- update, lambda rec: (rec[4], rec[2].keys())
- ):
- multiparams = [params for state, state_dict,
- params, mapper, conn in grouper]
- cached_connections[connection].\
- execute(statement, multiparams)
-
- def _save_obj(self, states, uowtransaction, single=False):
- """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
- of objects.
-
- This is called within the context of a UOWTransaction during a
- flush operation, given a list of states to be flushed. The
- base mapper in an inheritance hierarchy handles the inserts/
- updates for all descendant mappers.
-
- """
-
- # if batch=false, call _save_obj separately for each object
- if not single and not self.batch:
- for state in _sort_states(states):
- self._save_obj([state],
- uowtransaction,
- single=True)
- return
-
- # if session has a connection callable,
- # organize individual states with the connection
- # to use for insert/update
- if uowtransaction.session.connection_callable:
- connection_callable = \
- uowtransaction.session.connection_callable
- else:
- connection = uowtransaction.transaction.connection(self)
- connection_callable = None
-
- tups = []
-
- for state in _sort_states(states):
- if connection_callable:
- conn = connection_callable(self, state.obj())
- else:
- conn = connection
-
- has_identity = bool(state.key)
- mapper = _state_mapper(state)
- instance_key = state.key or mapper._identity_key_from_state(state)
-
- row_switch = None
-
- # call before_XXX extensions
- if not has_identity:
- mapper.dispatch.before_insert(mapper, conn, state)
- else:
- mapper.dispatch.before_update(mapper, conn, state)
-
- # detect if we have a "pending" instance (i.e. has
- # no instance_key attached to it), and another instance
- # with the same identity key already exists as persistent.
- # convert to an UPDATE if so.
- if not has_identity and \
- instance_key in uowtransaction.session.identity_map:
- instance = \
- uowtransaction.session.identity_map[instance_key]
- existing = attributes.instance_state(instance)
- if not uowtransaction.is_deleted(existing):
- raise orm_exc.FlushError(
- "New instance %s with identity key %s conflicts "
- "with persistent instance %s" %
- (state_str(state), instance_key,
- state_str(existing)))
-
- self._log_debug(
- "detected row switch for identity %s. "
- "will update %s, remove %s from "
- "transaction", instance_key,
- state_str(state), state_str(existing))
-
- # remove the "delete" flag from the existing element
- uowtransaction.remove_state_actions(existing)
- row_switch = existing
-
- tups.append(
- (state, state.dict, mapper, conn,
- has_identity, instance_key, row_switch)
- )
-
- # dictionary of connection->connection_with_cache_options.
- cached_connections = util.PopulateDict(
- lambda conn:conn.execution_options(
- compiled_cache=self._compiled_cache
- ))
-
- table_to_mapper = self._sorted_tables
-
- for table in table_to_mapper:
- insert = []
- update = []
-
- for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in tups:
- if table not in mapper._pks_by_table:
- continue
-
- pks = mapper._pks_by_table[table]
-
- isinsert = not has_identity and not row_switch
-
- params = {}
- value_params = {}
-
- if isinsert:
- has_all_pks = True
- for col in mapper._cols_by_table[table]:
- if col is mapper.version_id_col:
- params[col.key] = \
- mapper.version_id_generator(None)
- else:
- # pull straight from the dict for
- # pending objects
- prop = mapper._columntoproperty[col]
- value = state_dict.get(prop.key, None)
-
- if value is None:
- if col in pks:
- has_all_pks = False
- elif col.default is None and \
- col.server_default is None:
- params[col.key] = value
-
- elif isinstance(value, sql.ClauseElement):
- value_params[col] = value
- else:
- params[col.key] = value
-
- insert.append((state, state_dict, params, mapper,
- connection, value_params, has_all_pks))
- else:
- hasdata = hasnull = False
- for col in mapper._cols_by_table[table]:
- if col is mapper.version_id_col:
- params[col._label] = \
- mapper._get_committed_state_attr_by_column(
- row_switch or state,
- row_switch and row_switch.dict
- or state_dict,
- col)
-
- prop = mapper._columntoproperty[col]
- history = attributes.get_state_history(
- state, prop.key,
- attributes.PASSIVE_NO_INITIALIZE
- )
- if history.added:
- params[col.key] = history.added[0]
- hasdata = True
- else:
- params[col.key] = \
- mapper.version_id_generator(
- params[col._label])
-
- # HACK: check for history, in case the
- # history is only
- # in a different table than the one
- # where the version_id_col is.
- for prop in mapper._columntoproperty.\
- itervalues():
- history = attributes.get_state_history(
- state, prop.key,
- attributes.PASSIVE_NO_INITIALIZE)
- if history.added:
- hasdata = True
- else:
- prop = mapper._columntoproperty[col]
- history = attributes.get_state_history(
- state, prop.key,
- attributes.PASSIVE_NO_INITIALIZE)
- if history.added:
- if isinstance(history.added[0],
- sql.ClauseElement):
- value_params[col] = history.added[0]
- else:
- value = history.added[0]
- params[col.key] = value
-
- if col in pks:
- if history.deleted and \
- not row_switch:
- # if passive_updates and sync detected
- # this was a pk->pk sync, use the new
- # value to locate the row, since the
- # DB would already have set this
- if ("pk_cascaded", state, col) in \
- uowtransaction.\
- attributes:
- value = history.added[0]
- params[col._label] = value
- else:
- # use the old value to
- # locate the row
- value = history.deleted[0]
- params[col._label] = value
- hasdata = True
- else:
- # row switch logic can reach us here
- # remove the pk from the update params
- # so the update doesn't
- # attempt to include the pk in the
- # update statement
- del params[col.key]
- value = history.added[0]
- params[col._label] = value
- if value is None:
- hasnull = True
- else:
- hasdata = True
- elif col in pks:
- value = state.manager[prop.key].\
- impl.get(state, state_dict)
- if value is None:
- hasnull = True
- params[col._label] = value
- if hasdata:
- if hasnull:
- raise sa_exc.FlushError(
- "Can't update table "
- "using NULL for primary "
- "key value")
- update.append((state, state_dict, params, mapper,
- connection, value_params))
-
- if update:
- mapper = table_to_mapper[table]
-
- needs_version_id = mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col)
-
- def update_stmt():
- clause = sql.and_()
-
- for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
-
- if needs_version_id:
- clause.clauses.append(mapper.version_id_col ==\
- sql.bindparam(mapper.version_id_col._label,
- type_=col.type))
-
- return table.update(clause)
-
- statement = self._memo(('update', table), update_stmt)
-
- rows = 0
- for state, state_dict, params, mapper, \
- connection, value_params in update:
-
- if value_params:
- c = connection.execute(
- statement.values(value_params),
- params)
- else:
- c = cached_connections[connection].\
- execute(statement, params)
-
- mapper._postfetch(
- uowtransaction,
- table,
- state,
- state_dict,
- c.context.prefetch_cols,
- c.context.postfetch_cols,
- c.context.compiled_parameters[0],
- value_params)
- rows += c.rowcount
-
- if connection.dialect.supports_sane_rowcount:
- if rows != len(update):
- raise orm_exc.StaleDataError(
- "UPDATE statement on table '%s' expected to update %d row(s); "
- "%d were matched." %
- (table.description, len(update), rows))
-
- elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description,
- stacklevel=12)
-
- if insert:
- statement = self._memo(('insert', table), table.insert)
-
- for (connection, pkeys, hasvalue, has_all_pks), \
- records in groupby(insert,
- lambda rec: (rec[4],
- rec[2].keys(),
- bool(rec[5]),
- rec[6])
- ):
- if has_all_pks and not hasvalue:
- records = list(records)
- multiparams = [rec[2] for rec in records]
- c = cached_connections[connection].\
- execute(statement, multiparams)
-
- for (state, state_dict, params, mapper,
- conn, value_params, has_all_pks), \
- last_inserted_params in \
- zip(records, c.context.compiled_parameters):
- mapper._postfetch(
- uowtransaction,
- table,
- state,
- state_dict,
- c.context.prefetch_cols,
- c.context.postfetch_cols,
- last_inserted_params,
- value_params)
-
- else:
- for state, state_dict, params, mapper, \
- connection, value_params, \
- has_all_pks in records:
-
- if value_params:
- result = connection.execute(
- statement.values(value_params),
- params)
- else:
- result = cached_connections[connection].\
- execute(statement, params)
-
- primary_key = result.context.inserted_primary_key
-
- if primary_key is not None:
- # set primary key attributes
- for pk, col in zip(primary_key,
- mapper._pks_by_table[table]):
- prop = mapper._columntoproperty[col]
- if state_dict.get(prop.key) is None:
- # TODO: would rather say:
- #state_dict[prop.key] = pk
- mapper._set_state_attr_by_column(
- state,
- state_dict,
- col, pk)
-
- mapper._postfetch(
- uowtransaction,
- table,
- state,
- state_dict,
- result.context.prefetch_cols,
- result.context.postfetch_cols,
- result.context.compiled_parameters[0],
- value_params)
-
-
- for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in tups:
-
- if mapper._readonly_props:
- readonly = state.unmodified_intersection(
- [p.key for p in mapper._readonly_props
- if p.expire_on_flush or p.key not in state.dict]
- )
- if readonly:
- state.expire_attributes(state.dict, readonly)
-
- # if eager_defaults option is enabled,
- # refresh whatever has been expired.
- if self.eager_defaults and state.unloaded:
- state.key = self._identity_key_from_state(state)
- uowtransaction.session.query(self)._load_on_ident(
- state.key, refresh_state=state,
- only_load_props=state.unloaded)
-
- # call after_XXX extensions
- if not has_identity:
- mapper.dispatch.after_insert(mapper, connection, state)
- else:
- mapper.dispatch.after_update(mapper, connection, state)
-
- def _postfetch(self, uowtransaction, table,
- state, dict_, prefetch_cols, postfetch_cols,
- params, value_params):
- """During a flush, expire attributes in need of newly
- persisted database state."""
-
- if self.version_id_col is not None:
- prefetch_cols = list(prefetch_cols) + [self.version_id_col]
-
- for c in prefetch_cols:
- if c.key in params and c in self._columntoproperty:
- self._set_state_attr_by_column(state, dict_, c, params[c.key])
-
- if postfetch_cols:
- state.expire_attributes(state.dict,
- [self._columntoproperty[c].key
- for c in postfetch_cols if c in
- self._columntoproperty]
- )
-
- # synchronize newly inserted ids from one table to the next
- # TODO: this still goes a little too often. would be nice to
- # have definitive list of "columns that changed" here
- for m, equated_pairs in self._table_to_equated[table]:
- sync.populate(state, m, state, m,
- equated_pairs,
- uowtransaction,
- self.passive_updates)
-
@util.memoized_property
def _table_to_equated(self):
"""memoized map of tables to collections of columns to be
@@ -2386,128 +1892,6 @@ class Mapper(object):
return result
- def _delete_obj(self, states, uowtransaction):
- """Issue ``DELETE`` statements for a list of objects.
-
- This is called within the context of a UOWTransaction during a
- flush operation.
-
- """
- if uowtransaction.session.connection_callable:
- connection_callable = \
- uowtransaction.session.connection_callable
- else:
- connection = uowtransaction.transaction.connection(self)
- connection_callable = None
-
- tups = []
- cached_connections = util.PopulateDict(
- lambda conn:conn.execution_options(
- compiled_cache=self._compiled_cache
- ))
-
- for state in _sort_states(states):
- mapper = _state_mapper(state)
-
- if connection_callable:
- conn = connection_callable(self, state.obj())
- else:
- conn = connection
-
- mapper.dispatch.before_delete(mapper, conn, state)
-
- tups.append((state,
- state.dict,
- _state_mapper(state),
- bool(state.key),
- conn))
-
- table_to_mapper = self._sorted_tables
-
- for table in reversed(table_to_mapper.keys()):
- delete = util.defaultdict(list)
- for state, state_dict, mapper, has_identity, connection in tups:
- if not has_identity or table not in mapper._pks_by_table:
- continue
-
- params = {}
- delete[connection].append(params)
- for col in mapper._pks_by_table[table]:
- params[col.key] = \
- value = \
- mapper._get_state_attr_by_column(
- state, state_dict, col)
- if value is None:
- raise sa_exc.FlushError(
- "Can't delete from table "
- "using NULL for primary "
- "key value")
-
- if mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col):
- params[mapper.version_id_col.key] = \
- mapper._get_committed_state_attr_by_column(
- state, state_dict,
- mapper.version_id_col)
-
- mapper = table_to_mapper[table]
- need_version_id = mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col)
-
- def delete_stmt():
- clause = sql.and_()
- for col in mapper._pks_by_table[table]:
- clause.clauses.append(
- col == sql.bindparam(col.key, type_=col.type))
-
- if need_version_id:
- clause.clauses.append(
- mapper.version_id_col ==
- sql.bindparam(
- mapper.version_id_col.key,
- type_=mapper.version_id_col.type
- )
- )
-
- return table.delete(clause)
-
- for connection, del_objects in delete.iteritems():
- statement = self._memo(('delete', table), delete_stmt)
- rows = -1
-
- connection = cached_connections[connection]
-
- if need_version_id and \
- not connection.dialect.supports_sane_multi_rowcount:
- # TODO: need test coverage for this [ticket:1761]
- if connection.dialect.supports_sane_rowcount:
- rows = 0
- # execute deletes individually so that versioned
- # rows can be verified
- for params in del_objects:
- c = connection.execute(statement, params)
- rows += c.rowcount
- else:
- util.warn(
- "Dialect %s does not support deleted rowcount "
- "- versioning cannot be verified." %
- connection.dialect.dialect_description,
- stacklevel=12)
- connection.execute(statement, del_objects)
- else:
- c = connection.execute(statement, del_objects)
- if connection.dialect.supports_sane_multi_rowcount:
- rows = c.rowcount
-
- if rows != -1 and rows != len(del_objects):
- raise orm_exc.StaleDataError(
- "DELETE statement on table '%s' expected to delete %d row(s); "
- "%d were matched." %
- (table.description, len(del_objects), c.rowcount)
- )
-
- for state, state_dict, mapper, has_identity, connection in tups:
- mapper.dispatch.after_delete(mapper, connection, state)
def _instance_processor(self, context, path, reduced_path, adapter,
polymorphic_from=None,
@@ -2517,6 +1901,12 @@ class Mapper(object):
"""Produce a mapper level row processor callable
which processes rows into mapped instances."""
+ # note that this method, most of which exists in a closure
+ # called _instance(), resists being broken out, as
+ # attempts to do so tend to add significant function
+ # call overhead. _instance() is the most
+ # performance-critical section in the whole ORM.
+
pk_cols = self.primary_key
if polymorphic_from or refresh_state:
@@ -2960,13 +2350,6 @@ def _event_on_resurrect(state):
state, state.dict, col, val)
-def _sort_states(states):
- pending = set(states)
- persistent = set(s for s in pending if s.key is not None)
- pending.difference_update(persistent)
- return sorted(pending, key=operator.attrgetter("insert_order")) + \
- sorted(persistent, key=lambda q:q.key[1])
-
class _ColumnMapping(util.py25_dict):
"""Error reporting helper for mapper._columntoproperty."""