diff options
Diffstat (limited to 'test/testlib/testing.py')
| -rw-r--r-- | test/testlib/testing.py | 134 |
1 files changed, 100 insertions, 34 deletions
diff --git a/test/testlib/testing.py b/test/testlib/testing.py index cf0936e92..1e2ca62e9 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -2,15 +2,27 @@ # monkeypatches unittest.TestLoader.suiteClass at import time -import itertools, os, operator, re, sys, unittest, warnings +import itertools +import operator +import re +import sys +import types +import unittest +import warnings from cStringIO import StringIO + import testlib.config as config -from testlib.compat import * +from testlib.compat import set, _function_named, reversed -sql, sqltypes, schema, MetaData, clear_mappers, Session, util = None, None, None, None, None, None, None -sa_exceptions = None +# Delayed imports +MetaData = None +Session = None +clear_mappers = None +sa_exc = None +schema = None +sqltypes = None +util = None -__all__ = ('TestBase', 'AssertsExecutionResults', 'ComparesTables', 'ORMTest', 'AssertsCompiledSQL') _ops = { '<': operator.lt, '>': operator.gt, @@ -25,6 +37,9 @@ _ops = { '<': operator.lt, # sugar ('testing.db'); set here by config() at runtime db = None +# more sugar, installed by __init__ +requires = None + def fails_if(callable_): """Mark a test as expected to fail if callable_ returns True. @@ -224,17 +239,17 @@ def emits_warning(*messages): # - update: jython looks ok, it uses cpython's module def decorate(fn): def safe(*args, **kw): - global sa_exceptions - if sa_exceptions is None: - import sqlalchemy.exceptions as sa_exceptions + global sa_exc + if sa_exc is None: + import sqlalchemy.exc as sa_exc if not messages: filters = [dict(action='ignore', - category=sa_exceptions.SAWarning)] + category=sa_exc.SAWarning)] else: filters = [dict(action='ignore', message=message, - category=sa_exceptions.SAWarning) + category=sa_exc.SAWarning) for message in messages ] for f in filters: warnings.filterwarnings(**f) @@ -259,17 +274,17 @@ def uses_deprecated(*messages): def decorate(fn): def safe(*args, **kw): - global sa_exceptions - if sa_exceptions is None: - import sqlalchemy.exceptions as sa_exceptions + global sa_exc + if sa_exc is None: + import sqlalchemy.exc as sa_exc if not messages: filters = [dict(action='ignore', - category=sa_exceptions.SADeprecationWarning)] + category=sa_exc.SADeprecationWarning)] else: filters = [dict(action='ignore', message=message, - category=sa_exceptions.SADeprecationWarning) + category=sa_exc.SADeprecationWarning) for message in [ (m.startswith('//') and ('Call to deprecated function ' + m[2:]) or m) @@ -287,13 +302,13 @@ def uses_deprecated(*messages): def resetwarnings(): """Reset warning behavior to testing defaults.""" - global sa_exceptions - if sa_exceptions is None: - import sqlalchemy.exceptions as sa_exceptions + global sa_exc + if sa_exc is None: + import sqlalchemy.exc as sa_exc warnings.resetwarnings() - warnings.filterwarnings('error', category=sa_exceptions.SADeprecationWarning) - warnings.filterwarnings('error', category=sa_exceptions.SAWarning) + warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SAWarning) if sys.version_info < (2, 4): warnings.filterwarnings('ignore', category=FutureWarning) @@ -338,6 +353,23 @@ def rowset(results): return set([tuple(row) for row in results]) +def eq_(a, b, msg=None): + """Assert a == b, with repr messaging on failure.""" + assert a == b, msg or "%r != %r" % (a, b) + +def ne_(a, b, msg=None): + """Assert a != b, with repr messaging on failure.""" + assert a != b, msg or "%r == %r" % (a, b) + +def is_(a, b, msg=None): + """Assert a is b, with repr messaging on failure.""" + assert a is b, msg or "%r is not %r" % (a, b) + +def is_not_(a, b, msg=None): + """Assert a is not b, with repr messaging on failure.""" + assert a is not b, msg or "%r is %r" % (a, b) + + class TestData(object): """Tracks SQL expressions as they are executed via an instrumented ExecutionContext.""" @@ -360,10 +392,6 @@ class ExecutionContextWrapper(object): can be tracked.""" def __init__(self, ctx): - global sql - if sql is None: - from sqlalchemy import sql - self.__dict__['ctx'] = ctx def __getattr__(self, key): return getattr(self.ctx, key) @@ -414,7 +442,7 @@ class ExecutionContextWrapper(object): query = self.convert_statement(query) equivalent = ( (statement == query) - or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) ) + or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) ) ) \ and \ ( (params is None) or (params == parameters) @@ -422,7 +450,7 @@ class ExecutionContextWrapper(object): for (k, v) in p.items()]) for p in parameters] ) - testdata.unittest.assert_(equivalent, + testdata.unittest.assert_(equivalent, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) testdata.sql_count += 1 self.ctx.post_execution() @@ -445,6 +473,44 @@ class ExecutionContextWrapper(object): query = re.sub(r':([\w_]+)', repl, query) return query + +def _import_by_name(name): + submodule = name.split('.')[-1] + return __import__(name, globals(), locals(), [submodule]) + +class CompositeModule(types.ModuleType): + """Merged attribute access for multiple modules.""" + + # break the habit + __all__ = () + + def __init__(self, name, *modules, **overrides): + """Construct a new lazy composite of modules. + + Modules may be string names or module-like instances. Individual + attribute overrides may be specified as keyword arguments for + convenience. + + The constructed module will resolve attribute access in reverse order: + overrides, then each member of reversed(modules). Modules specified + by name will be loaded lazily when encountered in attribute + resolution. + + """ + types.ModuleType.__init__(self, name) + self.__modules = list(reversed(modules)) + for key, value in overrides.iteritems(): + setattr(self, key, value) + + def __getattr__(self, key): + for idx, mod in enumerate(self.__modules): + if isinstance(mod, basestring): + self.__modules[idx] = mod = _import_by_name(mod) + if hasattr(mod, key): + return getattr(mod, key) + raise AttributeError(key) + + class TestBase(unittest.TestCase): # A sequence of dialect names to exclude from the test class. __unsupported_on__ = () @@ -469,14 +535,14 @@ class TestBase(unittest.TestCase): def shortDescription(self): """overridden to not return docstrings""" return None - + def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs): try: callable_(*args, **kwargs) assert False, "Callable did not raise expected exception" except except_cls, e: assert re.search(msg, str(e)), "Exception message did not match: '%s'" % str(e) - + if not hasattr(unittest.TestCase, 'assertTrue'): assertTrue = unittest.TestCase.failUnless if not hasattr(unittest.TestCase, 'assertFalse'): @@ -522,7 +588,7 @@ class ComparesTables(object): set(type(c.type).__mro__).difference(base_mro) ) ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) - + if isinstance(c.type, sqltypes.String): self.assertEquals(c.type.length, reflected_c.type.length) @@ -535,18 +601,18 @@ class ComparesTables(object): elif not c.primary_key or not against('postgres'): print repr(c) assert reflected_c.default is None, reflected_c.default - + assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] - + class AssertsExecutionResults(object): def assert_result(self, result, class_, *objects): result = list(result) print repr(result) self.assert_list(result, class_, objects) - + def assert_list(self, result, class_, list): self.assert_(len(result) == len(list), "result list is not the same size as test list, " + @@ -675,10 +741,10 @@ class ORMTest(TestBase, AssertsExecutionResults): def define_tables(self, _otest_metadata): raise NotImplementedError() - + def setup_mappers(self): pass - + def insert_data(self): pass |
