diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
| commit | 8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch) | |
| tree | ae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/test/engines.py | |
| parent | 7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff) | |
| download | sqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz | |
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/test/engines.py')
| -rw-r--r-- | lib/sqlalchemy/test/engines.py | 60 |
1 files changed, 52 insertions, 8 deletions
diff --git a/lib/sqlalchemy/test/engines.py b/lib/sqlalchemy/test/engines.py index f0001978b..187ad2ff0 100644 --- a/lib/sqlalchemy/test/engines.py +++ b/lib/sqlalchemy/test/engines.py @@ -2,6 +2,7 @@ import sys, types, weakref from collections import deque import config from sqlalchemy.util import function_named, callable +import re class ConnectionKiller(object): def __init__(self): @@ -11,7 +12,8 @@ class ConnectionKiller(object): self.proxy_refs[con_proxy] = True def _apply_all(self, methods): - for rec in self.proxy_refs: + # must copy keys atomically + for rec in self.proxy_refs.keys(): if rec is not None and rec.is_valid: try: for name in methods: @@ -38,6 +40,10 @@ class ConnectionKiller(object): testing_reaper = ConnectionKiller() +def drop_all_tables(metadata): + testing_reaper.close_all() + metadata.drop_all() + def assert_conns_closed(fn): def decorated(*args, **kw): try: @@ -56,6 +62,14 @@ def rollback_open_connections(fn): testing_reaper.rollback_all() return function_named(decorated, fn.__name__) +def close_first(fn): + """Decorator that closes all connections before fn execution.""" + def decorated(*args, **kw): + testing_reaper.close_all() + fn(*args, **kw) + return function_named(decorated, fn.__name__) + + def close_open_connections(fn): """Decorator that closes all connections after fn execution.""" @@ -69,7 +83,10 @@ def close_open_connections(fn): def all_dialects(): import sqlalchemy.databases as d for name in d.__all__: - mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) + # TEMPORARY + mod = getattr(d, name, None) + if not mod: + mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) yield mod.dialect() class ReconnectFixture(object): @@ -115,7 +132,11 @@ def testing_engine(url=None, options=None): listeners.append(testing_reaper) engine = create_engine(url, **options) - + + # may want to call this, results + # in first-connect initializers + #engine.connect() + return engine def utf8_engine(url=None, options=None): @@ -123,7 +144,7 @@ def utf8_engine(url=None, options=None): from sqlalchemy.engine import url as engine_url - if config.db.name == 'mysql': + if config.db.driver == 'mysqldb': dbapi_ver = config.db.dialect.dbapi.version_info if (dbapi_ver < (1, 2, 1) or dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2), @@ -139,19 +160,35 @@ def utf8_engine(url=None, options=None): return testing_engine(url, options) -def mock_engine(db=None): - """Provides a mocking engine based on the current testing.db.""" +def mock_engine(dialect_name=None): + """Provides a mocking engine based on the current testing.db. + + This is normally used to test DDL generation flow as emitted + by an Engine. + + It should not be used in other cases, as assert_compile() and + assert_sql_execution() are much better choices with fewer + moving parts. + + """ from sqlalchemy import create_engine - dbi = db or config.db + if not dialect_name: + dialect_name = config.db.name + buffer = [] def executor(sql, *a, **kw): buffer.append(sql) - engine = create_engine(dbi.name + '://', + def assert_sql(stmts): + recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + assert recv == stmts, recv + + engine = create_engine(dialect_name + '://', strategy='mock', executor=executor) assert not hasattr(engine, 'mock') engine.mock = buffer + engine.assert_sql = assert_sql return engine class ReplayableSession(object): @@ -168,9 +205,16 @@ class ReplayableSession(object): Natives = set([getattr(types, t) for t in dir(types) if not t.startswith('_')]). \ difference([getattr(types, t) + # Py3K + #for t in ('FunctionType', 'BuiltinFunctionType', + # 'MethodType', 'BuiltinMethodType', + # 'LambdaType', )]) + + # Py2K for t in ('FunctionType', 'BuiltinFunctionType', 'MethodType', 'BuiltinMethodType', 'LambdaType', 'UnboundMethodType',)]) + # end Py2K def __init__(self): self.buffer = deque() |
