summaryrefslogtreecommitdiff
path: root/test/testlib/testing.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/testlib/testing.py')
-rw-r--r--test/testlib/testing.py134
1 files changed, 100 insertions, 34 deletions
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
index cf0936e92..1e2ca62e9 100644
--- a/test/testlib/testing.py
+++ b/test/testlib/testing.py
@@ -2,15 +2,27 @@
# monkeypatches unittest.TestLoader.suiteClass at import time
-import itertools, os, operator, re, sys, unittest, warnings
+import itertools
+import operator
+import re
+import sys
+import types
+import unittest
+import warnings
from cStringIO import StringIO
+
import testlib.config as config
-from testlib.compat import *
+from testlib.compat import set, _function_named, reversed
-sql, sqltypes, schema, MetaData, clear_mappers, Session, util = None, None, None, None, None, None, None
-sa_exceptions = None
+# Delayed imports
+MetaData = None
+Session = None
+clear_mappers = None
+sa_exc = None
+schema = None
+sqltypes = None
+util = None
-__all__ = ('TestBase', 'AssertsExecutionResults', 'ComparesTables', 'ORMTest', 'AssertsCompiledSQL')
_ops = { '<': operator.lt,
'>': operator.gt,
@@ -25,6 +37,9 @@ _ops = { '<': operator.lt,
# sugar ('testing.db'); set here by config() at runtime
db = None
+# more sugar, installed by __init__
+requires = None
+
def fails_if(callable_):
"""Mark a test as expected to fail if callable_ returns True.
@@ -224,17 +239,17 @@ def emits_warning(*messages):
# - update: jython looks ok, it uses cpython's module
def decorate(fn):
def safe(*args, **kw):
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
if not messages:
filters = [dict(action='ignore',
- category=sa_exceptions.SAWarning)]
+ category=sa_exc.SAWarning)]
else:
filters = [dict(action='ignore',
message=message,
- category=sa_exceptions.SAWarning)
+ category=sa_exc.SAWarning)
for message in messages ]
for f in filters:
warnings.filterwarnings(**f)
@@ -259,17 +274,17 @@ def uses_deprecated(*messages):
def decorate(fn):
def safe(*args, **kw):
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
if not messages:
filters = [dict(action='ignore',
- category=sa_exceptions.SADeprecationWarning)]
+ category=sa_exc.SADeprecationWarning)]
else:
filters = [dict(action='ignore',
message=message,
- category=sa_exceptions.SADeprecationWarning)
+ category=sa_exc.SADeprecationWarning)
for message in
[ (m.startswith('//') and
('Call to deprecated function ' + m[2:]) or m)
@@ -287,13 +302,13 @@ def uses_deprecated(*messages):
def resetwarnings():
"""Reset warning behavior to testing defaults."""
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
warnings.resetwarnings()
- warnings.filterwarnings('error', category=sa_exceptions.SADeprecationWarning)
- warnings.filterwarnings('error', category=sa_exceptions.SAWarning)
+ warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings('error', category=sa_exc.SAWarning)
if sys.version_info < (2, 4):
warnings.filterwarnings('ignore', category=FutureWarning)
@@ -338,6 +353,23 @@ def rowset(results):
return set([tuple(row) for row in results])
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+def is_not_(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+
class TestData(object):
"""Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
@@ -360,10 +392,6 @@ class ExecutionContextWrapper(object):
can be tracked."""
def __init__(self, ctx):
- global sql
- if sql is None:
- from sqlalchemy import sql
-
self.__dict__['ctx'] = ctx
def __getattr__(self, key):
return getattr(self.ctx, key)
@@ -414,7 +442,7 @@ class ExecutionContextWrapper(object):
query = self.convert_statement(query)
equivalent = ( (statement == query)
- or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
+ or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
) \
and \
( (params is None) or (params == parameters)
@@ -422,7 +450,7 @@ class ExecutionContextWrapper(object):
for (k, v) in p.items()])
for p in parameters]
)
- testdata.unittest.assert_(equivalent,
+ testdata.unittest.assert_(equivalent,
"Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
testdata.sql_count += 1
self.ctx.post_execution()
@@ -445,6 +473,44 @@ class ExecutionContextWrapper(object):
query = re.sub(r':([\w_]+)', repl, query)
return query
+
+def _import_by_name(name):
+ submodule = name.split('.')[-1]
+ return __import__(name, globals(), locals(), [submodule])
+
+class CompositeModule(types.ModuleType):
+ """Merged attribute access for multiple modules."""
+
+ # break the habit
+ __all__ = ()
+
+ def __init__(self, name, *modules, **overrides):
+ """Construct a new lazy composite of modules.
+
+ Modules may be string names or module-like instances. Individual
+ attribute overrides may be specified as keyword arguments for
+ convenience.
+
+ The constructed module will resolve attribute access in reverse order:
+ overrides, then each member of reversed(modules). Modules specified
+ by name will be loaded lazily when encountered in attribute
+ resolution.
+
+ """
+ types.ModuleType.__init__(self, name)
+ self.__modules = list(reversed(modules))
+ for key, value in overrides.iteritems():
+ setattr(self, key, value)
+
+ def __getattr__(self, key):
+ for idx, mod in enumerate(self.__modules):
+ if isinstance(mod, basestring):
+ self.__modules[idx] = mod = _import_by_name(mod)
+ if hasattr(mod, key):
+ return getattr(mod, key)
+ raise AttributeError(key)
+
+
class TestBase(unittest.TestCase):
# A sequence of dialect names to exclude from the test class.
__unsupported_on__ = ()
@@ -469,14 +535,14 @@ class TestBase(unittest.TestCase):
def shortDescription(self):
"""overridden to not return docstrings"""
return None
-
+
def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs):
try:
callable_(*args, **kwargs)
assert False, "Callable did not raise expected exception"
except except_cls, e:
assert re.search(msg, str(e)), "Exception message did not match: '%s'" % str(e)
-
+
if not hasattr(unittest.TestCase, 'assertTrue'):
assertTrue = unittest.TestCase.failUnless
if not hasattr(unittest.TestCase, 'assertFalse'):
@@ -522,7 +588,7 @@ class ComparesTables(object):
set(type(c.type).__mro__).difference(base_mro)
)
) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
-
+
if isinstance(c.type, sqltypes.String):
self.assertEquals(c.type.length, reflected_c.type.length)
@@ -535,18 +601,18 @@ class ComparesTables(object):
elif not c.primary_key or not against('postgres'):
print repr(c)
assert reflected_c.default is None, reflected_c.default
-
+
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name]
-
+
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
result = list(result)
print repr(result)
self.assert_list(result, class_, objects)
-
+
def assert_list(self, result, class_, list):
self.assert_(len(result) == len(list),
"result list is not the same size as test list, " +
@@ -675,10 +741,10 @@ class ORMTest(TestBase, AssertsExecutionResults):
def define_tables(self, _otest_metadata):
raise NotImplementedError()
-
+
def setup_mappers(self):
pass
-
+
def insert_data(self):
pass