diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-09-27 02:37:33 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-09-27 02:37:33 -0400 |
commit | 20cdc64588b0f6ae52f8380c11d0ed848005377b (patch) | |
tree | 08f6cc8f82263f1e402c1c05c83b66a1f4b016ac /lib/sqlalchemy/testing | |
parent | 21cac5b598a83ef0e24423dc523629b475aa3af0 (diff) | |
download | sqlalchemy-20cdc64588b0f6ae52f8380c11d0ed848005377b.tar.gz |
trying different approaches to test layout. in this one, the testing modules
become an externally usable package but still remains within the main sqlalchemy parent package.
in this system, we use kind of an ugly hack to get the noseplugin imported outside of the
"sqlalchemy" package, while still making it available within sqlalchemy for usage by
third party libraries.
Diffstat (limited to 'lib/sqlalchemy/testing')
22 files changed, 3058 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py new file mode 100644 index 000000000..415705e93 --- /dev/null +++ b/lib/sqlalchemy/testing/__init__.py @@ -0,0 +1,21 @@ +from __future__ import absolute_import + +from .warnings import testing_warn, assert_warnings, resetwarnings + +from . import config + +from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\ + fails_on, fails_on_everything_except, skip, only_on, exclude, against,\ + _server_version + +from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ + eq_, ne_, is_, is_not_, startswith_, assert_raises, \ + assert_raises_message, AssertsCompiledSQL, ComparesTables, AssertsExecutionResults + +from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict + +crashes = skip + +from .config import db, requirements as requires + + diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py new file mode 100644 index 000000000..1e8559c1a --- /dev/null +++ b/lib/sqlalchemy/testing/assertions.py @@ -0,0 +1,349 @@ +from __future__ import absolute_import + +from . import util as testutil +from sqlalchemy import pool, orm, util +from sqlalchemy.engine import default +from sqlalchemy import exc as sa_exc +from sqlalchemy.util import decorator +from sqlalchemy import types as sqltypes, schema +import warnings +import re +from .warnings import resetwarnings +from .exclusions import db_spec, _is_excluded +from . import assertsql +from . import config +import itertools +from .util import fail + +def emits_warning(*messages): + """Mark a test as emitting a warning. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + # TODO: it would be nice to assert that a named warning was + # emitted. should work with some monkeypatching of warnings, + # and may work on non-CPython if they keep to the spirit of + # warnings.showwarning's docstring. + # - update: jython looks ok, it uses cpython's module + + @decorator + def decorate(fn, *args, **kw): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SAWarning)) + else: + filters.extend(dict(action='ignore', + message=message, + category=sa_exc.SAWarning) + for message in messages) + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return decorate + +def emits_warning_on(db, *warnings): + """Mark a test as emitting a warning on a specific dialect. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + spec = db_spec(db) + + @decorator + def decorate(fn, *args, **kw): + if isinstance(db, basestring): + if not spec(config.db): + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + else: + if not _is_excluded(*db): + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + return decorate + + +def uses_deprecated(*messages): + """Mark a test as immune from fatal deprecation warnings. + + With no arguments, squelches all SADeprecationWarning failures. + Or pass one or more strings; these will be matched to the root + of the warning description by warnings.filterwarnings(). + + As a special case, you may pass a function name prefixed with // + and it will be re-written as needed to match the standard warning + verbiage emitted by the sqlalchemy.util.deprecated decorator. + """ + + @decorator + def decorate(fn, *args, **kw): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SADeprecationWarning)) + else: + filters.extend( + [dict(action='ignore', + message=message, + category=sa_exc.SADeprecationWarning) + for message in + [(m.startswith('//') and + ('Call to deprecated function ' + m[2:]) or m) + for m in messages]]) + + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return decorate + + + +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + + testutil.lazy_gc() + assert not pool._refs, str(pool._refs) + + + +def eq_(a, b, msg=None): + """Assert a == b, with repr messaging on failure.""" + assert a == b, msg or "%r != %r" % (a, b) + +def ne_(a, b, msg=None): + """Assert a != b, with repr messaging on failure.""" + assert a != b, msg or "%r == %r" % (a, b) + +def is_(a, b, msg=None): + """Assert a is b, with repr messaging on failure.""" + assert a is b, msg or "%r is not %r" % (a, b) + +def is_not_(a, b, msg=None): + """Assert a is not b, with repr messaging on failure.""" + assert a is not b, msg or "%r is %r" % (a, b) + +def startswith_(a, fragment, msg=None): + """Assert a.startswith(fragment), with repr messaging on failure.""" + assert a.startswith(fragment), msg or "%r does not start with %r" % ( + a, fragment) + +def assert_raises(except_cls, callable_, *args, **kw): + try: + callable_(*args, **kw) + success = False + except except_cls: + success = True + + # assert outside the block so it works for AssertionError too ! + assert success, "Callable did not raise an exception" + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + try: + callable_(*args, **kwargs) + assert False, "Callable did not raise an exception" + except except_cls, e: + assert re.search(msg, unicode(e), re.UNICODE), u"%r !~ %s" % (msg, e) + print unicode(e).encode('utf-8') + + +class AssertsCompiledSQL(object): + def assert_compile(self, clause, result, params=None, + checkparams=None, dialect=None, + checkpositional=None, + use_default_dialect=False, + allow_dialect_select=False): + if use_default_dialect: + dialect = default.DefaultDialect() + elif dialect == None and not allow_dialect_select: + dialect = getattr(self, '__dialect__', None) + if dialect == 'default': + dialect = default.DefaultDialect() + elif dialect is None: + dialect = config.db.dialect + + kw = {} + if params is not None: + kw['column_keys'] = params.keys() + + if isinstance(clause, orm.Query): + context = clause._compile_context() + context.statement.use_labels = True + clause = context.statement + + c = clause.compile(dialect=dialect, **kw) + + param_str = repr(getattr(c, 'params', {})) + # Py3K + #param_str = param_str.encode('utf-8').decode('ascii', 'ignore') + + print "\nSQL String:\n" + str(c) + param_str + + cc = re.sub(r'[\n\t]', '', str(c)) + + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) + + if checkparams is not None: + eq_(c.construct_params(params), checkparams) + if checkpositional is not None: + p = c.construct_params(params) + eq_(tuple([p[x] for x in c.positiontup]), checkpositional) + +class ComparesTables(object): + def assert_tables_equal(self, table, reflected_table, strict_types=False): + assert len(table.c) == len(reflected_table.c) + for c, reflected_c in zip(table.c, reflected_table.c): + eq_(c.name, reflected_c.name) + assert reflected_c is reflected_table.c[c.name] + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) + + if strict_types: + assert type(reflected_c.type) is type(c.type), \ + "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + else: + self.assert_types_base(reflected_c, c) + + if isinstance(c.type, sqltypes.String): + eq_(c.type.length, reflected_c.type.length) + + eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) + if c.server_default: + assert isinstance(reflected_c.server_default, + schema.FetchedValue) + + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] is not None + + def assert_types_base(self, c1, c2): + assert c1.type._compare_type_affinity(c2.type),\ + "On column %r, type '%s' doesn't correspond to type '%s'" % \ + (c1.name, c1.type, c2.type) + +class AssertsExecutionResults(object): + def assert_result(self, result, class_, *objects): + result = list(result) + print repr(result) + self.assert_list(result, class_, objects) + + def assert_list(self, result, class_, list): + self.assert_(len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + class_.__name__) + for i in range(0, len(list)): + self.assert_row(class_, result[i], list[i]) + + def assert_row(self, class_, rowobj, desc): + self.assert_(rowobj.__class__ is class_, + "item class is not " + repr(class_)) + for key, value in desc.iteritems(): + if isinstance(value, tuple): + if isinstance(value[1], list): + self.assert_list(getattr(rowobj, key), value[0], value[1]) + else: + self.assert_row(value[0], getattr(rowobj, key), value[1]) + else: + self.assert_(getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" % ( + key, getattr(rowobj, key), value)) + + def assert_unordered_result(self, result, cls, *expected): + """As assert_result, but the order of objects is not considered. + + The algorithm is very expensive but not a big deal for the small + numbers of rows that the test suite manipulates. + """ + + class immutabledict(dict): + def __hash__(self): + return id(self) + + found = util.IdentitySet(result) + expected = set([immutabledict(e) for e in expected]) + + for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found): + fail('Unexpected type "%s", expected "%s"' % ( + type(wrong).__name__, cls.__name__)) + + if len(found) != len(expected): + fail('Unexpected object count "%s", expected "%s"' % ( + len(found), len(expected))) + + NOVALUE = object() + def _compare_item(obj, spec): + for key, value in spec.iteritems(): + if isinstance(value, tuple): + try: + self.assert_unordered_result( + getattr(obj, key), value[0], *value[1]) + except AssertionError: + return False + else: + if getattr(obj, key, NOVALUE) != value: + return False + return True + + for expected_item in expected: + for found_item in found: + if _compare_item(found_item, expected_item): + found.remove(found_item) + break + else: + fail( + "Expected %s instance with attributes %s not found." % ( + cls.__name__, repr(expected_item))) + return True + + def assert_sql_execution(self, db, callable_, *rules): + assertsql.asserter.add_rules(rules) + try: + callable_() + assertsql.asserter.statement_complete() + finally: + assertsql.asserter.clear_rules() + + def assert_sql(self, db, callable_, list_, with_sequences=None): + if with_sequences is not None and config.db.dialect.supports_sequences: + rules = with_sequences + else: + rules = list_ + + newrules = [] + for rule in rules: + if isinstance(rule, dict): + newrule = assertsql.AllOf(*[ + assertsql.ExactSQL(k, v) for k, v in rule.iteritems() + ]) + else: + newrule = assertsql.ExactSQL(*rule) + newrules.append(newrule) + + self.assert_sql_execution(db, callable_, *newrules) + + def assert_sql_count(self, db, callable_, count): + self.assert_sql_execution(db, callable_, assertsql.CountStatements(count)) + + diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py new file mode 100644 index 000000000..897f4b3b1 --- /dev/null +++ b/lib/sqlalchemy/testing/assertsql.py @@ -0,0 +1,316 @@ + +from sqlalchemy.interfaces import ConnectionProxy +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.engine.base import Connection +from sqlalchemy import util +import re + +class AssertRule(object): + + def process_execute(self, clauseelement, *multiparams, **params): + pass + + def process_cursor_execute(self, statement, parameters, context, + executemany): + pass + + def is_consumed(self): + """Return True if this rule has been consumed, False if not. + + Should raise an AssertionError if this rule's condition has + definitely failed. + + """ + + raise NotImplementedError() + + def rule_passed(self): + """Return True if the last test of this rule passed, False if + failed, None if no test was applied.""" + + raise NotImplementedError() + + def consume_final(self): + """Return True if this rule has been consumed. + + Should raise an AssertionError if this rule's condition has not + been consumed or has failed. + + """ + + if self._result is None: + assert False, 'Rule has not been consumed' + return self.is_consumed() + +class SQLMatchRule(AssertRule): + def __init__(self): + self._result = None + self._errmsg = "" + + def rule_passed(self): + return self._result + + def is_consumed(self): + if self._result is None: + return False + + assert self._result, self._errmsg + + return True + +class ExactSQL(SQLMatchRule): + + def __init__(self, sql, params=None): + SQLMatchRule.__init__(self) + self.sql = sql + self.params = params + + def process_cursor_execute(self, statement, parameters, context, + executemany): + if not context: + return + _received_statement = \ + _process_engine_statement(context.unicode_statement, + context) + _received_parameters = context.compiled_parameters + + # TODO: remove this step once all unit tests are migrated, as + # ExactSQL should really be *exact* SQL + + sql = _process_assertion_statement(self.sql, context) + equivalent = _received_statement == sql + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + equivalent = equivalent and params \ + == context.compiled_parameters + else: + params = {} + self._result = equivalent + if not self._result: + self._errmsg = \ + 'Testing for exact statement %r exact params %r, '\ + 'received %r with params %r' % (sql, params, + _received_statement, _received_parameters) + + +class RegexSQL(SQLMatchRule): + + def __init__(self, regex, params=None): + SQLMatchRule.__init__(self) + self.regex = re.compile(regex) + self.orig_regex = regex + self.params = params + + def process_cursor_execute(self, statement, parameters, context, + executemany): + if not context: + return + _received_statement = \ + _process_engine_statement(context.unicode_statement, + context) + _received_parameters = context.compiled_parameters + equivalent = bool(self.regex.match(_received_statement)) + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + + # do a positive compare only + + for param, received in zip(params, _received_parameters): + for k, v in param.iteritems(): + if k not in received or received[k] != v: + equivalent = False + break + else: + params = {} + self._result = equivalent + if not self._result: + self._errmsg = \ + 'Testing for regex %r partial params %r, received %r '\ + 'with params %r' % (self.orig_regex, params, + _received_statement, + _received_parameters) + +class CompiledSQL(SQLMatchRule): + + def __init__(self, statement, params): + SQLMatchRule.__init__(self) + self.statement = statement + self.params = params + + def process_cursor_execute(self, statement, parameters, context, + executemany): + if not context: + return + _received_parameters = list(context.compiled_parameters) + + # recompile from the context, using the default dialect + + compiled = \ + context.compiled.statement.compile(dialect=DefaultDialect(), + column_keys=context.compiled.column_keys) + _received_statement = re.sub(r'\n', '', str(compiled)) + equivalent = self.statement == _received_statement + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + all_params = list(params) + all_received = list(_received_parameters) + while params: + param = dict(params.pop(0)) + for k, v in context.compiled.params.iteritems(): + param.setdefault(k, v) + if param not in _received_parameters: + equivalent = False + break + else: + _received_parameters.remove(param) + if _received_parameters: + equivalent = False + else: + params = {} + all_params = {} + all_received = [] + self._result = equivalent + if not self._result: + print 'Testing for compiled statement %r partial params '\ + '%r, received %r with params %r' % (self.statement, + all_params, _received_statement, all_received) + self._errmsg = \ + 'Testing for compiled statement %r partial params %r, '\ + 'received %r with params %r' % (self.statement, + all_params, _received_statement, all_received) + + + # print self._errmsg + +class CountStatements(AssertRule): + + def __init__(self, count): + self.count = count + self._statement_count = 0 + + def process_execute(self, clauseelement, *multiparams, **params): + self._statement_count += 1 + + def process_cursor_execute(self, statement, parameters, context, + executemany): + pass + + def is_consumed(self): + return False + + def consume_final(self): + assert self.count == self._statement_count, \ + 'desired statement count %d does not match %d' \ + % (self.count, self._statement_count) + return True + +class AllOf(AssertRule): + + def __init__(self, *rules): + self.rules = set(rules) + + def process_execute(self, clauseelement, *multiparams, **params): + for rule in self.rules: + rule.process_execute(clauseelement, *multiparams, **params) + + def process_cursor_execute(self, statement, parameters, context, + executemany): + for rule in self.rules: + rule.process_cursor_execute(statement, parameters, context, + executemany) + + def is_consumed(self): + if not self.rules: + return True + for rule in list(self.rules): + if rule.rule_passed(): # a rule passed, move on + self.rules.remove(rule) + return len(self.rules) == 0 + assert False, 'No assertion rules were satisfied for statement' + + def consume_final(self): + return len(self.rules) == 0 + +def _process_engine_statement(query, context): + if util.jython: + + # oracle+zxjdbc passes a PyStatement when returning into + + query = unicode(query) + if context.engine.name == 'mssql' \ + and query.endswith('; select scope_identity()'): + query = query[:-25] + query = re.sub(r'\n', '', query) + return query + +def _process_assertion_statement(query, context): + paramstyle = context.dialect.paramstyle + if paramstyle == 'named': + pass + elif paramstyle =='pyformat': + query = re.sub(r':([\w_]+)', r"%(\1)s", query) + else: + # positional params + repl = None + if paramstyle=='qmark': + repl = "?" + elif paramstyle=='format': + repl = r"%s" + elif paramstyle=='numeric': + repl = None + query = re.sub(r':([\w_]+)', repl, query) + + return query + +class SQLAssert(object): + + rules = None + + def add_rules(self, rules): + self.rules = list(rules) + + def statement_complete(self): + for rule in self.rules: + if not rule.consume_final(): + assert False, \ + 'All statements are complete, but pending '\ + 'assertion rules remain' + + def clear_rules(self): + del self.rules + + def execute(self, conn, clauseelement, multiparams, params, result): + if self.rules is not None: + if not self.rules: + assert False, \ + 'All rules have been exhausted, but further '\ + 'statements remain' + rule = self.rules[0] + rule.process_execute(clauseelement, *multiparams, **params) + if rule.is_consumed(): + self.rules.pop(0) + + def cursor_execute(self, conn, cursor, statement, parameters, + context, executemany): + if self.rules: + rule = self.rules[0] + rule.process_cursor_execute(statement, parameters, context, + executemany) + +asserter = SQLAssert() + diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py new file mode 100644 index 000000000..2945bd456 --- /dev/null +++ b/lib/sqlalchemy/testing/config.py @@ -0,0 +1,3 @@ +requirements = None +db = None + diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py new file mode 100644 index 000000000..f7401550e --- /dev/null +++ b/lib/sqlalchemy/testing/engines.py @@ -0,0 +1,429 @@ +from __future__ import absolute_import + +import types +import weakref +from collections import deque +from . import config +from .util import decorator +from sqlalchemy import event, pool +import re +import warnings + +class ConnectionKiller(object): + def __init__(self): + self.proxy_refs = weakref.WeakKeyDictionary() + self.testing_engines = weakref.WeakKeyDictionary() + self.conns = set() + + def add_engine(self, engine): + self.testing_engines[engine] = True + + def connect(self, dbapi_conn, con_record): + self.conns.add(dbapi_conn) + + def checkout(self, dbapi_con, con_record, con_proxy): + self.proxy_refs[con_proxy] = True + + def _safe(self, fn): + try: + fn() + except (SystemExit, KeyboardInterrupt): + raise + except Exception, e: + warnings.warn( + "testing_reaper couldn't " + "rollback/close connection: %s" % e) + + def rollback_all(self): + for rec in self.proxy_refs.keys(): + if rec is not None and rec.is_valid: + self._safe(rec.rollback) + + def close_all(self): + for rec in self.proxy_refs.keys(): + if rec is not None: + self._safe(rec._close) + + def _after_test_ctx(self): + pass + # this can cause a deadlock with pg8000 - pg8000 acquires + # prepared statment lock inside of rollback() - if async gc + # is collecting in finalize_fairy, deadlock. + # not sure if this should be if pypy/jython only + #for conn in self.conns: + # self._safe(conn.rollback) + + def _stop_test_ctx(self): + if config.options.low_connections: + self._stop_test_ctx_minimal() + else: + self._stop_test_ctx_aggressive() + + def _stop_test_ctx_minimal(self): + self.close_all() + + self.conns = set() + + for rec in self.testing_engines.keys(): + if rec is not config.db: + rec.dispose() + + def _stop_test_ctx_aggressive(self): + self.close_all() + for conn in self.conns: + self._safe(conn.close) + self.conns = set() + for rec in self.testing_engines.keys(): + rec.dispose() + + def assert_all_closed(self): + for rec in self.proxy_refs: + if rec.is_valid: + assert False + +testing_reaper = ConnectionKiller() + +def drop_all_tables(metadata, bind): + testing_reaper.close_all() + if hasattr(bind, 'close'): + bind.close() + metadata.drop_all(bind) + +@decorator +def assert_conns_closed(fn, *args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.assert_all_closed() + +@decorator +def rollback_open_connections(fn, *args, **kw): + """Decorator that rolls back all open connections after fn execution.""" + + try: + fn(*args, **kw) + finally: + testing_reaper.rollback_all() + +@decorator +def close_first(fn, *args, **kw): + """Decorator that closes all connections before fn execution.""" + + testing_reaper.close_all() + fn(*args, **kw) + + +@decorator +def close_open_connections(fn, *args, **kw): + """Decorator that closes all connections after fn execution.""" + try: + fn(*args, **kw) + finally: + testing_reaper.close_all() + +def all_dialects(exclude=None): + import sqlalchemy.databases as d + for name in d.__all__: + # TEMPORARY + if exclude and name in exclude: + continue + mod = getattr(d, name, None) + if not mod: + mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) + yield mod.dialect() + +class ReconnectFixture(object): + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + conn = self.dbapi.connect(*args, **kwargs) + self.connections.append(conn) + return conn + + def _safe(self, fn): + try: + fn() + except (SystemExit, KeyboardInterrupt): + raise + except Exception, e: + warnings.warn( + "ReconnectFixture couldn't " + "close connection: %s" % e) + + def shutdown(self): + # TODO: this doesn't cover all cases + # as nicely as we'd like, namely MySQLdb. + # would need to implement R. Brewer's + # proxy server idea to get better + # coverage. + for c in list(self.connections): + self._safe(c.close) + self.connections = [] + +def reconnecting_engine(url=None, options=None): + url = url or config.db_url + dbapi = config.db.dialect.dbapi + if not options: + options = {} + options['module'] = ReconnectFixture(dbapi) + engine = testing_engine(url, options) + _dispose = engine.dispose + def dispose(): + engine.dialect.dbapi.shutdown() + _dispose() + engine.test_shutdown = engine.dialect.dbapi.shutdown + engine.dispose = dispose + return engine + +def testing_engine(url=None, options=None): + """Produce an engine configured by --options with optional overrides.""" + + from sqlalchemy import create_engine + from .assertsql import asserter + + if not options: + use_reaper = True + else: + use_reaper = options.pop('use_reaper', True) + + url = url or config.db_url + if options is None: + options = config.db_opts + + engine = create_engine(url, **options) + if isinstance(engine.pool, pool.QueuePool): + engine.pool._timeout = 0 + engine.pool._max_overflow = 0 + event.listen(engine, 'after_execute', asserter.execute) + event.listen(engine, 'after_cursor_execute', asserter.cursor_execute) + if use_reaper: + event.listen(engine.pool, 'connect', testing_reaper.connect) + event.listen(engine.pool, 'checkout', testing_reaper.checkout) + testing_reaper.add_engine(engine) + + return engine + +def utf8_engine(url=None, options=None): + """Hook for dialects or drivers that don't handle utf8 by default.""" + + from sqlalchemy.engine import url as engine_url + + if config.db.dialect.name == 'mysql' and \ + config.db.driver in ['mysqldb', 'pymysql']: + # note 1.2.1.gamma.6 or greater of MySQLdb + # needed here + url = url or config.db_url + url = engine_url.make_url(url) + url.query['charset'] = 'utf8' + url.query['use_unicode'] = '0' + url = str(url) + + return testing_engine(url, options) + +def mock_engine(dialect_name=None): + """Provides a mocking engine based on the current testing.db. + + This is normally used to test DDL generation flow as emitted + by an Engine. + + It should not be used in other cases, as assert_compile() and + assert_sql_execution() are much better choices with fewer + moving parts. + + """ + + from sqlalchemy import create_engine + + if not dialect_name: + dialect_name = config.db.name + + buffer = [] + def executor(sql, *a, **kw): + buffer.append(sql) + def assert_sql(stmts): + recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + assert recv == stmts, recv + def print_sql(): + d = engine.dialect + return "\n".join( + str(s.compile(dialect=d)) + for s in engine.mock + ) + engine = create_engine(dialect_name + '://', + strategy='mock', executor=executor) + assert not hasattr(engine, 'mock') + engine.mock = buffer + engine.assert_sql = assert_sql + engine.print_sql = print_sql + return engine + +class DBAPIProxyCursor(object): + """Proxy a DBAPI cursor. + + Tests can provide subclasses of this to intercept + DBAPI-level cursor operations. + + """ + def __init__(self, engine, conn): + self.engine = engine + self.connection = conn + self.cursor = conn.cursor() + + def execute(self, stmt, parameters=None, **kw): + if parameters: + return self.cursor.execute(stmt, parameters, **kw) + else: + return self.cursor.execute(stmt, **kw) + + def executemany(self, stmt, params, **kw): + return self.cursor.executemany(stmt, params, **kw) + + def __getattr__(self, key): + return getattr(self.cursor, key) + +class DBAPIProxyConnection(object): + """Proxy a DBAPI connection. + + Tests can provide subclasses of this to intercept + DBAPI-level connection operations. + + """ + def __init__(self, engine, cursor_cls): + self.conn = self._sqla_unwrap = engine.pool._creator() + self.engine = engine + self.cursor_cls = cursor_cls + + def cursor(self): + return self.cursor_cls(self.engine, self.conn) + + def close(self): + self.conn.close() + + def __getattr__(self, key): + return getattr(self.conn, key) + +def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor): + """Produce an engine that provides proxy hooks for + common methods. + + """ + def mock_conn(): + return conn_cls(config.db, cursor_cls) + return testing_engine(options={'creator':mock_conn}) + +class ReplayableSession(object): + """A simple record/playback tool. + + This is *not* a mock testing class. It only records a session for later + playback and makes no assertions on call consistency whatsoever. It's + unlikely to be suitable for anything other than DB-API recording. + + """ + + Callable = object() + NoAttribute = object() + + # Py3K + #Natives = set([getattr(types, t) + # for t in dir(types) if not t.startswith('_')]). \ + # union([type(t) if not isinstance(t, type) + # else t for t in __builtins__.values()]).\ + # difference([getattr(types, t) + # for t in ('FunctionType', 'BuiltinFunctionType', + # 'MethodType', 'BuiltinMethodType', + # 'LambdaType', )]) + # Py2K + Natives = set([getattr(types, t) + for t in dir(types) if not t.startswith('_')]). \ + difference([getattr(types, t) + for t in ('FunctionType', 'BuiltinFunctionType', + 'MethodType', 'BuiltinMethodType', + 'LambdaType', 'UnboundMethodType',)]) + # end Py2K + + def __init__(self): + self.buffer = deque() + + def recorder(self, base): + return self.Recorder(self.buffer, base) + + def player(self): + return self.Player(self.buffer) + + class Recorder(object): + def __init__(self, buffer, subject): + self._buffer = buffer + self._subject = subject + + def __call__(self, *args, **kw): + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + + result = subject(*args, **kw) + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + @property + def _sqla_unwrap(self): + return self._subject + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + try: + result = type(subject).__getattribute__(subject, key) + except AttributeError: + buffer.append(ReplayableSession.NoAttribute) + raise + else: + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + class Player(object): + def __init__(self, buffer): + self._buffer = buffer + + def __call__(self, *args, **kw): + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + else: + return result + + @property + def _sqla_unwrap(self): + return None + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + elif result is ReplayableSession.NoAttribute: + raise AttributeError(key) + else: + return result + diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py new file mode 100644 index 000000000..1b24e73b7 --- /dev/null +++ b/lib/sqlalchemy/testing/entities.py @@ -0,0 +1,83 @@ +import sqlalchemy as sa +from sqlalchemy import exc as sa_exc + +_repr_stack = set() +class BasicEntity(object): + def __init__(self, **kw): + for key, value in kw.iteritems(): + setattr(self, key, value) + + def __repr__(self): + if id(self) in _repr_stack: + return object.__repr__(self) + _repr_stack.add(id(self)) + try: + return "%s(%s)" % ( + (self.__class__.__name__), + ', '.join(["%s=%r" % (key, getattr(self, key)) + for key in sorted(self.__dict__.keys()) + if not key.startswith('_')])) + finally: + _repr_stack.remove(id(self)) + +_recursion_stack = set() +class ComparableEntity(BasicEntity): + def __hash__(self): + return hash(self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + """'Deep, sparse compare. + + Deeply compare two entities, following the non-None attributes of the + non-persisted object, if possible. + + """ + if other is self: + return True + elif not self.__class__ == other.__class__: + return False + + if id(self) in _recursion_stack: + return True + _recursion_stack.add(id(self)) + + try: + # pick the entity thats not SA persisted as the source + try: + self_key = sa.orm.attributes.instance_state(self).key + except sa.orm.exc.NO_STATE: + self_key = None + + if other is None: + a = self + b = other + elif self_key is not None: + a = other + b = self + else: + a = self + b = other + + for attr in a.__dict__.keys(): + if attr.startswith('_'): + continue + value = getattr(a, attr) + + try: + # handle lazy loader errors + battr = getattr(b, attr) + except (AttributeError, sa_exc.UnboundExecutionError): + return False + + if hasattr(value, '__iter__'): + if list(value) != list(battr): + return False + else: + if value is not None and value != battr: + return False + return True + finally: + _recursion_stack.remove(id(self)) diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py new file mode 100644 index 000000000..ba2eebe4f --- /dev/null +++ b/lib/sqlalchemy/testing/exclusions.py @@ -0,0 +1,269 @@ +import operator +from nose import SkipTest +from sqlalchemy.util import decorator +from . import config +from sqlalchemy import util + + +def fails_if(predicate, reason=None): + predicate = _as_predicate(predicate) + + @decorator + def decorate(fn, *args, **kw): + if not predicate(): + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected (%s): %s " % ( + fn.__name__, predicate, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for '%s' (%s)" % + (fn.__name__, predicate)) + return decorate + +def skip_if(predicate, reason=None): + predicate = _as_predicate(predicate) + + @decorator + def decorate(fn, *args, **kw): + if predicate(): + if reason: + msg = "'%s' : %s" % ( + fn.__name__, + reason + ) + else: + msg = "'%s': %s" % ( + fn.__name__, predicate + ) + raise SkipTest(msg) + else: + return fn(*args, **kw) + return decorate + +def only_if(predicate, reason=None): + predicate = _as_predicate(predicate) + return skip_if(NotPredicate(predicate), reason) + +def succeeds_if(predicate, reason=None): + predicate = _as_predicate(predicate) + return fails_if(NotPredicate(predicate), reason) + +class Predicate(object): + @classmethod + def as_predicate(cls, predicate): + if isinstance(predicate, Predicate): + return predicate + elif isinstance(predicate, list): + return OrPredicate([cls.as_predicate(pred) for pred in predicate]) + elif isinstance(predicate, tuple): + return SpecPredicate(*predicate) + elif isinstance(predicate, basestring): + return SpecPredicate(predicate, None, None) + elif util.callable(predicate): + return LambdaPredicate(predicate) + else: + assert False, "unknown predicate type: %s" % predicate + +class SpecPredicate(Predicate): + def __init__(self, db, op=None, spec=None, description=None): + self.db = db + self.op = op + self.spec = spec + self.description = description + + _ops = { + '<': operator.lt, + '>': operator.gt, + '==': operator.eq, + '!=': operator.ne, + '<=': operator.le, + '>=': operator.ge, + 'in': operator.contains, + 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + } + + def __call__(self, engine=None): + if engine is None: + engine = config.db + + if "+" in self.db: + dialect, driver = self.db.split('+') + else: + dialect, driver = self.db, None + + if dialect and engine.name != dialect: + return False + if driver is not None and engine.driver != driver: + return False + + if self.op is not None: + assert driver is None, "DBAPI version specs not supported yet" + + version = _server_version(engine) + oper = hasattr(self.op, '__call__') and self.op \ + or self._ops[self.op] + return oper(version, self.spec) + else: + return True + + def _as_string(self, negate=False): + if self.description is not None: + return self.description + elif self.op is None: + if negate: + return "not %s" % self.db + else: + return "%s" % self.db + else: + if negate: + return "not %s %s %s" % ( + self.db, + self.op, + self.spec + ) + else: + return "%s %s %s" % ( + self.db, + self.op, + self.spec + ) + + def __str__(self): + return self._as_string() + +class LambdaPredicate(Predicate): + def __init__(self, lambda_, description=None, args=None, kw=None): + self.lambda_ = lambda_ + self.args = args or () + self.kw = kw or {} + if description: + self.description = description + elif lambda_.__doc__: + self.description = lambda_.__doc__ + else: + self.description = "custom function" + + def __call__(self): + return self.lambda_(*self.args, **self.kw) + + def _as_string(self, negate=False): + if negate: + return "not " + self.description + else: + return self.description + + def __str__(self): + return self._as_string() + +class NotPredicate(Predicate): + def __init__(self, predicate): + self.predicate = predicate + + def __call__(self, *arg, **kw): + return not self.predicate(*arg, **kw) + + def __str__(self): + return self.predicate._as_string(True) + +class OrPredicate(Predicate): + def __init__(self, predicates, description=None): + self.predicates = predicates + self.description = description + + def __call__(self, *arg, **kw): + for pred in self.predicates: + if pred(*arg, **kw): + self._str = pred + return True + return False + + _str = None + + def _eval_str(self, negate=False): + if self._str is None: + if negate: + conjunction = " and " + else: + conjunction = " or " + return conjunction.join(p._as_string(negate=negate) + for p in self.predicates) + else: + return self._str._as_string(negate=negate) + + def _negation_str(self): + if self.description is not None: + return "Not " + (self.description % {"spec": self._str}) + else: + return self._eval_str(negate=True) + + def _as_string(self, negate=False): + if negate: + return self._negation_str() + else: + if self.description is not None: + return self.description % {"spec": self._str} + else: + return self._eval_str() + + def __str__(self): + return self._as_string() + +_as_predicate = Predicate.as_predicate + +def _is_excluded(db, op, spec): + return SpecPredicate(db, op, spec)() + +def _server_version(engine): + """Return a server_version_info tuple.""" + + # force metadata to be retrieved + conn = engine.connect() + version = getattr(engine.dialect, 'server_version_info', ()) + conn.close() + return version + +def db_spec(*dbs): + return OrPredicate( + Predicate.as_predicate(db) for db in dbs + ) + +def open(fn): + return fn + +@decorator +def future(fn, *args, **kw): + return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature") + +def fails_on(db, reason): + return fails_if(SpecPredicate(db), reason) + +def fails_on_everything_except(*dbs): + return succeeds_if( + OrPredicate([ + SpecPredicate(db) for db in dbs + ]) + ) + +def skip(db, reason): + return skip_if(SpecPredicate(db), reason) + +def only_on(dbs, reason): + return only_if( + OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]) + ) + + +def exclude(db, op, spec, reason): + return skip_if(SpecPredicate(db, op, spec), reason) + + +def against(*queries): + return OrPredicate([ + Predicate.as_predicate(query) + for query in queries + ])() diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py new file mode 100644 index 000000000..018276d4d --- /dev/null +++ b/lib/sqlalchemy/testing/fixtures.py @@ -0,0 +1,334 @@ +from . import config +from . import assertions, schema +from .util import adict +from .engines import drop_all_tables +from .entities import BasicEntity, ComparableEntity +import sys +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta + +class TestBase(object): + # A sequence of database names to always run, regardless of the + # constraints below. + __whitelist__ = () + + # A sequence of requirement names matching testing.requires decorators + __requires__ = () + + # A sequence of dialect names to exclude from the test class. + __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None + + # A sequence of no-arg callables. If any are True, the entire testcase is + # skipped. + __skip_if__ = None + + def assert_(self, val, msg=None): + assert val, msg + +class TablesTest(TestBase): + + # 'once', None + run_setup_bind = 'once' + + # 'once', 'each', None + run_define_tables = 'once' + + # 'once', 'each', None + run_create_tables = 'once' + + # 'once', 'each', None + run_inserts = 'each' + + # 'each', None + run_deletes = 'each' + + # 'once', None + run_dispose_bind = None + + bind = None + metadata = None + tables = None + other = None + + @classmethod + def setup_class(cls): + cls._init_class() + + cls._setup_once_tables() + + cls._setup_once_inserts() + + @classmethod + def _init_class(cls): + if cls.run_define_tables == 'each': + if cls.run_create_tables == 'once': + cls.run_create_tables = 'each' + assert cls.run_inserts in ('each', None) + + if cls.other is None: + cls.other = adict() + + if cls.tables is None: + cls.tables = adict() + + if cls.bind is None: + setattr(cls, 'bind', cls.setup_bind()) + + if cls.metadata is None: + setattr(cls, 'metadata', sa.MetaData()) + + if cls.metadata.bind is None: + cls.metadata.bind = cls.bind + + @classmethod + def _setup_once_inserts(cls): + if cls.run_inserts == 'once': + cls._load_fixtures() + cls.insert_data() + + @classmethod + def _setup_once_tables(cls): + if cls.run_define_tables == 'once': + cls.define_tables(cls.metadata) + if cls.run_create_tables == 'once': + cls.metadata.create_all(cls.bind) + cls.tables.update(cls.metadata.tables) + + def _setup_each_tables(self): + if self.run_define_tables == 'each': + self.tables.clear() + if self.run_create_tables == 'each': + drop_all_tables(self.metadata, self.bind) + self.metadata.clear() + self.define_tables(self.metadata) + if self.run_create_tables == 'each': + self.metadata.create_all(self.bind) + self.tables.update(self.metadata.tables) + elif self.run_create_tables == 'each': + drop_all_tables(self.metadata, self.bind) + self.metadata.create_all(self.bind) + + def _setup_each_inserts(self): + if self.run_inserts == 'each': + self._load_fixtures() + self.insert_data() + + def _teardown_each_tables(self): + # no need to run deletes if tables are recreated on setup + if self.run_define_tables != 'each' and self.run_deletes == 'each': + for table in reversed(self.metadata.sorted_tables): + try: + table.delete().execute().close() + except sa.exc.DBAPIError, ex: + print >> sys.stderr, "Error emptying table %s: %r" % ( + table, ex) + + def setup(self): + self._setup_each_tables() + self._setup_each_inserts() + + def teardown(self): + self._teardown_each_tables() + + @classmethod + def _teardown_once_metadata_bind(cls): + if cls.run_create_tables: + drop_all_tables(cls.metadata, cls.bind) + + if cls.run_dispose_bind == 'once': + cls.dispose_bind(cls.bind) + + cls.metadata.bind = None + + if cls.run_setup_bind is not None: + cls.bind = None + + @classmethod + def teardown_class(cls): + cls._teardown_once_metadata_bind() + + @classmethod + def setup_bind(cls): + return config.db + + @classmethod + def dispose_bind(cls, bind): + if hasattr(bind, 'dispose'): + bind.dispose() + elif hasattr(bind, 'close'): + bind.close() + + @classmethod + def define_tables(cls, metadata): + pass + + @classmethod + def fixtures(cls): + return {} + + @classmethod + def insert_data(cls): + pass + + def sql_count_(self, count, fn): + self.assert_sql_count(self.bind, fn, count) + + def sql_eq_(self, callable_, statements, with_sequences=None): + self.assert_sql(self.bind, + callable_, statements, with_sequences) + + @classmethod + def _load_fixtures(cls): + """Insert rows as represented by the fixtures() method.""" + headers, rows = {}, {} + for table, data in cls.fixtures().iteritems(): + if len(data) < 2: + continue + if isinstance(table, basestring): + table = cls.tables[table] + headers[table] = data[0] + rows[table] = data[1:] + for table in cls.metadata.sorted_tables: + if table not in headers: + continue + cls.bind.execute( + table.insert(), + [dict(zip(headers[table], column_values)) + for column_values in rows[table]]) + + +class _ORMTest(object): + __requires__ = ('subqueries',) + + @classmethod + def teardown_class(cls): + sa.orm.session.Session.close_all() + sa.orm.clear_mappers() + +class ORMTest(_ORMTest, TestBase): + pass + +class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): + # 'once', 'each', None + run_setup_classes = 'once' + + # 'once', 'each', None + run_setup_mappers = 'each' + + classes = None + + @classmethod + def setup_class(cls): + cls._init_class() + + if cls.classes is None: + cls.classes = adict() + + cls._setup_once_tables() + cls._setup_once_classes() + cls._setup_once_mappers() + cls._setup_once_inserts() + + @classmethod + def teardown_class(cls): + cls._teardown_once_class() + cls._teardown_once_metadata_bind() + + def setup(self): + self._setup_each_tables() + self._setup_each_mappers() + self._setup_each_inserts() + + def teardown(self): + sa.orm.session.Session.close_all() + self._teardown_each_mappers() + self._teardown_each_tables() + + @classmethod + def _teardown_once_class(cls): + cls.classes.clear() + _ORMTest.teardown_class() + + + @classmethod + def _setup_once_classes(cls): + if cls.run_setup_classes == 'once': + cls._with_register_classes(cls.setup_classes) + + @classmethod + def _setup_once_mappers(cls): + if cls.run_setup_mappers == 'once': + cls._with_register_classes(cls.setup_mappers) + + def _setup_each_mappers(self): + if self.run_setup_mappers == 'each': + self._with_register_classes(self.setup_mappers) + + @classmethod + def _with_register_classes(cls, fn): + """Run a setup method, framing the operation with a Base class + that will catch new subclasses to be established within + the "classes" registry. + + """ + cls_registry = cls.classes + class FindFixture(type): + def __init__(cls, classname, bases, dict_): + cls_registry[classname] = cls + return type.__init__(cls, classname, bases, dict_) + + + class _Base(object): + __metaclass__ = FindFixture + class Basic(BasicEntity, _Base): + pass + class Comparable(ComparableEntity, _Base): + pass + cls.Basic = Basic + cls.Comparable = Comparable + fn() + + def _teardown_each_mappers(self): + # some tests create mappers in the test bodies + # and will define setup_mappers as None - + # clear mappers in any case + if self.run_setup_mappers != 'once': + sa.orm.clear_mappers() + + @classmethod + def setup_classes(cls): + pass + + @classmethod + def setup_mappers(cls): + pass + +class DeclarativeMappedTest(MappedTest): + run_setup_classes = 'once' + run_setup_mappers = 'once' + + @classmethod + def _setup_once_tables(cls): + pass + + @classmethod + def _with_register_classes(cls, fn): + cls_registry = cls.classes + class FindFixtureDeclarative(DeclarativeMeta): + def __init__(cls, classname, bases, dict_): + cls_registry[classname] = cls + return DeclarativeMeta.__init__( + cls, classname, bases, dict_) + class DeclarativeBasic(object): + __table_cls__ = schema.Table + _DeclBase = declarative_base(metadata=cls.metadata, + metaclass=FindFixtureDeclarative, + cls=DeclarativeBasic) + cls.DeclarativeBasic = _DeclBase + fn() + if cls.metadata.tables: + cls.metadata.create_all(config.db) diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py new file mode 100644 index 000000000..f5b8b827c --- /dev/null +++ b/lib/sqlalchemy/testing/pickleable.py @@ -0,0 +1,107 @@ +"""Classes used in pickling tests, need to be at the module level for unpickling.""" + +from . import fixtures + +class User(fixtures.ComparableEntity): + pass + +class Order(fixtures.ComparableEntity): + pass + +class Dingaling(fixtures.ComparableEntity): + pass + +class EmailUser(User): + pass + +class Address(fixtures.ComparableEntity): + pass + +# TODO: these are kind of arbitrary.... +class Child1(fixtures.ComparableEntity): + pass + +class Child2(fixtures.ComparableEntity): + pass + +class Parent(fixtures.ComparableEntity): + pass + +class Screen(object): + def __init__(self, obj, parent=None): + self.obj = obj + self.parent = parent + +class Foo(object): + def __init__(self, moredata): + self.data = 'im data' + self.stuff = 'im stuff' + self.moredata = moredata + __hash__ = object.__hash__ + def __eq__(self, other): + return other.data == self.data and \ + other.stuff == self.stuff and \ + other.moredata == self.moredata + + +class Bar(object): + def __init__(self, x, y): + self.x = x + self.y = y + __hash__ = object.__hash__ + def __eq__(self, other): + return other.__class__ is self.__class__ and \ + other.x == self.x and \ + other.y == self.y + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + +class OldSchool: + def __init__(self, x, y): + self.x = x + self.y = y + def __eq__(self, other): + return other.__class__ is self.__class__ and \ + other.x == self.x and \ + other.y == self.y + +class OldSchoolWithoutCompare: + def __init__(self, x, y): + self.x = x + self.y = y + +class BarWithoutCompare(object): + def __init__(self, x, y): + self.x = x + self.y = y + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + + +class NotComparable(object): + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return NotImplemented + + def __ne__(self, other): + return NotImplemented + + +class BrokenComparable(object): + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + raise NotImplementedError + diff --git a/lib/sqlalchemy/testing/plugin/__init__.py b/lib/sqlalchemy/testing/plugin/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/__init__.py diff --git a/lib/sqlalchemy/testing/plugin/config.py b/lib/sqlalchemy/testing/plugin/config.py new file mode 100644 index 000000000..08b9753dc --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/config.py @@ -0,0 +1,186 @@ +"""Option and configuration implementations, run by the nose plugin +on test suite startup.""" + +import time +import warnings +import sys +import re + +logging = None +db = None +db_label = None +db_url = None +db_opts = {} +options = None +file_config = None + +def _log(option, opt_str, value, parser): + global logging + if not logging: + import logging + logging.basicConfig() + + if opt_str.endswith('-info'): + logging.getLogger(value).setLevel(logging.INFO) + elif opt_str.endswith('-debug'): + logging.getLogger(value).setLevel(logging.DEBUG) + + +def _list_dbs(*args): + print "Available --db options (use --dburi to override)" + for macro in sorted(file_config.options('db')): + print "%20s\t%s" % (macro, file_config.get('db', macro)) + sys.exit(0) + +def _server_side_cursors(options, opt_str, value, parser): + db_opts['server_side_cursors'] = True + +def _zero_timeout(options, opt_str, value, parser): + warnings.warn("--zero-timeout testing option is now on in all cases") + +def _engine_strategy(options, opt_str, value, parser): + if value: + db_opts['strategy'] = value + +pre_configure = [] +post_configure = [] + +def _setup_options(opt, file_config): + global options + from sqlalchemy.testing import config + config.options = options = opt +pre_configure.append(_setup_options) + +def _monkeypatch_cdecimal(options, file_config): + if options.cdecimal: + import sys + import cdecimal + sys.modules['decimal'] = cdecimal +pre_configure.append(_monkeypatch_cdecimal) + +def _engine_uri(options, file_config): + global db_label, db_url + db_label = 'sqlite' + if options.dburi: + db_url = options.dburi + db_label = db_url[:db_url.index(':')] + elif options.db: + db_label = options.db + db_url = None + + if db_url is None: + if db_label not in file_config.options('db'): + raise RuntimeError( + "Unknown engine. Specify --dbs for known engines.") + db_url = file_config.get('db', db_label) +post_configure.append(_engine_uri) + +def _require(options, file_config): + if not(options.require or + (file_config.has_section('require') and + file_config.items('require'))): + return + + try: + import pkg_resources + except ImportError: + raise RuntimeError("setuptools is required for version requirements") + + cmdline = [] + for requirement in options.require: + pkg_resources.require(requirement) + cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0]) + + if file_config.has_section('require'): + for label, requirement in file_config.items('require'): + if not label == db_label or label.startswith('%s.' % db_label): + continue + seen = [c for c in cmdline if requirement.startswith(c)] + if seen: + continue + pkg_resources.require(requirement) +post_configure.append(_require) + +def _engine_pool(options, file_config): + if options.mockpool: + from sqlalchemy import pool + db_opts['poolclass'] = pool.AssertionPool +post_configure.append(_engine_pool) + +def _create_testing_engine(options, file_config): + from sqlalchemy.testing import engines, config + from sqlalchemy import testing + global db + config.db = testing.db = db = engines.testing_engine(db_url, db_opts) + config.db_opts = db_opts + config.db_url = db_url + +post_configure.append(_create_testing_engine) + +def _prep_testing_database(options, file_config): + from sqlalchemy.testing import engines + from sqlalchemy import schema + + # also create alt schemas etc. here? + if options.dropfirst: + e = engines.utf8_engine() + existing = e.table_names() + if existing: + print "Dropping existing tables in database: " + db_url + try: + print "Tables: %s" % ', '.join(existing) + except: + pass + print "Abort within 5 seconds..." + time.sleep(5) + md = schema.MetaData(e, reflect=True) + md.drop_all() + e.dispose() + +post_configure.append(_prep_testing_database) + +def _set_table_options(options, file_config): + from sqlalchemy.testing import schema + + table_options = schema.table_options + for spec in options.tableopts: + key, value = spec.split('=') + table_options[key] = value + + if options.mysql_engine: + table_options['mysql_engine'] = options.mysql_engine +post_configure.append(_set_table_options) + +def _reverse_topological(options, file_config): + if options.reversetop: + from sqlalchemy.orm import unitofwork, session, mapper, dependency + from sqlalchemy.util import topological + from sqlalchemy.testing.util import RandomSet + topological.set = unitofwork.set = session.set = mapper.set = \ + dependency.set = RandomSet +post_configure.append(_reverse_topological) + +def _requirements(options, file_config): + from sqlalchemy.testing import config + from sqlalchemy import testing + requirement_cls = file_config.get('sqla_testing', "requirement_cls") + + modname, clsname = requirement_cls.split(":") + + # importlib.import_module() only introduced in 2.7, a little + # late + mod = __import__(modname) + for component in modname.split(".")[1:]: + mod = getattr(mod, component) + req_cls = getattr(mod, clsname) + config.requirements = testing.requires = req_cls(db, config) + +post_configure.append(_requirements) + +def _setup_profiling(options, file_config): + from sqlalchemy.testing import profiling + profiling._profile_stats = profiling.ProfileStatsFile( + file_config.get('sqla_testing', 'profile_file')) + +post_configure.append(_setup_profiling) + diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py new file mode 100644 index 000000000..c9e12c305 --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/noseplugin.py @@ -0,0 +1,199 @@ +"""Enhance nose with extra options and behaviors for running SQLAlchemy tests. + +This module is imported relative to the "plugins" package as a top level +package by the sqla_nose.py runner, so that the plugin can be loaded with +the rest of nose including the coverage plugin before any of SQLAlchemy itself +is imported, so that coverage works. + +When third party libraries use this library, it can be imported +normally as "from sqlalchemy.testing.plugin import noseplugin". + +""" +import os +import ConfigParser + +from nose.plugins import Plugin +from nose import SkipTest +from . import config + +from .config import _log, _list_dbs, _zero_timeout, \ + _engine_strategy, _server_side_cursors, pre_configure,\ + post_configure + +# late imports +fixtures = None +engines = None +exclusions = None +warnings = None +profiling = None +assertions = None +requirements = None +util = None +file_config = None + +class NoseSQLAlchemy(Plugin): + """ + Handles the setup and extra properties required for testing SQLAlchemy + """ + enabled = True + + name = 'sqla_testing' + score = 100 + + def options(self, parser, env=os.environ): + Plugin.options(self, parser, env) + opt = parser.add_option + opt("--log-info", action="callback", type="string", callback=_log, + help="turn on info logging for <LOG> (multiple OK)") + opt("--log-debug", action="callback", type="string", callback=_log, + help="turn on debug logging for <LOG> (multiple OK)") + opt("--require", action="append", dest="require", default=[], + help="require a particular driver or module version (multiple OK)") + opt("--db", action="store", dest="db", default="sqlite", + help="Use prefab database uri") + opt('--dbs', action='callback', callback=_list_dbs, + help="List available prefab dbs") + opt("--dburi", action="store", dest="dburi", + help="Database uri (overrides --db)") + opt("--dropfirst", action="store_true", dest="dropfirst", + help="Drop all tables in the target database first") + opt("--mockpool", action="store_true", dest="mockpool", + help="Use mock pool (asserts only one connection used)") + opt("--zero-timeout", action="callback", callback=_zero_timeout, + help="Set pool_timeout to zero, applies to QueuePool only") + opt("--low-connections", action="store_true", dest="low_connections", + help="Use a low number of distinct connections - i.e. for Oracle TNS" + ) + opt("--enginestrategy", action="callback", type="string", + callback=_engine_strategy, + help="Engine strategy (plain or threadlocal, defaults to plain)") + opt("--reversetop", action="store_true", dest="reversetop", default=False, + help="Use a random-ordering set implementation in the ORM (helps " + "reveal dependency issues)") + opt("--with-cdecimal", action="store_true", dest="cdecimal", default=False, + help="Monkeypatch the cdecimal library into Python 'decimal' for all tests") + opt("--unhashable", action="store_true", dest="unhashable", default=False, + help="Disallow SQLAlchemy from performing a hash() on mapped test objects.") + opt("--noncomparable", action="store_true", dest="noncomparable", default=False, + help="Disallow SQLAlchemy from performing == on mapped test objects.") + opt("--truthless", action="store_true", dest="truthless", default=False, + help="Disallow SQLAlchemy from truth-evaluating mapped test objects.") + opt("--serverside", action="callback", callback=_server_side_cursors, + help="Turn on server side cursors for PG") + opt("--mysql-engine", action="store", dest="mysql_engine", default=None, + help="Use the specified MySQL storage engine for all tables, default is " + "a db-default/InnoDB combo.") + opt("--table-option", action="append", dest="tableopts", default=[], + help="Add a dialect-specific table option, key=value") + opt("--write-profiles", action="store_true", dest="write_profiles", default=False, + help="Write/update profiling data.") + global file_config + file_config = ConfigParser.ConfigParser() + file_config.read(['setup.cfg', 'test.cfg', os.path.expanduser('~/.satest.cfg')]) + config.file_config = file_config + + def configure(self, options, conf): + Plugin.configure(self, options, conf) + self.options = options + for fn in pre_configure: + fn(self.options, file_config) + + def begin(self): + # Lazy setup of other options (post coverage) + for fn in post_configure: + fn(self.options, file_config) + + # late imports, has to happen after config as well + # as nose plugins like coverage + global util, fixtures, engines, exclusions, \ + assertions, warnings, profiling + from sqlalchemy.testing import fixtures, engines, exclusions, \ + assertions, warnings, profiling + from sqlalchemy import util + + def describeTest(self, test): + return "" + + def wantFunction(self, fn): + if fn.__module__.startswith('test.lib') or \ + fn.__module__.startswith('test.bootstrap'): + return False + + def wantClass(self, cls): + """Return true if you want the main test selector to collect + tests from this class, false if you don't, and None if you don't + care. + + :Parameters: + cls : class + The class being examined by the selector + + """ + + if not issubclass(cls, fixtures.TestBase): + return False + elif cls.__name__.startswith('_'): + return False + else: + return True + + def _do_skips(self, cls): + from sqlalchemy.testing import config + if hasattr(cls, '__requires__'): + def test_suite(): + return 'ok' + test_suite.__name__ = cls.__name__ + for requirement in cls.__requires__: + check = getattr(config.requirements, requirement) + check(test_suite)() + + if cls.__unsupported_on__: + spec = exclusions.db_spec(*cls.__unsupported_on__) + if spec(config.db): + raise SkipTest( + "'%s' unsupported on DB implementation '%s'" % ( + cls.__name__, config.db.name) + ) + + if getattr(cls, '__only_on__', None): + spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) + if not spec(config.db): + raise SkipTest( + "'%s' unsupported on DB implementation '%s'" % ( + cls.__name__, config.db.name) + ) + + if getattr(cls, '__skip_if__', False): + for c in getattr(cls, '__skip_if__'): + if c(): + raise SkipTest("'%s' skipped by %s" % ( + cls.__name__, c.__name__) + ) + + for db, op, spec in getattr(cls, '__excluded_on__', ()): + exclusions.exclude(db, op, spec, + "'%s' unsupported on DB %s version %s" % ( + cls.__name__, config.db.name, + exclusions._server_version(config.db))) + + def beforeTest(self, test): + warnings.resetwarnings() + profiling._current_test = test.id() + + def afterTest(self, test): + engines.testing_reaper._after_test_ctx() + warnings.resetwarnings() + + def startContext(self, ctx): + if not isinstance(ctx, type) \ + or not issubclass(ctx, fixtures.TestBase): + return + self._do_skips(ctx) + + def stopContext(self, ctx): + if not isinstance(ctx, type) \ + or not issubclass(ctx, fixtures.TestBase): + return + engines.testing_reaper._stop_test_ctx() + if not config.options.low_connections: + assertions.global_cleanup_assertions() diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py new file mode 100644 index 000000000..be32b1d1d --- /dev/null +++ b/lib/sqlalchemy/testing/profiling.py @@ -0,0 +1,292 @@ +"""Profiling support for unit and performance tests. + +These are special purpose profiling methods which operate +in a more fine-grained way than nose's profiling plugin. + +""" + +import os +import sys +from .util import gc_collect, decorator +from . import config +from nose import SkipTest +import pstats +import time +import collections +from sqlalchemy import util +try: + import cProfile +except ImportError: + cProfile = None +from sqlalchemy.util.compat import jython, pypy, win32 + +_current_test = None + +def profiled(target=None, **target_opts): + """Function profiling. + + @profiled('label') + or + @profiled('label', report=True, sort=('calls',), limit=20) + + Enables profiling for a function when 'label' is targetted for + profiling. Report options can be supplied, and override the global + configuration and command-line options. + """ + + profile_config = {'targets': set(), + 'report': True, + 'print_callers': False, + 'print_callees': False, + 'graphic': False, + 'sort': ('time', 'calls'), + 'limit': None} + if target is None: + target = 'anonymous_target' + + filename = "%s.prof" % target + + @decorator + def decorate(fn, *args, **kw): + elapsed, load_stats, result = _profile( + filename, fn, *args, **kw) + + graphic = target_opts.get('graphic', profile_config['graphic']) + if graphic: + os.system("runsnake %s" % filename) + else: + report = target_opts.get('report', profile_config['report']) + if report: + sort_ = target_opts.get('sort', profile_config['sort']) + limit = target_opts.get('limit', profile_config['limit']) + print ("Profile report for target '%s' (%s)" % ( + target, filename) + ) + + stats = load_stats() + stats.sort_stats(*sort_) + if limit: + stats.print_stats(limit) + else: + stats.print_stats() + + print_callers = target_opts.get('print_callers', + profile_config['print_callers']) + if print_callers: + stats.print_callers() + + print_callees = target_opts.get('print_callees', + profile_config['print_callees']) + if print_callees: + stats.print_callees() + + os.unlink(filename) + return result + return decorate + + +class ProfileStatsFile(object): + """"Store per-platform/fn profiling results in a file. + + We're still targeting Py2.5, 2.4 on 0.7 with no dependencies, + so no json lib :( need to roll something silly + + """ + def __init__(self, filename): + self.write = config.options is not None and config.options.write_profiles + self.fname = os.path.abspath(filename) + self.short_fname = os.path.split(self.fname)[-1] + self.data = collections.defaultdict(lambda: collections.defaultdict(dict)) + self._read() + if self.write: + # rewrite for the case where features changed, + # etc. + self._write() + + @util.memoized_property + def platform_key(self): + + dbapi_key = config.db.name + "_" + config.db.driver + + # keep it at 2.7, 3.1, 3.2, etc. for now. + py_version = '.'.join([str(v) for v in sys.version_info[0:2]]) + + platform_tokens = [py_version] + platform_tokens.append(dbapi_key) + if jython: + platform_tokens.append("jython") + if pypy: + platform_tokens.append("pypy") + if win32: + platform_tokens.append("win") + _has_cext = config.requirements._has_cextensions() + platform_tokens.append(_has_cext and "cextensions" or "nocextensions") + return "_".join(platform_tokens) + + def has_stats(self): + test_key = _current_test + return test_key in self.data and self.platform_key in self.data[test_key] + + def result(self, callcount): + test_key = _current_test + per_fn = self.data[test_key] + per_platform = per_fn[self.platform_key] + + if 'counts' not in per_platform: + per_platform['counts'] = counts = [] + else: + counts = per_platform['counts'] + + if 'current_count' not in per_platform: + per_platform['current_count'] = current_count = 0 + else: + current_count = per_platform['current_count'] + + has_count = len(counts) > current_count + + if not has_count: + counts.append(callcount) + if self.write: + self._write() + result = None + else: + result = per_platform['lineno'], counts[current_count] + per_platform['current_count'] += 1 + return result + + + def _header(self): + return \ + "# %s\n"\ + "# This file is written out on a per-environment basis.\n"\ + "# For each test in aaa_profiling, the corresponding function and \n"\ + "# environment is located within this file. If it doesn't exist,\n"\ + "# the test is skipped.\n"\ + "# If a callcount does exist, it is compared to what we received. \n"\ + "# assertions are raised if the counts do not match.\n"\ + "# \n"\ + "# To add a new callcount test, apply the function_call_count \n"\ + "# decorator and re-run the tests using the --write-profiles option - \n"\ + "# this file will be rewritten including the new count.\n"\ + "# \n"\ + "" % (self.fname) + + def _read(self): + try: + profile_f = open(self.fname) + except IOError: + return + for lineno, line in enumerate(profile_f): + line = line.strip() + if not line or line.startswith("#"): + continue + + test_key, platform_key, counts = line.split() + per_fn = self.data[test_key] + per_platform = per_fn[platform_key] + per_platform['counts'] = [int(count) for count in counts.split(",")] + per_platform['lineno'] = lineno + 1 + per_platform['current_count'] = 0 + profile_f.close() + + def _write(self): + print("Writing profile file %s" % self.fname) + profile_f = open(self.fname, "w") + profile_f.write(self._header()) + for test_key in sorted(self.data): + + per_fn = self.data[test_key] + profile_f.write("\n# TEST: %s\n\n" % test_key) + for platform_key in sorted(per_fn): + per_platform = per_fn[platform_key] + profile_f.write( + "%s %s %s\n" % ( + test_key, + platform_key, ",".join(str(count) for count in per_platform['counts']) + ) + ) + profile_f.close() + +from sqlalchemy.util.compat import update_wrapper + +def function_call_count(variance=0.05): + """Assert a target for a test case's function call count. + + The main purpose of this assertion is to detect changes in + callcounts for various functions - the actual number is not as important. + Callcounts are stored in a file keyed to Python version and OS platform + information. This file is generated automatically for new tests, + and versioned so that unexpected changes in callcounts will be detected. + + """ + + def decorate(fn): + def wrap(*args, **kw): + + + if cProfile is None: + raise SkipTest("cProfile is not installed") + + if not _profile_stats.has_stats() and not _profile_stats.write: + # run the function anyway, to support dependent tests + # (not a great idea but we have these in test_zoomark) + fn(*args, **kw) + raise SkipTest("No profiling stats available on this " + "platform for this function. Run tests with " + "--write-profiles to add statistics to %s for " + "this platform." % _profile_stats.short_fname) + + gc_collect() + + + timespent, load_stats, fn_result = _profile( + fn, *args, **kw + ) + stats = load_stats() + callcount = stats.total_calls + + expected = _profile_stats.result(callcount) + if expected is None: + expected_count = None + else: + line_no, expected_count = expected + + print("Pstats calls: %d Expected %s" % ( + callcount, + expected_count + ) + ) + stats.print_stats() + #stats.print_callers() + + if expected_count: + deviance = int(callcount * variance) + if abs(callcount - expected_count) > deviance: + raise AssertionError( + "Adjusted function call count %s not within %s%% " + "of expected %s. (Delete line %d of file %s to regenerate " + "this callcount, when tests are run with --write-profiles.)" + % ( + callcount, (variance * 100), + expected_count, line_no, + _profile_stats.fname)) + return fn_result + return update_wrapper(wrap, fn) + return decorate + + +def _profile(fn, *args, **kw): + filename = "%s.prof" % fn.__name__ + + def load_stats(): + st = pstats.Stats(filename) + os.unlink(filename) + return st + + began = time.time() + cProfile.runctx('result = fn(*args, **kw)', globals(), locals(), + filename=filename) + ended = time.time() + + return ended - began, load_stats, locals()['result'] + diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py new file mode 100644 index 000000000..eca883d4e --- /dev/null +++ b/lib/sqlalchemy/testing/requirements.py @@ -0,0 +1,38 @@ +"""Global database feature support policy. + +Provides decorators to mark tests requiring specific feature support from the +target database. + +""" + +from .exclusions import \ + skip, \ + skip_if,\ + only_if,\ + only_on,\ + fails_on,\ + fails_on_everything_except,\ + fails_if,\ + SpecPredicate,\ + against + +def no_support(db, reason): + return SpecPredicate(db, description=reason) + +def exclude(db, op, spec, description=None): + return SpecPredicate(db, op, spec, description=description) + + +def _chain_decorators_on(*decorators): + def decorate(fn): + for decorator in reversed(decorators): + fn = decorator(fn) + return fn + return decorate + +class Requirements(object): + def __init__(self, db, config): + self.db = db + self.config = config + + diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py new file mode 100644 index 000000000..03da78c64 --- /dev/null +++ b/lib/sqlalchemy/testing/schema.py @@ -0,0 +1,85 @@ +"""Enhanced versions of schema.Table and schema.Column which establish +desired state for different backends. +""" + +from . import exclusions +from sqlalchemy import schema, event +from . import config + +__all__ = 'Table', 'Column', + +table_options = {} + +def Table(*args, **kw): + """A schema.Table wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k, kw.pop(k)) for k in kw.keys() + if k.startswith('test_')]) + + kw.update(table_options) + + if exclusions.against('mysql'): + if 'mysql_engine' not in kw and 'mysql_type' not in kw: + if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: + kw['mysql_engine'] = 'InnoDB' + else: + kw['mysql_engine'] = 'MyISAM' + + # Apply some default cascading rules for self-referential foreign keys. + # MySQL InnoDB has some issues around seleting self-refs too. + if exclusions.against('firebird'): + table_name = args[0] + unpack = (config.db.dialect. + identifier_preparer.unformat_identifiers) + + # Only going after ForeignKeys in Columns. May need to + # expand to ForeignKeyConstraint too. + fks = [fk + for col in args if isinstance(col, schema.Column) + for fk in col.foreign_keys] + + for fk in fks: + # root around in raw spec + ref = fk._colspec + if isinstance(ref, schema.Column): + name = ref.table.name + else: + # take just the table name: on FB there cannot be + # a schema, so the first element is always the + # table name, possibly followed by the field name + name = unpack(ref)[0] + if name == table_name: + if fk.ondelete is None: + fk.ondelete = 'CASCADE' + if fk.onupdate is None: + fk.onupdate = 'CASCADE' + + return schema.Table(*args, **kw) + + +def Column(*args, **kw): + """A schema.Column wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k, kw.pop(k)) for k in kw.keys() + if k.startswith('test_')]) + + col = schema.Column(*args, **kw) + if 'test_needs_autoincrement' in test_opts and \ + kw.get('primary_key', False) and \ + exclusions.against('firebird', 'oracle'): + def add_seq(c, tbl): + c._init_items( + schema.Sequence(_truncate_name( + config.db.dialect, tbl.name + '_' + c.name + '_seq'), + optional=True) + ) + event.listen(col, 'after_parent_attach', add_seq, propagate=True) + return col + +def _truncate_name(dialect, name): + if len(name) > dialect.max_identifier_length: + return name[0:max(dialect.max_identifier_length - 6, 0)] + \ + "_" + hex(hash(name) % 64)[2:] + else: + return name + diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/lib/sqlalchemy/testing/suite/__init__.py diff --git a/lib/sqlalchemy/testing/suite/requirements.py b/lib/sqlalchemy/testing/suite/requirements.py new file mode 100644 index 000000000..5eda39b2b --- /dev/null +++ b/lib/sqlalchemy/testing/suite/requirements.py @@ -0,0 +1,24 @@ +from ..requirements import Requirements +from .. import exclusions + + +class SuiteRequirements(Requirements): + + @property + def create_table(self): + """target platform can emit basic CreateTable DDL.""" + + return exclusions.open + + @property + def drop_table(self): + """target platform can emit basic DropTable DDL.""" + + return exclusions.open + + @property + def autoincrement_insert(self): + """target platform generates new surrogate integer primary key values + when insert() is executed, excluding the pk column.""" + + return exclusions.open diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py new file mode 100644 index 000000000..1285c4196 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -0,0 +1,48 @@ +from .. import fixtures, config, util +from ..config import requirements +from ..assertions import eq_ + +from sqlalchemy import Table, Column, Integer, String + + +class TableDDLTest(fixtures.TestBase): + + def _simple_fixture(self): + return Table('test_table', self.metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + def _simple_roundtrip(self): + with config.db.begin() as conn: + conn.execute("insert into test_table(id, data) values " + "(1, 'some data')") + result = conn.execute("select id, data from test_table") + eq_( + result.first(), + (1, 'some data') + ) + + + @requirements.create_table + @util.provide_metadata + def test_create_table(self): + table = self._simple_fixture() + table.create( + config.db, checkfirst=False + ) + self._simple_roundtrip() + + + @requirements.drop_table + @util.provide_metadata + def test_drop_table(self): + table = self._simple_fixture() + table.create( + config.db, checkfirst=False + ) + table.drop( + config.db, checkfirst=False + ) + +__all__ = ('TableDDLTest',)
\ No newline at end of file diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_reflection.py diff --git a/lib/sqlalchemy/testing/suite/test_sequencing.py b/lib/sqlalchemy/testing/suite/test_sequencing.py new file mode 100644 index 000000000..7b09ecb76 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_sequencing.py @@ -0,0 +1,36 @@ +from .. import fixtures, config, util +from ..config import requirements +from ..assertions import eq_ + +from sqlalchemy import Table, Column, Integer, String + + +class InsertSequencingTest(fixtures.TablesTest): + run_deletes = 'each' + + @classmethod + def define_tables(cls, metadata): + Table('plain_pk', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + def _assert_round_trip(self, table): + row = config.db.execute(table.select()).first() + eq_( + row, + (1, "some data") + ) + + @requirements.autoincrement_insert + def test_autoincrement_on_insert(self): + + config.db.execute( + self.tables.plain_pk.insert(), + data="some data" + ) + self._assert_round_trip(self.tables.plain_pk) + + + +__all__ = ('InsertSequencingTest',)
\ No newline at end of file diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py new file mode 100644 index 000000000..625b9e6a5 --- /dev/null +++ b/lib/sqlalchemy/testing/util.py @@ -0,0 +1,196 @@ +from sqlalchemy.util import jython, pypy, defaultdict, decorator +from sqlalchemy.util.compat import decimal + +import gc +import time +import random +import sys +import types + +if jython: + def jython_gc_collect(*args): + """aggressive gc.collect for tests.""" + gc.collect() + time.sleep(0.1) + gc.collect() + gc.collect() + return 0 + + # "lazy" gc, for VM's that don't GC on refcount == 0 + lazy_gc = jython_gc_collect +elif pypy: + def pypy_gc_collect(*args): + gc.collect() + gc.collect() + lazy_gc = pypy_gc_collect +else: + # assume CPython - straight gc.collect, lazy_gc() is a pass + gc_collect = gc.collect + def lazy_gc(): + pass + +def picklers(): + picklers = set() + # Py2K + try: + import cPickle + picklers.add(cPickle) + except ImportError: + pass + # end Py2K + import pickle + picklers.add(pickle) + + # yes, this thing needs this much testing + for pickle_ in picklers: + for protocol in -1, 0, 1, 2: + yield pickle_.loads, lambda d: pickle_.dumps(d, protocol) + + +def round_decimal(value, prec): + if isinstance(value, float): + return round(value, prec) + + # can also use shift() here but that is 2.6 only + return (value * decimal.Decimal("1" + "0" * prec) + ).to_integral(decimal.ROUND_FLOOR) / \ + pow(10, prec) + +class RandomSet(set): + def __iter__(self): + l = list(set.__iter__(self)) + random.shuffle(l) + return iter(l) + + def pop(self): + index = random.randint(0, len(self) - 1) + item = list(set.__iter__(self))[index] + self.remove(item) + return item + + def union(self, other): + return RandomSet(set.union(self, other)) + + def difference(self, other): + return RandomSet(set.difference(self, other)) + + def intersection(self, other): + return RandomSet(set.intersection(self, other)) + + def copy(self): + return RandomSet(self) + +def conforms_partial_ordering(tuples, sorted_elements): + """True if the given sorting conforms to the given partial ordering.""" + + deps = defaultdict(set) + for parent, child in tuples: + deps[parent].add(child) + for i, node in enumerate(sorted_elements): + for n in sorted_elements[i:]: + if node in deps[n]: + return False + else: + return True + +def all_partial_orderings(tuples, elements): + edges = defaultdict(set) + for parent, child in tuples: + edges[child].add(parent) + + def _all_orderings(elements): + + if len(elements) == 1: + yield list(elements) + else: + for elem in elements: + subset = set(elements).difference([elem]) + if not subset.intersection(edges[elem]): + for sub_ordering in _all_orderings(subset): + yield [elem] + sub_ordering + + return iter(_all_orderings(elements)) + + +def function_named(fn, name): + """Return a function with a given __name__. + + Will assign to __name__ and return the original function if possible on + the Python implementation, otherwise a new function will be constructed. + + This function should be phased out as much as possible + in favor of @decorator. Tests that "generate" many named tests + should be modernized. + + """ + try: + fn.__name__ = name + except TypeError: + fn = types.FunctionType(fn.func_code, fn.func_globals, name, + fn.func_defaults, fn.func_closure) + return fn + + + +def run_as_contextmanager(ctx, fn, *arg, **kw): + """Run the given function under the given contextmanager, + simulating the behavior of 'with' to support older + Python versions. + + """ + + obj = ctx.__enter__() + try: + result = fn(obj, *arg, **kw) + ctx.__exit__(None, None, None) + return result + except: + exc_info = sys.exc_info() + raise_ = ctx.__exit__(*exc_info) + if raise_ is None: + raise + else: + return raise_ + +def rowset(results): + """Converts the results of sql execution into a plain set of column tuples. + + Useful for asserting the results of an unordered query. + """ + + return set([tuple(row) for row in results]) + + +def fail(msg): + assert False, msg + + +@decorator +def provide_metadata(fn, *args, **kw): + """Provide bound MetaData for a single test, dropping afterwards.""" + + from . import config + from sqlalchemy import schema + + metadata = schema.MetaData(config.db) + self = args[0] + prev_meta = getattr(self, 'metadata', None) + self.metadata = metadata + try: + return fn(*args, **kw) + finally: + metadata.drop_all() + self.metadata = prev_meta + +class adict(dict): + """Dict keys available as attributes. Shadows.""" + def __getattribute__(self, key): + try: + return self[key] + except KeyError: + return dict.__getattribute__(self, key) + + def get_all(self, *keys): + return tuple([self[key] for key in keys]) + + diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py new file mode 100644 index 000000000..799fca128 --- /dev/null +++ b/lib/sqlalchemy/testing/warnings.py @@ -0,0 +1,43 @@ +from __future__ import absolute_import + +import warnings +from sqlalchemy import exc as sa_exc +from sqlalchemy import util + +def testing_warn(msg, stacklevel=3): + """Replaces sqlalchemy.util.warn during tests.""" + + filename = "sqlalchemy.testing.warnings" + lineno = 1 + if isinstance(msg, basestring): + warnings.warn_explicit(msg, sa_exc.SAWarning, filename, lineno) + else: + warnings.warn_explicit(msg, filename, lineno) + +def resetwarnings(): + """Reset warning behavior to testing defaults.""" + + util.warn = util.langhelpers.warn = testing_warn + + warnings.filterwarnings('ignore', + category=sa_exc.SAPendingDeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SAWarning) + +def assert_warnings(fn, warnings): + """Assert that each of the given warnings are emitted by fn.""" + + from .assertions import eq_, emits_warning + + canary = [] + orig_warn = util.warn + def capture_warnings(*args, **kw): + orig_warn(*args, **kw) + popwarn = warnings.pop(0) + canary.append(popwarn) + eq_(args[0], popwarn) + util.warn = util.langhelpers.warn = capture_warnings + + result = emits_warning()(fn)() + assert canary, "No warning was emitted" + return result |