summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/threadlocal.py
blob: a04e8edbe99ac424bfb6345b7fb147b95b133496 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from sqlalchemy import schema, exceptions, util, sql, types
import StringIO, sys, re
import base, default

"""provides a thread-local transactional wrapper around the basic ComposedSQLEngine.  multiple calls to engine.connect()
will return the same connection for the same thread. also provides begin/commit methods on the engine itself
which correspond to a thread-local transaction."""

class TLSession(object):
    """Per-thread bookkeeping for a TLEngine's implicit transaction.

    Tracks how many begin() calls are outstanding.  While that count is
    positive, get_connection() keeps handing back the connection that owns the
    transaction.  The real commit is only issued by the outermost commit();
    rollback() at any depth rolls back and ends the whole transaction.
    """
    def __init__(self, engine):
        self.engine = engine
        # number of begin() calls not yet matched by a commit()
        self.__tcount = 0

    def get_connection(self, close_with_result=False):
        """Return a connection for the current thread.

        If a transaction is in progress, its connection is returned again
        (with its open count incremented) and ``close_with_result`` is ignored;
        otherwise a new TLConnection is created.
        """
        # Guard only the attribute lookup, so an AttributeError raised from
        # inside _increment_connect() is not mistaken for "no transaction".
        try:
            transaction = self.__transaction
        except AttributeError:
            return TLConnection(self, close_with_result=close_with_result)
        return transaction._increment_connect()

    def reset(self):
        """Discard any transaction state and force-close its connection.

        State is cleared *before* the connection is closed, so an error raised
        while closing cannot leave a stale transaction attached to the session.
        Any such error still propagates to the caller.
        """
        try:
            transaction = self.__transaction
        except AttributeError:
            transaction = None
        else:
            del self.__transaction
        try:
            del self.__trans
        except AttributeError:
            pass
        self.__tcount = 0
        if transaction is not None:
            transaction._force_close()

    def in_transaction(self):
        """Return True while at least one begin() is outstanding."""
        return self.__tcount > 0

    def begin(self):
        """Begin (or nest into) the thread's transaction and return it.

        Only the outermost call acquires a connection and starts a real
        transaction; nested calls just increment the nesting count.  If
        starting the transaction fails, the connection is released, the
        session is left with no transaction, and the error propagates.
        """
        if self.__tcount == 0:
            connection = self.get_connection()
            # stored before _begin() so reset() can release it on failure
            self.__transaction = connection
            started = False
            try:
                self.__trans = connection._begin()
                started = True
            finally:
                if not started:
                    # don't leave a half-initialized transaction behind
                    self.reset()
        self.__tcount += 1
        return self.__trans

    def rollback(self):
        """Roll back the whole transaction, regardless of nesting depth.

        The session is reset afterwards, even if the rollback fails.  Does
        nothing when no transaction is in progress.
        """
        if self.__tcount > 0:
            try:
                self.__trans._rollback_impl()
            finally:
                self.reset()

    def commit(self):
        """Close one nesting level; commit for real at the outermost level.

        When the count is exactly one, the underlying transaction is committed
        and the session is reset (even if the commit fails).  Deeper levels
        only decrement the count.  Does nothing when no transaction is open.
        """
        if self.__tcount == 1:
            try:
                self.__trans._commit_impl()
            finally:
                self.reset()
        elif self.__tcount > 1:
            self.__tcount -= 1

    def is_begun(self):
        """Return True while at least one begin() is outstanding."""
        return self.__tcount > 0

class TLConnection(base.Connection):
    """A Connection owned by a TLSession and shared by reference counting.

    close() only closes the underlying connection when the last outstanding
    reference is released.  begin() and in_transaction() are routed through
    the owning session, so they act on the thread's implicit transaction.
    """
    def __init__(self, session, close_with_result):
        base.Connection.__init__(self, session.engine, close_with_result=close_with_result)
        self.__session = session
        # how many callers currently hold this connection open
        self.__opencount = 1

    session = property(lambda self: self.__session)

    def _increment_connect(self):
        """Register one more holder of this connection and return it."""
        self.__opencount = self.__opencount + 1
        return self

    def _create_transaction(self, parent):
        """Build the transaction object used by base.Connection.begin()."""
        return TLTransaction(self, parent)

    def _begin(self):
        """Start a transaction directly on this connection, bypassing the session."""
        return base.Connection.begin(self)

    def in_transaction(self):
        """Report whether the owning session has a transaction in progress."""
        return self.session.in_transaction()

    def begin(self):
        """Begin (or nest into) the owning session's transaction."""
        return self.session.begin()

    def close(self):
        """Release one holder; close for real when the last holder releases."""
        remaining = self.__opencount - 1
        if remaining == 0:
            base.Connection.close(self)
        self.__opencount = remaining

    def _force_close(self):
        """Close the connection regardless of how many holders remain."""
        self.__opencount = 0
        base.Connection.close(self)
            
class TLTransaction(base.Transaction):
    """Transaction whose public commit()/rollback() go through the session.

    The session decides when the real commit/rollback happens; it calls back
    into _commit_impl()/_rollback_impl() to perform it on the connection.
    """
    def commit(self):
        """Ask the owning session to commit (honours nesting)."""
        owner = self.connection.session
        owner.commit()

    def rollback(self):
        """Ask the owning session to roll back the whole transaction."""
        owner = self.connection.session
        owner.rollback()

    def _commit_impl(self):
        """Perform the actual commit via base.Transaction."""
        base.Transaction.commit(self)

    def _rollback_impl(self):
        """Perform the actual rollback via base.Transaction."""
        base.Transaction.rollback(self)
            
class TLEngine(base.Engine):
    """An Engine that supports thread-local managed transactions.

    It is best used together with a threadlocal Pool object.
    """
    def __init__(self, *args, **kwargs):
        """Create the engine and its per-thread storage.

        The TLEngine relies upon the ConnectionProvider having "threadlocal"
        behavior: once a connection is checked out for the current thread,
        that same connection is handed back repeatedly.
        """
        super(TLEngine, self).__init__(*args, **kwargs)
        # per-thread storage; holds this thread's TLSession once created
        self.context = util.ThreadLocal()

    def raw_connection(self):
        """Return a DBAPI connection obtained from the connection provider."""
        provider = self.connection_provider
        return provider.get_connection()

    def connect(self, **kwargs):
        """Return a Connection that is not thread-locally scoped.

        This is the equivalent of calling connect() on a ComposedSQLEngine.
        Keyword arguments are accepted for interface compatibility but are
        currently ignored.
        """
        dbapi_connection = self.connection_provider.unique_connection()
        return base.Connection(self, dbapi_connection)

    def _session(self):
        """Return this thread's TLSession, creating it on first access."""
        context = self.context
        if not hasattr(context, 'session'):
            context.session = TLSession(self)
        return context.session
    session = property(_session, doc="returns the current thread's TLSession")

    def contextual_connect(self, **kwargs):
        """Return a TLConnection scoped to the current thread."""
        current = self.session
        return current.get_connection(**kwargs)

    def begin(self):
        """Begin (or nest into) the current thread's transaction."""
        current = self.session
        return current.begin()

    def commit(self):
        """Commit one nesting level of the current thread's transaction."""
        current = self.session
        current.commit()

    def rollback(self):
        """Roll back the current thread's transaction entirely."""
        current = self.session
        current.rollback()

class TLocalConnectionProvider(default.PoolConnectionProvider):
    """PoolConnectionProvider that also exposes the pool's unique_connection().

    TLEngine.connect() calls unique_connection() to obtain connections that
    are not thread-locally scoped.
    """
    def unique_connection(self):
        """Delegate to the underlying pool's unique_connection()."""
        pool = self._pool
        return pool.unique_connection()