summaryrefslogtreecommitdiff
path: root/test/engine
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-06-05 17:25:51 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-06-05 17:25:51 +0000
commit120dcee5a71187d4bebfe50aedbbefb09184cac1 (patch)
treef2a090a510c8df405d0b1bef2936bafa3511be07 /test/engine
parentf8314ef9ff08af5f104731de402d6e6bd8c043f3 (diff)
downloadsqlalchemy-120dcee5a71187d4bebfe50aedbbefb09184cac1.tar.gz
reorganized unit tests into subdirectories
Diffstat (limited to 'test/engine')
-rw-r--r--test/engine/__init__.py0
-rw-r--r--test/engine/alltests.py28
-rw-r--r--test/engine/autoconnect_engine.py90
-rw-r--r--test/engine/parseconnect.py33
-rw-r--r--test/engine/pool.py164
-rw-r--r--test/engine/proxy_engine.py198
-rw-r--r--test/engine/reflection.py222
-rw-r--r--test/engine/transaction.py210
8 files changed, 945 insertions, 0 deletions
diff --git a/test/engine/__init__.py b/test/engine/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/test/engine/__init__.py
diff --git a/test/engine/alltests.py b/test/engine/alltests.py
new file mode 100644
index 000000000..c63cb2861
--- /dev/null
+++ b/test/engine/alltests.py
@@ -0,0 +1,28 @@
+import testbase
+import unittest
+
+
+def suite():
+ modules_to_test = (
+ # connectivity, execution
+ 'engine.parseconnect',
+ 'engine.pool',
+ 'engine.transaction',
+
+ # schema/tables
+ 'engine.reflection',
+
+ 'engine.proxy_engine'
+ )
+ alltests = unittest.TestSuite()
+ for name in modules_to_test:
+ mod = __import__(name)
+ for token in name.split('.')[1:]:
+ mod = getattr(mod, token)
+ alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
+ return alltests
+
+
+
+if __name__ == '__main__':
+ testbase.runTests(suite())
diff --git a/test/engine/autoconnect_engine.py b/test/engine/autoconnect_engine.py
new file mode 100644
index 000000000..39bcf3e53
--- /dev/null
+++ b/test/engine/autoconnect_engine.py
@@ -0,0 +1,90 @@
+from sqlalchemy import *
+from sqlalchemy.ext.proxy import AutoConnectEngine
+
+from testbase import PersistTest
+import testbase
+import os
+
+#
+# Define an engine, table and mapper at the module level, to show that the
+# table and mapper can be used with different real engines in multiple threads
+#
+
+
+module_engine = AutoConnectEngine( testbase.db_uri )
+users = Table('users', module_engine,
+ Column('user_id', Integer, primary_key=True),
+ Column('user_name', String(16)),
+ Column('password', String(20))
+ )
+
+class User(object):
+ pass
+
+
+class AutoConnectEngineTest1(PersistTest):
+
+ def setUp(self):
+ clear_mappers()
+ objectstore.clear()
+
+ def test_engine_connect(self):
+ users.create()
+ assign_mapper(User, users)
+ try:
+ trans = objectstore.begin()
+
+ user = User()
+ user.user_name='fred'
+ user.password='*'
+ trans.commit()
+
+ # select
+ sqluser = User.select_by(user_name='fred')[0]
+ assert sqluser.user_name == 'fred'
+
+ # modify
+ sqluser.user_name = 'fred jones'
+
+ # commit - saves everything that changed
+ objectstore.commit()
+
+ allusers = [ user.user_name for user in User.select() ]
+ assert allusers == [ 'fred jones' ]
+ finally:
+ users.drop()
+
+
+
+
+if __name__ == "__main__":
+ testbase.main()
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/test/engine/parseconnect.py b/test/engine/parseconnect.py
new file mode 100644
index 000000000..43389c272
--- /dev/null
+++ b/test/engine/parseconnect.py
@@ -0,0 +1,33 @@
+from testbase import PersistTest
+import sqlalchemy.engine.url as url
+import unittest
+
+class ParseConnectTest(PersistTest):
+ def testrfc1738(self):
+ for text in (
+ 'dbtype://username:password@hostspec:110//usr/db_file.db',
+ 'dbtype://username:password@hostspec/database',
+ 'dbtype://username:password@hostspec',
+ 'dbtype://username:password@/database',
+ 'dbtype://username@hostspec',
+ 'dbtype://username:password@127.0.0.1:1521',
+ 'dbtype://hostspec/database',
+ 'dbtype://hostspec',
+ 'dbtype://hostspec/?arg1=val1&arg2=val2',
+ 'dbtype:///database',
+ 'dbtype:///:memory:',
+ 'dbtype:///foo/bar/im/a/file',
+ 'dbtype:///E:/work/src/LEM/db/hello.db',
+ 'dbtype:///E:/work/src/LEM/db/hello.db?foo=bar&hoho=lala',
+ 'dbtype://',
+ 'dbtype://username:password@/db'
+ ):
+ u = url.make_url(text)
+ print u, text
+ print "username=", u.username, "password=", u.password, "database=", u.database, "host=", u.host
+ assert str(u) == text
+
+
+if __name__ == "__main__":
+ unittest.main()
+ \ No newline at end of file
diff --git a/test/engine/pool.py b/test/engine/pool.py
new file mode 100644
index 000000000..9a8c7cffd
--- /dev/null
+++ b/test/engine/pool.py
@@ -0,0 +1,164 @@
+from testbase import PersistTest
+import unittest, sys, os, time
+
+from pysqlite2 import dbapi2 as sqlite
+import sqlalchemy.pool as pool
+
+class PoolTest(PersistTest):
+
+ def setUp(self):
+ pool.clear_managers()
+
+ def testmanager(self):
+ manager = pool.manage(sqlite)
+
+ connection = manager.connect('foo.db')
+ connection2 = manager.connect('foo.db')
+ connection3 = manager.connect('bar.db')
+
+ self.echo( "connection " + repr(connection))
+ self.assert_(connection.cursor() is not None)
+ self.assert_(connection is connection2)
+ self.assert_(connection2 is not connection3)
+
+ def testbadargs(self):
+ manager = pool.manage(sqlite)
+
+ try:
+ connection = manager.connect(None)
+ except:
+ pass
+
+ def testnonthreadlocalmanager(self):
+ manager = pool.manage(sqlite, use_threadlocal = False)
+
+ connection = manager.connect('foo.db')
+ connection2 = manager.connect('foo.db')
+
+ self.echo( "connection " + repr(connection))
+
+ self.assert_(connection.cursor() is not None)
+ self.assert_(connection is not connection2)
+
+ def testqueuepool_del(self):
+ self._do_testqueuepool(useclose=False)
+
+ def testqueuepool_close(self):
+ self._do_testqueuepool(useclose=True)
+
+ def _do_testqueuepool(self, useclose=False):
+
+ p = pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = False, echo = False)
+
+ def status(pool):
+ tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
+ self.echo( "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup)
+ return tup
+
+ c1 = p.connect()
+ self.assert_(status(p) == (3,0,-2,1))
+ c2 = p.connect()
+ self.assert_(status(p) == (3,0,-1,2))
+ c3 = p.connect()
+ self.assert_(status(p) == (3,0,0,3))
+ c4 = p.connect()
+ self.assert_(status(p) == (3,0,1,4))
+ c5 = p.connect()
+ self.assert_(status(p) == (3,0,2,5))
+ c6 = p.connect()
+ self.assert_(status(p) == (3,0,3,6))
+ if useclose:
+ c4.close()
+ c3.close()
+ c2.close()
+ else:
+ c4 = c3 = c2 = None
+ self.assert_(status(p) == (3,3,3,3))
+ if useclose:
+ c1.close()
+ c5.close()
+ c6.close()
+ else:
+ c1 = c5 = c6 = None
+ self.assert_(status(p) == (3,3,0,0))
+ c1 = p.connect()
+ c2 = p.connect()
+ self.assert_(status(p) == (3, 1, 0, 2))
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
+ self.assert_(status(p) == (3, 2, 0, 1))
+
+ def test_timeout(self):
+ p = pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = 0, use_threadlocal = False, echo = False, timeout=2)
+ c1 = p.get()
+ c2 = p.get()
+ c3 = p.get()
+ now = time.time()
+ c4 = p.get()
+ assert int(time.time() - now) == 2
+
+ def testthreadlocal_del(self):
+ self._do_testthreadlocal(useclose=False)
+
+ def testthreadlocal_close(self):
+ self._do_testthreadlocal(useclose=True)
+
+ def _do_testthreadlocal(self, useclose=False):
+ for p in (
+ pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True, echo = False),
+ pool.SingletonThreadPool(creator = lambda: sqlite.connect('foo.db'), use_threadlocal = True)
+ ):
+ c1 = p.connect()
+ c2 = p.connect()
+ self.assert_(c1 is c2)
+ c3 = p.unique_connection()
+ self.assert_(c3 is not c1)
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
+ c2 = p.connect()
+ self.assert_(c1 is c2)
+ self.assert_(c3 is not c1)
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
+
+ c3 = None
+
+ if useclose:
+ c1 = p.connect()
+ c2 = p.connect()
+ c3 = p.connect()
+ c3.close()
+ c2.close()
+ self.assert_(c1.connection is not None)
+ c1.close()
+ else:
+ c1 = c2 = c3 = None
+
+            # extra tests with QueuePool to ensure connections get __del__()ed when dereferenced
+ if isinstance(p, pool.QueuePool):
+ self.assert_(p.checkedout() == 0)
+ c1 = p.connect()
+ c2 = p.connect()
+ if useclose:
+ c2.close()
+ c1.close()
+ else:
+ c2 = None
+ c1 = None
+ self.assert_(p.checkedout() == 0)
+
+ def tearDown(self):
+ pool.clear_managers()
+ for file in ('foo.db', 'bar.db'):
+ if os.access(file, os.F_OK):
+ os.remove(file)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/engine/proxy_engine.py b/test/engine/proxy_engine.py
new file mode 100644
index 000000000..df0c64398
--- /dev/null
+++ b/test/engine/proxy_engine.py
@@ -0,0 +1,198 @@
+import os
+
+from sqlalchemy import *
+from sqlalchemy.ext.proxy import ProxyEngine
+
+from testbase import PersistTest
+import testbase
+
+#
+# Define an engine, table and mapper at the module level, to show that the
+# table and mapper can be used with different real engines in multiple threads
+#
+
+
+class ProxyTestBase(PersistTest):
+ def setUpAll(self):
+
+ global users, User, module_engine, module_metadata
+
+ module_engine = ProxyEngine(echo=testbase.echo)
+ module_metadata = MetaData()
+
+ users = Table('users', module_metadata,
+ Column('user_id', Integer, primary_key=True),
+ Column('user_name', String(16)),
+ Column('password', String(20))
+ )
+
+ class User(object):
+ pass
+
+ User.mapper = mapper(User, users)
+ def tearDownAll(self):
+ clear_mappers()
+
+class ConstructTest(ProxyTestBase):
+    """tests that we can build SQL constructs without engine-specific parameters, particularly
+    oid_column, being needed, as the proxy engine is usually not connected yet."""
+
+ def test_join(self):
+ engine = ProxyEngine()
+ t = Table('table1', engine,
+ Column('col1', Integer, primary_key=True))
+ t2 = Table('table2', engine,
+ Column('col2', Integer, ForeignKey('table1.col1')))
+ j = join(t, t2)
+
+
+class ProxyEngineTest1(ProxyTestBase):
+
+ def test_engine_connect(self):
+ # connect to a real engine
+ module_engine.connect(testbase.db_uri)
+ module_metadata.create_all(module_engine)
+
+ session = create_session(bind_to=module_engine)
+ try:
+
+ user = User()
+ user.user_name='fred'
+ user.password='*'
+
+ session.save(user)
+ session.flush()
+
+ query = session.query(User)
+
+ # select
+ sqluser = query.select_by(user_name='fred')[0]
+ assert sqluser.user_name == 'fred'
+
+ # modify
+ sqluser.user_name = 'fred jones'
+
+ # flush - saves everything that changed
+ session.flush()
+
+ allusers = [ user.user_name for user in query.select() ]
+ assert allusers == ['fred jones']
+
+ finally:
+ module_metadata.drop_all(module_engine)
+
+
+class ThreadProxyTest(ProxyTestBase):
+
+ def tearDownAll(self):
+ os.remove('threadtesta.db')
+ os.remove('threadtestb.db')
+
+ def test_multi_thread(self):
+
+ from threading import Thread
+ from Queue import Queue
+
+ # start 2 threads with different connection params
+ # and perform simultaneous operations, showing that the
+ # 2 threads don't share a connection
+ qa = Queue()
+ qb = Queue()
+ def run(db_uri, uname, queue):
+ def test():
+
+ try:
+ module_engine.connect(db_uri)
+ module_metadata.create_all(module_engine)
+ try:
+ session = create_session(bind_to=module_engine)
+
+ query = session.query(User)
+
+ all = list(query.select())
+ assert all == []
+
+ u = User()
+ u.user_name = uname
+ u.password = 'whatever'
+
+ session.save(u)
+ session.flush()
+
+ names = [u.user_name for u in query.select()]
+ assert names == [uname]
+ finally:
+ module_metadata.drop_all(module_engine)
+ except Exception, e:
+ import traceback
+ traceback.print_exc()
+ queue.put(e)
+ else:
+ queue.put(False)
+ return test
+
+ # NOTE: I'm not sure how to give the test runner the option to
+ # override these uris, or how to safely clear them after test runs
+ a = Thread(target=run('sqlite:///threadtesta.db', 'jim', qa))
+ b = Thread(target=run('sqlite:///threadtestb.db', 'joe', qb))
+
+ a.start()
+ b.start()
+
+ # block and wait for the threads to push their results
+ res = qa.get(True)
+ if res != False:
+ raise res
+
+ res = qb.get(True)
+ if res != False:
+ raise res
+
+
+class ProxyEngineTest2(ProxyTestBase):
+
+ def test_table_singleton_a(self):
+ """set up for table singleton check
+ """
+ #
+ # For this 'test', create a proxy engine instance, connect it
+ # to a real engine, and make it do some work
+ #
+ engine = ProxyEngine()
+ cats = Table('cats', engine,
+ Column('cat_id', Integer, primary_key=True),
+ Column('cat_name', String))
+
+ engine.connect(testbase.db_uri)
+
+ cats.create(engine)
+ cats.drop(engine)
+
+ ProxyEngineTest2.cats_table_a = cats
+ assert isinstance(cats, Table)
+
+ def test_table_singleton_b(self):
+ """check that a table on a 2nd proxy engine instance gets 2nd table
+ instance
+ """
+ #
+ # Now create a new proxy engine instance and attach the same
+ # table as the first test. This should result in 2 table instances,
+ # since different proxy engine instances can't attach to the
+ # same table instance
+ #
+ engine = ProxyEngine()
+ cats = Table('cats', engine,
+ Column('cat_id', Integer, primary_key=True),
+ Column('cat_name', String))
+ assert id(cats) != id(ProxyEngineTest2.cats_table_a)
+
+ # the real test -- if we're still using the old engine reference,
+ # this will fail because the old reference's local storage will
+ # not have the default attributes
+ engine.connect(testbase.db_uri)
+ cats.create(engine)
+ cats.drop(engine)
+
+if __name__ == "__main__":
+ testbase.main()
diff --git a/test/engine/reflection.py b/test/engine/reflection.py
new file mode 100644
index 000000000..85c97d704
--- /dev/null
+++ b/test/engine/reflection.py
@@ -0,0 +1,222 @@
+
+import sqlalchemy.ansisql as ansisql
+import sqlalchemy.databases.postgres as postgres
+
+from sqlalchemy import *
+
+from testbase import PersistTest
+import testbase
+import unittest, re
+
+class ReflectionTest(PersistTest):
+ def testbasic(self):
+ # really trip it up with a circular reference
+
+ use_function_defaults = testbase.db.engine.name == 'postgres' or testbase.db.engine.name == 'oracle'
+
+ use_string_defaults = use_function_defaults or testbase.db.engine.__module__.endswith('sqlite')
+
+ if use_function_defaults:
+ defval = func.current_date()
+ deftype = Date
+ else:
+ defval = "3"
+ deftype = Integer
+
+ if use_string_defaults:
+ deftype2 = String
+ defval2 = "im a default"
+ else:
+ deftype2 = Integer
+ defval2 = "15"
+
+ users = Table('engine_users', testbase.db,
+ Column('user_id', INT, primary_key = True),
+ Column('user_name', VARCHAR(20), nullable = False),
+ Column('test1', CHAR(5), nullable = False),
+ Column('test2', FLOAT(5), nullable = False),
+ Column('test3', TEXT),
+ Column('test4', DECIMAL, nullable = False),
+ Column('test5', TIMESTAMP),
+ Column('parent_user_id', Integer, ForeignKey('engine_users.user_id')),
+ Column('test6', DateTime, nullable = False),
+ Column('test7', String),
+ Column('test8', Binary),
+ Column('test_passivedefault', deftype, PassiveDefault(defval)),
+ Column('test_passivedefault2', Integer, PassiveDefault("5")),
+ Column('test_passivedefault3', deftype2, PassiveDefault(defval2)),
+ Column('test9', Binary(100)),
+ mysql_engine='InnoDB'
+ )
+
+ addresses = Table('engine_email_addresses', testbase.db,
+ Column('address_id', Integer, primary_key = True),
+ Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
+ Column('email_address', String(20)),
+ mysql_engine='InnoDB'
+ )
+
+
+# users.c.parent_user_id.set_foreign_key(ForeignKey(users.c.user_id))
+
+ users.create()
+ addresses.create()
+
+ # clear out table registry
+ users.deregister()
+ addresses.deregister()
+
+ try:
+ users = Table('engine_users', testbase.db, autoload = True)
+ addresses = Table('engine_email_addresses', testbase.db, autoload = True)
+ finally:
+ addresses.drop()
+ users.drop()
+
+ users.create()
+ addresses.create()
+ try:
+            # create a join from the two tables, this ensures that
+            # there's a foreign key set up
+            # previously, we couldn't get foreign keys out of mysql. seems like
+ # we can now as long as we use InnoDB
+# if testbase.db.engine.__module__.endswith('mysql'):
+ # addresses.c.remote_user_id.append_item(ForeignKey('engine_users.user_id'))
+ print users
+ print addresses
+ j = join(users, addresses)
+ print str(j.onclause)
+ self.assert_((users.c.user_id==addresses.c.remote_user_id).compare(j.onclause))
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def testmultipk(self):
+ table = Table(
+ 'engine_multi', testbase.db,
+ Column('multi_id', Integer, primary_key=True),
+ Column('multi_rev', Integer, primary_key=True),
+ Column('name', String(50), nullable=False),
+ Column('val', String(100))
+ )
+ table.create()
+ # clear out table registry
+ table.deregister()
+
+ try:
+ table = Table('engine_multi', testbase.db, autoload=True)
+ finally:
+ table.drop()
+
+ print repr(
+ [table.c['multi_id'].primary_key,
+ table.c['multi_rev'].primary_key
+ ]
+ )
+ table.create()
+ table.insert().execute({'multi_id':1,'multi_rev':1,'name':'row1', 'val':'value1'})
+ table.insert().execute({'multi_id':2,'multi_rev':18,'name':'row2', 'val':'value2'})
+ table.insert().execute({'multi_id':3,'multi_rev':3,'name':'row3', 'val':'value3'})
+ table.select().execute().fetchall()
+ table.drop()
+
+ def testtoengine(self):
+ meta = MetaData('md1')
+ meta2 = MetaData('md2')
+
+ table = Table('mytable', meta,
+ Column('myid', Integer, key = 'id'),
+ Column('name', String, key = 'name', nullable=False),
+ Column('description', String, key = 'description'),
+ )
+
+ print repr(table)
+
+ table2 = table.tometadata(meta2)
+
+ print repr(table2)
+
+ assert table is not table2
+ assert table2.c.id.nullable
+ assert not table2.c.name.nullable
+ assert table2.c.description.nullable
+
+ def testoverride(self):
+ table = Table(
+ 'override_test', testbase.db,
+ Column('col1', Integer, primary_key=True),
+ Column('col2', String(20)),
+ Column('col3', Numeric)
+ )
+ table.create()
+ # clear out table registry
+ table.deregister()
+
+ try:
+ table = Table(
+ 'override_test', testbase.db,
+ Column('col2', Unicode()),
+ Column('col4', String(30)), autoload=True)
+
+ print repr(table)
+ self.assert_(isinstance(table.c.col1.type, Integer))
+ self.assert_(isinstance(table.c.col2.type, Unicode))
+ self.assert_(isinstance(table.c.col4.type, String))
+ finally:
+ table.drop()
+
+class CreateDropTest(PersistTest):
+ def setUpAll(self):
+ global metadata
+ metadata = MetaData()
+ users = Table('users', metadata,
+ Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
+ Column('user_name', String(40)),
+ )
+
+ addresses = Table('email_addresses', metadata,
+ Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
+ Column('user_id', Integer, ForeignKey(users.c.user_id)),
+ Column('email_address', String(40)),
+
+ )
+
+ orders = Table('orders', metadata,
+ Column('order_id', Integer, Sequence('order_id_seq', optional=True), primary_key = True),
+ Column('user_id', Integer, ForeignKey(users.c.user_id)),
+ Column('description', String(50)),
+ Column('isopen', Integer),
+
+ )
+
+ orderitems = Table('items', metadata,
+ Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
+ Column('order_id', INT, ForeignKey("orders")),
+ Column('item_name', VARCHAR(50)),
+
+ )
+
+ def test_sorter( self ):
+ tables = metadata._sort_tables(metadata.tables.values())
+ table_names = [t.name for t in tables]
+ self.assert_( table_names == ['users', 'orders', 'items', 'email_addresses'] or table_names == ['users', 'email_addresses', 'orders', 'items'])
+
+
+ def test_createdrop(self):
+ metadata.create_all(engine=testbase.db)
+ self.assertEqual( testbase.db.has_table('items'), True )
+ self.assertEqual( testbase.db.has_table('email_addresses'), True )
+ metadata.create_all(engine=testbase.db)
+ self.assertEqual( testbase.db.has_table('items'), True )
+
+ metadata.drop_all(engine=testbase.db)
+ self.assertEqual( testbase.db.has_table('items'), False )
+ self.assertEqual( testbase.db.has_table('email_addresses'), False )
+ metadata.drop_all(engine=testbase.db)
+ self.assertEqual( testbase.db.has_table('items'), False )
+
+
+
+if __name__ == "__main__":
+ testbase.main()
+
diff --git a/test/engine/transaction.py b/test/engine/transaction.py
new file mode 100644
index 000000000..408c9dc99
--- /dev/null
+++ b/test/engine/transaction.py
@@ -0,0 +1,210 @@
+
+import testbase
+import unittest, sys, datetime
+import tables
+db = testbase.db
+from sqlalchemy import *
+
+
+class TransactionTest(testbase.PersistTest):
+ def setUpAll(self):
+ global users, metadata
+ metadata = MetaData()
+ users = Table('query_users', metadata,
+ Column('user_id', INT, primary_key = True),
+ Column('user_name', VARCHAR(20)),
+ )
+ users.create(testbase.db)
+
+ def tearDown(self):
+ testbase.db.connect().execute(users.delete())
+ def tearDownAll(self):
+ users.drop(testbase.db)
+
+ @testbase.unsupported('mysql')
+ def testrollback(self):
+ """test a basic rollback"""
+ connection = testbase.db.connect()
+ transaction = connection.begin()
+ connection.execute(users.insert(), user_id=1, user_name='user1')
+ connection.execute(users.insert(), user_id=2, user_name='user2')
+ connection.execute(users.insert(), user_id=3, user_name='user3')
+ transaction.rollback()
+
+ result = connection.execute("select * from query_users")
+ assert len(result.fetchall()) == 0
+ connection.close()
+
+ @testbase.unsupported('mysql')
+ def testnesting(self):
+ connection = testbase.db.connect()
+ transaction = connection.begin()
+ connection.execute(users.insert(), user_id=1, user_name='user1')
+ connection.execute(users.insert(), user_id=2, user_name='user2')
+ connection.execute(users.insert(), user_id=3, user_name='user3')
+ trans2 = connection.begin()
+ connection.execute(users.insert(), user_id=4, user_name='user4')
+ connection.execute(users.insert(), user_id=5, user_name='user5')
+ trans2.commit()
+ transaction.rollback()
+ self.assert_(connection.scalar("select count(1) from query_users") == 0)
+
+ result = connection.execute("select * from query_users")
+ assert len(result.fetchall()) == 0
+ connection.close()
+
+class AutoRollbackTest(testbase.PersistTest):
+ def setUpAll(self):
+ global metadata
+ metadata = MetaData()
+
+ def tearDownAll(self):
+ metadata.drop_all(testbase.db)
+
+ @testbase.unsupported('sqlite')
+ def testrollback_deadlock(self):
+ """test that returning connections to the pool clears any object locks."""
+ conn1 = testbase.db.connect()
+ conn2 = testbase.db.connect()
+ users = Table('deadlock_users', metadata,
+ Column('user_id', INT, primary_key = True),
+ Column('user_name', VARCHAR(20)),
+ )
+ users.create(conn1)
+ conn1.execute("select * from deadlock_users")
+ conn1.close()
+ # without auto-rollback in the connection pool's return() logic, this deadlocks in Postgres,
+ # because conn1 is returned to the pool but still has a lock on "deadlock_users"
+ # comment out the rollback in pool/ConnectionFairy._close() to see !
+ users.drop(conn2)
+ conn2.close()
+
+class TLTransactionTest(testbase.PersistTest):
+ def setUpAll(self):
+ global users, metadata, tlengine
+ tlengine = create_engine(testbase.db_uri, strategy='threadlocal', echo=True)
+ metadata = MetaData()
+ users = Table('query_users', metadata,
+ Column('user_id', INT, primary_key = True),
+ Column('user_name', VARCHAR(20)),
+ )
+ users.create(tlengine)
+ def tearDown(self):
+ tlengine.execute(users.delete())
+ def tearDownAll(self):
+ users.drop(tlengine)
+ tlengine.dispose()
+
+ @testbase.unsupported('mysql')
+ def testrollback(self):
+ """test a basic rollback"""
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=1, user_name='user1')
+ tlengine.execute(users.insert(), user_id=2, user_name='user2')
+ tlengine.execute(users.insert(), user_id=3, user_name='user3')
+ tlengine.rollback()
+
+ external_connection = tlengine.connect()
+ result = external_connection.execute("select * from query_users")
+ try:
+ assert len(result.fetchall()) == 0
+ finally:
+ external_connection.close()
+
+ @testbase.unsupported('mysql')
+ def testcommit(self):
+ """test a basic commit"""
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=1, user_name='user1')
+ tlengine.execute(users.insert(), user_id=2, user_name='user2')
+ tlengine.execute(users.insert(), user_id=3, user_name='user3')
+ tlengine.commit()
+
+ external_connection = tlengine.connect()
+ result = external_connection.execute("select * from query_users")
+ try:
+ assert len(result.fetchall()) == 3
+ finally:
+ external_connection.close()
+
+ @testbase.unsupported('mysql', 'sqlite')
+ def testnesting(self):
+        """tests nesting of transactions"""
+ external_connection = tlengine.connect()
+ self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=1, user_name='user1')
+ tlengine.execute(users.insert(), user_id=2, user_name='user2')
+ tlengine.execute(users.insert(), user_id=3, user_name='user3')
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=4, user_name='user4')
+ tlengine.execute(users.insert(), user_id=5, user_name='user5')
+ tlengine.commit()
+ tlengine.rollback()
+ try:
+ self.assert_(external_connection.scalar("select count(1) from query_users") == 0)
+ finally:
+ external_connection.close()
+
+ @testbase.unsupported('mysql')
+ def testmixednesting(self):
+        """tests nesting of transactions off the TLEngine directly inside of
+        transactions off the connection from the TLEngine"""
+ external_connection = tlengine.connect()
+ self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
+ conn = tlengine.contextual_connect()
+ trans = conn.begin()
+ trans2 = conn.begin()
+ tlengine.execute(users.insert(), user_id=1, user_name='user1')
+ tlengine.execute(users.insert(), user_id=2, user_name='user2')
+ tlengine.execute(users.insert(), user_id=3, user_name='user3')
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=4, user_name='user4')
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=5, user_name='user5')
+ tlengine.execute(users.insert(), user_id=6, user_name='user6')
+ tlengine.execute(users.insert(), user_id=7, user_name='user7')
+ tlengine.commit()
+ tlengine.execute(users.insert(), user_id=8, user_name='user8')
+ tlengine.commit()
+ trans2.commit()
+ trans.rollback()
+ conn.close()
+ try:
+ self.assert_(external_connection.scalar("select count(1) from query_users") == 0)
+ finally:
+ external_connection.close()
+
+ @testbase.unsupported('mysql')
+ def testmoremixednesting(self):
+        """tests nesting of transactions off the connection from the TLEngine
+        inside of transactions off the TLEngine directly."""
+ external_connection = tlengine.connect()
+ self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
+ tlengine.begin()
+ connection = tlengine.contextual_connect()
+ connection.execute(users.insert(), user_id=1, user_name='user1')
+ tlengine.begin()
+ connection.execute(users.insert(), user_id=2, user_name='user2')
+ connection.execute(users.insert(), user_id=3, user_name='user3')
+ trans = connection.begin()
+ connection.execute(users.insert(), user_id=4, user_name='user4')
+ connection.execute(users.insert(), user_id=5, user_name='user5')
+ trans.commit()
+ tlengine.commit()
+ tlengine.rollback()
+ connection.close()
+ try:
+ self.assert_(external_connection.scalar("select count(1) from query_users") == 0)
+ finally:
+ external_connection.close()
+
+ def testconnections(self):
+ """tests that contextual_connect is threadlocal"""
+ c1 = tlengine.contextual_connect()
+ c2 = tlengine.contextual_connect()
+ assert c1.connection is c2.connection
+ c1.close()
+
+if __name__ == "__main__":
+ testbase.main()