diff options
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 48ff09b87..dc5de7ac6 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -235,7 +235,7 @@ class SessionTransaction(object): return SessionTransaction( self.session, self, nested=nested) - def _iterate_parents(self, upto=None): + def _iterate_self_and_parents(self, upto=None): current = self result = () @@ -269,6 +269,11 @@ class SessionTransaction(object): self._key_switches = weakref.WeakKeyDictionary() def _restore_snapshot(self, dirty_only=False): + """Restore the restoration state taken before a transaction began. + + Corresponds to a rollback. + + """ assert self._is_transaction_boundary self.session._expunge_states( @@ -290,6 +295,11 @@ class SessionTransaction(object): s._expire(s.dict, self.session.identity_map._modified) def _remove_snapshot(self): + """Remove the restoration state taken before a transaction began. + + Corresponds to a commit. + + """ assert self._is_transaction_boundary if not self.nested and self.session.expire_on_commit: @@ -358,7 +368,7 @@ class SessionTransaction(object): stx = self.session.transaction if stx is not self: - for subtransaction in stx._iterate_parents(upto=self): + for subtransaction in stx._iterate_self_and_parents(upto=self): subtransaction.commit() if not self.session._flushing: @@ -405,14 +415,18 @@ class SessionTransaction(object): stx = self.session.transaction if stx is not self: - for subtransaction in stx._iterate_parents(upto=self): + for subtransaction in stx._iterate_self_and_parents(upto=self): subtransaction.close() boundary = self + rollback_err = None if self._state in (ACTIVE, PREPARED): - for transaction in self._iterate_parents(): + for transaction in self._iterate_self_and_parents(): if transaction._parent is None or transaction.nested: - transaction._rollback_impl() + try: + transaction._rollback_impl() + except: + rollback_err = sys.exc_info() transaction._state = DEACTIVE boundary = transaction break @@ -421,7 +435,7 @@ class SessionTransaction(object): sess = self.session - if sess._enable_transaction_accounting and \ + if not rollback_err and sess._enable_transaction_accounting and \ not sess._is_clean(): # if items were added, deleted, or mutated @@ -433,19 +447,24 @@ class SessionTransaction(object): boundary._restore_snapshot(dirty_only=boundary.nested) self.close() + if self._parent and _capture_exception: self._parent._rollback_exception = sys.exc_info()[1] + if rollback_err: + util.reraise(*rollback_err) + sess.dispatch.after_soft_rollback(sess, self) return self._parent def _rollback_impl(self): - for t in set(self._connections.values()): - t[1].rollback() - - if self.session._enable_transaction_accounting: - self._restore_snapshot(dirty_only=self.nested) + try: + for t in set(self._connections.values()): + t[1].rollback() + finally: + if self.session._enable_transaction_accounting: + self._restore_snapshot(dirty_only=self.nested) self.session.dispatch.after_rollback(self.session) @@ -1078,7 +1097,7 @@ class Session(_SessionClassMethods): def _close_impl(self, invalidate): self.expunge_all() if self.transaction is not None: - for transaction in self.transaction._iterate_parents(): + for transaction in self.transaction._iterate_self_and_parents(): transaction.close(invalidate) def expunge_all(self): |