summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r--lib/sqlalchemy/orm/session.py43
1 files changed, 31 insertions, 12 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 48ff09b87..dc5de7ac6 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -235,7 +235,7 @@ class SessionTransaction(object):
return SessionTransaction(
self.session, self, nested=nested)
- def _iterate_parents(self, upto=None):
+ def _iterate_self_and_parents(self, upto=None):
current = self
result = ()
@@ -269,6 +269,11 @@ class SessionTransaction(object):
self._key_switches = weakref.WeakKeyDictionary()
def _restore_snapshot(self, dirty_only=False):
+ """Restore the restoration state taken before a transaction began.
+
+ Corresponds to a rollback.
+
+ """
assert self._is_transaction_boundary
self.session._expunge_states(
@@ -290,6 +295,11 @@ class SessionTransaction(object):
s._expire(s.dict, self.session.identity_map._modified)
def _remove_snapshot(self):
+ """Remove the restoration state taken before a transaction began.
+
+ Corresponds to a commit.
+
+ """
assert self._is_transaction_boundary
if not self.nested and self.session.expire_on_commit:
@@ -358,7 +368,7 @@ class SessionTransaction(object):
stx = self.session.transaction
if stx is not self:
- for subtransaction in stx._iterate_parents(upto=self):
+ for subtransaction in stx._iterate_self_and_parents(upto=self):
subtransaction.commit()
if not self.session._flushing:
@@ -405,14 +415,18 @@ class SessionTransaction(object):
stx = self.session.transaction
if stx is not self:
- for subtransaction in stx._iterate_parents(upto=self):
+ for subtransaction in stx._iterate_self_and_parents(upto=self):
subtransaction.close()
boundary = self
+ rollback_err = None
if self._state in (ACTIVE, PREPARED):
- for transaction in self._iterate_parents():
+ for transaction in self._iterate_self_and_parents():
if transaction._parent is None or transaction.nested:
- transaction._rollback_impl()
+ try:
+ transaction._rollback_impl()
+ except:
+ rollback_err = sys.exc_info()
transaction._state = DEACTIVE
boundary = transaction
break
@@ -421,7 +435,7 @@ class SessionTransaction(object):
sess = self.session
- if sess._enable_transaction_accounting and \
+ if not rollback_err and sess._enable_transaction_accounting and \
not sess._is_clean():
# if items were added, deleted, or mutated
@@ -433,19 +447,24 @@ class SessionTransaction(object):
boundary._restore_snapshot(dirty_only=boundary.nested)
self.close()
+
if self._parent and _capture_exception:
self._parent._rollback_exception = sys.exc_info()[1]
+ if rollback_err:
+ util.reraise(*rollback_err)
+
sess.dispatch.after_soft_rollback(sess, self)
return self._parent
def _rollback_impl(self):
- for t in set(self._connections.values()):
- t[1].rollback()
-
- if self.session._enable_transaction_accounting:
- self._restore_snapshot(dirty_only=self.nested)
+ try:
+ for t in set(self._connections.values()):
+ t[1].rollback()
+ finally:
+ if self.session._enable_transaction_accounting:
+ self._restore_snapshot(dirty_only=self.nested)
self.session.dispatch.after_rollback(self.session)
@@ -1078,7 +1097,7 @@ class Session(_SessionClassMethods):
def _close_impl(self, invalidate):
self.expunge_all()
if self.transaction is not None:
- for transaction in self.transaction._iterate_parents():
+ for transaction in self.transaction._iterate_self_and_parents():
transaction.close(invalidate)
def expunge_all(self):