summaryrefslogtreecommitdiff
path: root/test/dialect/mysql/test_for_update.py
blob: 3537c3220dcef1edf9dd93b285a7306ff97c3d8f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Test MySQL FOR UPDATE behavior.

See #4246

"""
import contextlib

from sqlalchemy import Column
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import testing
from sqlalchemy import update
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.testing import fixtures


class MySQLForUpdateLockingTest(fixtures.DeclarativeMappedTest):
    """Check which tables MySQL actually locks for ORM ``FOR UPDATE`` queries.

    Each test runs a query inside the transaction opened by
    :meth:`run_test`, then uses a second, independent transaction to try
    an UPDATE of every row in ``a`` and/or ``b``.  Because
    ``innodb_lock_wait_timeout`` is set to 1 second, a
    "Lock wait timeout exceeded" error from that probe shows the table's
    rows are locked by the first transaction.
    """

    __backend__ = True
    __only_on__ = "mysql"
    __requires__ = ("mysql_for_update",)

    @classmethod
    def setup_classes(cls):
        """Declare the mapped classes ``A`` and ``B`` (``A.bs`` -> ``B``)."""
        Base = cls.DeclarativeBasic

        class A(Base):
            __tablename__ = "a"
            id = Column(Integer, primary_key=True)
            x = Column(Integer)
            y = Column(Integer)
            bs = relationship("B")
            # InnoDB, since the lock probes rely on innodb_lock_wait_timeout
            __table_args__ = {"mysql_engine": "InnoDB"}

        class B(Base):
            __tablename__ = "b"
            id = Column(Integer, primary_key=True)
            a_id = Column(ForeignKey("a.id"))
            x = Column(Integer)
            y = Column(Integer)
            # InnoDB, since the lock probes rely on innodb_lock_wait_timeout
            __table_args__ = {"mysql_engine": "InnoDB"}

    @classmethod
    def insert_data(cls):
        """Insert two ``A`` rows with five related ``B`` rows in total."""
        A = cls.classes.A
        B = cls.classes.B

        # all the x/y are < 10
        s = Session()
        try:
            s.add_all(
                [
                    A(x=5, y=5, bs=[B(x=4, y=4), B(x=2, y=8), B(x=7, y=1)]),
                    A(x=7, y=5, bs=[B(x=4, y=4), B(x=5, y=8)]),
                ]
            )
            s.commit()
        finally:
            s.close()

    @contextlib.contextmanager
    def run_test(self):
        """Yield a Session joined to an open transaction on a new connection.

        The transaction is always rolled back and the connection closed
        afterwards, so any locks taken by the test's queries are released
        and nothing they did is persisted.  The connection is closed even
        if setting up the session or transaction fails.
        """
        connection = testing.db.connect()
        try:
            # keep lock waits on this connection short, like the probes
            connection.execute("set innodb_lock_wait_timeout=1")
            main_trans = connection.begin()
            session = Session(bind=connection)
            try:
                yield session
            finally:
                # release the Session before rolling back the outer
                # transaction it was joined to
                session.close()
                main_trans.rollback()
        finally:
            connection.close()

    def _assert_is_locked(self, mapped_cls, should_be_locked):
        """Assert whether ``mapped_cls``'s table is locked by another transaction.

        Opens a separate transaction via ``testing.db.begin()`` and tries to
        UPDATE every row of the table.

        :param mapped_cls: mapped class with integer ``x`` and ``y`` columns.
        :param should_be_locked: True if the UPDATE is expected to time out
            waiting for a lock; False if it is expected to succeed.
        """
        # NOTE(review): when the probe UPDATE succeeds, the begin() context
        # manager ends that transaction normally (committing it); presumably
        # fixture data is reset between tests -- confirm.
        with testing.db.begin() as alt_trans:
            alt_trans.execute("set innodb_lock_wait_timeout=1")
            # set x/y > 10
            try:
                alt_trans.execute(update(mapped_cls).values(x=15, y=19))
            except (exc.InternalError, exc.OperationalError) as err:
                assert "Lock wait timeout exceeded" in str(err)
                assert should_be_locked, (
                    "%s is locked but was expected not to be"
                    % mapped_cls.__name__
                )
            else:
                assert not should_be_locked, (
                    "%s is not locked but was expected to be"
                    % mapped_cls.__name__
                )

    def _assert_a_is_locked(self, should_be_locked):
        """Assert whether table ``a`` is locked; see :meth:`_assert_is_locked`."""
        self._assert_is_locked(self.classes.A, should_be_locked)

    def _assert_b_is_locked(self, should_be_locked):
        """Assert whether table ``b`` is locked; see :meth:`_assert_is_locked`."""
        self._assert_is_locked(self.classes.B, should_be_locked)

    def test_basic_lock(self):
        A = self.classes.A
        with self.run_test() as s:
            s.query(A).with_for_update().all()
            # test our fixture
            self._assert_a_is_locked(True)

    def test_basic_not_lock(self):
        A = self.classes.A
        with self.run_test() as s:
            s.query(A).all()
            # test our fixture
            self._assert_a_is_locked(False)

    def test_joined_lock_subquery(self):
        A = self.classes.A
        with self.run_test() as s:
            # .first() with joinedload is presumably rendered with A in a
            # LIMITed subquery; compare test_joined_lock_no_subquery
            s.query(A).options(joinedload(A.bs)).with_for_update().first()

            # test for issue #4246, should be locked
            self._assert_a_is_locked(True)
            self._assert_b_is_locked(True)

    def test_joined_lock_subquery_inner_for_update(self):
        A = self.classes.A
        B = self.classes.B
        with self.run_test() as s:
            q = s.query(A).with_for_update().subquery()
            s.query(q).join(B).all()

            # FOR UPDATE is inside the subquery, should be locked
            self._assert_a_is_locked(True)

            # FOR UPDATE is inside the subquery, B is not locked
            self._assert_b_is_locked(False)

    def test_joined_lock_subquery_inner_for_update_outer(self):
        A = self.classes.A
        B = self.classes.B
        with self.run_test() as s:
            q = s.query(A).with_for_update().subquery()
            s.query(q).join(B).with_for_update().all()

            # FOR UPDATE is inside the subquery, should be locked
            self._assert_a_is_locked(True)

            # FOR UPDATE is also outside the subquery, B is locked
            self._assert_b_is_locked(True)

    def test_joined_lock_subquery_order_for_update_outer(self):
        A = self.classes.A
        B = self.classes.B
        with self.run_test() as s:
            q = s.query(A).order_by(A.id).subquery()
            s.query(q).join(B).with_for_update().all()
            # FOR UPDATE is only on the outer query; the inner subquery
            # (ORDER BY, no FOR UPDATE) does not lock A, but B is locked.
            # NOTE(review): presumably MySQL materializes the ORDER BY
            # derived table, so the outer FOR UPDATE never reaches a's
            # rows -- confirm.
            self._assert_a_is_locked(False)
            self._assert_b_is_locked(True)

    def test_joined_lock_no_subquery(self):
        A = self.classes.A
        with self.run_test() as s:
            s.query(A).options(joinedload(A.bs)).with_for_update().all()
            # no subquery, should be locked
            self._assert_a_is_locked(True)
            self._assert_b_is_locked(True)