diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-12-02 09:18:11 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-12-27 13:13:02 -0500 |
| commit | cc46ea711df77540d5d658e9c7b3ab1e88288929 (patch) | |
| tree | 401208da0f9f783517c059fb5ccae52a2b43b8dd /lib/sqlalchemy | |
| parent | fd99a4aa808f91f87d0a678708dd9c2b131fda04 (diff) | |
| download | sqlalchemy-cc46ea711df77540d5d658e9c7b3ab1e88288929.tar.gz | |
propose concurrency check for SessionTransaction
the discussion at #7387 refers to a condition that seems
to happen in the wild also, such as [1] [2] [3], it's not
entirely clear why this specific spot is how this occurs,
however it's maybe that when the connection is being acquired
from the pool, under load there might be a wait on the connection
pool, leading to more time for another errant thread to be
calling .close(), just a theory.
in this patch we propose using decorators and context managers
along with declarative state declarations to block reentrant
or concurrent calls to methods that conflict with expected
state changes.
The :class:`_orm.Session` (and by extension :class:`.AsyncSession`) now has
new state-tracking functionality that will proactively trap any unexpected
state changes which occur as a particular transactional method proceeds.
This is to allow situations where the :class:`_orm.Session` is being used
in a thread-unsafe manner, where event hooks or similar may be calling
unexpected methods within operations, as well as potentially under other
concurrency situations such as asyncio or gevent to raise an informative
message when the illegal access first occurs, rather than passing silently
leading to secondary failures due to the :class:`_orm.Session` being in an
invalid state.
[1] https://stackoverflow.com/questions/25768428/sqlalchemy-connection-errors
[2] https://groups.google.com/g/sqlalchemy/c/n5oVX3v4WOw
[3] https://github.com/cosmicpython/code/issues/23
Fixes: #7433
Change-Id: I699b935c0ec4e5a63f12cf878af6f7a92a30a3aa
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/exc.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 161 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/state_changes.py | 179 |
3 files changed, 282 insertions, 67 deletions
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index e35c41836..e51214fd9 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -249,6 +249,15 @@ class InvalidRequestError(SQLAlchemyError): """ +class IllegalStateChangeError(InvalidRequestError): + """An object that tracks state encountered an illegal state change + of some kind. + + .. versionadded:: 2.0 + + """ + + class NoInspectionAvailable(InvalidRequestError): """A subject passed to :func:`sqlalchemy.inspection.inspect` produced no context for inspection.""" diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index e921bb8f0..13fc7f22e 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -26,6 +26,9 @@ from .base import instance_str from .base import object_mapper from .base import object_state from .base import state_str +from .state_changes import _StateChange +from .state_changes import _StateChangeState +from .state_changes import _StateChangeStates from .unitofwork import UOWTransaction from .. import engine from .. import exc as sa_exc @@ -101,11 +104,16 @@ class _SessionClassMethods: return object_session(instance) -ACTIVE = util.symbol("ACTIVE") -PREPARED = util.symbol("PREPARED") -COMMITTED = util.symbol("COMMITTED") -DEACTIVE = util.symbol("DEACTIVE") -CLOSED = util.symbol("CLOSED") +class SessionTransactionState(_StateChangeState): + ACTIVE = 1 + PREPARED = 2 + COMMITTED = 3 + DEACTIVE = 4 + CLOSED = 5 + + +# backwards compatibility +ACTIVE, PREPARED, COMMITTED, DEACTIVE, CLOSED = tuple(SessionTransactionState) class ORMExecuteState(util.MemoizedSlots): @@ -476,7 +484,7 @@ class ORMExecuteState(util.MemoizedSlots): ] -class SessionTransaction(TransactionalContext): +class SessionTransaction(_StateChange, TransactionalContext): """A :class:`.Session`-level transaction. :class:`.SessionTransaction` is produced from the @@ -532,7 +540,7 @@ class SessionTransaction(TransactionalContext): self.nested = nested if nested: self._previous_nested_transaction = session._nested_transaction - self._state = ACTIVE + self._state = SessionTransactionState.ACTIVE if not parent and nested: raise sa_exc.InvalidRequestError( "Can't start a SAVEPOINT transaction when no existing " @@ -547,6 +555,31 @@ class SessionTransaction(TransactionalContext): self.session.dispatch.after_transaction_create(self.session, self) + def _raise_for_prerequisite_state(self, operation_name, state): + if state is SessionTransactionState.DEACTIVE: + if self._rollback_exception: + raise sa_exc.PendingRollbackError( + "This Session's transaction has been rolled back " + "due to a previous exception during flush." + " To begin a new transaction with this Session, " + "first issue Session.rollback()." + f" Original exception was: {self._rollback_exception}", + code="7s2a", + ) + else: + raise sa_exc.InvalidRequestError( + "This session is in 'inactive' state, due to the " + "SQL transaction being rolled back; no further SQL " + "can be emitted within this transaction." + ) + elif state is SessionTransactionState.CLOSED: + raise sa_exc.ResourceClosedError("This transaction is closed") + else: + raise sa_exc.InvalidRequestError( + f"This session is in '{state.name.lower()}' state; no " + "further SQL can be emitted within this transaction." + ) + @property def parent(self): """The parent :class:`.SessionTransaction` of this @@ -576,58 +609,26 @@ class SessionTransaction(TransactionalContext): @property def is_active(self): - return self.session is not None and self._state is ACTIVE - - def _assert_active( - self, - prepared_ok=False, - rollback_ok=False, - deactive_ok=False, - closed_msg="This transaction is closed", - ): - if self._state is COMMITTED: - raise sa_exc.InvalidRequestError( - "This session is in 'committed' state; no further " - "SQL can be emitted within this transaction." - ) - elif self._state is PREPARED: - if not prepared_ok: - raise sa_exc.InvalidRequestError( - "This session is in 'prepared' state; no further " - "SQL can be emitted within this transaction." - ) - elif self._state is DEACTIVE: - if not deactive_ok and not rollback_ok: - if self._rollback_exception: - raise sa_exc.PendingRollbackError( - "This Session's transaction has been rolled back " - "due to a previous exception during flush." - " To begin a new transaction with this Session, " - "first issue Session.rollback()." - " Original exception was: %s" - % self._rollback_exception, - code="7s2a", - ) - elif not deactive_ok: - raise sa_exc.InvalidRequestError( - "This session is in 'inactive' state, due to the " - "SQL transaction being rolled back; no further " - "SQL can be emitted within this transaction." - ) - elif self._state is CLOSED: - raise sa_exc.ResourceClosedError(closed_msg) + return ( + self.session is not None + and self._state is SessionTransactionState.ACTIVE + ) @property def _is_transaction_boundary(self): return self.nested or not self._parent + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) def connection(self, bindkey, execution_options=None, **kwargs): - self._assert_active() bind = self.session.get_bind(bindkey, **kwargs) return self._connection_for_bind(bind, execution_options) + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) def _begin(self, nested=False): - self._assert_active() return SessionTransaction(self.session, self, nested=nested) def _iterate_self_and_parents(self, upto=None): @@ -718,8 +719,10 @@ class SessionTransaction(TransactionalContext): self._parent._deleted.update(self._deleted) self._parent._key_switches.update(self._key_switches) + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) def _connection_for_bind(self, bind, execution_options): - self._assert_active() if bind in self._connections: if execution_options: @@ -792,8 +795,11 @@ class SessionTransaction(TransactionalContext): ) self._prepare_impl() + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED + ) def _prepare_impl(self): - self._assert_active() + if self._parent is None or self.nested: self.session.dispatch.before_commit(self.session) @@ -822,12 +828,16 @@ class SessionTransaction(TransactionalContext): with util.safe_reraise(): self.rollback() - self._state = PREPARED + self._state = SessionTransactionState.PREPARED + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE, SessionTransactionState.PREPARED), + SessionTransactionState.CLOSED, + ) def commit(self, _to_root=False): - self._assert_active(prepared_ok=True) - if self._state is not PREPARED: - self._prepare_impl() + if self._state is not SessionTransactionState.PREPARED: + with self._expect_state(SessionTransactionState.PREPARED): + self._prepare_impl() if self._parent is None or self.nested: for conn, trans, should_commit, autoclose in set( @@ -836,20 +846,28 @@ class SessionTransaction(TransactionalContext): if should_commit: trans.commit() - self._state = COMMITTED + self._state = SessionTransactionState.COMMITTED self.session.dispatch.after_commit(self.session) self._remove_snapshot() - self.close() + with self._expect_state(SessionTransactionState.CLOSED): + self.close() if _to_root and self._parent: return self._parent.commit(_to_root=True) return self._parent + @_StateChange.declare_states( + ( + SessionTransactionState.ACTIVE, + SessionTransactionState.DEACTIVE, + SessionTransactionState.PREPARED, + ), + SessionTransactionState.CLOSED, + ) def rollback(self, _capture_exception=False, _to_root=False): - self._assert_active(prepared_ok=True, rollback_ok=True) stx = self.session._transaction if stx is not self: @@ -858,26 +876,29 @@ class SessionTransaction(TransactionalContext): boundary = self rollback_err = None - if self._state in (ACTIVE, PREPARED): + if self._state in ( + SessionTransactionState.ACTIVE, + SessionTransactionState.PREPARED, + ): for transaction in self._iterate_self_and_parents(): if transaction._parent is None or transaction.nested: try: for t in set(transaction._connections.values()): t[1].rollback() - transaction._state = DEACTIVE + transaction._state = SessionTransactionState.DEACTIVE self.session.dispatch.after_rollback(self.session) except: rollback_err = sys.exc_info() finally: - transaction._state = DEACTIVE + transaction._state = SessionTransactionState.DEACTIVE transaction._restore_snapshot( dirty_only=transaction.nested ) boundary = transaction break else: - transaction._state = DEACTIVE + transaction._state = SessionTransactionState.DEACTIVE sess = self.session @@ -892,7 +913,8 @@ class SessionTransaction(TransactionalContext): ) boundary._restore_snapshot(dirty_only=boundary.nested) - self.close() + with self._expect_state(SessionTransactionState.CLOSED): + self.close() if self._parent and _capture_exception: self._parent._rollback_exception = sys.exc_info()[1] @@ -906,6 +928,9 @@ class SessionTransaction(TransactionalContext): return self._parent.rollback(_to_root=True) return self._parent + @_StateChange.declare_states( + _StateChangeStates.ANY, SessionTransactionState.CLOSED + ) def close(self, invalidate=False): if self.nested: self.session._nested_transaction = ( @@ -925,20 +950,22 @@ class SessionTransaction(TransactionalContext): if autoclose: connection.close() - self._state = CLOSED - self.session.dispatch.after_transaction_end(self.session, self) + self._state = SessionTransactionState.CLOSED + sess = self.session self.session = None self._connections = None + sess.dispatch.after_transaction_end(sess, self) + def _get_subject(self): return self.session def _transaction_is_active(self): - return self._state is ACTIVE + return self._state is SessionTransactionState.ACTIVE def _transaction_is_closed(self): - return self._state is CLOSED + return self._state is SessionTransactionState.CLOSED def _rollback_can_be_called(self): return self._state not in (COMMITTED, CLOSED) diff --git a/lib/sqlalchemy/orm/state_changes.py b/lib/sqlalchemy/orm/state_changes.py new file mode 100644 index 000000000..7d2c3e056 --- /dev/null +++ b/lib/sqlalchemy/orm/state_changes.py @@ -0,0 +1,179 @@ +# orm/state_changes.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +"""State tracking utilities used by :class:`_orm.Session`. + +""" + +import contextlib +from enum import Enum +from typing import Any +from typing import Callable +from typing import Optional +from typing import Tuple +from typing import Union + +from .. import exc as sa_exc +from .. import util +from ..util.typing import Literal + + +class _StateChangeState(Enum): + pass + + +class _StateChangeStates(_StateChangeState): + ANY = 1 + NO_CHANGE = 2 + CHANGE_IN_PROGRESS = 3 + + +class _StateChange: + """Supplies state assertion decorators. + + The current use case is for the :class:`_orm.SessionTransaction` class. The + :class:`_StateChange` class itself is agnostic of the + :class:`_orm.SessionTransaction` class so could in theory be generalized + for other systems as well. + + """ + + _next_state: _StateChangeState = _StateChangeStates.ANY + _state: _StateChangeState = _StateChangeStates.NO_CHANGE + _current_fn: Optional[Callable] = None + + def _raise_for_prerequisite_state(self, operation_name, state): + raise sa_exc.IllegalStateChangeError( + f"Can't run operation '{operation_name}()' when Session " + f"is in state {state!r}" + ) + + @classmethod + def declare_states( + cls, + prerequisite_states: Union[ + Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...] + ], + moves_to: _StateChangeState, + ) -> Callable[..., Any]: + """Method decorator declaring valid states. + + :param prerequisite_states: sequence of acceptable prerequisite + states. Can be the single constant _State.ANY to indicate no + prerequisite state + + :param moves_to: the expected state at the end of the method, assuming + no exceptions raised. Can be the constant _State.NO_CHANGE to + indicate state should not change at the end of the method. + + """ + assert prerequisite_states, "no prequisite states sent" + has_prerequisite_states = ( + prerequisite_states is not _StateChangeStates.ANY + ) + + expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE + + @util.decorator + def _go(fn, self, *arg, **kw): + + current_state = self._state + + if ( + has_prerequisite_states + and current_state not in prerequisite_states + ): + self._raise_for_prerequisite_state(fn.__name__, current_state) + + next_state = self._next_state + existing_fn = self._current_fn + expect_state = moves_to if expect_state_change else current_state + + if ( + # destination states are restricted + next_state is not _StateChangeStates.ANY + # method seeks to change state + and expect_state_change + # destination state incorrect + and next_state is not expect_state + ): + if existing_fn and next_state in ( + _StateChangeStates.NO_CHANGE, + _StateChangeStates.CHANGE_IN_PROGRESS, + ): + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' can't be called here; " + f"method '{existing_fn.__name__}()' is already " + f"in progress and this would cause an unexpected " + f"state change to {moves_to!r}" + ) + else: + raise sa_exc.IllegalStateChangeError( + f"Cant run operation '{fn.__name__}()' here; " + f"will move to state {moves_to!r} where we are " + f"expecting {next_state!r}" + ) + + self._current_fn = fn + self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS + try: + ret_value = fn(self, *arg, **kw) + except: + raise + else: + if self._state is expect_state: + return ret_value + + if self._state is current_state: + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' failed to " + "change state " + f"to {moves_to!r} as expected" + ) + elif existing_fn: + raise sa_exc.IllegalStateChangeError( + f"While method '{existing_fn.__name__}()' was " + "running, " + f"method '{fn.__name__}()' caused an " + "unexpected " + f"state change to {self._state!r}" + ) + else: + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' caused an unexpected " + f"state change to {self._state!r}" + ) + + finally: + self._next_state = next_state + self._current_fn = existing_fn + + return _go + + @contextlib.contextmanager + def _expect_state(self, expected: _StateChangeState): + """called within a method that changes states. + + method must also use the ``@declare_states()`` decorator. + + """ + assert self._next_state is _StateChangeStates.CHANGE_IN_PROGRESS, ( + "Unexpected call to _expect_state outside of " + "state-changing method" + ) + + self._next_state = expected + try: + yield + except: + raise + else: + if self._state is not expected: + raise sa_exc.IllegalStateChangeError( + f"Unexpected state change to {self._state!r}" + ) + finally: + self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS |
