summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/threadlocal.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine/threadlocal.py')
-rw-r--r--lib/sqlalchemy/engine/threadlocal.py230
1 files changed, 57 insertions, 173 deletions
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py
index 27d857623..a9892ae7e 100644
--- a/lib/sqlalchemy/engine/threadlocal.py
+++ b/lib/sqlalchemy/engine/threadlocal.py
@@ -7,211 +7,95 @@ invoked automatically when the threadlocal engine strategy is used.
from sqlalchemy import util
from sqlalchemy.engine import base
-
-
-class TLSession(object):
- def __init__(self, engine):
- self.engine = engine
- self.__tcount = 0
-
- def get_connection(self, close_with_result=False):
- try:
- return self.__transaction._increment_connect()
- except AttributeError:
- return self.engine.TLConnection(self, self.engine.pool.connect(),
- close_with_result=close_with_result)
-
- def reset(self):
- try:
- self.__transaction._force_close()
- del self.__transaction
- del self.__trans
- except AttributeError:
- pass
- self.__tcount = 0
-
- def _conn_closed(self):
- if self.__tcount == 1:
- self.__trans._trans.rollback()
- self.reset()
-
- def in_transaction(self):
- return self.__tcount > 0
-
- def prepare(self):
- if self.__tcount == 1:
- self.__trans._trans.prepare()
-
- def begin_twophase(self, xid=None):
- if self.__tcount == 0:
- self.__transaction = self.get_connection()
- self.__trans = self.__transaction._begin_twophase(xid=xid)
- self.__tcount += 1
- return self.__trans
-
- def begin(self, **kwargs):
- if self.__tcount == 0:
- self.__transaction = self.get_connection()
- self.__trans = self.__transaction._begin(**kwargs)
- self.__tcount += 1
- return self.__trans
-
- def rollback(self):
- if self.__tcount > 0:
- try:
- self.__trans._trans.rollback()
- finally:
- self.reset()
-
- def commit(self):
- if self.__tcount == 1:
- try:
- self.__trans._trans.commit()
- finally:
- self.reset()
- elif self.__tcount > 1:
- self.__tcount -= 1
-
- def close(self):
- if self.__tcount == 1:
- self.rollback()
- elif self.__tcount > 1:
- self.__tcount -= 1
-
- def is_begun(self):
- return self.__tcount > 0
-
+import weakref
class TLConnection(base.Connection):
- def __init__(self, session, connection, **kwargs):
- base.Connection.__init__(self, session.engine, connection, **kwargs)
- self.__session = session
- self.__opencount = 1
-
- def _branch(self):
- return self.engine.Connection(self.engine, self.connection, _branch=True)
-
- def session(self):
- return self.__session
- session = property(session)
-
+ def __init__(self, *arg, **kw):
+ super(TLConnection, self).__init__(*arg, **kw)
+ self.__opencount = 0
+
def _increment_connect(self):
self.__opencount += 1
return self
-
- def _begin(self, **kwargs):
- return TLTransaction(
- super(TLConnection, self).begin(**kwargs), self.__session)
-
- def _begin_twophase(self, xid=None):
- return TLTransaction(
- super(TLConnection, self).begin_twophase(xid=xid), self.__session)
-
- def in_transaction(self):
- return self.session.in_transaction()
-
- def begin(self, **kwargs):
- return self.session.begin(**kwargs)
-
- def begin_twophase(self, xid=None):
- return self.session.begin_twophase(xid=xid)
- def begin_nested(self):
- raise NotImplementedError("SAVEPOINT transactions with the 'threadlocal' strategy")
-
def close(self):
if self.__opencount == 1:
base.Connection.close(self)
- self.__session._conn_closed()
self.__opencount -= 1
def _force_close(self):
self.__opencount = 0
base.Connection.close(self)
-
-class TLTransaction(base.Transaction):
- def __init__(self, trans, session):
- self._trans = trans
- self._session = session
-
- def connection(self):
- return self._trans.connection
- connection = property(connection)
-
- def is_active(self):
- return self._trans.is_active
- is_active = property(is_active)
-
- def rollback(self):
- self._session.rollback()
-
- def prepare(self):
- self._session.prepare()
-
- def commit(self):
- self._session.commit()
-
- def close(self):
- self._session.close()
-
- def __enter__(self):
- return self
-
- def __exit__(self, type, value, traceback):
- self._trans.__exit__(type, value, traceback)
-
-
+
class TLEngine(base.Engine):
- """An Engine that includes support for thread-local managed transactions.
+ """An Engine that includes support for thread-local managed transactions."""
- The TLEngine relies upon its Pool having "threadlocal" behavior,
- so that once a connection is checked out for the current thread,
- you get that same connection repeatedly.
- """
def __init__(self, *args, **kwargs):
- """Construct a new TLEngine."""
-
super(TLEngine, self).__init__(*args, **kwargs)
- self.context = util.threading.local()
-
+ self._connections = util.threading.local()
proxy = kwargs.get('proxy')
if proxy:
self.TLConnection = base._proxy_connection_cls(TLConnection, proxy)
else:
self.TLConnection = TLConnection
- def session(self):
- "Returns the current thread's TLSession"
- if not hasattr(self.context, 'session'):
- self.context.session = TLSession(self)
- return self.context.session
-
- session = property(session)
-
- def contextual_connect(self, **kwargs):
- """Return a TLConnection which is thread-locally scoped."""
-
- return self.session.get_connection(**kwargs)
-
- def begin_twophase(self, **kwargs):
- return self.session.begin_twophase(**kwargs)
+ def contextual_connect(self, **kw):
+ if not hasattr(self._connections, 'conn'):
+ connection = None
+ else:
+ connection = self._connections.conn()
+
+ if connection is None or connection.closed:
+ # guards against pool-level reapers, if desired.
+ # or not connection.connection.is_valid:
+ connection = self.TLConnection(self, self.pool.connect(), **kw)
+ self._connections.conn = conn = weakref.ref(connection)
+
+ return connection._increment_connect()
+
+ def begin_twophase(self, xid=None):
+ if not hasattr(self._connections, 'trans'):
+ self._connections.trans = []
+ self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid))
def begin_nested(self):
- raise NotImplementedError("SAVEPOINT transactions with the 'threadlocal' strategy")
+ if not hasattr(self._connections, 'trans'):
+ self._connections.trans = []
+ self._connections.trans.append(self.contextual_connect().begin_nested())
+
+ def begin(self):
+ if not hasattr(self._connections, 'trans'):
+ self._connections.trans = []
+ self._connections.trans.append(self.contextual_connect().begin())
- def begin(self, **kwargs):
- return self.session.begin(**kwargs)
-
def prepare(self):
- self.session.prepare()
+ self._connections.trans[-1].prepare()
def commit(self):
- self.session.commit()
-
+ trans = self._connections.trans.pop(-1)
+ trans.commit()
+
def rollback(self):
- self.session.rollback()
-
+ trans = self._connections.trans.pop(-1)
+ trans.rollback()
+
+ def dispose(self):
+ self._connections = util.threading.local()
+ super(TLEngine, self).dispose()
+
+ @property
+ def closed(self):
+ return not hasattr(self._connections, 'conn') or \
+ self._connections.conn() is None or \
+ self._connections.conn().closed
+
+ def close(self):
+ if not self.closed:
+ self.contextual_connect().close()
+ del self._connections.conn
+ self._connections.trans = []
+
def __repr__(self):
return 'TLEngine(%s)' % str(self.url)