summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r--lib/sqlalchemy/orm/session.py33
1 files changed, 17 insertions, 16 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 1e3a750d9..00a7d55e5 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -12,7 +12,7 @@ import sqlalchemy.exceptions as sa_exc
from sqlalchemy import util, sql, engine, log
from sqlalchemy.sql import util as sql_util, expression
from sqlalchemy.orm import (
- SessionExtension, attributes, exc, query, unitofwork, util as mapperutil,
+ SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, state
)
from sqlalchemy.orm.util import object_mapper as _object_mapper
from sqlalchemy.orm.util import class_mapper as _class_mapper
@@ -899,8 +899,8 @@ class Session(object):
self.flush()
def _finalize_loaded(self, states):
- for state in states:
- state.commit_all()
+ for state, dict_ in states.items():
+ state.commit_all(dict_)
def refresh(self, instance, attribute_names=None):
"""Refresh the attributes on the given instance.
@@ -1020,11 +1020,9 @@ class Session(object):
# primary key switch
self.identity_map.remove(state)
state.key = instance_key
-
- if state.key in self.identity_map and not self.identity_map.contains_state(state):
- self.identity_map.remove_key(state.key)
- self.identity_map.add(state)
- state.commit_all()
+
+ self.identity_map.replace(state)
+ state.commit_all(state.dict)
# remove from new last, might be the last strong ref
if state in self._new:
@@ -1213,7 +1211,7 @@ class Session(object):
prop.merge(self, instance, merged, dont_load, _recursive)
if dont_load:
- attributes.instance_state(merged).commit_all() # remove any history
+ attributes.instance_state(merged).commit_all(attributes.instance_dict(merged)) # remove any history
if new_instance:
merged_state._run_on_load(merged)
@@ -1368,7 +1366,7 @@ class Session(object):
self.identity_map.modified = False
return
- flush_context = UOWTransaction(self)
+ flush_context = UOWTransaction(self)
if self.extensions:
for ext in self.extensions:
@@ -1489,7 +1487,7 @@ class Session(object):
return util.IdentitySet(
[state
for state in self.identity_map.all_states()
- if state.check_modified()])
+ if state.modified])
@property
def dirty(self):
@@ -1528,7 +1526,7 @@ class Session(object):
return util.IdentitySet(self._new.values())
-_expire_state = attributes.InstanceState.expire_attributes
+_expire_state = state.InstanceState.expire_attributes
UOWEventHandler = unitofwork.UOWEventHandler
@@ -1548,16 +1546,19 @@ def _cascade_unknown_state_iterator(cascade, state, **kwargs):
yield _state_for_unknown_persistence_instance(o), m
def _state_for_unsaved_instance(instance, create=False):
- manager = attributes.manager_of_class(instance.__class__)
- if manager is None:
+ try:
+ state = attributes.instance_state(instance)
+ except AttributeError:
raise exc.UnmappedInstanceError(instance)
- if manager.has_state(instance):
- state = manager.state_of(instance)
+ if state:
if state.key is not None:
raise sa_exc.InvalidRequestError(
"Instance '%s' is already persistent" %
mapperutil.state_str(state))
elif create:
+ manager = attributes.manager_of_class(instance.__class__)
+ if manager is None:
+ raise exc.UnmappedInstanceError(instance)
state = manager.setup_instance(instance)
else:
raise exc.UnmappedInstanceError(instance)