diff options
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 33 |
1 files changed, 17 insertions, 16 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 1e3a750d9..00a7d55e5 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -12,7 +12,7 @@ import sqlalchemy.exceptions as sa_exc from sqlalchemy import util, sql, engine, log from sqlalchemy.sql import util as sql_util, expression from sqlalchemy.orm import ( - SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, + SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, state ) from sqlalchemy.orm.util import object_mapper as _object_mapper from sqlalchemy.orm.util import class_mapper as _class_mapper @@ -899,8 +899,8 @@ class Session(object): self.flush() def _finalize_loaded(self, states): - for state in states: - state.commit_all() + for state, dict_ in states.items(): + state.commit_all(dict_) def refresh(self, instance, attribute_names=None): """Refresh the attributes on the given instance. @@ -1020,11 +1020,9 @@ class Session(object): # primary key switch self.identity_map.remove(state) state.key = instance_key - - if state.key in self.identity_map and not self.identity_map.contains_state(state): - self.identity_map.remove_key(state.key) - self.identity_map.add(state) - state.commit_all() + + self.identity_map.replace(state) + state.commit_all(state.dict) # remove from new last, might be the last strong ref if state in self._new: @@ -1213,7 +1211,7 @@ class Session(object): prop.merge(self, instance, merged, dont_load, _recursive) if dont_load: - attributes.instance_state(merged).commit_all() # remove any history + attributes.instance_state(merged).commit_all(attributes.instance_dict(merged)) # remove any history if new_instance: merged_state._run_on_load(merged) @@ -1368,7 +1366,7 @@ class Session(object): self.identity_map.modified = False return - flush_context = UOWTransaction(self) + flush_context = UOWTransaction(self) if self.extensions: for ext in self.extensions: @@ -1489,7 +1487,7 @@ class Session(object): return util.IdentitySet( [state for state in self.identity_map.all_states() - if state.check_modified()]) + if state.modified]) @property def dirty(self): @@ -1528,7 +1526,7 @@ class Session(object): return util.IdentitySet(self._new.values()) -_expire_state = attributes.InstanceState.expire_attributes +_expire_state = state.InstanceState.expire_attributes UOWEventHandler = unitofwork.UOWEventHandler @@ -1548,16 +1546,19 @@ def _cascade_unknown_state_iterator(cascade, state, **kwargs): yield _state_for_unknown_persistence_instance(o), m def _state_for_unsaved_instance(instance, create=False): - manager = attributes.manager_of_class(instance.__class__) - if manager is None: + try: + state = attributes.instance_state(instance) + except AttributeError: raise exc.UnmappedInstanceError(instance) - if manager.has_state(instance): - state = manager.state_of(instance) + if state: if state.key is not None: raise sa_exc.InvalidRequestError( "Instance '%s' is already persistent" % mapperutil.state_str(state)) elif create: + manager = attributes.manager_of_class(instance.__class__) + if manager is None: + raise exc.UnmappedInstanceError(instance) state = manager.setup_instance(instance) else: raise exc.UnmappedInstanceError(instance) |
