summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/test/engines.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-08-06 21:11:27 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-08-06 21:11:27 +0000
commit8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch)
treeae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/test/engines.py
parent7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff)
downloadsqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/test/engines.py')
-rw-r--r--lib/sqlalchemy/test/engines.py60
1 files changed, 52 insertions, 8 deletions
diff --git a/lib/sqlalchemy/test/engines.py b/lib/sqlalchemy/test/engines.py
index f0001978b..187ad2ff0 100644
--- a/lib/sqlalchemy/test/engines.py
+++ b/lib/sqlalchemy/test/engines.py
@@ -2,6 +2,7 @@ import sys, types, weakref
from collections import deque
import config
from sqlalchemy.util import function_named, callable
+import re
class ConnectionKiller(object):
def __init__(self):
@@ -11,7 +12,8 @@ class ConnectionKiller(object):
self.proxy_refs[con_proxy] = True
def _apply_all(self, methods):
- for rec in self.proxy_refs:
+ # must copy keys atomically
+ for rec in self.proxy_refs.keys():
if rec is not None and rec.is_valid:
try:
for name in methods:
@@ -38,6 +40,10 @@ class ConnectionKiller(object):
testing_reaper = ConnectionKiller()
+def drop_all_tables(metadata):
+ testing_reaper.close_all()
+ metadata.drop_all()
+
def assert_conns_closed(fn):
def decorated(*args, **kw):
try:
@@ -56,6 +62,14 @@ def rollback_open_connections(fn):
testing_reaper.rollback_all()
return function_named(decorated, fn.__name__)
+def close_first(fn):
+ """Decorator that closes all connections before fn execution."""
+ def decorated(*args, **kw):
+ testing_reaper.close_all()
+ fn(*args, **kw)
+ return function_named(decorated, fn.__name__)
+
+
def close_open_connections(fn):
"""Decorator that closes all connections after fn execution."""
@@ -69,7 +83,10 @@ def close_open_connections(fn):
def all_dialects():
import sqlalchemy.databases as d
for name in d.__all__:
- mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+ # TEMPORARY
+ mod = getattr(d, name, None)
+ if not mod:
+ mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
yield mod.dialect()
class ReconnectFixture(object):
@@ -115,7 +132,11 @@ def testing_engine(url=None, options=None):
listeners.append(testing_reaper)
engine = create_engine(url, **options)
-
+
+ # may want to call this, results
+ # in first-connect initializers
+ #engine.connect()
+
return engine
def utf8_engine(url=None, options=None):
@@ -123,7 +144,7 @@ def utf8_engine(url=None, options=None):
from sqlalchemy.engine import url as engine_url
- if config.db.name == 'mysql':
+ if config.db.driver == 'mysqldb':
dbapi_ver = config.db.dialect.dbapi.version_info
if (dbapi_ver < (1, 2, 1) or
dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
@@ -139,19 +160,35 @@ def utf8_engine(url=None, options=None):
return testing_engine(url, options)
-def mock_engine(db=None):
- """Provides a mocking engine based on the current testing.db."""
+def mock_engine(dialect_name=None):
+ """Provides a mocking engine based on the current testing.db.
+
+ This is normally used to test DDL generation flow as emitted
+ by an Engine.
+
+ It should not be used in other cases, as assert_compile() and
+ assert_sql_execution() are much better choices with fewer
+ moving parts.
+
+ """
from sqlalchemy import create_engine
- dbi = db or config.db
+ if not dialect_name:
+ dialect_name = config.db.name
+
buffer = []
def executor(sql, *a, **kw):
buffer.append(sql)
- engine = create_engine(dbi.name + '://',
+ def assert_sql(stmts):
+ recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
+ assert recv == stmts, recv
+
+ engine = create_engine(dialect_name + '://',
strategy='mock', executor=executor)
assert not hasattr(engine, 'mock')
engine.mock = buffer
+ engine.assert_sql = assert_sql
return engine
class ReplayableSession(object):
@@ -168,9 +205,16 @@ class ReplayableSession(object):
Natives = set([getattr(types, t)
for t in dir(types) if not t.startswith('_')]). \
difference([getattr(types, t)
+ # Py3K
+ #for t in ('FunctionType', 'BuiltinFunctionType',
+ # 'MethodType', 'BuiltinMethodType',
+ # 'LambdaType', )])
+
+ # Py2K
for t in ('FunctionType', 'BuiltinFunctionType',
'MethodType', 'BuiltinMethodType',
'LambdaType', 'UnboundMethodType',)])
+ # end Py2K
def __init__(self):
self.buffer = deque()