summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/__init__.py21
-rw-r--r--lib/sqlalchemy/testing/assertions.py349
-rw-r--r--lib/sqlalchemy/testing/assertsql.py316
-rw-r--r--lib/sqlalchemy/testing/config.py3
-rw-r--r--lib/sqlalchemy/testing/engines.py429
-rw-r--r--lib/sqlalchemy/testing/entities.py83
-rw-r--r--lib/sqlalchemy/testing/exclusions.py269
-rw-r--r--lib/sqlalchemy/testing/fixtures.py334
-rw-r--r--lib/sqlalchemy/testing/pickleable.py107
-rw-r--r--lib/sqlalchemy/testing/plugin/__init__.py0
-rw-r--r--lib/sqlalchemy/testing/plugin/config.py186
-rw-r--r--lib/sqlalchemy/testing/plugin/noseplugin.py199
-rw-r--r--lib/sqlalchemy/testing/profiling.py292
-rw-r--r--lib/sqlalchemy/testing/requirements.py38
-rw-r--r--lib/sqlalchemy/testing/schema.py85
-rw-r--r--lib/sqlalchemy/testing/suite/__init__.py0
-rw-r--r--lib/sqlalchemy/testing/suite/requirements.py24
-rw-r--r--lib/sqlalchemy/testing/suite/test_ddl.py48
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py0
-rw-r--r--lib/sqlalchemy/testing/suite/test_sequencing.py36
-rw-r--r--lib/sqlalchemy/testing/util.py196
-rw-r--r--lib/sqlalchemy/testing/warnings.py43
22 files changed, 3058 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
new file mode 100644
index 000000000..415705e93
--- /dev/null
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -0,0 +1,21 @@
+from __future__ import absolute_import
+
+from .warnings import testing_warn, assert_warnings, resetwarnings
+
+from . import config
+
+from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
+ fails_on, fails_on_everything_except, skip, only_on, exclude, against,\
+ _server_version
+
+from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
+ eq_, ne_, is_, is_not_, startswith_, assert_raises, \
+ assert_raises_message, AssertsCompiledSQL, ComparesTables, AssertsExecutionResults
+
+from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
+
+crashes = skip
+
+from .config import db, requirements as requires
+
+
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
new file mode 100644
index 000000000..1e8559c1a
--- /dev/null
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -0,0 +1,349 @@
+from __future__ import absolute_import
+
+from . import util as testutil
+from sqlalchemy import pool, orm, util
+from sqlalchemy.engine import default
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.util import decorator
+from sqlalchemy import types as sqltypes, schema
+import warnings
+import re
+from .warnings import resetwarnings
+from .exclusions import db_spec, _is_excluded
+from . import assertsql
+from . import config
+import itertools
+from .util import fail
+
+def emits_warning(*messages):
+ """Mark a test as emitting a warning.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+ """
+ # TODO: it would be nice to assert that a named warning was
+ # emitted. should work with some monkeypatching of warnings,
+ # and may work on non-CPython if they keep to the spirit of
+ # warnings.showwarning's docstring.
+ # - update: jython looks ok, it uses cpython's module
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ # todo: should probably be strict about this, too
+ filters = [dict(action='ignore',
+ category=sa_exc.SAPendingDeprecationWarning)]
+ if not messages:
+ filters.append(dict(action='ignore',
+ category=sa_exc.SAWarning))
+ else:
+ filters.extend(dict(action='ignore',
+ message=message,
+ category=sa_exc.SAWarning)
+ for message in messages)
+ for f in filters:
+ warnings.filterwarnings(**f)
+ try:
+ return fn(*args, **kw)
+ finally:
+ resetwarnings()
+ return decorate
+
+def emits_warning_on(db, *warnings):
+ """Mark a test as emitting a warning on a specific dialect.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+ """
+ spec = db_spec(db)
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ if isinstance(db, basestring):
+ if not spec(config.db):
+ return fn(*args, **kw)
+ else:
+ wrapped = emits_warning(*warnings)(fn)
+ return wrapped(*args, **kw)
+ else:
+ if not _is_excluded(*db):
+ return fn(*args, **kw)
+ else:
+ wrapped = emits_warning(*warnings)(fn)
+ return wrapped(*args, **kw)
+ return decorate
+
+
+def uses_deprecated(*messages):
+ """Mark a test as immune from fatal deprecation warnings.
+
+ With no arguments, squelches all SADeprecationWarning failures.
+ Or pass one or more strings; these will be matched to the root
+ of the warning description by warnings.filterwarnings().
+
+ As a special case, you may pass a function name prefixed with //
+ and it will be re-written as needed to match the standard warning
+ verbiage emitted by the sqlalchemy.util.deprecated decorator.
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ # todo: should probably be strict about this, too
+ filters = [dict(action='ignore',
+ category=sa_exc.SAPendingDeprecationWarning)]
+ if not messages:
+ filters.append(dict(action='ignore',
+ category=sa_exc.SADeprecationWarning))
+ else:
+ filters.extend(
+ [dict(action='ignore',
+ message=message,
+ category=sa_exc.SADeprecationWarning)
+ for message in
+ [(m.startswith('//') and
+ ('Call to deprecated function ' + m[2:]) or m)
+ for m in messages]])
+
+ for f in filters:
+ warnings.filterwarnings(**f)
+ try:
+ return fn(*args, **kw)
+ finally:
+ resetwarnings()
+ return decorate
+
+
+
+def global_cleanup_assertions():
+ """Check things that have to be finalized at the end of a test suite.
+
+ Hardcoded at the moment, a modular system can be built here
+ to support things like PG prepared transactions, tables all
+ dropped, etc.
+
+ """
+
+ testutil.lazy_gc()
+ assert not pool._refs, str(pool._refs)
+
+
+
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+def is_not_(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+def startswith_(a, fragment, msg=None):
+ """Assert a.startswith(fragment), with repr messaging on failure."""
+ assert a.startswith(fragment), msg or "%r does not start with %r" % (
+ a, fragment)
+
+def assert_raises(except_cls, callable_, *args, **kw):
+ try:
+ callable_(*args, **kw)
+ success = False
+ except except_cls:
+ success = True
+
+    # assert outside the block so it works for AssertionError too!
+ assert success, "Callable did not raise an exception"
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+ try:
+ callable_(*args, **kwargs)
+ assert False, "Callable did not raise an exception"
+ except except_cls, e:
+ assert re.search(msg, unicode(e), re.UNICODE), u"%r !~ %s" % (msg, e)
+ print unicode(e).encode('utf-8')
+
+
+class AssertsCompiledSQL(object):
+ def assert_compile(self, clause, result, params=None,
+ checkparams=None, dialect=None,
+ checkpositional=None,
+ use_default_dialect=False,
+ allow_dialect_select=False):
+ if use_default_dialect:
+ dialect = default.DefaultDialect()
+ elif dialect == None and not allow_dialect_select:
+ dialect = getattr(self, '__dialect__', None)
+ if dialect == 'default':
+ dialect = default.DefaultDialect()
+ elif dialect is None:
+ dialect = config.db.dialect
+
+ kw = {}
+ if params is not None:
+ kw['column_keys'] = params.keys()
+
+ if isinstance(clause, orm.Query):
+ context = clause._compile_context()
+ context.statement.use_labels = True
+ clause = context.statement
+
+ c = clause.compile(dialect=dialect, **kw)
+
+ param_str = repr(getattr(c, 'params', {}))
+ # Py3K
+ #param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
+
+ print "\nSQL String:\n" + str(c) + param_str
+
+ cc = re.sub(r'[\n\t]', '', str(c))
+
+ eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
+
+ if checkparams is not None:
+ eq_(c.construct_params(params), checkparams)
+ if checkpositional is not None:
+ p = c.construct_params(params)
+ eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
+
+class ComparesTables(object):
+ def assert_tables_equal(self, table, reflected_table, strict_types=False):
+ assert len(table.c) == len(reflected_table.c)
+ for c, reflected_c in zip(table.c, reflected_table.c):
+ eq_(c.name, reflected_c.name)
+ assert reflected_c is reflected_table.c[c.name]
+ eq_(c.primary_key, reflected_c.primary_key)
+ eq_(c.nullable, reflected_c.nullable)
+
+ if strict_types:
+ assert type(reflected_c.type) is type(c.type), \
+ "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+ else:
+ self.assert_types_base(reflected_c, c)
+
+ if isinstance(c.type, sqltypes.String):
+ eq_(c.type.length, reflected_c.type.length)
+
+ eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
+ if c.server_default:
+ assert isinstance(reflected_c.server_default,
+ schema.FetchedValue)
+
+ assert len(table.primary_key) == len(reflected_table.primary_key)
+ for c in table.primary_key:
+ assert reflected_table.primary_key.columns[c.name] is not None
+
+ def assert_types_base(self, c1, c2):
+ assert c1.type._compare_type_affinity(c2.type),\
+ "On column %r, type '%s' doesn't correspond to type '%s'" % \
+ (c1.name, c1.type, c2.type)
+
+class AssertsExecutionResults(object):
+ def assert_result(self, result, class_, *objects):
+ result = list(result)
+ print repr(result)
+ self.assert_list(result, class_, objects)
+
+ def assert_list(self, result, class_, list):
+ self.assert_(len(result) == len(list),
+ "result list is not the same size as test list, " +
+ "for class " + class_.__name__)
+ for i in range(0, len(list)):
+ self.assert_row(class_, result[i], list[i])
+
+ def assert_row(self, class_, rowobj, desc):
+ self.assert_(rowobj.__class__ is class_,
+ "item class is not " + repr(class_))
+ for key, value in desc.iteritems():
+ if isinstance(value, tuple):
+ if isinstance(value[1], list):
+ self.assert_list(getattr(rowobj, key), value[0], value[1])
+ else:
+ self.assert_row(value[0], getattr(rowobj, key), value[1])
+ else:
+ self.assert_(getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s" % (
+ key, getattr(rowobj, key), value))
+
+ def assert_unordered_result(self, result, cls, *expected):
+ """As assert_result, but the order of objects is not considered.
+
+ The algorithm is very expensive but not a big deal for the small
+ numbers of rows that the test suite manipulates.
+ """
+
+ class immutabledict(dict):
+ def __hash__(self):
+ return id(self)
+
+ found = util.IdentitySet(result)
+ expected = set([immutabledict(e) for e in expected])
+
+ for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
+ fail('Unexpected type "%s", expected "%s"' % (
+ type(wrong).__name__, cls.__name__))
+
+ if len(found) != len(expected):
+ fail('Unexpected object count "%s", expected "%s"' % (
+ len(found), len(expected)))
+
+ NOVALUE = object()
+ def _compare_item(obj, spec):
+ for key, value in spec.iteritems():
+ if isinstance(value, tuple):
+ try:
+ self.assert_unordered_result(
+ getattr(obj, key), value[0], *value[1])
+ except AssertionError:
+ return False
+ else:
+ if getattr(obj, key, NOVALUE) != value:
+ return False
+ return True
+
+ for expected_item in expected:
+ for found_item in found:
+ if _compare_item(found_item, expected_item):
+ found.remove(found_item)
+ break
+ else:
+ fail(
+ "Expected %s instance with attributes %s not found." % (
+ cls.__name__, repr(expected_item)))
+ return True
+
+ def assert_sql_execution(self, db, callable_, *rules):
+ assertsql.asserter.add_rules(rules)
+ try:
+ callable_()
+ assertsql.asserter.statement_complete()
+ finally:
+ assertsql.asserter.clear_rules()
+
+ def assert_sql(self, db, callable_, list_, with_sequences=None):
+ if with_sequences is not None and config.db.dialect.supports_sequences:
+ rules = with_sequences
+ else:
+ rules = list_
+
+ newrules = []
+ for rule in rules:
+ if isinstance(rule, dict):
+ newrule = assertsql.AllOf(*[
+ assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
+ ])
+ else:
+ newrule = assertsql.ExactSQL(*rule)
+ newrules.append(newrule)
+
+ self.assert_sql_execution(db, callable_, *newrules)
+
+ def assert_sql_count(self, db, callable_, count):
+ self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
+
+
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
new file mode 100644
index 000000000..897f4b3b1
--- /dev/null
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -0,0 +1,316 @@
+
+from sqlalchemy.interfaces import ConnectionProxy
+from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.engine.base import Connection
+from sqlalchemy import util
+import re
+
+class AssertRule(object):
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ pass
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ pass
+
+ def is_consumed(self):
+ """Return True if this rule has been consumed, False if not.
+
+ Should raise an AssertionError if this rule's condition has
+ definitely failed.
+
+ """
+
+ raise NotImplementedError()
+
+ def rule_passed(self):
+ """Return True if the last test of this rule passed, False if
+ failed, None if no test was applied."""
+
+ raise NotImplementedError()
+
+ def consume_final(self):
+ """Return True if this rule has been consumed.
+
+ Should raise an AssertionError if this rule's condition has not
+ been consumed or has failed.
+
+ """
+
+ if self._result is None:
+ assert False, 'Rule has not been consumed'
+ return self.is_consumed()
+
+class SQLMatchRule(AssertRule):
+ def __init__(self):
+ self._result = None
+ self._errmsg = ""
+
+ def rule_passed(self):
+ return self._result
+
+ def is_consumed(self):
+ if self._result is None:
+ return False
+
+ assert self._result, self._errmsg
+
+ return True
+
+class ExactSQL(SQLMatchRule):
+
+ def __init__(self, sql, params=None):
+ SQLMatchRule.__init__(self)
+ self.sql = sql
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ if not context:
+ return
+ _received_statement = \
+ _process_engine_statement(context.unicode_statement,
+ context)
+ _received_parameters = context.compiled_parameters
+
+ # TODO: remove this step once all unit tests are migrated, as
+ # ExactSQL should really be *exact* SQL
+
+ sql = _process_assertion_statement(self.sql, context)
+ equivalent = _received_statement == sql
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+ if not isinstance(params, list):
+ params = [params]
+ equivalent = equivalent and params \
+ == context.compiled_parameters
+ else:
+ params = {}
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = \
+ 'Testing for exact statement %r exact params %r, '\
+ 'received %r with params %r' % (sql, params,
+ _received_statement, _received_parameters)
+
+
+class RegexSQL(SQLMatchRule):
+
+ def __init__(self, regex, params=None):
+ SQLMatchRule.__init__(self)
+ self.regex = re.compile(regex)
+ self.orig_regex = regex
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ if not context:
+ return
+ _received_statement = \
+ _process_engine_statement(context.unicode_statement,
+ context)
+ _received_parameters = context.compiled_parameters
+ equivalent = bool(self.regex.match(_received_statement))
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+ if not isinstance(params, list):
+ params = [params]
+
+ # do a positive compare only
+
+ for param, received in zip(params, _received_parameters):
+ for k, v in param.iteritems():
+ if k not in received or received[k] != v:
+ equivalent = False
+ break
+ else:
+ params = {}
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = \
+ 'Testing for regex %r partial params %r, received %r '\
+ 'with params %r' % (self.orig_regex, params,
+ _received_statement,
+ _received_parameters)
+
+class CompiledSQL(SQLMatchRule):
+
+ def __init__(self, statement, params):
+ SQLMatchRule.__init__(self)
+ self.statement = statement
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ if not context:
+ return
+ _received_parameters = list(context.compiled_parameters)
+
+ # recompile from the context, using the default dialect
+
+ compiled = \
+ context.compiled.statement.compile(dialect=DefaultDialect(),
+ column_keys=context.compiled.column_keys)
+ _received_statement = re.sub(r'\n', '', str(compiled))
+ equivalent = self.statement == _received_statement
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+ if not isinstance(params, list):
+ params = [params]
+ all_params = list(params)
+ all_received = list(_received_parameters)
+ while params:
+ param = dict(params.pop(0))
+ for k, v in context.compiled.params.iteritems():
+ param.setdefault(k, v)
+ if param not in _received_parameters:
+ equivalent = False
+ break
+ else:
+ _received_parameters.remove(param)
+ if _received_parameters:
+ equivalent = False
+ else:
+ params = {}
+ all_params = {}
+ all_received = []
+ self._result = equivalent
+ if not self._result:
+ print 'Testing for compiled statement %r partial params '\
+ '%r, received %r with params %r' % (self.statement,
+ all_params, _received_statement, all_received)
+ self._errmsg = \
+ 'Testing for compiled statement %r partial params %r, '\
+ 'received %r with params %r' % (self.statement,
+ all_params, _received_statement, all_received)
+
+
+ # print self._errmsg
+
+class CountStatements(AssertRule):
+
+ def __init__(self, count):
+ self.count = count
+ self._statement_count = 0
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ self._statement_count += 1
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ pass
+
+ def is_consumed(self):
+ return False
+
+ def consume_final(self):
+ assert self.count == self._statement_count, \
+ 'desired statement count %d does not match %d' \
+ % (self.count, self._statement_count)
+ return True
+
+class AllOf(AssertRule):
+
+ def __init__(self, *rules):
+ self.rules = set(rules)
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ for rule in self.rules:
+ rule.process_execute(clauseelement, *multiparams, **params)
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ for rule in self.rules:
+ rule.process_cursor_execute(statement, parameters, context,
+ executemany)
+
+ def is_consumed(self):
+ if not self.rules:
+ return True
+ for rule in list(self.rules):
+ if rule.rule_passed(): # a rule passed, move on
+ self.rules.remove(rule)
+ return len(self.rules) == 0
+ assert False, 'No assertion rules were satisfied for statement'
+
+ def consume_final(self):
+ return len(self.rules) == 0
+
+def _process_engine_statement(query, context):
+ if util.jython:
+
+ # oracle+zxjdbc passes a PyStatement when returning into
+
+ query = unicode(query)
+ if context.engine.name == 'mssql' \
+ and query.endswith('; select scope_identity()'):
+ query = query[:-25]
+ query = re.sub(r'\n', '', query)
+ return query
+
+def _process_assertion_statement(query, context):
+ paramstyle = context.dialect.paramstyle
+ if paramstyle == 'named':
+ pass
+ elif paramstyle =='pyformat':
+ query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+ else:
+ # positional params
+ repl = None
+ if paramstyle=='qmark':
+ repl = "?"
+ elif paramstyle=='format':
+ repl = r"%s"
+ elif paramstyle=='numeric':
+ repl = None
+ query = re.sub(r':([\w_]+)', repl, query)
+
+ return query
+
+class SQLAssert(object):
+
+ rules = None
+
+ def add_rules(self, rules):
+ self.rules = list(rules)
+
+ def statement_complete(self):
+ for rule in self.rules:
+ if not rule.consume_final():
+ assert False, \
+ 'All statements are complete, but pending '\
+ 'assertion rules remain'
+
+ def clear_rules(self):
+ del self.rules
+
+ def execute(self, conn, clauseelement, multiparams, params, result):
+ if self.rules is not None:
+ if not self.rules:
+ assert False, \
+ 'All rules have been exhausted, but further '\
+ 'statements remain'
+ rule = self.rules[0]
+ rule.process_execute(clauseelement, *multiparams, **params)
+ if rule.is_consumed():
+ self.rules.pop(0)
+
+ def cursor_execute(self, conn, cursor, statement, parameters,
+ context, executemany):
+ if self.rules:
+ rule = self.rules[0]
+ rule.process_cursor_execute(statement, parameters, context,
+ executemany)
+
+asserter = SQLAssert()
+
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
new file mode 100644
index 000000000..2945bd456
--- /dev/null
+++ b/lib/sqlalchemy/testing/config.py
@@ -0,0 +1,3 @@
+requirements = None
+db = None
+
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
new file mode 100644
index 000000000..f7401550e
--- /dev/null
+++ b/lib/sqlalchemy/testing/engines.py
@@ -0,0 +1,429 @@
+from __future__ import absolute_import
+
+import types
+import weakref
+from collections import deque
+from . import config
+from .util import decorator
+from sqlalchemy import event, pool
+import re
+import warnings
+
+class ConnectionKiller(object):
+ def __init__(self):
+ self.proxy_refs = weakref.WeakKeyDictionary()
+ self.testing_engines = weakref.WeakKeyDictionary()
+ self.conns = set()
+
+ def add_engine(self, engine):
+ self.testing_engines[engine] = True
+
+ def connect(self, dbapi_conn, con_record):
+ self.conns.add(dbapi_conn)
+
+ def checkout(self, dbapi_con, con_record, con_proxy):
+ self.proxy_refs[con_proxy] = True
+
+ def _safe(self, fn):
+ try:
+ fn()
+ except (SystemExit, KeyboardInterrupt):
+ raise
+ except Exception, e:
+ warnings.warn(
+ "testing_reaper couldn't "
+ "rollback/close connection: %s" % e)
+
+ def rollback_all(self):
+ for rec in self.proxy_refs.keys():
+ if rec is not None and rec.is_valid:
+ self._safe(rec.rollback)
+
+ def close_all(self):
+ for rec in self.proxy_refs.keys():
+ if rec is not None:
+ self._safe(rec._close)
+
+ def _after_test_ctx(self):
+ pass
+ # this can cause a deadlock with pg8000 - pg8000 acquires
+        # prepared statement lock inside of rollback() - if async gc
+ # is collecting in finalize_fairy, deadlock.
+ # not sure if this should be if pypy/jython only
+ #for conn in self.conns:
+ # self._safe(conn.rollback)
+
+ def _stop_test_ctx(self):
+ if config.options.low_connections:
+ self._stop_test_ctx_minimal()
+ else:
+ self._stop_test_ctx_aggressive()
+
+ def _stop_test_ctx_minimal(self):
+ self.close_all()
+
+ self.conns = set()
+
+ for rec in self.testing_engines.keys():
+ if rec is not config.db:
+ rec.dispose()
+
+ def _stop_test_ctx_aggressive(self):
+ self.close_all()
+ for conn in self.conns:
+ self._safe(conn.close)
+ self.conns = set()
+ for rec in self.testing_engines.keys():
+ rec.dispose()
+
+ def assert_all_closed(self):
+ for rec in self.proxy_refs:
+ if rec.is_valid:
+ assert False
+
+testing_reaper = ConnectionKiller()
+
+def drop_all_tables(metadata, bind):
+ testing_reaper.close_all()
+ if hasattr(bind, 'close'):
+ bind.close()
+ metadata.drop_all(bind)
+
+@decorator
+def assert_conns_closed(fn, *args, **kw):
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.assert_all_closed()
+
+@decorator
+def rollback_open_connections(fn, *args, **kw):
+ """Decorator that rolls back all open connections after fn execution."""
+
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.rollback_all()
+
+@decorator
+def close_first(fn, *args, **kw):
+ """Decorator that closes all connections before fn execution."""
+
+ testing_reaper.close_all()
+ fn(*args, **kw)
+
+
+@decorator
+def close_open_connections(fn, *args, **kw):
+ """Decorator that closes all connections after fn execution."""
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.close_all()
+
+def all_dialects(exclude=None):
+ import sqlalchemy.databases as d
+ for name in d.__all__:
+ # TEMPORARY
+ if exclude and name in exclude:
+ continue
+ mod = getattr(d, name, None)
+ if not mod:
+ mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+ yield mod.dialect()
+
+class ReconnectFixture(object):
+ def __init__(self, dbapi):
+ self.dbapi = dbapi
+ self.connections = []
+
+ def __getattr__(self, key):
+ return getattr(self.dbapi, key)
+
+ def connect(self, *args, **kwargs):
+ conn = self.dbapi.connect(*args, **kwargs)
+ self.connections.append(conn)
+ return conn
+
+ def _safe(self, fn):
+ try:
+ fn()
+ except (SystemExit, KeyboardInterrupt):
+ raise
+ except Exception, e:
+ warnings.warn(
+ "ReconnectFixture couldn't "
+ "close connection: %s" % e)
+
+ def shutdown(self):
+ # TODO: this doesn't cover all cases
+ # as nicely as we'd like, namely MySQLdb.
+ # would need to implement R. Brewer's
+ # proxy server idea to get better
+ # coverage.
+ for c in list(self.connections):
+ self._safe(c.close)
+ self.connections = []
+
+def reconnecting_engine(url=None, options=None):
+ url = url or config.db_url
+ dbapi = config.db.dialect.dbapi
+ if not options:
+ options = {}
+ options['module'] = ReconnectFixture(dbapi)
+ engine = testing_engine(url, options)
+ _dispose = engine.dispose
+ def dispose():
+ engine.dialect.dbapi.shutdown()
+ _dispose()
+ engine.test_shutdown = engine.dialect.dbapi.shutdown
+ engine.dispose = dispose
+ return engine
+
+def testing_engine(url=None, options=None):
+ """Produce an engine configured by --options with optional overrides."""
+
+ from sqlalchemy import create_engine
+ from .assertsql import asserter
+
+ if not options:
+ use_reaper = True
+ else:
+ use_reaper = options.pop('use_reaper', True)
+
+ url = url or config.db_url
+ if options is None:
+ options = config.db_opts
+
+ engine = create_engine(url, **options)
+ if isinstance(engine.pool, pool.QueuePool):
+ engine.pool._timeout = 0
+ engine.pool._max_overflow = 0
+ event.listen(engine, 'after_execute', asserter.execute)
+ event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
+ if use_reaper:
+ event.listen(engine.pool, 'connect', testing_reaper.connect)
+ event.listen(engine.pool, 'checkout', testing_reaper.checkout)
+ testing_reaper.add_engine(engine)
+
+ return engine
+
+def utf8_engine(url=None, options=None):
+ """Hook for dialects or drivers that don't handle utf8 by default."""
+
+ from sqlalchemy.engine import url as engine_url
+
+ if config.db.dialect.name == 'mysql' and \
+ config.db.driver in ['mysqldb', 'pymysql']:
+ # note 1.2.1.gamma.6 or greater of MySQLdb
+ # needed here
+ url = url or config.db_url
+ url = engine_url.make_url(url)
+ url.query['charset'] = 'utf8'
+ url.query['use_unicode'] = '0'
+ url = str(url)
+
+ return testing_engine(url, options)
+
+def mock_engine(dialect_name=None):
+ """Provides a mocking engine based on the current testing.db.
+
+ This is normally used to test DDL generation flow as emitted
+ by an Engine.
+
+ It should not be used in other cases, as assert_compile() and
+ assert_sql_execution() are much better choices with fewer
+ moving parts.
+
+ """
+
+ from sqlalchemy import create_engine
+
+ if not dialect_name:
+ dialect_name = config.db.name
+
+ buffer = []
+ def executor(sql, *a, **kw):
+ buffer.append(sql)
+ def assert_sql(stmts):
+ recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
+ assert recv == stmts, recv
+ def print_sql():
+ d = engine.dialect
+ return "\n".join(
+ str(s.compile(dialect=d))
+ for s in engine.mock
+ )
+ engine = create_engine(dialect_name + '://',
+ strategy='mock', executor=executor)
+ assert not hasattr(engine, 'mock')
+ engine.mock = buffer
+ engine.assert_sql = assert_sql
+ engine.print_sql = print_sql
+ return engine
+
+class DBAPIProxyCursor(object):
+ """Proxy a DBAPI cursor.
+
+ Tests can provide subclasses of this to intercept
+ DBAPI-level cursor operations.
+
+ """
+ def __init__(self, engine, conn):
+ self.engine = engine
+ self.connection = conn
+ self.cursor = conn.cursor()
+
+ def execute(self, stmt, parameters=None, **kw):
+ if parameters:
+ return self.cursor.execute(stmt, parameters, **kw)
+ else:
+ return self.cursor.execute(stmt, **kw)
+
+ def executemany(self, stmt, params, **kw):
+ return self.cursor.executemany(stmt, params, **kw)
+
+ def __getattr__(self, key):
+ return getattr(self.cursor, key)
+
+class DBAPIProxyConnection(object):
+ """Proxy a DBAPI connection.
+
+ Tests can provide subclasses of this to intercept
+ DBAPI-level connection operations.
+
+ """
+ def __init__(self, engine, cursor_cls):
+ self.conn = self._sqla_unwrap = engine.pool._creator()
+ self.engine = engine
+ self.cursor_cls = cursor_cls
+
+ def cursor(self):
+ return self.cursor_cls(self.engine, self.conn)
+
+ def close(self):
+ self.conn.close()
+
+ def __getattr__(self, key):
+ return getattr(self.conn, key)
+
+def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor):
+ """Produce an engine that provides proxy hooks for
+ common methods.
+
+ """
+ def mock_conn():
+ return conn_cls(config.db, cursor_cls)
+ return testing_engine(options={'creator':mock_conn})
+
+class ReplayableSession(object):
+ """A simple record/playback tool.
+
+ This is *not* a mock testing class. It only records a session for later
+ playback and makes no assertions on call consistency whatsoever. It's
+ unlikely to be suitable for anything other than DB-API recording.
+
+ """
+
+ Callable = object()
+ NoAttribute = object()
+
+ # Py3K
+ #Natives = set([getattr(types, t)
+ # for t in dir(types) if not t.startswith('_')]). \
+ # union([type(t) if not isinstance(t, type)
+ # else t for t in __builtins__.values()]).\
+ # difference([getattr(types, t)
+ # for t in ('FunctionType', 'BuiltinFunctionType',
+ # 'MethodType', 'BuiltinMethodType',
+ # 'LambdaType', )])
+ # Py2K
+ Natives = set([getattr(types, t)
+ for t in dir(types) if not t.startswith('_')]). \
+ difference([getattr(types, t)
+ for t in ('FunctionType', 'BuiltinFunctionType',
+ 'MethodType', 'BuiltinMethodType',
+ 'LambdaType', 'UnboundMethodType',)])
+ # end Py2K
+
+ def __init__(self):
+ self.buffer = deque()
+
+ def recorder(self, base):
+ return self.Recorder(self.buffer, base)
+
+ def player(self):
+ return self.Player(self.buffer)
+
+ class Recorder(object):
+ def __init__(self, buffer, subject):
+ self._buffer = buffer
+ self._subject = subject
+
+ def __call__(self, *args, **kw):
+ subject, buffer = [object.__getattribute__(self, x)
+ for x in ('_subject', '_buffer')]
+
+ result = subject(*args, **kw)
+ if type(result) not in ReplayableSession.Natives:
+ buffer.append(ReplayableSession.Callable)
+ return type(self)(buffer, result)
+ else:
+ buffer.append(result)
+ return result
+
+ @property
+ def _sqla_unwrap(self):
+ return self._subject
+
+ def __getattribute__(self, key):
+ try:
+ return object.__getattribute__(self, key)
+ except AttributeError:
+ pass
+
+ subject, buffer = [object.__getattribute__(self, x)
+ for x in ('_subject', '_buffer')]
+ try:
+ result = type(subject).__getattribute__(subject, key)
+ except AttributeError:
+ buffer.append(ReplayableSession.NoAttribute)
+ raise
+ else:
+ if type(result) not in ReplayableSession.Natives:
+ buffer.append(ReplayableSession.Callable)
+ return type(self)(buffer, result)
+ else:
+ buffer.append(result)
+ return result
+
+ class Player(object):
+ def __init__(self, buffer):
+ self._buffer = buffer
+
+ def __call__(self, *args, **kw):
+ buffer = object.__getattribute__(self, '_buffer')
+ result = buffer.popleft()
+ if result is ReplayableSession.Callable:
+ return self
+ else:
+ return result
+
+ @property
+ def _sqla_unwrap(self):
+ return None
+
+ def __getattribute__(self, key):
+ try:
+ return object.__getattribute__(self, key)
+ except AttributeError:
+ pass
+ buffer = object.__getattribute__(self, '_buffer')
+ result = buffer.popleft()
+ if result is ReplayableSession.Callable:
+ return self
+ elif result is ReplayableSession.NoAttribute:
+ raise AttributeError(key)
+ else:
+ return result
+
diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py
new file mode 100644
index 000000000..1b24e73b7
--- /dev/null
+++ b/lib/sqlalchemy/testing/entities.py
@@ -0,0 +1,83 @@
+import sqlalchemy as sa
+from sqlalchemy import exc as sa_exc
+
_repr_stack = set()
class BasicEntity(object):
    """Simple test entity: constructor keyword arguments become attributes.

    ``__repr__`` renders all non-underscored attributes sorted by name,
    guarding against infinite recursion on cyclic object graphs via the
    module-level ``_repr_stack``.
    """

    def __init__(self, **kw):
        # .items() rather than .iteritems() so this runs on both
        # Python 2 and Python 3 (identical behavior on Python 2)
        for key, value in kw.items():
            setattr(self, key, value)

    def __repr__(self):
        if id(self) in _repr_stack:
            # already rendering this object higher up the stack;
            # fall back to the default repr to break the cycle
            return object.__repr__(self)
        _repr_stack.add(id(self))
        try:
            return "%s(%s)" % (
                (self.__class__.__name__),
                ', '.join(["%s=%r" % (key, getattr(self, key))
                           for key in sorted(self.__dict__.keys())
                           if not key.startswith('_')]))
        finally:
            _repr_stack.remove(id(self))
+
_recursion_stack = set()
class ComparableEntity(BasicEntity):
    """BasicEntity subclass adding deep, ORM-aware equality comparison."""

    def __hash__(self):
        # hash by class only; instances of a class compare via __eq__
        return hash(self.__class__)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __eq__(self, other):
        """'Deep, sparse compare.

        Deeply compare two entities, following the non-None attributes of the
        non-persisted object, if possible.

        """
        if other is self:
            return True
        elif not self.__class__ == other.__class__:
            return False

        if id(self) in _recursion_stack:
            # cycle: assume equal at this depth; an outer frame decides
            return True
        _recursion_stack.add(id(self))

        try:
            # pick the entity thats not SA persisted as the source
            try:
                self_key = sa.orm.attributes.instance_state(self).key
            except sa.orm.exc.NO_STATE:
                self_key = None

            # NOTE(review): unreachable — if other is None the class
            # comparison above has already returned False
            if other is None:
                a = self
                b = other
            elif self_key is not None:
                # self is persistent; drive the compare from `other`
                a = other
                b = self
            else:
                a = self
                b = other

            for attr in a.__dict__.keys():
                if attr.startswith('_'):
                    continue
                value = getattr(a, attr)

                try:
                    # handle lazy loader errors
                    battr = getattr(b, attr)
                except (AttributeError, sa_exc.UnboundExecutionError):
                    return False

                if hasattr(value, '__iter__'):
                    # collections: compare element-wise in order
                    if list(value) != list(battr):
                        return False
                else:
                    # "sparse": None attributes on the source are ignored
                    if value is not None and value != battr:
                        return False
            return True
        finally:
            _recursion_stack.remove(id(self))
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
new file mode 100644
index 000000000..ba2eebe4f
--- /dev/null
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -0,0 +1,269 @@
+import operator
+from nose import SkipTest
+from sqlalchemy.util import decorator
+from . import config
+from sqlalchemy import util
+
+
def fails_if(predicate, reason=None):
    """Mark a test as expected to fail when ``predicate`` is satisfied.

    When the predicate is false the test runs normally; when true, the
    test is invoked and *must* raise, otherwise AssertionError is raised.
    ``reason``, previously accepted but ignored, is now included in the
    reported messages when given.
    """
    predicate = _as_predicate(predicate)

    @decorator
    def decorate(fn, *args, **kw):
        if not predicate():
            return fn(*args, **kw)
        else:
            try:
                fn(*args, **kw)
            # "except ... as" form is valid on Python 2.6+ and Python 3,
            # unlike the old "except Exception, ex" spelling
            except Exception as ex:
                print ("'%s' failed as expected (%s): %s " % (
                    fn.__name__, reason or predicate, str(ex)))
                return True
            else:
                raise AssertionError(
                    "Unexpected success for '%s' (%s)" %
                    (fn.__name__, reason or predicate))
    return decorate
+
def skip_if(predicate, reason=None):
    """Skip the decorated test when ``predicate`` evaluates true,
    raising nose's SkipTest with ``reason`` (or the predicate's own
    description) in the message."""
    predicate = _as_predicate(predicate)

    @decorator
    def decorate(fn, *args, **kw):
        if not predicate():
            return fn(*args, **kw)
        if reason:
            msg = "'%s' : %s" % (
                fn.__name__,
                reason
            )
        else:
            msg = "'%s': %s" % (
                fn.__name__, predicate
            )
        raise SkipTest(msg)
    return decorate
+
def only_if(predicate, reason=None):
    """Run the decorated test only when ``predicate`` is true;
    otherwise skip it."""
    pred = _as_predicate(predicate)
    return skip_if(NotPredicate(pred), reason)
+
def succeeds_if(predicate, reason=None):
    """Expect the decorated test to succeed only when ``predicate`` is
    true; on other configurations it is expected to fail."""
    pred = _as_predicate(predicate)
    return fails_if(NotPredicate(pred), reason)
+
class Predicate(object):
    """Base class for exclusion predicates.  ``as_predicate`` coerces
    plain Python values into Predicate instances."""

    @classmethod
    def as_predicate(cls, predicate):
        if isinstance(predicate, Predicate):
            # already a predicate object
            return predicate
        if isinstance(predicate, list):
            # list -> OR of each coerced element
            return OrPredicate(
                [cls.as_predicate(item) for item in predicate])
        if isinstance(predicate, tuple):
            # tuple -> positional (db, op, spec) arguments
            return SpecPredicate(*predicate)
        if isinstance(predicate, basestring):
            # bare string -> database/driver name match only
            return SpecPredicate(predicate, None, None)
        if util.callable(predicate):
            return LambdaPredicate(predicate)
        assert False, "unknown predicate type: %s" % predicate
+
class SpecPredicate(Predicate):
    """Predicate matching a database dialect, optionally qualified with a
    driver ("dialect+driver") and/or a server version comparison, e.g.
    ``SpecPredicate("mysql", ">=", (5, 0))``.
    """

    def __init__(self, db, op=None, spec=None, description=None):
        self.db = db
        self.op = op
        self.spec = spec
        self.description = description

    # symbolic comparison operators accepted for ``op``
    _ops = {
        '<': operator.lt,
        '>': operator.gt,
        '==': operator.eq,
        '!=': operator.ne,
        '<=': operator.le,
        '>=': operator.ge,
        'in': operator.contains,
        'between': lambda val, pair: val >= pair[0] and val <= pair[1],
    }

    def __call__(self, engine=None):
        """Return True if ``engine`` (default: the configured test
        engine) matches this spec."""
        if engine is None:
            engine = config.db

        if "+" in self.db:
            dialect, driver = self.db.split('+')
        else:
            dialect, driver = self.db, None

        if dialect and engine.name != dialect:
            return False
        if driver is not None and engine.driver != driver:
            return False

        if self.op is not None:
            assert driver is None, "DBAPI version specs not supported yet"

            version = _server_version(engine)
            # ``op`` may be a callable comparator or a symbolic key into
            # _ops; a conditional expression replaces the fragile
            # "cond and a or b" idiom, which misbehaves for falsy values
            oper = self.op if hasattr(self.op, '__call__') \
                else self._ops[self.op]
            return oper(version, self.spec)
        else:
            return True

    def _as_string(self, negate=False):
        """Render the spec, honoring an explicit description if set."""
        if self.description is not None:
            return self.description
        elif self.op is None:
            if negate:
                return "not %s" % self.db
            else:
                return "%s" % self.db
        else:
            if negate:
                return "not %s %s %s" % (
                    self.db,
                    self.op,
                    self.spec
                )
            else:
                return "%s %s %s" % (
                    self.db,
                    self.op,
                    self.spec
                )

    def __str__(self):
        return self._as_string()
+
class LambdaPredicate(Predicate):
    """Predicate which defers evaluation to an arbitrary callable."""

    def __init__(self, lambda_, description=None, args=None, kw=None):
        self.lambda_ = lambda_
        self.args = args or ()
        self.kw = kw or {}
        # explicit description wins, then the callable's docstring,
        # then a generic fallback
        self.description = (
            description or lambda_.__doc__ or "custom function")

    def __call__(self):
        return self.lambda_(*self.args, **self.kw)

    def _as_string(self, negate=False):
        prefix = "not " if negate else ""
        return prefix + self.description

    def __str__(self):
        return self._as_string()
+
class NotPredicate(Predicate):
    """Inverts the boolean result of a wrapped predicate."""

    def __init__(self, predicate):
        self.predicate = predicate

    def __call__(self, *arg, **kw):
        result = self.predicate(*arg, **kw)
        return not result

    def __str__(self):
        # render the wrapped predicate in its negated form
        return self.predicate._as_string(True)
+
class OrPredicate(Predicate):
    """Composite predicate satisfied when any member predicate is."""

    # remembers the member which most recently evaluated True, so that
    # string rendering can name just that member
    _str = None

    def __init__(self, predicates, description=None):
        self.predicates = predicates
        self.description = description

    def __call__(self, *arg, **kw):
        for member in self.predicates:
            if member(*arg, **kw):
                self._str = member
                return True
        return False

    def _eval_str(self, negate=False):
        if self._str is not None:
            return self._str._as_string(negate=negate)
        # no member has matched yet; render the whole disjunction
        # (conjunction when negated, per De Morgan)
        joiner = " and " if negate else " or "
        return joiner.join(
            member._as_string(negate=negate)
            for member in self.predicates)

    def _negation_str(self):
        if self.description is None:
            return self._eval_str(negate=True)
        return "Not " + (self.description % {"spec": self._str})

    def _as_string(self, negate=False):
        if negate:
            return self._negation_str()
        if self.description is None:
            return self._eval_str()
        return self.description % {"spec": self._str}

    def __str__(self):
        return self._as_string()
+
# module-internal shorthand used by the decorator factories above
_as_predicate = Predicate.as_predicate

def _is_excluded(db, op, spec):
    """Return True if the configured test engine matches (db, op, spec)."""
    return SpecPredicate(db, op, spec)()
+
def _server_version(engine):
    """Return a server_version_info tuple."""

    # force metadata to be retrieved
    # (connecting ensures the dialect has populated server_version_info;
    # the attribute may be absent, hence the () default)
    conn = engine.connect()
    version = getattr(engine.dialect, 'server_version_info', ())
    conn.close()
    return version
+
def db_spec(*dbs):
    """Return a predicate matching any of the given database specs.

    The coerced predicates are materialized into a list: OrPredicate may
    iterate its members more than once (boolean evaluation plus string
    rendering via ``_eval_str``), which with the previous generator
    argument would silently see an exhausted iterable the second time.
    """
    return OrPredicate(
        [Predicate.as_predicate(db) for db in dbs]
    )
+
def open(fn):
    """No-op exclusion decorator: marks a test as open to all backends.

    NOTE(review): shadows the builtin ``open`` within this module.
    """
    return fn
+
@decorator
def future(fn, *args, **kw):
    """Mark a test as exercising a future (not yet working) feature.

    NOTE(review): looks suspicious — ``*args``/``**kw`` (the decorated
    call's arguments) are forwarded onto LambdaPredicate's
    description/args/kw parameters, and the ``fails_if`` decorator is
    returned rather than applied/invoked; verify intended behavior.
    """
    return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature")
+
def fails_on(db, reason):
    """Mark a test as expected to fail on the named database/driver."""
    return fails_if(SpecPredicate(db), reason)
+
def fails_on_everything_except(*dbs):
    """Expect the test to fail on every backend except those named."""
    return succeeds_if(
        OrPredicate([
            SpecPredicate(db) for db in dbs
        ])
    )
+
def skip(db, reason):
    """Unconditionally skip the test on the named database/driver."""
    return skip_if(SpecPredicate(db), reason)
+
def only_on(dbs, reason):
    """Run the test only on the given database spec(s).

    ``reason`` was previously accepted but silently dropped; it is now
    forwarded so the skip message explains why the test did not run.
    """
    return only_if(
        OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]),
        reason
    )
+
+
def exclude(db, op, spec, reason):
    """Skip the test for database versions matching (db, op, spec),
    e.g. ``exclude('mysql', '<', (5, 0), 'no subquery support')``."""
    return skip_if(SpecPredicate(db, op, spec), reason)
+
+
def against(*queries):
    """Boolean helper: True if the configured engine matches any of the
    given specs.  Note this evaluates immediately and returns a bool,
    not a predicate object."""
    return OrPredicate([
        Predicate.as_predicate(query)
        for query in queries
    ])()
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
new file mode 100644
index 000000000..018276d4d
--- /dev/null
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -0,0 +1,334 @@
+from . import config
+from . import assertions, schema
+from .util import adict
+from .engines import drop_all_tables
+from .entities import BasicEntity, ComparableEntity
+import sys
+import sqlalchemy as sa
+from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
+
class TestBase(object):
    """Root class for test suites.  The class-level attributes below are
    consumed by the nose plugin's skip/exclusion machinery
    (``_do_skips``)."""

    # A sequence of database names to always run, regardless of the
    # constraints below.
    __whitelist__ = ()

    # A sequence of requirement names matching testing.requires decorators
    __requires__ = ()

    # A sequence of dialect names to exclude from the test class.
    __unsupported_on__ = ()

    # If present, test class is only runnable for the *single* specified
    # dialect. If you need multiple, use __unsupported_on__ and invert.
    __only_on__ = None

    # A sequence of no-arg callables. If any are True, the entire testcase is
    # skipped.
    __skip_if__ = None

    def assert_(self, val, msg=None):
        # minimal assertion helper retained for legacy test bodies
        assert val, msg
+
class TablesTest(TestBase):
    """Test base which manages a MetaData, table creation, and fixture
    row insertion.  Each ``run_*`` flag selects whether its step occurs
    once per class, once per test, or never."""

    # 'once', None
    run_setup_bind = 'once'

    # 'once', 'each', None
    run_define_tables = 'once'

    # 'once', 'each', None
    run_create_tables = 'once'

    # 'once', 'each', None
    run_inserts = 'each'

    # 'each', None
    run_deletes = 'each'

    # 'once', None
    run_dispose_bind = None

    # populated by _init_class()
    bind = None
    metadata = None
    tables = None
    other = None

    @classmethod
    def setup_class(cls):
        cls._init_class()

        cls._setup_once_tables()

        cls._setup_once_inserts()

    @classmethod
    def _init_class(cls):
        # per-test table definitions force per-test creation, and
        # per-class inserts would then reference stale tables
        if cls.run_define_tables == 'each':
            if cls.run_create_tables == 'once':
                cls.run_create_tables = 'each'
            assert cls.run_inserts in ('each', None)

        if cls.other is None:
            cls.other = adict()

        if cls.tables is None:
            cls.tables = adict()

        if cls.bind is None:
            setattr(cls, 'bind', cls.setup_bind())

        if cls.metadata is None:
            setattr(cls, 'metadata', sa.MetaData())

        if cls.metadata.bind is None:
            cls.metadata.bind = cls.bind

    @classmethod
    def _setup_once_inserts(cls):
        if cls.run_inserts == 'once':
            cls._load_fixtures()
            cls.insert_data()

    @classmethod
    def _setup_once_tables(cls):
        if cls.run_define_tables == 'once':
            cls.define_tables(cls.metadata)
            if cls.run_create_tables == 'once':
                cls.metadata.create_all(cls.bind)
            cls.tables.update(cls.metadata.tables)

    def _setup_each_tables(self):
        if self.run_define_tables == 'each':
            # rebuild table definitions (and optionally the schema)
            # from scratch for this test
            self.tables.clear()
            if self.run_create_tables == 'each':
                drop_all_tables(self.metadata, self.bind)
            self.metadata.clear()
            self.define_tables(self.metadata)
            if self.run_create_tables == 'each':
                self.metadata.create_all(self.bind)
            self.tables.update(self.metadata.tables)
        elif self.run_create_tables == 'each':
            drop_all_tables(self.metadata, self.bind)
            self.metadata.create_all(self.bind)

    def _setup_each_inserts(self):
        if self.run_inserts == 'each':
            self._load_fixtures()
            self.insert_data()

    def _teardown_each_tables(self):
        # no need to run deletes if tables are recreated on setup
        if self.run_define_tables != 'each' and self.run_deletes == 'each':
            # delete children first by walking sorted_tables in reverse
            for table in reversed(self.metadata.sorted_tables):
                try:
                    table.delete().execute().close()
                except sa.exc.DBAPIError, ex:
                    print >> sys.stderr, "Error emptying table %s: %r" % (
                        table, ex)

    def setup(self):
        self._setup_each_tables()
        self._setup_each_inserts()

    def teardown(self):
        self._teardown_each_tables()

    @classmethod
    def _teardown_once_metadata_bind(cls):
        if cls.run_create_tables:
            drop_all_tables(cls.metadata, cls.bind)

        if cls.run_dispose_bind == 'once':
            cls.dispose_bind(cls.bind)

        cls.metadata.bind = None

        if cls.run_setup_bind is not None:
            cls.bind = None

    @classmethod
    def teardown_class(cls):
        cls._teardown_once_metadata_bind()

    @classmethod
    def setup_bind(cls):
        # default bind is the globally configured test engine
        return config.db

    @classmethod
    def dispose_bind(cls, bind):
        # Engine exposes dispose(); a raw Connection exposes close()
        if hasattr(bind, 'dispose'):
            bind.dispose()
        elif hasattr(bind, 'close'):
            bind.close()

    @classmethod
    def define_tables(cls, metadata):
        # overridden by subclasses to populate ``metadata``
        pass

    @classmethod
    def fixtures(cls):
        # overridden by subclasses: {table or name: [header_row, *rows]}
        return {}

    @classmethod
    def insert_data(cls):
        # overridden by subclasses for ad-hoc inserts
        pass

    def sql_count_(self, count, fn):
        self.assert_sql_count(self.bind, fn, count)

    def sql_eq_(self, callable_, statements, with_sequences=None):
        self.assert_sql(self.bind,
                        callable_, statements, with_sequences)

    @classmethod
    def _load_fixtures(cls):
        """Insert rows as represented by the fixtures() method."""
        headers, rows = {}, {}
        for table, data in cls.fixtures().iteritems():
            if len(data) < 2:
                # header only, no rows to insert
                continue
            if isinstance(table, basestring):
                table = cls.tables[table]
            headers[table] = data[0]
            rows[table] = data[1:]
        # insert in dependency order so FK targets exist first
        for table in cls.metadata.sorted_tables:
            if table not in headers:
                continue
            cls.bind.execute(
                table.insert(),
                [dict(zip(headers[table], column_values))
                 for column_values in rows[table]])
+
+
class _ORMTest(object):
    """Mixin providing ORM-level class teardown: close all sessions and
    clear all mappers."""
    __requires__ = ('subqueries',)

    @classmethod
    def teardown_class(cls):
        sa.orm.session.Session.close_all()
        sa.orm.clear_mappers()
+
class ORMTest(_ORMTest, TestBase):
    """ORM test base without table management."""
    pass
+
class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
    """TablesTest variant which additionally manages mapped classes and
    mapper configuration, registered via ``setup_classes`` /
    ``setup_mappers``."""

    # 'once', 'each', None
    run_setup_classes = 'once'

    # 'once', 'each', None
    run_setup_mappers = 'each'

    # registry of classes created within setup_classes/setup_mappers
    classes = None

    @classmethod
    def setup_class(cls):
        cls._init_class()

        if cls.classes is None:
            cls.classes = adict()

        cls._setup_once_tables()
        cls._setup_once_classes()
        cls._setup_once_mappers()
        cls._setup_once_inserts()

    @classmethod
    def teardown_class(cls):
        cls._teardown_once_class()
        cls._teardown_once_metadata_bind()

    def setup(self):
        self._setup_each_tables()
        self._setup_each_mappers()
        self._setup_each_inserts()

    def teardown(self):
        sa.orm.session.Session.close_all()
        self._teardown_each_mappers()
        self._teardown_each_tables()

    @classmethod
    def _teardown_once_class(cls):
        cls.classes.clear()
        _ORMTest.teardown_class()


    @classmethod
    def _setup_once_classes(cls):
        if cls.run_setup_classes == 'once':
            cls._with_register_classes(cls.setup_classes)

    @classmethod
    def _setup_once_mappers(cls):
        if cls.run_setup_mappers == 'once':
            cls._with_register_classes(cls.setup_mappers)

    def _setup_each_mappers(self):
        if self.run_setup_mappers == 'each':
            self._with_register_classes(self.setup_mappers)

    @classmethod
    def _with_register_classes(cls, fn):
        """Run a setup method, framing the operation with a Base class
        that will catch new subclasses to be established within
        the "classes" registry.

        """
        cls_registry = cls.classes
        # metaclass records every subclass by name into the registry
        class FindFixture(type):
            def __init__(cls, classname, bases, dict_):
                cls_registry[classname] = cls
                return type.__init__(cls, classname, bases, dict_)


        class _Base(object):
            __metaclass__ = FindFixture
        class Basic(BasicEntity, _Base):
            pass
        class Comparable(ComparableEntity, _Base):
            pass
        cls.Basic = Basic
        cls.Comparable = Comparable
        fn()

    def _teardown_each_mappers(self):
        # some tests create mappers in the test bodies
        # and will define setup_mappers as None -
        # clear mappers in any case
        if self.run_setup_mappers != 'once':
            sa.orm.clear_mappers()

    @classmethod
    def setup_classes(cls):
        # overridden by subclasses to define mapped classes
        pass

    @classmethod
    def setup_mappers(cls):
        # overridden by subclasses to configure mappers
        pass
+
class DeclarativeMappedTest(MappedTest):
    """MappedTest variant using declarative mapping; tables are created
    from the declarative classes rather than define_tables()."""

    run_setup_classes = 'once'
    run_setup_mappers = 'once'

    @classmethod
    def _setup_once_tables(cls):
        # tables come from the declarative classes, created below
        pass

    @classmethod
    def _with_register_classes(cls, fn):
        cls_registry = cls.classes
        # declarative metaclass which also records subclasses by name
        class FindFixtureDeclarative(DeclarativeMeta):
            def __init__(cls, classname, bases, dict_):
                cls_registry[classname] = cls
                return DeclarativeMeta.__init__(
                    cls, classname, bases, dict_)
        class DeclarativeBasic(object):
            __table_cls__ = schema.Table
        _DeclBase = declarative_base(metadata=cls.metadata,
                                     metaclass=FindFixtureDeclarative,
                                     cls=DeclarativeBasic)
        cls.DeclarativeBasic = _DeclBase
        fn()
        if cls.metadata.tables:
            cls.metadata.create_all(config.db)
diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py
new file mode 100644
index 000000000..f5b8b827c
--- /dev/null
+++ b/lib/sqlalchemy/testing/pickleable.py
@@ -0,0 +1,107 @@
+"""Classes used in pickling tests, need to be at the module level for unpickling."""
+
+from . import fixtures
+
class User(fixtures.ComparableEntity):
    """Module-level mapped class placeholder, picklable."""
    pass
+
class Order(fixtures.ComparableEntity):
    """Module-level mapped class placeholder, picklable."""
    pass
+
class Dingaling(fixtures.ComparableEntity):
    """Module-level mapped class placeholder, picklable."""
    pass
+
class EmailUser(User):
    """User subclass for inheritance-mapping pickling tests."""
    pass
+
class Address(fixtures.ComparableEntity):
    """Module-level mapped class placeholder, picklable."""
    pass
+
+# TODO: these are kind of arbitrary....
class Child1(fixtures.ComparableEntity):
    """Module-level mapped class placeholder, picklable."""
    pass
+
class Child2(fixtures.ComparableEntity):
    """Module-level mapped class placeholder, picklable."""
    pass
+
class Parent(fixtures.ComparableEntity):
    """Module-level mapped class placeholder, picklable."""
    pass
+
class Screen(object):
    """Simple linked node: wraps ``obj`` with an optional parent Screen."""
    def __init__(self, obj, parent=None):
        self.obj = obj
        self.parent = parent
+
class Foo(object):
    """Value object holding two fixed strings plus ``moredata``;
    equality compares all three attributes, hashing stays identity-based."""

    __hash__ = object.__hash__

    def __init__(self, moredata):
        self.data = 'im data'
        self.stuff = 'im stuff'
        self.moredata = moredata

    def __eq__(self, other):
        if other.data != self.data:
            return False
        if other.stuff != self.stuff:
            return False
        return other.moredata == self.moredata
+
+
class Bar(object):
    """Comparable (x, y) pair; hashing stays identity-based."""

    __hash__ = object.__hash__

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        if other.__class__ is not self.__class__:
            return False
        return (other.x, other.y) == (self.x, self.y)

    def __str__(self):
        return "Bar(%d, %d)" % (self.x, self.y)
+
class OldSchool:
    """Old-style (classic, on Python 2) comparable (x, y) pair."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        if other.__class__ is not self.__class__:
            return False
        return (other.x, other.y) == (self.x, self.y)
+
class OldSchoolWithoutCompare:
    """Old-style (x, y) pair with default identity comparison only."""
    def __init__(self, x, y):
        self.x = x
        self.y = y
+
class BarWithoutCompare(object):
    """(x, y) pair with a __str__ but default identity comparison only."""
    def __init__(self, x, y):
        self.x = x
        self.y = y
    def __str__(self):
        return "Bar(%d, %d)" % (self.x, self.y)
+
+
class NotComparable(object):
    """Hashable-by-id object whose comparisons return NotImplemented,
    declining comparison entirely."""

    def __init__(self, data):
        self.data = data

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return NotImplemented

    def __ne__(self, other):
        return NotImplemented
+
+
class BrokenComparable(object):
    """Hashable-by-id object whose comparisons *raise*, for testing that
    SQLAlchemy never implicitly compares mapped instances."""

    def __init__(self, data):
        self.data = data

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        raise NotImplementedError

    def __ne__(self, other):
        raise NotImplementedError
+
diff --git a/lib/sqlalchemy/testing/plugin/__init__.py b/lib/sqlalchemy/testing/plugin/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/__init__.py
diff --git a/lib/sqlalchemy/testing/plugin/config.py b/lib/sqlalchemy/testing/plugin/config.py
new file mode 100644
index 000000000..08b9753dc
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/config.py
@@ -0,0 +1,186 @@
+"""Option and configuration implementations, run by the nose plugin
+on test suite startup."""
+
+import time
+import warnings
+import sys
+import re
+
# module-level configuration state, populated during plugin startup
logging = None       # lazily-imported logging module (see _log)
db = None            # the Engine under test
db_label = None      # shorthand name of the target database
db_url = None        # URL of the target database
db_opts = {}         # keyword options passed to engine creation
options = None       # parsed command-line option namespace
file_config = None   # ConfigParser for setup.cfg / test.cfg
+
def _log(option, opt_str, value, parser):
    """optparse callback for --log-info/--log-debug: raise the logging
    level for the named logger."""
    global logging
    if not logging:
        # import on first use, replacing the module-level placeholder
        import logging
        logging.basicConfig()

    if opt_str.endswith('-info'):
        logging.getLogger(value).setLevel(logging.INFO)
    elif opt_str.endswith('-debug'):
        logging.getLogger(value).setLevel(logging.DEBUG)
+
+
def _list_dbs(*args):
    """optparse callback for --dbs: print the [db] aliases from the
    config file and exit."""
    print "Available --db options (use --dburi to override)"
    for macro in sorted(file_config.options('db')):
        print "%20s\t%s" % (macro, file_config.get('db', macro))
    sys.exit(0)
+
def _server_side_cursors(options, opt_str, value, parser):
    """optparse callback for --serverside: enable server-side cursors."""
    db_opts['server_side_cursors'] = True
+
def _zero_timeout(options, opt_str, value, parser):
    """optparse callback for the obsolete --zero-timeout flag."""
    warnings.warn("--zero-timeout testing option is now on in all cases")
+
def _engine_strategy(options, opt_str, value, parser):
    """optparse callback for --enginestrategy: forward the chosen
    strategy to create_engine()."""
    if value:
        db_opts['strategy'] = value
+
# hook lists run by the nose plugin: pre_configure before coverage et al.
# are active, post_configure lazily at test-run begin
pre_configure = []
post_configure = []
+
def _setup_options(opt, file_config):
    """Publish the parsed option namespace on sqlalchemy.testing.config."""
    global options
    from sqlalchemy.testing import config
    config.options = options = opt
pre_configure.append(_setup_options)
+
def _monkeypatch_cdecimal(options, file_config):
    """When --with-cdecimal is given, globally substitute the cdecimal C
    implementation for the stdlib decimal module."""
    if options.cdecimal:
        import sys
        import cdecimal
        sys.modules['decimal'] = cdecimal
pre_configure.append(_monkeypatch_cdecimal)
+
def _engine_uri(options, file_config):
    """Resolve the target database URL from --dburi, or from the --db
    alias looked up in the config file's [db] section."""
    global db_label, db_url
    db_label = 'sqlite'
    if options.dburi:
        db_url = options.dburi
        # label is the dialect portion of the URL, e.g. "mysql"
        db_label = db_url[:db_url.index(':')]
    elif options.db:
        db_label = options.db
        db_url = None

    if db_url is None:
        if db_label not in file_config.options('db'):
            raise RuntimeError(
                "Unknown engine. Specify --dbs for known engines.")
        db_url = file_config.get('db', db_label)
post_configure.append(_engine_uri)
+
def _require(options, file_config):
    """Enforce package version requirements given via --require and/or
    the [require] section of the config file, using setuptools."""
    if not(options.require or
           (file_config.has_section('require') and
            file_config.items('require'))):
        return

    try:
        import pkg_resources
    except ImportError:
        raise RuntimeError("setuptools is required for version requirements")

    cmdline = []
    for requirement in options.require:
        pkg_resources.require(requirement)
        # record the bare distribution name: split at the first version
        # comparison character.  Raw string + character class fix the
        # original pattern '\s*(<!>=)', which only matched the literal
        # text "<!>=" and so never split anything.
        cmdline.append(re.split(r'\s*([<>!=])', requirement, 1)[0])

    if file_config.has_section('require'):
        for label, requirement in file_config.items('require'):
            # process entries for this database label or its
            # driver-qualified sub-labels ("mysql.oursql" etc.); the
            # original "not label == db_label or label.startswith(...)"
            # mis-parenthesization skipped exactly the sub-labels it
            # meant to include.
            if label != db_label and not label.startswith('%s.' % db_label):
                continue
            # skip anything already required on the command line
            seen = [c for c in cmdline if requirement.startswith(c)]
            if seen:
                continue
            pkg_resources.require(requirement)
post_configure.append(_require)
+
def _engine_pool(options, file_config):
    """When --mockpool is set, use AssertionPool so that checking out
    more than one connection at a time raises."""
    if options.mockpool:
        from sqlalchemy import pool
        db_opts['poolclass'] = pool.AssertionPool
post_configure.append(_engine_pool)
+
def _create_testing_engine(options, file_config):
    """Create the global test Engine from db_url/db_opts and publish it
    on sqlalchemy.testing and sqlalchemy.testing.config."""
    from sqlalchemy.testing import engines, config
    from sqlalchemy import testing
    global db
    config.db = testing.db = db = engines.testing_engine(db_url, db_opts)
    config.db_opts = db_opts
    config.db_url = db_url

post_configure.append(_create_testing_engine)
+
def _prep_testing_database(options, file_config):
    """When --dropfirst is set, reflect and drop all existing tables in
    the target database, after a five second abort window."""
    from sqlalchemy.testing import engines
    from sqlalchemy import schema

    # also create alt schemas etc. here?
    if options.dropfirst:
        e = engines.utf8_engine()
        existing = e.table_names()
        if existing:
            print "Dropping existing tables in database: " + db_url
            try:
                print "Tables: %s" % ', '.join(existing)
            except:
                pass
            print "Abort within 5 seconds..."
            time.sleep(5)
            md = schema.MetaData(e, reflect=True)
            md.drop_all()
        e.dispose()

post_configure.append(_prep_testing_database)
+
def _set_table_options(options, file_config):
    """Apply --table-option key=value pairs and --mysql-engine to the
    global table_options used by testing.schema."""
    from sqlalchemy.testing import schema

    table_options = schema.table_options
    for spec in options.tableopts:
        key, value = spec.split('=')
        table_options[key] = value

    if options.mysql_engine:
        table_options['mysql_engine'] = options.mysql_engine
post_configure.append(_set_table_options)
+
def _reverse_topological(options, file_config):
    """When --reversetop is set, patch a randomly-ordered set type into
    the unit-of-work modules to shake out hidden ordering dependencies."""
    if options.reversetop:
        from sqlalchemy.orm import unitofwork, session, mapper, dependency
        from sqlalchemy.util import topological
        from sqlalchemy.testing.util import RandomSet
        topological.set = unitofwork.set = session.set = mapper.set = \
            dependency.set = RandomSet
post_configure.append(_reverse_topological)
+
def _requirements(options, file_config):
    """Import and instantiate the requirements class named by the
    "requirement_cls" option in [sqla_testing], given as "module:Class"."""
    from sqlalchemy.testing import config
    from sqlalchemy import testing
    requirement_cls = file_config.get('sqla_testing', "requirement_cls")

    modname, clsname = requirement_cls.split(":")

    # importlib.import_module() only introduced in 2.7, a little
    # late
    mod = __import__(modname)
    for component in modname.split(".")[1:]:
        mod = getattr(mod, component)
    req_cls = getattr(mod, clsname)
    config.requirements = testing.requires = req_cls(db, config)

post_configure.append(_requirements)
+
def _setup_profiling(options, file_config):
    """Create the global ProfileStatsFile from the [sqla_testing]
    profile_file setting."""
    from sqlalchemy.testing import profiling
    profiling._profile_stats = profiling.ProfileStatsFile(
        file_config.get('sqla_testing', 'profile_file'))

post_configure.append(_setup_profiling)
+
diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py
new file mode 100644
index 000000000..c9e12c305
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/noseplugin.py
@@ -0,0 +1,199 @@
+"""Enhance nose with extra options and behaviors for running SQLAlchemy tests.
+
+This module is imported relative to the "plugins" package as a top level
+package by the sqla_nose.py runner, so that the plugin can be loaded with
+the rest of nose including the coverage plugin before any of SQLAlchemy itself
+is imported, so that coverage works.
+
+When third party libraries use this library, it can be imported
+normally as "from sqlalchemy.testing.plugin import noseplugin".
+
+"""
+import os
+import ConfigParser
+
+from nose.plugins import Plugin
+from nose import SkipTest
+from . import config
+
+from .config import _log, _list_dbs, _zero_timeout, \
+ _engine_strategy, _server_side_cursors, pre_configure,\
+ post_configure
+
# late imports: populated in NoseSQLAlchemy.begin() so that nothing from
# SQLAlchemy itself is imported before nose's coverage plugin starts
fixtures = None
engines = None
exclusions = None
warnings = None
profiling = None
assertions = None
requirements = None
util = None
file_config = None
+
class NoseSQLAlchemy(Plugin):
    """
    Handles the setup and extra properties required for testing SQLAlchemy
    """
    enabled = True

    name = 'sqla_testing'
    score = 100

    def options(self, parser, env=os.environ):
        """Register command-line options and read the test config files."""
        Plugin.options(self, parser, env)
        opt = parser.add_option
        opt("--log-info", action="callback", type="string", callback=_log,
            help="turn on info logging for <LOG> (multiple OK)")
        opt("--log-debug", action="callback", type="string", callback=_log,
            help="turn on debug logging for <LOG> (multiple OK)")
        opt("--require", action="append", dest="require", default=[],
            help="require a particular driver or module version (multiple OK)")
        opt("--db", action="store", dest="db", default="sqlite",
            help="Use prefab database uri")
        opt('--dbs', action='callback', callback=_list_dbs,
            help="List available prefab dbs")
        opt("--dburi", action="store", dest="dburi",
            help="Database uri (overrides --db)")
        opt("--dropfirst", action="store_true", dest="dropfirst",
            help="Drop all tables in the target database first")
        opt("--mockpool", action="store_true", dest="mockpool",
            help="Use mock pool (asserts only one connection used)")
        opt("--zero-timeout", action="callback", callback=_zero_timeout,
            help="Set pool_timeout to zero, applies to QueuePool only")
        opt("--low-connections", action="store_true", dest="low_connections",
            help="Use a low number of distinct connections - i.e. for Oracle TNS"
        )
        opt("--enginestrategy", action="callback", type="string",
            callback=_engine_strategy,
            help="Engine strategy (plain or threadlocal, defaults to plain)")
        opt("--reversetop", action="store_true", dest="reversetop", default=False,
            help="Use a random-ordering set implementation in the ORM (helps "
                  "reveal dependency issues)")
        opt("--with-cdecimal", action="store_true", dest="cdecimal", default=False,
            help="Monkeypatch the cdecimal library into Python 'decimal' for all tests")
        opt("--unhashable", action="store_true", dest="unhashable", default=False,
            help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
        opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
            help="Disallow SQLAlchemy from performing == on mapped test objects.")
        opt("--truthless", action="store_true", dest="truthless", default=False,
            help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
        opt("--serverside", action="callback", callback=_server_side_cursors,
            help="Turn on server side cursors for PG")
        opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
            help="Use the specified MySQL storage engine for all tables, default is "
                 "a db-default/InnoDB combo.")
        opt("--table-option", action="append", dest="tableopts", default=[],
            help="Add a dialect-specific table option, key=value")
        opt("--write-profiles", action="store_true", dest="write_profiles", default=False,
            help="Write/update profiling data.")
        global file_config
        file_config = ConfigParser.ConfigParser()
        file_config.read(['setup.cfg', 'test.cfg', os.path.expanduser('~/.satest.cfg')])
        config.file_config = file_config

    def configure(self, options, conf):
        """Run the pre_configure hooks (before coverage starts)."""
        Plugin.configure(self, options, conf)
        self.options = options
        for fn in pre_configure:
            fn(self.options, file_config)

    def begin(self):
        # Lazy setup of other options (post coverage)
        for fn in post_configure:
            fn(self.options, file_config)

        # late imports, has to happen after config as well
        # as nose plugins like coverage
        global util, fixtures, engines, exclusions, \
            assertions, warnings, profiling
        from sqlalchemy.testing import fixtures, engines, exclusions, \
            assertions, warnings, profiling
        from sqlalchemy import util

    def describeTest(self, test):
        # suppress nose's verbose per-test description
        return ""

    def wantFunction(self, fn):
        # exclude legacy test-support modules from collection
        if fn.__module__.startswith('test.lib') or \
                fn.__module__.startswith('test.bootstrap'):
            return False

    def wantClass(self, cls):
        """Return true if you want the main test selector to collect
        tests from this class, false if you don't, and None if you don't
        care.

        :Parameters:
          cls : class
            The class being examined by the selector

        """

        if not issubclass(cls, fixtures.TestBase):
            return False
        elif cls.__name__.startswith('_'):
            return False
        else:
            return True

    def _do_skips(self, cls):
        """Raise SkipTest if the class's requirement/dialect constraints
        exclude it under the configured engine."""
        from sqlalchemy.testing import config
        if hasattr(cls, '__requires__'):
            # evaluate each requirement decorator against a dummy test
            def test_suite():
                return 'ok'
            test_suite.__name__ = cls.__name__
            for requirement in cls.__requires__:
                check = getattr(config.requirements, requirement)
                check(test_suite)()

        if cls.__unsupported_on__:
            spec = exclusions.db_spec(*cls.__unsupported_on__)
            if spec(config.db):
                raise SkipTest(
                    "'%s' unsupported on DB implementation '%s'" % (
                        cls.__name__, config.db.name)
                )

        if getattr(cls, '__only_on__', None):
            spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
            if not spec(config.db):
                raise SkipTest(
                    "'%s' unsupported on DB implementation '%s'" % (
                        cls.__name__, config.db.name)
                )

        if getattr(cls, '__skip_if__', False):
            for c in getattr(cls, '__skip_if__'):
                if c():
                    raise SkipTest("'%s' skipped by %s" % (
                        cls.__name__, c.__name__)
                    )

        for db, op, spec in getattr(cls, '__excluded_on__', ()):
            exclusions.exclude(db, op, spec,
                               "'%s' unsupported on DB %s version %s" % (
                                   cls.__name__, config.db.name,
                                   exclusions._server_version(config.db)))

    def beforeTest(self, test):
        # reset warning filters and tell the profiler which test runs
        warnings.resetwarnings()
        profiling._current_test = test.id()

    def afterTest(self, test):
        engines.testing_reaper._after_test_ctx()
        warnings.resetwarnings()

    def startContext(self, ctx):
        if not isinstance(ctx, type) \
                or not issubclass(ctx, fixtures.TestBase):
            return
        self._do_skips(ctx)

    def stopContext(self, ctx):
        if not isinstance(ctx, type) \
                or not issubclass(ctx, fixtures.TestBase):
            return
        engines.testing_reaper._stop_test_ctx()
        if not config.options.low_connections:
            assertions.global_cleanup_assertions()
diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py
new file mode 100644
index 000000000..be32b1d1d
--- /dev/null
+++ b/lib/sqlalchemy/testing/profiling.py
@@ -0,0 +1,292 @@
+"""Profiling support for unit and performance tests.
+
+These are special purpose profiling methods which operate
+in a more fine-grained way than nose's profiling plugin.
+
+"""
+
+import os
+import sys
+from .util import gc_collect, decorator
+from . import config
+from nose import SkipTest
+import pstats
+import time
+import collections
+from sqlalchemy import util
+try:
+ import cProfile
+except ImportError:
+ cProfile = None
+from sqlalchemy.util.compat import jython, pypy, win32
+
+_current_test = None
+
def profiled(target=None, **target_opts):
    """Function profiling.

    @profiled('label')
    or
    @profiled('label', report=True, sort=('calls',), limit=20)

    Enables profiling for a function when 'label' is targeted for
    profiling.  Report options can be supplied, and override the global
    configuration and command-line options.

    :param target: label under which the profile report is printed.
    :param target_opts: per-target overrides for the reporting defaults
      (``report``, ``sort``, ``limit``, ``print_callers``,
      ``print_callees``, ``graphic``).
    """

    # reporting defaults; individual keys may be overridden via target_opts
    profile_config = {'targets': set(),
                      'report': True,
                      'print_callers': False,
                      'print_callees': False,
                      'graphic': False,
                      'sort': ('time', 'calls'),
                      'limit': None}
    if target is None:
        target = 'anonymous_target'

    @decorator
    def decorate(fn, *args, **kw):
        # _profile() derives its dump filename from fn.__name__ and
        # returns a loader that reads *and unlinks* that file, so the
        # filename must be computed the same way here (previously the
        # filename was passed as _profile's first argument, which
        # shifted all of _profile's parameters by one).
        elapsed, load_stats, result = _profile(fn, *args, **kw)
        filename = "%s.prof" % fn.__name__

        graphic = target_opts.get('graphic', profile_config['graphic'])
        if graphic:
            os.system("runsnake %s" % filename)
            os.unlink(filename)
        else:
            report = target_opts.get('report', profile_config['report'])
            if report:
                sort_ = target_opts.get('sort', profile_config['sort'])
                limit = target_opts.get('limit', profile_config['limit'])
                print ("Profile report for target '%s' (%s)" % (
                    target, filename)
                )

                # load_stats() removes the .prof file once read; no
                # further unlink is needed (previously a second
                # os.unlink() here would raise OSError)
                stats = load_stats()
                stats.sort_stats(*sort_)
                if limit:
                    stats.print_stats(limit)
                else:
                    stats.print_stats()

                print_callers = target_opts.get('print_callers',
                                                profile_config['print_callers'])
                if print_callers:
                    stats.print_callers()

                print_callees = target_opts.get('print_callees',
                                                profile_config['print_callees'])
                if print_callees:
                    stats.print_callees()
            else:
                # no report requested; remove the dump file ourselves
                os.unlink(filename)
        return result
    return decorate
+
+
+class ProfileStatsFile(object):
+ """"Store per-platform/fn profiling results in a file.
+
+ We're still targeting Py2.5, 2.4 on 0.7 with no dependencies,
+ so no json lib :( need to roll something silly
+
+ """
+ def __init__(self, filename):
+ self.write = config.options is not None and config.options.write_profiles
+ self.fname = os.path.abspath(filename)
+ self.short_fname = os.path.split(self.fname)[-1]
+ self.data = collections.defaultdict(lambda: collections.defaultdict(dict))
+ self._read()
+ if self.write:
+ # rewrite for the case where features changed,
+ # etc.
+ self._write()
+
+ @util.memoized_property
+ def platform_key(self):
+
+ dbapi_key = config.db.name + "_" + config.db.driver
+
+ # keep it at 2.7, 3.1, 3.2, etc. for now.
+ py_version = '.'.join([str(v) for v in sys.version_info[0:2]])
+
+ platform_tokens = [py_version]
+ platform_tokens.append(dbapi_key)
+ if jython:
+ platform_tokens.append("jython")
+ if pypy:
+ platform_tokens.append("pypy")
+ if win32:
+ platform_tokens.append("win")
+ _has_cext = config.requirements._has_cextensions()
+ platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
+ return "_".join(platform_tokens)
+
+ def has_stats(self):
+ test_key = _current_test
+ return test_key in self.data and self.platform_key in self.data[test_key]
+
+ def result(self, callcount):
+ test_key = _current_test
+ per_fn = self.data[test_key]
+ per_platform = per_fn[self.platform_key]
+
+ if 'counts' not in per_platform:
+ per_platform['counts'] = counts = []
+ else:
+ counts = per_platform['counts']
+
+ if 'current_count' not in per_platform:
+ per_platform['current_count'] = current_count = 0
+ else:
+ current_count = per_platform['current_count']
+
+ has_count = len(counts) > current_count
+
+ if not has_count:
+ counts.append(callcount)
+ if self.write:
+ self._write()
+ result = None
+ else:
+ result = per_platform['lineno'], counts[current_count]
+ per_platform['current_count'] += 1
+ return result
+
+
+ def _header(self):
+ return \
+ "# %s\n"\
+ "# This file is written out on a per-environment basis.\n"\
+ "# For each test in aaa_profiling, the corresponding function and \n"\
+ "# environment is located within this file. If it doesn't exist,\n"\
+ "# the test is skipped.\n"\
+ "# If a callcount does exist, it is compared to what we received. \n"\
+ "# assertions are raised if the counts do not match.\n"\
+ "# \n"\
+ "# To add a new callcount test, apply the function_call_count \n"\
+ "# decorator and re-run the tests using the --write-profiles option - \n"\
+ "# this file will be rewritten including the new count.\n"\
+ "# \n"\
+ "" % (self.fname)
+
+ def _read(self):
+ try:
+ profile_f = open(self.fname)
+ except IOError:
+ return
+ for lineno, line in enumerate(profile_f):
+ line = line.strip()
+ if not line or line.startswith("#"):
+ continue
+
+ test_key, platform_key, counts = line.split()
+ per_fn = self.data[test_key]
+ per_platform = per_fn[platform_key]
+ per_platform['counts'] = [int(count) for count in counts.split(",")]
+ per_platform['lineno'] = lineno + 1
+ per_platform['current_count'] = 0
+ profile_f.close()
+
+ def _write(self):
+ print("Writing profile file %s" % self.fname)
+ profile_f = open(self.fname, "w")
+ profile_f.write(self._header())
+ for test_key in sorted(self.data):
+
+ per_fn = self.data[test_key]
+ profile_f.write("\n# TEST: %s\n\n" % test_key)
+ for platform_key in sorted(per_fn):
+ per_platform = per_fn[platform_key]
+ profile_f.write(
+ "%s %s %s\n" % (
+ test_key,
+ platform_key, ",".join(str(count) for count in per_platform['counts'])
+ )
+ )
+ profile_f.close()
+
+from sqlalchemy.util.compat import update_wrapper
+
def function_call_count(variance=0.05):
    """Assert a target for a test case's function call count.

    The main purpose of this assertion is to detect changes in
    callcounts for various functions - the actual number is not as important.
    Callcounts are stored in a file keyed to Python version and OS platform
    information. This file is generated automatically for new tests,
    and versioned so that unexpected changes in callcounts will be detected.

    :param variance: allowed fractional deviation between the observed
       and stored callcount before an assertion is raised.
    """

    def decorate(fn):
        def wrap(*args, **kw):

            if cProfile is None:
                raise SkipTest("cProfile is not installed")

            if not _profile_stats.has_stats() and not _profile_stats.write:
                # run the function anyway, to support dependent tests
                # (not a great idea but we have these in test_zoomark)
                fn(*args, **kw)
                raise SkipTest("No profiling stats available on this "
                        "platform for this function. Run tests with "
                        "--write-profiles to add statistics to %s for "
                        "this platform." % _profile_stats.short_fname)

            # collect garbage first so leftovers from prior tests don't
            # inflate the observed count
            gc_collect()

            timespent, load_stats, fn_result = _profile(
                fn, *args, **kw
            )
            stats = load_stats()
            callcount = stats.total_calls

            # result() returns None when the count was newly recorded,
            # else (line number within the stats file, expected count)
            expected = _profile_stats.result(callcount)
            if expected is None:
                expected_count = None
            else:
                line_no, expected_count = expected

            print("Pstats calls: %d Expected %s" % (
                    callcount,
                    expected_count
                )
            )
            stats.print_stats()
            #stats.print_callers()

            if expected_count:
                # observed count may drift by `variance` before failing
                deviance = int(callcount * variance)
                if abs(callcount - expected_count) > deviance:
                    raise AssertionError(
                        "Adjusted function call count %s not within %s%% "
                        "of expected %s. (Delete line %d of file %s to regenerate "
                        "this callcount, when tests are run with --write-profiles.)"
                        % (
                        callcount, (variance * 100),
                        expected_count, line_no,
                        _profile_stats.fname))
            return fn_result
        return update_wrapper(wrap, fn)
    return decorate
+
+
+def _profile(fn, *args, **kw):
+ filename = "%s.prof" % fn.__name__
+
+ def load_stats():
+ st = pstats.Stats(filename)
+ os.unlink(filename)
+ return st
+
+ began = time.time()
+ cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
+ filename=filename)
+ ended = time.time()
+
+ return ended - began, load_stats, locals()['result']
+
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
new file mode 100644
index 000000000..eca883d4e
--- /dev/null
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -0,0 +1,38 @@
+"""Global database feature support policy.
+
+Provides decorators to mark tests requiring specific feature support from the
+target database.
+
+"""
+
+from .exclusions import \
+ skip, \
+ skip_if,\
+ only_if,\
+ only_on,\
+ fails_on,\
+ fails_on_everything_except,\
+ fails_if,\
+ SpecPredicate,\
+ against
+
def no_support(db, reason):
    """Return a SpecPredicate excluding ``db``, carrying ``reason`` as
    its description."""
    return SpecPredicate(db, description=reason)
+
def exclude(db, op, spec, description=None):
    """Return a SpecPredicate for ``db`` built from operator ``op`` and
    version ``spec``, with an optional description."""
    return SpecPredicate(db, op, spec, description=description)
+
+
+def _chain_decorators_on(*decorators):
+ def decorate(fn):
+ for decorator in reversed(decorators):
+ fn = decorator(fn)
+ return fn
+ return decorate
+
class Requirements(object):
    """Base class for backend feature requirement sets.

    Instances are consulted via attribute lookup, e.g.
    ``config.requirements.<feature_name>``.
    """
    def __init__(self, db, config):
        # db: the engine under test; config: the active test configuration
        self.db = db
        self.config = config
+
+
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
new file mode 100644
index 000000000..03da78c64
--- /dev/null
+++ b/lib/sqlalchemy/testing/schema.py
@@ -0,0 +1,85 @@
+"""Enhanced versions of schema.Table and schema.Column which establish
+desired state for different backends.
+"""
+
+from . import exclusions
+from sqlalchemy import schema, event
+from . import config
+
+__all__ = 'Table', 'Column',
+
+table_options = {}
+
def Table(*args, **kw):
    """A schema.Table wrapper/hook for dialect-specific tweaks.

    Keyword arguments prefixed with ``test_`` are consumed here as test
    directives; everything else passes through to schema.Table.
    """

    # pop test directives; iterate a snapshot of the keys since kw is
    # mutated during the scan (iterating kw.keys() directly would raise
    # RuntimeError on Py3)
    test_opts = dict([(k, kw.pop(k)) for k in list(kw)
                      if k.startswith('test_')])

    kw.update(table_options)

    if exclusions.against('mysql'):
        if 'mysql_engine' not in kw and 'mysql_type' not in kw:
            if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
                kw['mysql_engine'] = 'InnoDB'
            else:
                kw['mysql_engine'] = 'MyISAM'

    # Apply some default cascading rules for self-referential foreign keys.
    # MySQL InnoDB has some issues around selecting self-refs too.
    if exclusions.against('firebird'):
        table_name = args[0]
        unpack = (config.db.dialect.
                  identifier_preparer.unformat_identifiers)

        # Only going after ForeignKeys in Columns. May need to
        # expand to ForeignKeyConstraint too.
        fks = [fk
               for col in args if isinstance(col, schema.Column)
               for fk in col.foreign_keys]

        for fk in fks:
            # root around in raw spec
            ref = fk._colspec
            if isinstance(ref, schema.Column):
                name = ref.table.name
            else:
                # take just the table name: on FB there cannot be
                # a schema, so the first element is always the
                # table name, possibly followed by the field name
                name = unpack(ref)[0]
            if name == table_name:
                if fk.ondelete is None:
                    fk.ondelete = 'CASCADE'
                if fk.onupdate is None:
                    fk.onupdate = 'CASCADE'

    return schema.Table(*args, **kw)
+
+
def Column(*args, **kw):
    """A schema.Column wrapper/hook for dialect-specific tweaks.

    Keyword arguments prefixed with ``test_`` are consumed here as test
    directives; everything else passes through to schema.Column.
    """

    # pop test directives over a snapshot of the keys, since kw is
    # mutated during the scan (iterating kw.keys() directly would raise
    # RuntimeError on Py3)
    test_opts = dict([(k, kw.pop(k)) for k in list(kw)
                      if k.startswith('test_')])

    col = schema.Column(*args, **kw)
    if 'test_needs_autoincrement' in test_opts and \
        kw.get('primary_key', False) and \
        exclusions.against('firebird', 'oracle'):
        # attach an optional Sequence once the column is bound to its
        # table, since the sequence name depends on the table name
        def add_seq(c, tbl):
            c._init_items(
                schema.Sequence(_truncate_name(
                    config.db.dialect, tbl.name + '_' + c.name + '_seq'),
                    optional=True)
            )
        event.listen(col, 'after_parent_attach', add_seq, propagate=True)
    return col
+
+def _truncate_name(dialect, name):
+ if len(name) > dialect.max_identifier_length:
+ return name[0:max(dialect.max_identifier_length - 6, 0)] + \
+ "_" + hex(hash(name) % 64)[2:]
+ else:
+ return name
+
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/__init__.py
diff --git a/lib/sqlalchemy/testing/suite/requirements.py b/lib/sqlalchemy/testing/suite/requirements.py
new file mode 100644
index 000000000..5eda39b2b
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/requirements.py
@@ -0,0 +1,24 @@
+from ..requirements import Requirements
+from .. import exclusions
+
+
class SuiteRequirements(Requirements):
    """Default requirement rules for the portable test suite.

    Each property returns an exclusions rule; the defaults here leave
    every feature enabled (``exclusions.open``).
    """

    @property
    def create_table(self):
        """target platform can emit basic CreateTable DDL."""

        return exclusions.open

    @property
    def drop_table(self):
        """target platform can emit basic DropTable DDL."""

        return exclusions.open

    @property
    def autoincrement_insert(self):
        """target platform generates new surrogate integer primary key values
        when insert() is executed, excluding the pk column."""

        return exclusions.open
diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py
new file mode 100644
index 000000000..1285c4196
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_ddl.py
@@ -0,0 +1,48 @@
+from .. import fixtures, config, util
+from ..config import requirements
+from ..assertions import eq_
+
+from sqlalchemy import Table, Column, Integer, String
+
+
class TableDDLTest(fixtures.TestBase):
    """Suite-level tests for basic CREATE TABLE / DROP TABLE support."""

    def _simple_fixture(self):
        """Return a minimal two-column table bound to the per-test
        metadata supplied by @util.provide_metadata."""
        return Table('test_table', self.metadata,
            Column('id', Integer, primary_key=True),
            Column('data', String(50))
        )

    def _simple_roundtrip(self):
        """Insert and select one row to prove the table really exists."""
        with config.db.begin() as conn:
            conn.execute("insert into test_table(id, data) values "
                "(1, 'some data')")
            result = conn.execute("select id, data from test_table")
            eq_(
                result.first(),
                (1, 'some data')
            )


    @requirements.create_table
    @util.provide_metadata
    def test_create_table(self):
        """Table.create() emits working CREATE TABLE DDL."""
        table = self._simple_fixture()
        table.create(
            config.db, checkfirst=False
        )
        self._simple_roundtrip()


    @requirements.drop_table
    @util.provide_metadata
    def test_drop_table(self):
        """Table.drop() emits working DROP TABLE DDL."""
        table = self._simple_fixture()
        table.create(
            config.db, checkfirst=False
        )
        table.drop(
            config.db, checkfirst=False
        )
+
+__all__ = ('TableDDLTest',) \ No newline at end of file
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
diff --git a/lib/sqlalchemy/testing/suite/test_sequencing.py b/lib/sqlalchemy/testing/suite/test_sequencing.py
new file mode 100644
index 000000000..7b09ecb76
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_sequencing.py
@@ -0,0 +1,36 @@
+from .. import fixtures, config, util
+from ..config import requirements
+from ..assertions import eq_
+
+from sqlalchemy import Table, Column, Integer, String
+
+
class InsertSequencingTest(fixtures.TablesTest):
    """Suite-level tests for implicit generation of integer primary keys."""

    # wipe table contents between each test
    run_deletes = 'each'

    @classmethod
    def define_tables(cls, metadata):
        """Create the single 'plain_pk' fixture table."""
        Table('plain_pk', metadata,
            Column('id', Integer, primary_key=True),
            Column('data', String(50))
        )

    def _assert_round_trip(self, table):
        """The single inserted row should have received pk value 1."""
        row = config.db.execute(table.select()).first()
        eq_(
            row,
            (1, "some data")
        )

    @requirements.autoincrement_insert
    def test_autoincrement_on_insert(self):
        """insert() omitting the pk column produces a new surrogate pk."""
        config.db.execute(
            self.tables.plain_pk.insert(),
            data="some data"
        )
        self._assert_round_trip(self.tables.plain_pk)
+
+
+
+__all__ = ('InsertSequencingTest',) \ No newline at end of file
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
new file mode 100644
index 000000000..625b9e6a5
--- /dev/null
+++ b/lib/sqlalchemy/testing/util.py
@@ -0,0 +1,196 @@
+from sqlalchemy.util import jython, pypy, defaultdict, decorator
+from sqlalchemy.util.compat import decimal
+
+import gc
+import time
+import random
+import sys
+import types
+
# Platform-specific gc entry points.  Both names must always be bound:
# profiling.py does `from .util import gc_collect`, and test code uses
# lazy_gc() for VM's that don't collect on refcount == 0.
if jython:
    def jython_gc_collect(*args):
        """aggressive gc.collect for tests."""
        gc.collect()
        time.sleep(0.1)
        gc.collect()
        gc.collect()
        return 0

    # previously only lazy_gc was bound here, leaving gc_collect
    # undefined on Jython (NameError at import time in profiling.py)
    gc_collect = lazy_gc = jython_gc_collect
elif pypy:
    def pypy_gc_collect(*args):
        gc.collect()
        gc.collect()

    # previously only lazy_gc was bound here as well
    gc_collect = lazy_gc = pypy_gc_collect
else:
    # assume CPython - straight gc.collect, lazy_gc() is a pass
    gc_collect = gc.collect
    def lazy_gc():
        pass
+
def picklers():
    """Yield (loads, dumps) pairs covering each available pickle
    implementation and protocol."""
    picklers = set()
    # Py2K
    try:
        import cPickle
        picklers.add(cPickle)
    except ImportError:
        pass
    # end Py2K
    import pickle
    picklers.add(pickle)

    # yes, this thing needs this much testing
    for pickle_ in picklers:
        for protocol in -1, 0, 1, 2:
            # bind pickle_/protocol as lambda defaults: a plain closure
            # late-binds, so every yielded dumps would share whatever
            # values the loop variables last held
            yield pickle_.loads, \
                lambda d, pickle_=pickle_, protocol=protocol: \
                    pickle_.dumps(d, protocol)
+
+
def round_decimal(value, prec):
    """Round ``value`` (float or Decimal) to ``prec`` decimal places,
    flooring toward negative infinity for Decimals."""
    if isinstance(value, float):
        return round(value, prec)

    # Decimal path: shift left, truncate with floor, shift back.
    # (Decimal.shift() would also work but is 2.6-only.)
    scale = decimal.Decimal("1" + "0" * prec)
    return (value * scale).to_integral(decimal.ROUND_FLOOR) / pow(10, prec)
+
class RandomSet(set):
    """A set whose iteration order and pop() choice are randomized,
    used to verify that code under test is order-independent."""

    def __iter__(self):
        items = list(set.__iter__(self))
        random.shuffle(items)
        return iter(items)

    def pop(self):
        # pick an arbitrary member rather than CPython's fixed choice
        victim = list(set.__iter__(self))[random.randint(0, len(self) - 1)]
        self.remove(victim)
        return victim

    # set-algebra operations re-wrap so the randomized behavior sticks
    def union(self, other):
        return RandomSet(set.union(self, other))

    def difference(self, other):
        return RandomSet(set.difference(self, other))

    def intersection(self, other):
        return RandomSet(set.intersection(self, other))

    def copy(self):
        return RandomSet(self)
+
def conforms_partial_ordering(tuples, sorted_elements):
    """True if the given sorting conforms to the given partial ordering.

    ``tuples`` is a sequence of (parent, child) dependency pairs; the
    sort is valid when no element appears before one of its parents.
    """

    deps = defaultdict(set)
    for parent, child in tuples:
        deps[parent].add(child)
    # every element must be checked against everything after it; the
    # previous version's for/else was attached to the inner loop and
    # returned True after examining only the first element
    for i, node in enumerate(sorted_elements):
        for n in sorted_elements[i:]:
            if node in deps[n]:
                return False
    return True
+
def all_partial_orderings(tuples, elements):
    """Yield every ordering of ``elements`` consistent with the
    (parent, child) dependency pairs in ``tuples``."""
    incoming = defaultdict(set)
    for parent, child in tuples:
        incoming[child].add(parent)

    def _generate(remaining):
        if len(remaining) == 1:
            yield list(remaining)
        else:
            for candidate in remaining:
                rest = set(remaining).difference([candidate])
                # candidate may lead only if none of its parents remain
                if not rest.intersection(incoming[candidate]):
                    for tail in _generate(rest):
                        yield [candidate] + tail

    return iter(_generate(elements))
+
+
def function_named(fn, name):
    """Return a function with a given __name__.

    Will assign to __name__ and return the original function if possible on
    the Python implementation, otherwise a new function will be constructed.

    This function should be phased out as much as possible
    in favor of @decorator. Tests that "generate" many named tests
    should be modernized.

    """
    try:
        fn.__name__ = name
    except TypeError:
        # some callables don't permit __name__ assignment; rebuild a
        # plain function around the same code and closure
        # (func_code etc. are the Py2-era attribute names)
        fn = types.FunctionType(fn.func_code, fn.func_globals, name,
                          fn.func_defaults, fn.func_closure)
    return fn
+
+
+
def run_as_contextmanager(ctx, fn, *arg, **kw):
    """Run the given function under the given contextmanager,
    simulating the behavior of 'with' to support older
    Python versions.

    ``fn`` receives the value returned by ``ctx.__enter__`` as its
    first argument.
    """

    obj = ctx.__enter__()
    try:
        result = fn(obj, *arg, **kw)
        ctx.__exit__(None, None, None)
        return result
    except:
        exc_info = sys.exc_info()
        # per the context manager protocol, a truthy return from
        # __exit__ suppresses the exception
        raise_ = ctx.__exit__(*exc_info)
        if raise_ is None:
            raise
        else:
            # NOTE(review): a false-but-not-None return (e.g. False) is
            # returned to the caller instead of re-raising, which differs
            # from real 'with' semantics — confirm this is intended
            return raise_
+
def rowset(results):
    """Converts the results of sql execution into a plain set of column tuples.

    Useful for asserting the results of an unordered query.
    """

    return set(tuple(row) for row in results)
+
+
def fail(msg):
    """Unconditionally fail the current test with ``msg``.

    Raises AssertionError directly rather than via ``assert False``,
    which would be silently stripped when Python runs with -O.
    """
    raise AssertionError(msg)
+
+
@decorator
def provide_metadata(fn, *args, **kw):
    """Provide bound MetaData for a single test, dropping afterwards."""

    from . import config
    from sqlalchemy import schema

    metadata = schema.MetaData(config.db)
    # args[0] is the test instance; stash the MetaData on it as
    # self.metadata for the duration of the test
    self = args[0]
    prev_meta = getattr(self, 'metadata', None)
    self.metadata = metadata
    try:
        return fn(*args, **kw)
    finally:
        # drop everything created during the test, then restore any
        # previously-present metadata attribute
        metadata.drop_all()
        self.metadata = prev_meta
+
class adict(dict):
    """Dict whose keys are also readable as attributes.

    A key always shadows a real attribute of the same name.
    """

    def __getattribute__(self, key):
        # dict contents win; fall through to normal attribute access
        # only for names not present as keys
        if key in self:
            return self[key]
        return dict.__getattribute__(self, key)

    def get_all(self, *keys):
        """Return the values for ``keys`` as a tuple, in order."""
        return tuple(self[key] for key in keys)
+
+
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
new file mode 100644
index 000000000..799fca128
--- /dev/null
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -0,0 +1,43 @@
+from __future__ import absolute_import
+
+import warnings
+from sqlalchemy import exc as sa_exc
+from sqlalchemy import util
+
def testing_warn(msg, stacklevel=3):
    """Replaces sqlalchemy.util.warn during tests.

    Emits via warn_explicit with a fixed filename/lineno so warning
    filters apply deterministically.  ``stacklevel`` is accepted for
    signature compatibility but unused here.
    """

    filename = "sqlalchemy.testing.warnings"
    lineno = 1
    if isinstance(msg, basestring):
        warnings.warn_explicit(msg, sa_exc.SAWarning, filename, lineno)
    else:
        # msg is a Warning instance: its class is the category.
        # (warn_explicit requires four positional arguments; the prior
        # three-argument call raised TypeError on this path.)
        warnings.warn_explicit(msg, type(msg), filename, lineno)
+
def resetwarnings():
    """Reset warning behavior to testing defaults."""

    # re-install the testing warn hook in case a test replaced it
    util.warn = util.langhelpers.warn = testing_warn

    # pending deprecations are ignored; actual deprecation warnings and
    # SAWarnings become hard errors during the suite
    warnings.filterwarnings('ignore',
                      category=sa_exc.SAPendingDeprecationWarning)
    warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
    warnings.filterwarnings('error', category=sa_exc.SAWarning)
+
def assert_warnings(fn, warnings):
    """Assert that each of the given warnings are emitted by fn.

    ``warnings`` is a list of expected warning message strings, consumed
    in the order fn emits them.  Returns fn's return value.
    """

    from .assertions import eq_, emits_warning

    canary = []
    orig_warn = util.warn

    def capture_warnings(*args, **kw):
        orig_warn(*args, **kw)
        popwarn = warnings.pop(0)
        canary.append(popwarn)
        eq_(args[0], popwarn)

    util.warn = util.langhelpers.warn = capture_warnings
    try:
        result = emits_warning()(fn)()
    finally:
        # restore the original hook; previously the capture function
        # leaked past this assertion (and on exception)
        util.warn = util.langhelpers.warn = orig_warn
    assert canary, "No warning was emitted"
    return result