diff options
Diffstat (limited to 'lib/sqlalchemy/engine/threadlocal.py')
-rw-r--r-- | lib/sqlalchemy/engine/threadlocal.py | 50 |
1 files changed, 30 insertions, 20 deletions
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 0ec1f9613..5b2bdabc0 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -19,7 +19,6 @@ import weakref class TLConnection(base.Connection): - def __init__(self, *arg, **kw): super(TLConnection, self).__init__(*arg, **kw) self.__opencount = 0 @@ -43,6 +42,7 @@ class TLEngine(base.Engine): transactions. """ + _tl_connection_cls = TLConnection def __init__(self, *args, **kwargs): @@ -50,7 +50,7 @@ class TLEngine(base.Engine): self._connections = util.threading.local() def contextual_connect(self, **kw): - if not hasattr(self._connections, 'conn'): + if not hasattr(self._connections, "conn"): connection = None else: connection = self._connections.conn() @@ -60,29 +60,31 @@ class TLEngine(base.Engine): # or not connection.connection.is_valid: connection = self._tl_connection_cls( self, - self._wrap_pool_connect( - self.pool.connect, connection), - **kw) + self._wrap_pool_connect(self.pool.connect, connection), + **kw + ) self._connections.conn = weakref.ref(connection) return connection._increment_connect() def begin_twophase(self, xid=None): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append( - self.contextual_connect().begin_twophase(xid=xid)) + self.contextual_connect().begin_twophase(xid=xid) + ) return self def begin_nested(self): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append( - self.contextual_connect().begin_nested()) + self.contextual_connect().begin_nested() + ) return self def begin(self): - if not hasattr(self._connections, 'trans'): + if not hasattr(self._connections, "trans"): self._connections.trans = [] self._connections.trans.append(self.contextual_connect().begin()) return self @@ -97,21 +99,27 @@ class TLEngine(base.Engine): self.rollback() def prepare(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return self._connections.trans[-1].prepare() def commit(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return trans = self._connections.trans.pop(-1) trans.commit() def rollback(self): - if not hasattr(self._connections, 'trans') or \ - not self._connections.trans: + if ( + not hasattr(self._connections, "trans") + or not self._connections.trans + ): return trans = self._connections.trans.pop(-1) trans.rollback() @@ -122,9 +130,11 @@ class TLEngine(base.Engine): @property def closed(self): - return not hasattr(self._connections, 'conn') or \ - self._connections.conn() is None or \ - self._connections.conn().closed + return ( + not hasattr(self._connections, "conn") + or self._connections.conn() is None + or self._connections.conn().closed + ) def close(self): if not self.closed: @@ -135,4 +145,4 @@ class TLEngine(base.Engine): self._connections.trans = [] def __repr__(self): - return 'TLEngine(%r)' % self.url + return "TLEngine(%r)" % self.url |