diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-01-24 18:13:21 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-01-24 18:13:21 +0000 |
commit | 9806d81675ef62363753a028ada43bc460728cf5 (patch) | |
tree | a563783652ea5e6dde90a17aa095be1529bbc00f /lib/sqlalchemy/engine/threadlocal.py | |
parent | 72d1cbadde4619264cc795800af3859a52c2794c (diff) | |
download | sqlalchemy-9806d81675ef62363753a028ada43bc460728cf5.tar.gz |
- the "threadlocal" engine has been rewritten and simplified
and now supports SAVEPOINT operations.
Diffstat (limited to 'lib/sqlalchemy/engine/threadlocal.py')
-rw-r--r-- | lib/sqlalchemy/engine/threadlocal.py | 230 |
1 files changed, 57 insertions, 173 deletions
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 27d857623..a9892ae7e 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -7,211 +7,95 @@ invoked automatically when the threadlocal engine strategy is used. from sqlalchemy import util from sqlalchemy.engine import base - - -class TLSession(object): - def __init__(self, engine): - self.engine = engine - self.__tcount = 0 - - def get_connection(self, close_with_result=False): - try: - return self.__transaction._increment_connect() - except AttributeError: - return self.engine.TLConnection(self, self.engine.pool.connect(), - close_with_result=close_with_result) - - def reset(self): - try: - self.__transaction._force_close() - del self.__transaction - del self.__trans - except AttributeError: - pass - self.__tcount = 0 - - def _conn_closed(self): - if self.__tcount == 1: - self.__trans._trans.rollback() - self.reset() - - def in_transaction(self): - return self.__tcount > 0 - - def prepare(self): - if self.__tcount == 1: - self.__trans._trans.prepare() - - def begin_twophase(self, xid=None): - if self.__tcount == 0: - self.__transaction = self.get_connection() - self.__trans = self.__transaction._begin_twophase(xid=xid) - self.__tcount += 1 - return self.__trans - - def begin(self, **kwargs): - if self.__tcount == 0: - self.__transaction = self.get_connection() - self.__trans = self.__transaction._begin(**kwargs) - self.__tcount += 1 - return self.__trans - - def rollback(self): - if self.__tcount > 0: - try: - self.__trans._trans.rollback() - finally: - self.reset() - - def commit(self): - if self.__tcount == 1: - try: - self.__trans._trans.commit() - finally: - self.reset() - elif self.__tcount > 1: - self.__tcount -= 1 - - def close(self): - if self.__tcount == 1: - self.rollback() - elif self.__tcount > 1: - self.__tcount -= 1 - - def is_begun(self): - return self.__tcount > 0 - +import weakref class TLConnection(base.Connection): - def __init__(self, session, connection, **kwargs): - base.Connection.__init__(self, session.engine, connection, **kwargs) - self.__session = session - self.__opencount = 1 - - def _branch(self): - return self.engine.Connection(self.engine, self.connection, _branch=True) - - def session(self): - return self.__session - session = property(session) - + def __init__(self, *arg, **kw): + super(TLConnection, self).__init__(*arg, **kw) + self.__opencount = 0 + def _increment_connect(self): self.__opencount += 1 return self - - def _begin(self, **kwargs): - return TLTransaction( - super(TLConnection, self).begin(**kwargs), self.__session) - - def _begin_twophase(self, xid=None): - return TLTransaction( - super(TLConnection, self).begin_twophase(xid=xid), self.__session) - - def in_transaction(self): - return self.session.in_transaction() - - def begin(self, **kwargs): - return self.session.begin(**kwargs) - - def begin_twophase(self, xid=None): - return self.session.begin_twophase(xid=xid) - def begin_nested(self): - raise NotImplementedError("SAVEPOINT transactions with the 'threadlocal' strategy") - def close(self): if self.__opencount == 1: base.Connection.close(self) - self.__session._conn_closed() self.__opencount -= 1 def _force_close(self): self.__opencount = 0 base.Connection.close(self) - -class TLTransaction(base.Transaction): - def __init__(self, trans, session): - self._trans = trans - self._session = session - - def connection(self): - return self._trans.connection - connection = property(connection) - - def is_active(self): - return self._trans.is_active - is_active = property(is_active) - - def rollback(self): - self._session.rollback() - - def prepare(self): - self._session.prepare() - - def commit(self): - self._session.commit() - - def close(self): - self._session.close() - - def __enter__(self): - return self - - def __exit__(self, type, value, traceback): - self._trans.__exit__(type, value, traceback) - - + class TLEngine(base.Engine): - """An Engine that includes support for thread-local managed transactions. + """An Engine that includes support for thread-local managed transactions.""" - The TLEngine relies upon its Pool having "threadlocal" behavior, - so that once a connection is checked out for the current thread, - you get that same connection repeatedly. - """ def __init__(self, *args, **kwargs): - """Construct a new TLEngine.""" - super(TLEngine, self).__init__(*args, **kwargs) - self.context = util.threading.local() - + self._connections = util.threading.local() proxy = kwargs.get('proxy') if proxy: self.TLConnection = base._proxy_connection_cls(TLConnection, proxy) else: self.TLConnection = TLConnection - def session(self): - "Returns the current thread's TLSession" - if not hasattr(self.context, 'session'): - self.context.session = TLSession(self) - return self.context.session - - session = property(session) - - def contextual_connect(self, **kwargs): - """Return a TLConnection which is thread-locally scoped.""" - - return self.session.get_connection(**kwargs) - - def begin_twophase(self, **kwargs): - return self.session.begin_twophase(**kwargs) + def contextual_connect(self, **kw): + if not hasattr(self._connections, 'conn'): + connection = None + else: + connection = self._connections.conn() + + if connection is None or connection.closed: + # guards against pool-level reapers, if desired. + # or not connection.connection.is_valid: + connection = self.TLConnection(self, self.pool.connect(), **kw) + self._connections.conn = conn = weakref.ref(connection) + + return connection._increment_connect() + + def begin_twophase(self, xid=None): + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid)) def begin_nested(self): - raise NotImplementedError("SAVEPOINT transactions with the 'threadlocal' strategy") + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append(self.contextual_connect().begin_nested()) + + def begin(self): + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append(self.contextual_connect().begin()) - def begin(self, **kwargs): - return self.session.begin(**kwargs) - def prepare(self): - self.session.prepare() + self._connections.trans[-1].prepare() def commit(self): - self.session.commit() - + trans = self._connections.trans.pop(-1) + trans.commit() + def rollback(self): - self.session.rollback() - + trans = self._connections.trans.pop(-1) + trans.rollback() + + def dispose(self): + self._connections = util.threading.local() + super(TLEngine, self).dispose() + + @property + def closed(self): + return not hasattr(self._connections, 'conn') or \ + self._connections.conn() is None or \ + self._connections.conn().closed + + def close(self): + if not self.closed: + self.contextual_connect().close() + del self._connections.conn + self._connections.trans = [] + def __repr__(self): return 'TLEngine(%s)' % str(self.url) |