summaryrefslogtreecommitdiff
path: root/test/engine/test_transaction.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/engine/test_transaction.py')
-rw-r--r--test/engine/test_transaction.py218
1 files changed, 85 insertions, 133 deletions
diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py
index 8a5303642..b3b17e75a 100644
--- a/test/engine/test_transaction.py
+++ b/test/engine/test_transaction.py
@@ -133,6 +133,91 @@ class TransactionTest(fixtures.TestBase):
finally:
connection.close()
+ def test_branch_nested_rollback(self):
+ connection = testing.db.connect()
+ try:
+ connection.begin()
+ branched = connection.connect()
+ assert branched.in_transaction()
+ branched.execute(users.insert(), user_id=1, user_name='user1')
+ nested = branched.begin()
+ branched.execute(users.insert(), user_id=2, user_name='user2')
+ nested.rollback()
+ assert not connection.in_transaction()
+ eq_(connection.scalar("select count(*) from query_users"), 0)
+
+ finally:
+ connection.close()
+
+ def test_branch_autorollback(self):
+ connection = testing.db.connect()
+ try:
+ branched = connection.connect()
+ branched.execute(users.insert(), user_id=1, user_name='user1')
+ try:
+ branched.execute(users.insert(), user_id=1, user_name='user1')
+ except exc.DBAPIError:
+ pass
+ finally:
+ connection.close()
+
+ def test_branch_orig_rollback(self):
+ connection = testing.db.connect()
+ try:
+ branched = connection.connect()
+ branched.execute(users.insert(), user_id=1, user_name='user1')
+ nested = branched.begin()
+ assert branched.in_transaction()
+ branched.execute(users.insert(), user_id=2, user_name='user2')
+ nested.rollback()
+ eq_(connection.scalar("select count(*) from query_users"), 1)
+
+ finally:
+ connection.close()
+
+ def test_branch_autocommit(self):
+ connection = testing.db.connect()
+ try:
+ branched = connection.connect()
+ branched.execute(users.insert(), user_id=1, user_name='user1')
+ finally:
+ connection.close()
+ eq_(testing.db.scalar("select count(*) from query_users"), 1)
+
+ @testing.requires.savepoints
+ def test_branch_savepoint_rollback(self):
+ connection = testing.db.connect()
+ try:
+ trans = connection.begin()
+ branched = connection.connect()
+ assert branched.in_transaction()
+ branched.execute(users.insert(), user_id=1, user_name='user1')
+ nested = branched.begin_nested()
+ branched.execute(users.insert(), user_id=2, user_name='user2')
+ nested.rollback()
+ assert connection.in_transaction()
+ trans.commit()
+ eq_(connection.scalar("select count(*) from query_users"), 1)
+
+ finally:
+ connection.close()
+
+ @testing.requires.two_phase_transactions
+ def test_branch_twophase_rollback(self):
+ connection = testing.db.connect()
+ try:
+ branched = connection.connect()
+ assert not branched.in_transaction()
+ branched.execute(users.insert(), user_id=1, user_name='user1')
+ nested = branched.begin_twophase()
+ branched.execute(users.insert(), user_id=2, user_name='user2')
+ nested.rollback()
+ assert not connection.in_transaction()
+ eq_(connection.scalar("select count(*) from query_users"), 1)
+
+ finally:
+ connection.close()
+
def test_retains_through_options(self):
connection = testing.db.connect()
try:
@@ -1126,139 +1211,6 @@ class TLTransactionTest(fixtures.TestBase):
order_by(users.c.user_id)).fetchall(),
[(1, ), (2, )])
-counters = None
-
-
-class ForUpdateTest(fixtures.TestBase):
- __requires__ = 'ad_hoc_engines',
- __backend__ = True
-
- @classmethod
- def setup_class(cls):
- global counters, metadata
- metadata = MetaData()
- counters = Table('forupdate_counters', metadata,
- Column('counter_id', INT, primary_key=True),
- Column('counter_value', INT),
- test_needs_acid=True)
- counters.create(testing.db)
-
- def teardown(self):
- testing.db.execute(counters.delete()).close()
-
- @classmethod
- def teardown_class(cls):
- counters.drop(testing.db)
-
- def increment(self, count, errors, update_style=True, delay=0.005):
- con = testing.db.connect()
- sel = counters.select(for_update=update_style,
- whereclause=counters.c.counter_id == 1)
- for i in range(count):
- trans = con.begin()
- try:
- existing = con.execute(sel).first()
- incr = existing['counter_value'] + 1
- time.sleep(delay)
- con.execute(counters.update(counters.c.counter_id == 1,
- values={'counter_value': incr}))
- time.sleep(delay)
- readback = con.execute(sel).first()
- if readback['counter_value'] != incr:
- raise AssertionError('Got %s post-update, expected '
- '%s' % (readback['counter_value'], incr))
- trans.commit()
- except Exception as e:
- trans.rollback()
- errors.append(e)
- break
- con.close()
-
- @testing.crashes('mssql', 'FIXME: unknown')
- @testing.crashes('firebird', 'FIXME: unknown')
- @testing.crashes('sybase', 'FIXME: unknown')
- @testing.requires.independent_connections
- def test_queued_update(self):
- """Test SELECT FOR UPDATE with concurrent modifications.
-
- Runs concurrent modifications on a single row in the users
- table, with each mutator trying to increment a value stored in
- user_name.
-
- """
-
- db = testing.db
- db.execute(counters.insert(), counter_id=1, counter_value=0)
- iterations, thread_count = 10, 5
- threads, errors = [], []
- for i in range(thread_count):
- thrd = threading.Thread(target=self.increment,
- args=(iterations, ),
- kwargs={'errors': errors,
- 'update_style': True})
- thrd.start()
- threads.append(thrd)
- for thrd in threads:
- thrd.join()
- assert not errors
- sel = counters.select(whereclause=counters.c.counter_id == 1)
- final = db.execute(sel).first()
- eq_(final['counter_value'], iterations * thread_count)
-
- def overlap(self, ids, errors, update_style):
-
- sel = counters.select(for_update=update_style,
- whereclause=counters.c.counter_id.in_(ids))
- con = testing.db.connect()
- trans = con.begin()
- try:
- rows = con.execute(sel).fetchall()
- time.sleep(0.50)
- trans.commit()
- except Exception as e:
- trans.rollback()
- errors.append(e)
- con.close()
-
- def _threaded_overlap(self, thread_count, groups, update_style=True, pool=5):
- db = testing.db
- for cid in range(pool - 1):
- db.execute(counters.insert(), counter_id=cid + 1,
- counter_value=0)
- errors, threads = [], []
- for i in range(thread_count):
- thrd = threading.Thread(target=self.overlap,
- args=(groups.pop(0), errors,
- update_style))
- time.sleep(0.20) # give the previous thread a chance to start
- # to ensure it gets a lock
- thrd.start()
- threads.append(thrd)
- for thrd in threads:
- thrd.join()
- return errors
-
- @testing.crashes('mssql', 'FIXME: unknown')
- @testing.crashes('firebird', 'FIXME: unknown')
- @testing.crashes('sybase', 'FIXME: unknown')
- @testing.requires.independent_connections
- def test_queued_select(self):
- """Simple SELECT FOR UPDATE conflict test"""
-
- errors = self._threaded_overlap(2, [(1, 2, 3), (3, 4, 5)])
- assert not errors
-
- @testing.crashes('mssql', 'FIXME: unknown')
- @testing.fails_on('mysql', 'No support for NOWAIT')
- @testing.crashes('firebird', 'FIXME: unknown')
- @testing.crashes('sybase', 'FIXME: unknown')
- @testing.requires.independent_connections
- def test_nowait_select(self):
- """Simple SELECT FOR UPDATE NOWAIT conflict test"""
-
- errors = self._threaded_overlap(2, [(1, 2, 3), (3, 4, 5)],
- update_style='nowait')
- assert errors
class IsolationLevelTest(fixtures.TestBase):
__requires__ = ('isolation_level', 'ad_hoc_engines')