diff options
Diffstat (limited to 'lib/sqlalchemy/orm/mapper.py')
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 144 |
1 files changed, 99 insertions, 45 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 8af6153d6..87c4c8100 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -23,7 +23,6 @@ deque = __import__('collections').deque from sqlalchemy import sql, util, log, exc as sa_exc from sqlalchemy.sql import expression, visitors, operators, util as sqlutil from sqlalchemy.orm import attributes, exc, sync -from sqlalchemy.orm.identity import IdentityManagedState from sqlalchemy.orm.interfaces import ( MapperProperty, EXT_CONTINUE, PropComparator ) @@ -255,7 +254,8 @@ class Mapper(object): for mapper in self.iterate_to_root(): util.reset_memoized(mapper, '_equivalent_columns') - + util.reset_memoized(mapper, '_sorted_tables') + if self.order_by is False and not self.concrete and self.inherits.order_by is not False: self.order_by = self.inherits.order_by @@ -357,7 +357,6 @@ class Mapper(object): if manager is None: manager = attributes.register_class(self.class_, - instance_state_factory = IdentityManagedState, deferred_scalar_loader = _load_scalar_attributes ) @@ -372,6 +371,8 @@ class Mapper(object): event_registry = manager.events event_registry.add_listener('on_init', _event_on_init) event_registry.add_listener('on_init_failure', _event_on_init_failure) + event_registry.add_listener('on_resurrect', _event_on_resurrect) + for key, method in util.iterate_attributes(self.class_): if isinstance(method, types.FunctionType): if hasattr(method, '__sa_reconstructor__'): @@ -1173,6 +1174,19 @@ class Mapper(object): # persistence + @util.memoized_property + def _sorted_tables(self): + table_to_mapper = {} + for mapper in self.base_mapper.polymorphic_iterator(): + for t in mapper.tables: + table_to_mapper[t] = mapper + + sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys()) + ret = util.OrderedDict() + for t in sorted_: + ret[t] = table_to_mapper[t] + return ret + def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. @@ -1198,16 +1212,37 @@ class Mapper(object): # if session has a connection callable, # organize individual states with the connection to use for insert/update + tups = [] if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in _sort_states(states)] + for state in _sort_states(states): + m = _state_mapper(state) + tups.append( + ( + state, + m, + connection_callable(self, state.obj()), + _state_has_identity(state), + state.key or m._identity_key_from_state(state) + ) + ) else: connection = uowtransaction.transaction.connection(self) - tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in _sort_states(states)] + for state in _sort_states(states): + m = _state_mapper(state) + tups.append( + ( + state, + m, + connection, + _state_has_identity(state), + state.key or m._identity_key_from_state(state) + ) + ) if not postupdate: # call before_XXX extensions - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: if not has_identity: if 'before_insert' in mapper.extension: mapper.extension.before_insert(mapper, connection, state.obj()) @@ -1215,39 +1250,44 @@ class Mapper(object): if 'before_update' in mapper.extension: mapper.extension.before_update(mapper, connection, state.obj()) - for state, mapper, connection, has_identity in tups: - # detect if we have a "pending" instance (i.e. has no instance_key attached to it), - # and another instance with the same identity key already exists as persistent. convert to an - # UPDATE if so. - instance_key = mapper._identity_key_from_state(state) - if not postupdate and not has_identity and instance_key in uowtransaction.session.identity_map: - instance = uowtransaction.session.identity_map[instance_key] - existing = attributes.instance_state(instance) - if not uowtransaction.is_deleted(existing): - raise exc.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing))) - if self._should_log_debug: - self._log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing))) - uowtransaction.set_row_switch(existing) - - table_to_mapper = {} - for mapper in self.base_mapper.polymorphic_iterator(): - for t in mapper.tables: - table_to_mapper[t] = mapper + row_switches = set() + if not postupdate: + for state, mapper, connection, has_identity, instance_key in tups: + # detect if we have a "pending" instance (i.e. has no instance_key attached to it), + # and another instance with the same identity key already exists as persistent. convert to an + # UPDATE if so. + if not has_identity and instance_key in uowtransaction.session.identity_map: + instance = uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + if not uowtransaction.is_deleted(existing): + raise exc.FlushError( + "New instance %s with identity key %s conflicts with persistent instance %s" % + (state_str(state), instance_key, state_str(existing))) + if self._should_log_debug: + self._log_debug( + "detected row switch for identity %s. will update %s, remove %s from transaction", + instance_key, state_str(state), state_str(existing)) + + # remove the "delete" flag from the existing element + uowtransaction.set_row_switch(existing) + row_switches.add(state) + + table_to_mapper = self._sorted_tables - for table in sqlutil.sort_tables(table_to_mapper.iterkeys()): + for table in table_to_mapper.iterkeys(): insert = [] update = [] - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: if table not in mapper._pks_by_table: continue + pks = mapper._pks_by_table[table] - instance_key = mapper._identity_key_from_state(state) - + if self._should_log_debug: self._log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key))) - isinsert = not instance_key in uowtransaction.session.identity_map and not postupdate and not has_identity + isinsert = not has_identity and not postupdate and state not in row_switches params = {} value_params = {} @@ -1364,7 +1404,7 @@ class Mapper(object): sync.populate(state, m, state, m, m._inherits_equated_pairs) if not postupdate: - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: # expire readonly attributes readonly = state.unmodified.intersection( @@ -1434,12 +1474,9 @@ class Mapper(object): if 'before_delete' in mapper.extension: mapper.extension.before_delete(mapper, connection, state.obj()) - table_to_mapper = {} - for mapper in self.base_mapper.polymorphic_iterator(): - for t in mapper.tables: - table_to_mapper[t] = mapper + table_to_mapper = self._sorted_tables - for table in reversed(sqlutil.sort_tables(table_to_mapper.iterkeys())): + for table in reversed(table_to_mapper.keys()): delete = {} for state, mapper, connection in tups: if table not in mapper._pks_by_table: @@ -1485,6 +1522,10 @@ class Mapper(object): for dep in self._props.values() + self._dependency_processors: dep.register_dependencies(uowcommit) + def _register_processors(self, uowcommit): + for dep in self._props.values() + self._dependency_processors: + dep.register_processors(uowcommit) + # result set conversion def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None, polymorphic_discriminator=None): @@ -1514,7 +1555,7 @@ class Mapper(object): new_populators = [] existing_populators = [] - def populate_state(state, row, isnew, only_load_props, **flags): + def populate_state(state, dict_, row, isnew, only_load_props, **flags): if isnew: if context.options: state.load_options = context.options @@ -1533,7 +1574,7 @@ class Mapper(object): populators = [p for p in populators if p[0] in only_load_props] for key, populator in populators: - populator(state, row, isnew=isnew, **flags) + populator(state, dict_, row, isnew=isnew, **flags) session_identity_map = context.session.identity_map @@ -1573,9 +1614,11 @@ class Mapper(object): if identitykey in session_identity_map: instance = session_identity_map[identitykey] state = attributes.instance_state(instance) + dict_ = attributes.instance_dict(instance) if self._should_log_debug: - self._log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), identitykey)) + self._log_debug("_instance(): using existing instance %s identity %s", + instance_str(instance), identitykey) isnew = state.runid != context.runid currentload = not isnew @@ -1592,12 +1635,13 @@ class Mapper(object): # when eager_defaults is True. state = refresh_state instance = state.obj() + dict_ = attributes.instance_dict(instance) isnew = state.runid != context.runid currentload = True loaded_instance = False else: if self._should_log_debug: - self._log_debug("_instance(): identity key %s not in session" % str(identitykey)) + self._log_debug("_instance(): identity key %s not in session", identitykey) if self.allow_null_pks: for x in identitykey[1]: @@ -1625,8 +1669,10 @@ class Mapper(object): instance = self.class_manager.new_instance() if self._should_log_debug: - self._log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey))) + self._log_debug("_instance(): created new instance %s identity %s", + instance_str(instance), identitykey) + dict_ = attributes.instance_dict(instance) state = attributes.instance_state(instance) state.key = identitykey @@ -1638,12 +1684,12 @@ class Mapper(object): if currentload or populate_existing: if isnew: state.runid = context.runid - context.progress.add(state) + context.progress[state] = dict_ if not populate_instance or \ populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - populate_state(state, row, isnew, only_load_props) + populate_state(state, dict_, row, isnew, only_load_props) else: # populate attributes on non-loading instances which have been expired @@ -1652,16 +1698,16 @@ class Mapper(object): if state in context.partials: isnew = False - attrs = context.partials[state] + (d_, attrs) = context.partials[state] else: isnew = True attrs = state.unloaded - context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs + context.partials[state] = (dict_, attrs) #<-- allow query.instances to commit the subset of attrs if not populate_instance or \ populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - populate_state(state, row, isnew, attrs, instancekey=identitykey) + populate_state(state, dict_, row, isnew, attrs, instancekey=identitykey) if loaded_instance: state._run_on_load(instance) @@ -1759,6 +1805,14 @@ def _event_on_init_failure(state, instance, args, kwargs): instrumenting_mapper, instrumenting_mapper.class_, state.manager.events.original_init, instance, args, kwargs) +def _event_on_resurrect(state, instance): + # re-populate the primary key elements + # of the dict based on the mapping. + instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + for col, val in zip(instrumenting_mapper.primary_key, state.key[1]): + instrumenting_mapper._set_state_attr_by_column(state, col, val) + + def _sort_states(states): return sorted(states, key=operator.attrgetter('sort_key')) |
