diff options
Diffstat (limited to 'test/engine/test_transaction.py')
| -rw-r--r-- | test/engine/test_transaction.py | 218 |
1 files changed, 85 insertions, 133 deletions
diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index 8a5303642..b3b17e75a 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -133,6 +133,91 @@ class TransactionTest(fixtures.TestBase): finally: connection.close() + def test_branch_nested_rollback(self): + connection = testing.db.connect() + try: + connection.begin() + branched = connection.connect() + assert branched.in_transaction() + branched.execute(users.insert(), user_id=1, user_name='user1') + nested = branched.begin() + branched.execute(users.insert(), user_id=2, user_name='user2') + nested.rollback() + assert not connection.in_transaction() + eq_(connection.scalar("select count(*) from query_users"), 0) + + finally: + connection.close() + + def test_branch_autorollback(self): + connection = testing.db.connect() + try: + branched = connection.connect() + branched.execute(users.insert(), user_id=1, user_name='user1') + try: + branched.execute(users.insert(), user_id=1, user_name='user1') + except exc.DBAPIError: + pass + finally: + connection.close() + + def test_branch_orig_rollback(self): + connection = testing.db.connect() + try: + branched = connection.connect() + branched.execute(users.insert(), user_id=1, user_name='user1') + nested = branched.begin() + assert branched.in_transaction() + branched.execute(users.insert(), user_id=2, user_name='user2') + nested.rollback() + eq_(connection.scalar("select count(*) from query_users"), 1) + + finally: + connection.close() + + def test_branch_autocommit(self): + connection = testing.db.connect() + try: + branched = connection.connect() + branched.execute(users.insert(), user_id=1, user_name='user1') + finally: + connection.close() + eq_(testing.db.scalar("select count(*) from query_users"), 1) + + @testing.requires.savepoints + def test_branch_savepoint_rollback(self): + connection = testing.db.connect() + try: + trans = connection.begin() + branched = connection.connect() + assert branched.in_transaction() + branched.execute(users.insert(), user_id=1, user_name='user1') + nested = branched.begin_nested() + branched.execute(users.insert(), user_id=2, user_name='user2') + nested.rollback() + assert connection.in_transaction() + trans.commit() + eq_(connection.scalar("select count(*) from query_users"), 1) + + finally: + connection.close() + + @testing.requires.two_phase_transactions + def test_branch_twophase_rollback(self): + connection = testing.db.connect() + try: + branched = connection.connect() + assert not branched.in_transaction() + branched.execute(users.insert(), user_id=1, user_name='user1') + nested = branched.begin_twophase() + branched.execute(users.insert(), user_id=2, user_name='user2') + nested.rollback() + assert not connection.in_transaction() + eq_(connection.scalar("select count(*) from query_users"), 1) + + finally: + connection.close() + def test_retains_through_options(self): connection = testing.db.connect() try: @@ -1126,139 +1211,6 @@ class TLTransactionTest(fixtures.TestBase): order_by(users.c.user_id)).fetchall(), [(1, ), (2, )]) -counters = None - - -class ForUpdateTest(fixtures.TestBase): - __requires__ = 'ad_hoc_engines', - __backend__ = True - - @classmethod - def setup_class(cls): - global counters, metadata - metadata = MetaData() - counters = Table('forupdate_counters', metadata, - Column('counter_id', INT, primary_key=True), - Column('counter_value', INT), - test_needs_acid=True) - counters.create(testing.db) - - def teardown(self): - testing.db.execute(counters.delete()).close() - - @classmethod - def teardown_class(cls): - counters.drop(testing.db) - - def increment(self, count, errors, update_style=True, delay=0.005): - con = testing.db.connect() - sel = counters.select(for_update=update_style, - whereclause=counters.c.counter_id == 1) - for i in range(count): - trans = con.begin() - try: - existing = con.execute(sel).first() - incr = existing['counter_value'] + 1 - time.sleep(delay) - con.execute(counters.update(counters.c.counter_id == 1, - values={'counter_value': incr})) - time.sleep(delay) - readback = con.execute(sel).first() - if readback['counter_value'] != incr: - raise AssertionError('Got %s post-update, expected ' - '%s' % (readback['counter_value'], incr)) - trans.commit() - except Exception as e: - trans.rollback() - errors.append(e) - break - con.close() - - @testing.crashes('mssql', 'FIXME: unknown') - @testing.crashes('firebird', 'FIXME: unknown') - @testing.crashes('sybase', 'FIXME: unknown') - @testing.requires.independent_connections - def test_queued_update(self): - """Test SELECT FOR UPDATE with concurrent modifications. - - Runs concurrent modifications on a single row in the users - table, with each mutator trying to increment a value stored in - user_name. - - """ - - db = testing.db - db.execute(counters.insert(), counter_id=1, counter_value=0) - iterations, thread_count = 10, 5 - threads, errors = [], [] - for i in range(thread_count): - thrd = threading.Thread(target=self.increment, - args=(iterations, ), - kwargs={'errors': errors, - 'update_style': True}) - thrd.start() - threads.append(thrd) - for thrd in threads: - thrd.join() - assert not errors - sel = counters.select(whereclause=counters.c.counter_id == 1) - final = db.execute(sel).first() - eq_(final['counter_value'], iterations * thread_count) - - def overlap(self, ids, errors, update_style): - - sel = counters.select(for_update=update_style, - whereclause=counters.c.counter_id.in_(ids)) - con = testing.db.connect() - trans = con.begin() - try: - rows = con.execute(sel).fetchall() - time.sleep(0.50) - trans.commit() - except Exception as e: - trans.rollback() - errors.append(e) - con.close() - - def _threaded_overlap(self, thread_count, groups, update_style=True, pool=5): - db = testing.db - for cid in range(pool - 1): - db.execute(counters.insert(), counter_id=cid + 1, - counter_value=0) - errors, threads = [], [] - for i in range(thread_count): - thrd = threading.Thread(target=self.overlap, - args=(groups.pop(0), errors, - update_style)) - time.sleep(0.20) # give the previous thread a chance to start - # to ensure it gets a lock - thrd.start() - threads.append(thrd) - for thrd in threads: - thrd.join() - return errors - - @testing.crashes('mssql', 'FIXME: unknown') - @testing.crashes('firebird', 'FIXME: unknown') - @testing.crashes('sybase', 'FIXME: unknown') - @testing.requires.independent_connections - def test_queued_select(self): - """Simple SELECT FOR UPDATE conflict test""" - - errors = self._threaded_overlap(2, [(1, 2, 3), (3, 4, 5)]) - assert not errors - - @testing.crashes('mssql', 'FIXME: unknown') - @testing.fails_on('mysql', 'No support for NOWAIT') - @testing.crashes('firebird', 'FIXME: unknown') - @testing.crashes('sybase', 'FIXME: unknown') - @testing.requires.independent_connections - def test_nowait_select(self): - """Simple SELECT FOR UPDATE NOWAIT conflict test""" - - errors = self._threaded_overlap(2, [(1, 2, 3), (3, 4, 5)], - update_style='nowait') - assert errors class IsolationLevelTest(fixtures.TestBase): __requires__ = ('isolation_level', 'ad_hoc_engines') |
