summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertions.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-09-27 02:37:33 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2012-09-27 02:37:33 -0400
commit20cdc64588b0f6ae52f8380c11d0ed848005377b (patch)
tree08f6cc8f82263f1e402c1c05c83b66a1f4b016ac /lib/sqlalchemy/testing/assertions.py
parent21cac5b598a83ef0e24423dc523629b475aa3af0 (diff)
downloadsqlalchemy-20cdc64588b0f6ae52f8380c11d0ed848005377b.tar.gz
trying different approaches to test layout. in this one, the testing modules
become an externally usable package but still remains within the main sqlalchemy parent package. in this system, we use kind of an ugly hack to get the noseplugin imported outside of the "sqlalchemy" package, while still making it available within sqlalchemy for usage by third party libraries.
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r--lib/sqlalchemy/testing/assertions.py349
1 files changed, 349 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
new file mode 100644
index 000000000..1e8559c1a
--- /dev/null
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -0,0 +1,349 @@
+from __future__ import absolute_import
+
+from . import util as testutil
+from sqlalchemy import pool, orm, util
+from sqlalchemy.engine import default
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.util import decorator
+from sqlalchemy import types as sqltypes, schema
+import warnings
+import re
+from .warnings import resetwarnings
+from .exclusions import db_spec, _is_excluded
+from . import assertsql
+from . import config
+import itertools
+from .util import fail
+
+def emits_warning(*messages):
+ """Mark a test as emitting a warning.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+ """
+ # TODO: it would be nice to assert that a named warning was
+ # emitted. should work with some monkeypatching of warnings,
+ # and may work on non-CPython if they keep to the spirit of
+ # warnings.showwarning's docstring.
+ # - update: jython looks ok, it uses cpython's module
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ # todo: should probably be strict about this, too
+ filters = [dict(action='ignore',
+ category=sa_exc.SAPendingDeprecationWarning)]
+ if not messages:
+ filters.append(dict(action='ignore',
+ category=sa_exc.SAWarning))
+ else:
+ filters.extend(dict(action='ignore',
+ message=message,
+ category=sa_exc.SAWarning)
+ for message in messages)
+ for f in filters:
+ warnings.filterwarnings(**f)
+ try:
+ return fn(*args, **kw)
+ finally:
+ resetwarnings()
+ return decorate
+
+def emits_warning_on(db, *warnings):
+ """Mark a test as emitting a warning on a specific dialect.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+ """
+ spec = db_spec(db)
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ if isinstance(db, basestring):
+ if not spec(config.db):
+ return fn(*args, **kw)
+ else:
+ wrapped = emits_warning(*warnings)(fn)
+ return wrapped(*args, **kw)
+ else:
+ if not _is_excluded(*db):
+ return fn(*args, **kw)
+ else:
+ wrapped = emits_warning(*warnings)(fn)
+ return wrapped(*args, **kw)
+ return decorate
+
+
+def uses_deprecated(*messages):
+ """Mark a test as immune from fatal deprecation warnings.
+
+ With no arguments, squelches all SADeprecationWarning failures.
+ Or pass one or more strings; these will be matched to the root
+ of the warning description by warnings.filterwarnings().
+
+ As a special case, you may pass a function name prefixed with //
+ and it will be re-written as needed to match the standard warning
+ verbiage emitted by the sqlalchemy.util.deprecated decorator.
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ # todo: should probably be strict about this, too
+ filters = [dict(action='ignore',
+ category=sa_exc.SAPendingDeprecationWarning)]
+ if not messages:
+ filters.append(dict(action='ignore',
+ category=sa_exc.SADeprecationWarning))
+ else:
+ filters.extend(
+ [dict(action='ignore',
+ message=message,
+ category=sa_exc.SADeprecationWarning)
+ for message in
+ [(m.startswith('//') and
+ ('Call to deprecated function ' + m[2:]) or m)
+ for m in messages]])
+
+ for f in filters:
+ warnings.filterwarnings(**f)
+ try:
+ return fn(*args, **kw)
+ finally:
+ resetwarnings()
+ return decorate
+
+
+
+def global_cleanup_assertions():
+ """Check things that have to be finalized at the end of a test suite.
+
+ Hardcoded at the moment, a modular system can be built here
+ to support things like PG prepared transactions, tables all
+ dropped, etc.
+
+ """
+
+ testutil.lazy_gc()
+ assert not pool._refs, str(pool._refs)
+
+
+
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+def is_not_(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+def startswith_(a, fragment, msg=None):
+ """Assert a.startswith(fragment), with repr messaging on failure."""
+ assert a.startswith(fragment), msg or "%r does not start with %r" % (
+ a, fragment)
+
+def assert_raises(except_cls, callable_, *args, **kw):
+ try:
+ callable_(*args, **kw)
+ success = False
+ except except_cls:
+ success = True
+
+ # assert outside the block so it works for AssertionError too !
+ assert success, "Callable did not raise an exception"
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+ try:
+ callable_(*args, **kwargs)
+ assert False, "Callable did not raise an exception"
+ except except_cls, e:
+ assert re.search(msg, unicode(e), re.UNICODE), u"%r !~ %s" % (msg, e)
+ print unicode(e).encode('utf-8')
+
+
+class AssertsCompiledSQL(object):
+ def assert_compile(self, clause, result, params=None,
+ checkparams=None, dialect=None,
+ checkpositional=None,
+ use_default_dialect=False,
+ allow_dialect_select=False):
+ if use_default_dialect:
+ dialect = default.DefaultDialect()
+ elif dialect == None and not allow_dialect_select:
+ dialect = getattr(self, '__dialect__', None)
+ if dialect == 'default':
+ dialect = default.DefaultDialect()
+ elif dialect is None:
+ dialect = config.db.dialect
+
+ kw = {}
+ if params is not None:
+ kw['column_keys'] = params.keys()
+
+ if isinstance(clause, orm.Query):
+ context = clause._compile_context()
+ context.statement.use_labels = True
+ clause = context.statement
+
+ c = clause.compile(dialect=dialect, **kw)
+
+ param_str = repr(getattr(c, 'params', {}))
+ # Py3K
+ #param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
+
+ print "\nSQL String:\n" + str(c) + param_str
+
+ cc = re.sub(r'[\n\t]', '', str(c))
+
+ eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
+
+ if checkparams is not None:
+ eq_(c.construct_params(params), checkparams)
+ if checkpositional is not None:
+ p = c.construct_params(params)
+ eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
+
+class ComparesTables(object):
+ def assert_tables_equal(self, table, reflected_table, strict_types=False):
+ assert len(table.c) == len(reflected_table.c)
+ for c, reflected_c in zip(table.c, reflected_table.c):
+ eq_(c.name, reflected_c.name)
+ assert reflected_c is reflected_table.c[c.name]
+ eq_(c.primary_key, reflected_c.primary_key)
+ eq_(c.nullable, reflected_c.nullable)
+
+ if strict_types:
+ assert type(reflected_c.type) is type(c.type), \
+ "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+ else:
+ self.assert_types_base(reflected_c, c)
+
+ if isinstance(c.type, sqltypes.String):
+ eq_(c.type.length, reflected_c.type.length)
+
+ eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
+ if c.server_default:
+ assert isinstance(reflected_c.server_default,
+ schema.FetchedValue)
+
+ assert len(table.primary_key) == len(reflected_table.primary_key)
+ for c in table.primary_key:
+ assert reflected_table.primary_key.columns[c.name] is not None
+
+ def assert_types_base(self, c1, c2):
+ assert c1.type._compare_type_affinity(c2.type),\
+ "On column %r, type '%s' doesn't correspond to type '%s'" % \
+ (c1.name, c1.type, c2.type)
+
+class AssertsExecutionResults(object):
+ def assert_result(self, result, class_, *objects):
+ result = list(result)
+ print repr(result)
+ self.assert_list(result, class_, objects)
+
+ def assert_list(self, result, class_, list):
+ self.assert_(len(result) == len(list),
+ "result list is not the same size as test list, " +
+ "for class " + class_.__name__)
+ for i in range(0, len(list)):
+ self.assert_row(class_, result[i], list[i])
+
+ def assert_row(self, class_, rowobj, desc):
+ self.assert_(rowobj.__class__ is class_,
+ "item class is not " + repr(class_))
+ for key, value in desc.iteritems():
+ if isinstance(value, tuple):
+ if isinstance(value[1], list):
+ self.assert_list(getattr(rowobj, key), value[0], value[1])
+ else:
+ self.assert_row(value[0], getattr(rowobj, key), value[1])
+ else:
+ self.assert_(getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s" % (
+ key, getattr(rowobj, key), value))
+
+ def assert_unordered_result(self, result, cls, *expected):
+ """As assert_result, but the order of objects is not considered.
+
+ The algorithm is very expensive but not a big deal for the small
+ numbers of rows that the test suite manipulates.
+ """
+
+ class immutabledict(dict):
+ def __hash__(self):
+ return id(self)
+
+ found = util.IdentitySet(result)
+ expected = set([immutabledict(e) for e in expected])
+
+ for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
+ fail('Unexpected type "%s", expected "%s"' % (
+ type(wrong).__name__, cls.__name__))
+
+ if len(found) != len(expected):
+ fail('Unexpected object count "%s", expected "%s"' % (
+ len(found), len(expected)))
+
+ NOVALUE = object()
+ def _compare_item(obj, spec):
+ for key, value in spec.iteritems():
+ if isinstance(value, tuple):
+ try:
+ self.assert_unordered_result(
+ getattr(obj, key), value[0], *value[1])
+ except AssertionError:
+ return False
+ else:
+ if getattr(obj, key, NOVALUE) != value:
+ return False
+ return True
+
+ for expected_item in expected:
+ for found_item in found:
+ if _compare_item(found_item, expected_item):
+ found.remove(found_item)
+ break
+ else:
+ fail(
+ "Expected %s instance with attributes %s not found." % (
+ cls.__name__, repr(expected_item)))
+ return True
+
+ def assert_sql_execution(self, db, callable_, *rules):
+ assertsql.asserter.add_rules(rules)
+ try:
+ callable_()
+ assertsql.asserter.statement_complete()
+ finally:
+ assertsql.asserter.clear_rules()
+
+ def assert_sql(self, db, callable_, list_, with_sequences=None):
+ if with_sequences is not None and config.db.dialect.supports_sequences:
+ rules = with_sequences
+ else:
+ rules = list_
+
+ newrules = []
+ for rule in rules:
+ if isinstance(rule, dict):
+ newrule = assertsql.AllOf(*[
+ assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
+ ])
+ else:
+ newrule = assertsql.ExactSQL(*rule)
+ newrules.append(newrule)
+
+ self.assert_sql_execution(db, callable_, *newrules)
+
+ def assert_sql_count(self, db, callable_, count):
+ self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
+
+