diff options
Diffstat (limited to 'lib/sqlalchemy/testing')
22 files changed, 3058 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py new file mode 100644 index 000000000..415705e93 --- /dev/null +++ b/lib/sqlalchemy/testing/__init__.py @@ -0,0 +1,21 @@ +from __future__ import absolute_import + +from .warnings import testing_warn, assert_warnings, resetwarnings + +from . import config + +from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\ + fails_on, fails_on_everything_except, skip, only_on, exclude, against,\ + _server_version + +from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ + eq_, ne_, is_, is_not_, startswith_, assert_raises, \ + assert_raises_message, AssertsCompiledSQL, ComparesTables, AssertsExecutionResults + +from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict + +crashes = skip + +from .config import db, requirements as requires + + diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py new file mode 100644 index 000000000..1e8559c1a --- /dev/null +++ b/lib/sqlalchemy/testing/assertions.py @@ -0,0 +1,349 @@ +from __future__ import absolute_import + +from . import util as testutil +from sqlalchemy import pool, orm, util +from sqlalchemy.engine import default +from sqlalchemy import exc as sa_exc +from sqlalchemy.util import decorator +from sqlalchemy import types as sqltypes, schema +import warnings +import re +from .warnings import resetwarnings +from .exclusions import db_spec, _is_excluded +from . import assertsql +from . import config +import itertools +from .util import fail + +def emits_warning(*messages): + """Mark a test as emitting a warning. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + # TODO: it would be nice to assert that a named warning was + # emitted. should work with some monkeypatching of warnings, + # and may work on non-CPython if they keep to the spirit of + # warnings.showwarning's docstring. + # - update: jython looks ok, it uses cpython's module + + @decorator + def decorate(fn, *args, **kw): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SAWarning)) + else: + filters.extend(dict(action='ignore', + message=message, + category=sa_exc.SAWarning) + for message in messages) + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return decorate + +def emits_warning_on(db, *warnings): + """Mark a test as emitting a warning on a specific dialect. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + spec = db_spec(db) + + @decorator + def decorate(fn, *args, **kw): + if isinstance(db, basestring): + if not spec(config.db): + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + else: + if not _is_excluded(*db): + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + return decorate + + +def uses_deprecated(*messages): + """Mark a test as immune from fatal deprecation warnings. + + With no arguments, squelches all SADeprecationWarning failures. + Or pass one or more strings; these will be matched to the root + of the warning description by warnings.filterwarnings(). + + As a special case, you may pass a function name prefixed with // + and it will be re-written as needed to match the standard warning + verbiage emitted by the sqlalchemy.util.deprecated decorator. + """ + + @decorator + def decorate(fn, *args, **kw): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SADeprecationWarning)) + else: + filters.extend( + [dict(action='ignore', + message=message, + category=sa_exc.SADeprecationWarning) + for message in + [(m.startswith('//') and + ('Call to deprecated function ' + m[2:]) or m) + for m in messages]]) + + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return decorate + + + +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + + testutil.lazy_gc() + assert not pool._refs, str(pool._refs) + + + +def eq_(a, b, msg=None): + """Assert a == b, with repr messaging on failure.""" + assert a == b, msg or "%r != %r" % (a, b) + +def ne_(a, b, msg=None): + """Assert a != b, with repr messaging on failure.""" + assert a != b, msg or "%r == %r" % (a, b) + +def is_(a, b, msg=None): + """Assert a is b, with repr messaging on failure.""" + assert a is b, msg or "%r is not %r" % (a, b) + +def is_not_(a, b, msg=None): + """Assert a is not b, with repr messaging on failure.""" + assert a is not b, msg or "%r is %r" % (a, b) + +def startswith_(a, fragment, msg=None): + """Assert a.startswith(fragment), with repr messaging on failure.""" + assert a.startswith(fragment), msg or "%r does not start with %r" % ( + a, fragment) + +def assert_raises(except_cls, callable_, *args, **kw): + try: + callable_(*args, **kw) + success = False + except except_cls: + success = True + + # assert outside the block so it works for AssertionError too ! + assert success, "Callable did not raise an exception" + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + try: + callable_(*args, **kwargs) + assert False, "Callable did not raise an exception" + except except_cls, e: + assert re.search(msg, unicode(e), re.UNICODE), u"%r !~ %s" % (msg, e) + print unicode(e).encode('utf-8') + + +class AssertsCompiledSQL(object): + def assert_compile(self, clause, result, params=None, + checkparams=None, dialect=None, + checkpositional=None, + use_default_dialect=False, + allow_dialect_select=False): + if use_default_dialect: + dialect = default.DefaultDialect() + elif dialect == None and not allow_dialect_select: + dialect = getattr(self, '__dialect__', None) + if dialect == 'default': + dialect = default.DefaultDialect() + elif dialect is None: + dialect = config.db.dialect + + kw = {} + if params is not None: + kw['column_keys'] = params.keys() + + if isinstance(clause, orm.Query): + context = clause._compile_context() + context.statement.use_labels = True + clause = context.statement + + c = clause.compile(dialect=dialect, **kw) + + param_str = repr(getattr(c, 'params', {})) + # Py3K + #param_str = param_str.encode('utf-8').decode('ascii', 'ignore') + + print "\nSQL String:\n" + str(c) + param_str + + cc = re.sub(r'[\n\t]', '', str(c)) + + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) + + if checkparams is not None: + eq_(c.construct_params(params), checkparams) + if checkpositional is not None: + p = c.construct_params(params) + eq_(tuple([p[x] for x in c.positiontup]), checkpositional) + +class ComparesTables(object): + def assert_tables_equal(self, table, reflected_table, strict_types=False): + assert len(table.c) == len(reflected_table.c) + for c, reflected_c in zip(table.c, reflected_table.c): + eq_(c.name, reflected_c.name) + assert reflected_c is reflected_table.c[c.name] + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) + + if strict_types: + assert type(reflected_c.type) is type(c.type), \ + "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + else: + self.assert_types_base(reflected_c, c) + + if isinstance(c.type, sqltypes.String): + eq_(c.type.length, reflected_c.type.length) + + eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) + if c.server_default: + assert isinstance(reflected_c.server_default, + schema.FetchedValue) + + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] is not None + + def assert_types_base(self, c1, c2): + assert c1.type._compare_type_affinity(c2.type),\ + "On column %r, type '%s' doesn't correspond to type '%s'" % \ + (c1.name, c1.type, c2.type) + +class AssertsExecutionResults(object): + def assert_result(self, result, class_, *objects): + result = list(result) + print repr(result) + self.assert_list(result, class_, objects) + + def assert_list(self, result, class_, list): + self.assert_(len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + class_.__name__) + for i in range(0, len(list)): + self.assert_row(class_, result[i], list[i]) + + def assert_row(self, class_, rowobj, desc): + self.assert_(rowobj.__class__ is class_, + "item class is not " + repr(class_)) + for key, value in desc.iteritems(): + if isinstance(value, tuple): + if isinstance(value[1], list): + self.assert_list(getattr(rowobj, key), value[0], value[1]) + else: + self.assert_row(value[0], getattr(rowobj, key), value[1]) + else: + self.assert_(getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" % ( + key, getattr(rowobj, key), value)) + + def assert_unordered_result(self, result, cls, *expected): + """As assert_result, but the order of objects is not considered. + + The algorithm is very expensive but not a big deal for the small + numbers of rows that the test suite manipulates. + """ + + class immutabledict(dict): + def __hash__(self): + return id(self) + + found = util.IdentitySet(result) + expected = set([immutabledict(e) for e in expected]) + + for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found): + fail('Unexpected type "%s", expected "%s"' % ( + type(wrong).__name__, cls.__name__)) + + if len(found) != len(expected): + fail('Unexpected object count "%s", expected "%s"' % ( + len(found), len(expected))) + + NOVALUE = object() + def _compare_item(obj, spec): + for key, value in spec.iteritems(): + if isinstance(value, tuple): + try: + self.assert_unordered_result( + getattr(obj, key), value[0], *value[1]) + except AssertionError: + return False + else: + if getattr(obj, key, NOVALUE) != value: + return False + return True + + for expected_item in expected: + for found_item in found: + if _compare_item(found_item, expected_item): + found.remove(found_item) + break + else: + fail( + "Expected %s instance with attributes %s not found." % ( + cls.__name__, repr(expected_item))) + return True + + def assert_sql_execution(self, db, callable_, *rules): + assertsql.asserter.add_rules(rules) + try: + callable_() + assertsql.asserter.statement_complete() + finally: + assertsql.asserter.clear_rules() + + def assert_sql(self, db, callable_, list_, with_sequences=None): + if with_sequences is not None and config.db.dialect.supports_sequences: + rules = with_sequences + else: + rules = list_ + + newrules = [] + for rule in rules: + if isinstance(rule, dict): + newrule = assertsql.AllOf(*[ + assertsql.ExactSQL(k, v) for k, v in rule.iteritems() + ]) + else: + newrule = assertsql.ExactSQL(*rule) + newrules.append(newrule) + + self.assert_sql_execution(db, callable_, *newrules) + + def assert_sql_count(self, db, callable_, count): + self.assert_sql_execution(db, callable_, assertsql.CountStatements(count)) + + diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py new file mode 100644 index 000000000..897f4b3b1 --- /dev/null +++ b/lib/sqlalchemy/testing/assertsql.py @@ -0,0 +1,316 @@ + +from sqlalchemy.interfaces import ConnectionProxy +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.engine.base import Connection +from sqlalchemy import util +import re + +class AssertRule(object): + + def process_execute(self, clauseelement, *multiparams, **params): + pass + + def process_cursor_execute(self, statement, parameters, context, + executemany): + pass + + def is_consumed(self): + """Return True if this rule has been consumed, False if not. + + Should raise an AssertionError if this rule's condition has + definitely failed. + + """ + + raise NotImplementedError() + + def rule_passed(self): + """Return True if the last test of this rule passed, False if + failed, None if no test was applied.""" + + raise NotImplementedError() + + def consume_final(self): + """Return True if this rule has been consumed. + + Should raise an AssertionError if this rule's condition has not + been consumed or has failed. + + """ + + if self._result is None: + assert False, 'Rule has not been consumed' + return self.is_consumed() + +class SQLMatchRule(AssertRule): + def __init__(self): + self._result = None + self._errmsg = "" + + def rule_passed(self): + return self._result + + def is_consumed(self): + if self._result is None: + return False + + assert self._result, self._errmsg + + return True + +class ExactSQL(SQLMatchRule): + + def __init__(self, sql, params=None): + SQLMatchRule.__init__(self) + self.sql = sql + self.params = params + + def process_cursor_execute(self, statement, parameters, context, + executemany): + if not context: + return + _received_statement = \ + _process_engine_statement(context.unicode_statement, + context) + _received_parameters = context.compiled_parameters + + # TODO: remove this step once all unit tests are migrated, as + # ExactSQL should really be *exact* SQL + + sql = _process_assertion_statement(self.sql, context) + equivalent = _received_statement == sql + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + equivalent = equivalent and params \ + == context.compiled_parameters + else: + params = {} + self._result = equivalent + if not self._result: + self._errmsg = \ + 'Testing for exact statement %r exact params %r, '\ + 'received %r with params %r' % (sql, params, + _received_statement, _received_parameters) + + +class RegexSQL(SQLMatchRule): + + def __init__(self, regex, params=None): + SQLMatchRule.__init__(self) + self.regex = re.compile(regex) + self.orig_regex = regex + self.params = params + + def process_cursor_execute(self, statement, parameters, context, + executemany): + if not context: + return + _received_statement = \ + _process_engine_statement(context.unicode_statement, + context) + _received_parameters = context.compiled_parameters + equivalent = bool(self.regex.match(_received_statement)) + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + + # do a positive compare only + + for param, received in zip(params, _received_parameters): + for k, v in param.iteritems(): + if k not in received or received[k] != v: + equivalent = False + break + else: + params = {} + self._result = equivalent + if not self._result: + self._errmsg = \ + 'Testing for regex %r partial params %r, received %r '\ + 'with params %r' % (self.orig_regex, params, + _received_statement, + _received_parameters) + +class CompiledSQL(SQLMatchRule): + + def __init__(self, statement, params): + SQLMatchRule.__init__(self) + self.statement = statement + self.params = params + + def process_cursor_execute(self, statement, parameters, context, + executemany): + if not context: + return + _received_parameters = list(context.compiled_parameters) + + # recompile from the context, using the default dialect + + compiled = \ + context.compiled.statement.compile(dialect=DefaultDialect(), + column_keys=context.compiled.column_keys) + _received_statement = re.sub(r'\n', '', str(compiled)) + equivalent = self.statement == _received_statement + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + all_params = list(params) + all_received = list(_received_parameters) + while params: + param = dict(params.pop(0)) + for k, v in context.compiled.params.iteritems(): + param.setdefault(k, v) + if param not in _received_parameters: + equivalent = False + break + else: + _received_parameters.remove(param) + if _received_parameters: + equivalent = False + else: + params = {} + all_params = {} + all_received = [] + self._result = equivalent + if not self._result: + print 'Testing for compiled statement %r partial params '\ + '%r, received %r with params %r' % (self.statement, + all_params, _received_statement, all_received) + self._errmsg = \ + 'Testing for compiled statement %r partial params %r, '\ + 'received %r with params %r' % (self.statement, + all_params, _received_statement, all_received) + + + # print self._errmsg + +class CountStatements(AssertRule): + + def __init__(self, count): + self.count = count + self._statement_count = 0 + + def process_execute(self, clauseelement, *multiparams, **params): + self._statement_count += 1 + + def process_cursor_execute(self, statement, parameters, context, + executemany): + pass + + def is_consumed(self): + return False + + def consume_final(self): + assert self.count == self._statement_count, \ + 'desired statement count %d does not match %d' \ + % (self.count, self._statement_count) + return True + +class AllOf(AssertRule): + + def __init__(self, *rules): + self.rules = set(rules) + + def process_execute(self, clauseelement, *multiparams, **params): + for rule in self.rules: + rule.process_execute(clauseelement, *multiparams, **params) + + def process_cursor_execute(self, statement, parameters, context, + executemany): + for rule in self.rules: + rule.process_cursor_execute(statement, parameters, context, + executemany) + + def is_consumed(self): + if not self.rules: + return True + for rule in list(self.rules): + if rule.rule_passed(): # a rule passed, move on + self.rules.remove(rule) + return len(self.rules) == 0 + assert False, 'No assertion rules were satisfied for statement' + + def consume_final(self): + return len(self.rules) == 0 + +def _process_engine_statement(query, context): + if util.jython: + + # oracle+zxjdbc passes a PyStatement when returning into + + query = unicode(query) + if context.engine.name == 'mssql' \ + and query.endswith('; select scope_identity()'): + query = query[:-25] + query = re.sub(r'\n', '', query) + return query + +def _process_assertion_statement(query, context): + paramstyle = context.dialect.paramstyle + if paramstyle == 'named': + pass + elif paramstyle =='pyformat': + query = re.sub(r':([\w_]+)', r"%(\1)s", query) + else: + # positional params + repl = None + if paramstyle=='qmark': + repl = "?" + elif paramstyle=='format': + repl = r"%s" + elif paramstyle=='numeric': + repl = None + query = re.sub(r':([\w_]+)', repl, query) + + return query + +class SQLAssert(object): + + rules = None + + def add_rules(self, rules): + self.rules = list(rules) + + def statement_complete(self): + for rule in self.rules: + if not rule.consume_final(): + assert False, \ + 'All statements are complete, but pending '\ + 'assertion rules remain' + + def clear_rules(self): + del self.rules + + def execute(self, conn, clauseelement, multiparams, params, result): + if self.rules is not None: + if not self.rules: + assert False, \ + 'All rules have been exhausted, but further '\ + 'statements remain' + rule = self.rules[0] + rule.process_execute(clauseelement, *multiparams, **params) + if rule.is_consumed(): + self.rules.pop(0) + + def cursor_execute(self, conn, cursor, statement, parameters, + context, executemany): + if self.rules: + rule = self.rules[0] + rule.process_cursor_execute(statement, parameters, context, + executemany) + +asserter = SQLAssert() + diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py new file mode 100644 index 000000000..2945bd456 --- /dev/null +++ b/lib/sqlalchemy/testing/config.py @@ -0,0 +1,3 @@ +requirements = None +db = None + diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py new file mode 100644 index 000000000..f7401550e --- /dev/null +++ b/lib/sqlalchemy/testing/engines.py @@ -0,0 +1,429 @@ +from __future__ import absolute_import + +import types +import weakref +from collections import deque +from . import config +from .util import decorator +from sqlalchemy import event, pool +import re +import warnings + +class ConnectionKiller(object): + def __init__(self): + self.proxy_refs = weakref.WeakKeyDictionary() + self.testing_engines = weakref.WeakKeyDictionary() + self.conns = set() + + def add_engine(self, engine): + self.testing_engines[engine] = True + + def connect(self, dbapi_conn, con_record): + self.conns.add(dbapi_conn) + + def checkout(self, dbapi_con, con_record, con_proxy): + self.proxy_refs[con_proxy] = True + + def _safe(self, fn): + try: + fn() + except (SystemExit, KeyboardInterrupt): + raise + except Exception, e: + warnings.warn( + "testing_reaper couldn't " + "rollback/close connection: %s" % e) + + def rollback_all(self): + for rec in self.proxy_refs.keys(): + if rec is not None and rec.is_valid: + self._safe(rec.rollback) + + def close_all(self): + for rec in self.proxy_refs.keys(): + if rec is not None: + self._safe(rec._close) + + def _after_test_ctx(self): + pass + # this can cause a deadlock with pg8000 - pg8000 acquires + # prepared statment lock inside of rollback() - if async gc + # is collecting in finalize_fairy, deadlock. + # not sure if this should be if pypy/jython only + #for conn in self.conns: + # self._safe(conn.rollback) + + def _stop_test_ctx(self): + if config.options.low_connections: + self._stop_test_ctx_minimal() + else: + self._stop_test_ctx_aggressive() + + def _stop_test_ctx_minimal(self): + self.close_all() + + self.conns = set() + + for rec in self.testing_engines.keys(): + if rec is not config.db: + rec.dispose() + + def _stop_test_ctx_aggressive(self): + self.close_all() + for conn in self.conns: + self._safe(conn.close) + self.conns = set() + for rec in self.testing_engines.keys(): + rec.dispose() + + def assert_all_closed(self): + for rec in self.proxy_refs: + if rec.is_valid: + assert False + +testing_reaper = ConnectionKiller() + +def drop_all_tables(metadata, bind): + testing_reaper.close_all() + if hasattr(bind, 'close'): + bind.close() + metadata.drop_all(bind) + +@decorator +def assert_conns_closed(fn, *args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.assert_all_closed() + +@decorator +def rollback_open_connections(fn, *args, **kw): + """Decorator that rolls back all open connections after fn execution.""" + + try: + fn(*args, **kw) + finally: + testing_reaper.rollback_all() + +@decorator +def close_first(fn, *args, **kw): + """Decorator that closes all connections before fn execution.""" + + testing_reaper.close_all() + fn(*args, **kw) + + +@decorator +def close_open_connections(fn, *args, **kw): + """Decorator that closes all connections after fn execution.""" + try: + fn(*args, **kw) + finally: + testing_reaper.close_all() + +def all_dialects(exclude=None): + import sqlalchemy.databases as d + for name in d.__all__: + # TEMPORARY + if exclude and name in exclude: + continue + mod = getattr(d, name, None) + if not mod: + mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) + yield mod.dialect() + +class ReconnectFixture(object): + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + conn = self.dbapi.connect(*args, **kwargs) + self.connections.append(conn) + return conn + + def _safe(self, fn): + try: + fn() + except (SystemExit, KeyboardInterrupt): + raise + except Exception, e: + warnings.warn( + "ReconnectFixture couldn't " + "close connection: %s" % e) + + def shutdown(self): + # TODO: this doesn't cover all cases + # as nicely as we'd like, namely MySQLdb. + # would need to implement R. Brewer's + # proxy server idea to get better + # coverage. + for c in list(self.connections): + self._safe(c.close) + self.connections = [] + +def reconnecting_engine(url=None, options=None): + url = url or config.db_url + dbapi = config.db.dialect.dbapi + if not options: + options = {} + options['module'] = ReconnectFixture(dbapi) + engine = testing_engine(url, options) + _dispose = engine.dispose + def dispose(): + engine.dialect.dbapi.shutdown() + _dispose() + engine.test_shutdown = engine.dialect.dbapi.shutdown + engine.dispose = dispose + return engine + +def testing_engine(url=None, options=None): + """Produce an engine configured by --options with optional overrides.""" + + from sqlalchemy import create_engine + from .assertsql import asserter + + if not options: + use_reaper = True + else: + use_reaper = options.pop('use_reaper', True) + + url = url or config.db_url + if options is None: + options = config.db_opts + + engine = create_engine(url, **options) + if isinstance(engine.pool, pool.QueuePool): + engine.pool._timeout = 0 + engine.pool._max_overflow = 0 + event.listen(engine, 'after_execute', asserter.execute) + event.listen(engine, 'after_cursor_execute', asserter.cursor_execute) + if use_reaper: + event.listen(engine.pool, 'connect', testing_reaper.connect) + event.listen(engine.pool, 'checkout', testing_reaper.checkout) + testing_reaper.add_engine(engine) + + return engine + +def utf8_engine(url=None, options=None): + """Hook for dialects or drivers that don't handle utf8 by default.""" + + from sqlalchemy.engine import url as engine_url + + if config.db.dialect.name == 'mysql' and \ + config.db.driver in ['mysqldb', 'pymysql']: + # note 1.2.1.gamma.6 or greater of MySQLdb + # needed here + url = url or config.db_url + url = engine_url.make_url(url) + url.query['charset'] = 'utf8' + url.query['use_unicode'] = '0' + url = str(url) + + return testing_engine(url, options) + +def mock_engine(dialect_name=None): + """Provides a mocking engine based on the current testing.db. + + This is normally used to test DDL generation flow as emitted + by an Engine. + + It should not be used in other cases, as assert_compile() and + assert_sql_execution() are much better choices with fewer + moving parts. + + """ + + from sqlalchemy import create_engine + + if not dialect_name: + dialect_name = config.db.name + + buffer = [] + def executor(sql, *a, **kw): + buffer.append(sql) + def assert_sql(stmts): + recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + assert recv == stmts, recv + def print_sql(): + d = engine.dialect + return "\n".join( + str(s.compile(dialect=d)) + for s in engine.mock + ) + engine = create_engine(dialect_name + '://', + strategy='mock', executor=executor) + assert not hasattr(engine, 'mock') + engine.mock = buffer + engine.assert_sql = assert_sql + engine.print_sql = print_sql + return engine + +class DBAPIProxyCursor(object): + """Proxy a DBAPI cursor. + + Tests can provide subclasses of this to intercept + DBAPI-level cursor operations. + + """ + def __init__(self, engine, conn): + self.engine = engine + self.connection = conn + self.cursor = conn.cursor() + + def execute(self, stmt, parameters=None, **kw): + if parameters: + return self.cursor.execute(stmt, parameters, **kw) + else: + return self.cursor.execute(stmt, **kw) + + def executemany(self, stmt, params, **kw): + return self.cursor.executemany(stmt, params, **kw) + + def __getattr__(self, key): + return getattr(self.cursor, key) + +class DBAPIProxyConnection(object): + """Proxy a DBAPI connection. + + Tests can provide subclasses of this to intercept + DBAPI-level connection operations. + + """ + def __init__(self, engine, cursor_cls): + self.conn = self._sqla_unwrap = engine.pool._creator() + self.engine = engine + self.cursor_cls = cursor_cls + + def cursor(self): + return self.cursor_cls(self.engine, self.conn) + + def close(self): + self.conn.close() + + def __getattr__(self, key): + return getattr(self.conn, key) + +def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor): + """Produce an engine that provides proxy hooks for + common methods. + + """ + def mock_conn(): + return conn_cls(config.db, cursor_cls) + return testing_engine(options={'creator':mock_conn}) + +class ReplayableSession(object): + """A simple record/playback tool. + + This is *not* a mock testing class. It only records a session for later + playback and makes no assertions on call consistency whatsoever. It's + unlikely to be suitable for anything other than DB-API recording. + + """ + + Callable = object() + NoAttribute = object() + + # Py3K + #Natives = set([getattr(types, t) + # for t in dir(types) if not t.startswith('_')]). \ + # union([type(t) if not isinstance(t, type) + # else t for t in __builtins__.values()]).\ + # difference([getattr(types, t) + # for t in ('FunctionType', 'BuiltinFunctionType', + # 'MethodType', 'BuiltinMethodType', + # 'LambdaType', )]) + # Py2K + Natives = set([getattr(types, t) + for t in dir(types) if not t.startswith('_')]). \ + difference([getattr(types, t) + for t in ('FunctionType', 'BuiltinFunctionType', + 'MethodType', 'BuiltinMethodType', + 'LambdaType', 'UnboundMethodType',)]) + # end Py2K + + def __init__(self): + self.buffer = deque() + + def recorder(self, base): + return self.Recorder(self.buffer, base) + + def player(self): + return self.Player(self.buffer) + + class Recorder(object): + def __init__(self, buffer, subject): + self._buffer = buffer + self._subject = subject + + def __call__(self, *args, **kw): + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + + result = subject(*args, **kw) + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + @property + def _sqla_unwrap(self): + return self._subject + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + try: + result = type(subject).__getattribute__(subject, key) + except AttributeError: + buffer.append(ReplayableSession.NoAttribute) + raise + else: + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + class Player(object): + def __init__(self, buffer): + self._buffer = buffer + + def __call__(self, *args, **kw): + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + else: + return result + + @property + def _sqla_unwrap(self): + return None + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + elif result is ReplayableSession.NoAttribute: + raise AttributeError(key) + else: + return result + diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py new file mode 100644 index 000000000..1b24e73b7 --- /dev/null +++ b/lib/sqlalchemy/testing/entities.py @@ -0,0 +1,83 @@ +import sqlalchemy as sa +from sqlalchemy import exc as sa_exc + +_repr_stack = set() +class BasicEntity(object): + def __init__(self, **kw): + for key, value in kw.iteritems(): + setattr(self, key, value) + + def __repr__(self): + if id(self) in _repr_stack: + return object.__repr__(self) + _repr_stack.add(id(self)) + try: + return "%s(%s)" % ( + (self.__class__.__name__), + ', '.join(["%s=%r" % (key, getattr(self, key)) + for key in sorted(self.__dict__.keys()) + if not key.startswith('_')])) + finally: + _repr_stack.remove(id(self)) + +_recursion_stack = set() +class ComparableEntity(BasicEntity): + def __hash__(self): + return hash(self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + """'Deep, sparse compare. + + Deeply compare two entities, following the non-None attributes of the + non-persisted object, if possible. + + """ + if other is self: + return True + elif not self.__class__ == other.__class__: + return False + + if id(self) in _recursion_stack: + return True + _recursion_stack.add(id(self)) + + try: + # pick the entity thats not SA persisted as the source + try: + self_key = sa.orm.attributes.instance_state(self).key + except sa.orm.exc.NO_STATE: + self_key = None + + if other is None: + a = self + b = other + elif self_key is not None: + a = other + b = self + else: + a = self + b = other + + for attr in a.__dict__.keys(): + if attr.startswith('_'): + continue + value = getattr(a, attr) + + try: + # handle lazy loader errors + battr = getattr(b, attr) + except (AttributeError, sa_exc.UnboundExecutionError): + return False + + if hasattr(value, '__iter__'): + if list(value) != list(battr): + return False + else: + if value is not None and value != battr: + return False + return True + finally: + _recursion_stack.remove(id(self)) diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py new file mode 100644 index 000000000..ba2eebe4f --- /dev/null +++ b/lib/sqlalchemy/testing/exclusions.py @@ -0,0 +1,269 @@ +import operator +from nose import SkipTest +from sqlalchemy.util import decorator +from . import config +from sqlalchemy import util + + +def fails_if(predicate, reason=None): + predicate = _as_predicate(predicate) + + @decorator + def decorate(fn, *args, **kw): + if not predicate(): + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected (%s): %s " % ( + fn.__name__, predicate, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for '%s' (%s)" % + (fn.__name__, predicate)) + return decorate + +def skip_if(predicate, reason=None): + predicate = _as_predicate(predicate) + + @decorator + def decorate(fn, *args, **kw): + if predicate(): + if reason: + msg = "'%s' : %s" % ( + fn.__name__, + reason + ) + else: + msg = "'%s': %s" % ( + fn.__name__, predicate + ) + raise SkipTest(msg) + else: + return fn(*args, **kw) + return decorate + +def only_if(predicate, reason=None): + predicate = _as_predicate(predicate) + return skip_if(NotPredicate(predicate), reason) + +def succeeds_if(predicate, reason=None): + predicate = _as_predicate(predicate) + return fails_if(NotPredicate(predicate), reason) + +class Predicate(object): + @classmethod + def as_predicate(cls, predicate): + if isinstance(predicate, Predicate): + return predicate + elif isinstance(predicate, list): + return OrPredicate([cls.as_predicate(pred) for pred in predicate]) + elif isinstance(predicate, tuple): + return SpecPredicate(*predicate) + elif isinstance(predicate, basestring): + return SpecPredicate(predicate, None, None) + elif util.callable(predicate): + return LambdaPredicate(predicate) + else: + assert False, "unknown predicate type: %s" % predicate + +class SpecPredicate(Predicate): + def __init__(self, db, op=None, spec=None, description=None): + self.db = db + self.op = op + self.spec = spec + self.description = description + + _ops = { + '<': operator.lt, + '>': operator.gt, + '==': operator.eq, + '!=': operator.ne, + '<=': operator.le, + '>=': operator.ge, + 'in': operator.contains, + 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + } + + def __call__(self, engine=None): + if engine is None: + engine = config.db + + if "+" in self.db: + dialect, driver = self.db.split('+') + else: + dialect, driver = self.db, None + + if dialect and engine.name != dialect: + return False + if driver is not None and engine.driver != driver: + return False + + if self.op is not None: + assert driver is None, "DBAPI version specs not supported yet" + + version = _server_version(engine) + oper = hasattr(self.op, '__call__') and self.op \ + or self._ops[self.op] + return oper(version, self.spec) + else: + return True + + def _as_string(self, negate=False): + if self.description is not None: + return self.description + elif self.op is None: + if negate: + return "not %s" % self.db + else: + return "%s" % self.db + else: + if negate: + return "not %s %s %s" % ( + self.db, + self.op, + self.spec + ) + else: + return "%s %s %s" % ( + self.db, + self.op, + self.spec + ) + + def __str__(self): + return self._as_string() + +class LambdaPredicate(Predicate): + def __init__(self, lambda_, description=None, args=None, kw=None): + self.lambda_ = lambda_ + self.args = args or () + self.kw = kw or {} + if description: + self.description = description + elif lambda_.__doc__: + self.description = lambda_.__doc__ + else: + self.description = "custom function" + + def __call__(self): + return self.lambda_(*self.args, **self.kw) + + def _as_string(self, negate=False): + if negate: + return "not " + self.description + else: + return self.description + + def __str__(self): + return self._as_string() + +class NotPredicate(Predicate): + def __init__(self, predicate): + self.predicate = predicate + + def __call__(self, *arg, **kw): + return not self.predicate(*arg, **kw) + + def __str__(self): + return self.predicate._as_string(True) + +class OrPredicate(Predicate): + def __init__(self, predicates, description=None): + self.predicates = predicates + self.description = description + + def __call__(self, *arg, **kw): + for pred in self.predicates: + if pred(*arg, **kw): + self._str = pred + return True + return False + + _str = None + + def _eval_str(self, negate=False): + if self._str is None: + if negate: + conjunction = " and " + else: + conjunction = " or " + return conjunction.join(p._as_string(negate=negate) + for p in self.predicates) + else: + return self._str._as_string(negate=negate) + + def _negation_str(self): + if self.description is not None: + return "Not " + (self.description % {"spec": self._str}) + else: + return self._eval_str(negate=True) + + def _as_string(self, negate=False): + if negate: + return self._negation_str() + else: + if self.description is not None: + return self.description % {"spec": self._str} + else: + return self._eval_str() + + def __str__(self): + return self._as_string() + +_as_predicate = Predicate.as_predicate + +def _is_excluded(db, op, spec): + return SpecPredicate(db, op, spec)() + +def _server_version(engine): + """Return a server_version_info tuple.""" + + # force metadata to be retrieved + conn = engine.connect() + version = getattr(engine.dialect, 'server_version_info', ()) + conn.close() + return version + +def db_spec(*dbs): + return OrPredicate( + Predicate.as_predicate(db) for db in dbs + ) + +def open(fn): + return fn + +@decorator +def future(fn, *args, **kw): + return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature") + +def fails_on(db, reason): + return fails_if(SpecPredicate(db), reason) + +def fails_on_everything_except(*dbs): + return succeeds_if( + OrPredicate([ + SpecPredicate(db) for db in dbs + ]) + ) + +def skip(db, reason): + return skip_if(SpecPredicate(db), reason) + +def only_on(dbs, reason): + return only_if( + OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]) + ) + + +def exclude(db, op, spec, reason): + return skip_if(SpecPredicate(db, op, spec), reason) + + +def against(*queries): + return OrPredicate([ + Predicate.as_predicate(query) + for query in queries + ])() diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py new file mode 100644 index 000000000..018276d4d --- /dev/null +++ b/lib/sqlalchemy/testing/fixtures.py @@ -0,0 +1,334 @@ +from . import config +from . import assertions, schema +from .util import adict +from .engines import drop_all_tables +from .entities import BasicEntity, ComparableEntity +import sys +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta + +class TestBase(object): + # A sequence of database names to always run, regardless of the + # constraints below. + __whitelist__ = () + + # A sequence of requirement names matching testing.requires decorators + __requires__ = () + + # A sequence of dialect names to exclude from the test class. + __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None + + # A sequence of no-arg callables. If any are True, the entire testcase is + # skipped. + __skip_if__ = None + + def assert_(self, val, msg=None): + assert val, msg + +class TablesTest(TestBase): + + # 'once', None + run_setup_bind = 'once' + + # 'once', 'each', None + run_define_tables = 'once' + + # 'once', 'each', None + run_create_tables = 'once' + + # 'once', 'each', None + run_inserts = 'each' + + # 'each', None + run_deletes = 'each' + + # 'once', None + run_dispose_bind = None + + bind = None + metadata = None + tables = None + other = None + + @classmethod + def setup_class(cls): + cls._init_class() + + cls._setup_once_tables() + + cls._setup_once_inserts() + + @classmethod + def _init_class(cls): + if cls.run_define_tables == 'each': + if cls.run_create_tables == 'once': + cls.run_create_tables = 'each' + assert cls.run_inserts in ('each', None) + + if cls.other is None: + cls.other = adict() + + if cls.tables is None: + cls.tables = adict() + + if cls.bind is None: + setattr(cls, 'bind', cls.setup_bind()) + + if cls.metadata is None: + setattr(cls, 'metadata', sa.MetaData()) + + if cls.metadata.bind is None: + cls.metadata.bind = cls.bind + + @classmethod + def _setup_once_inserts(cls): + if cls.run_inserts == 'once': + cls._load_fixtures() + cls.insert_data() + + @classmethod + def _setup_once_tables(cls): + if cls.run_define_tables == 'once': + cls.define_tables(cls.metadata) + if cls.run_create_tables == 'once': + cls.metadata.create_all(cls.bind) + cls.tables.update(cls.metadata.tables) + + def _setup_each_tables(self): + if self.run_define_tables == 'each': + self.tables.clear() + if self.run_create_tables == 'each': + drop_all_tables(self.metadata, self.bind) + self.metadata.clear() + self.define_tables(self.metadata) + if self.run_create_tables == 'each': + self.metadata.create_all(self.bind) + self.tables.update(self.metadata.tables) + elif self.run_create_tables == 'each': + drop_all_tables(self.metadata, self.bind) + self.metadata.create_all(self.bind) + + def _setup_each_inserts(self): + if self.run_inserts == 'each': + self._load_fixtures() + self.insert_data() + + def _teardown_each_tables(self): + # no need to run deletes if tables are recreated on setup + if self.run_define_tables != 'each' and self.run_deletes == 'each': + for table in reversed(self.metadata.sorted_tables): + try: + table.delete().execute().close() + except sa.exc.DBAPIError, ex: + print >> sys.stderr, "Error emptying table %s: %r" % ( + table, ex) + + def setup(self): + self._setup_each_tables() + self._setup_each_inserts() + + def teardown(self): + self._teardown_each_tables() + + @classmethod + def _teardown_once_metadata_bind(cls): + if cls.run_create_tables: + drop_all_tables(cls.metadata, cls.bind) + + if cls.run_dispose_bind == 'once': + cls.dispose_bind(cls.bind) + + cls.metadata.bind = None + + if cls.run_setup_bind is not None: + cls.bind = None + + @classmethod + def teardown_class(cls): + cls._teardown_once_metadata_bind() + + @classmethod + def setup_bind(cls): + return config.db + + @classmethod + def dispose_bind(cls, bind): + if hasattr(bind, 'dispose'): + bind.dispose() + elif hasattr(bind, 'close'): + bind.close() + + @classmethod + def define_tables(cls, metadata): + pass + + @classmethod + def fixtures(cls): + return {} + + @classmethod + def insert_data(cls): + pass + + def sql_count_(self, count, fn): + self.assert_sql_count(self.bind, fn, count) + + def sql_eq_(self, callable_, statements, with_sequences=None): + self.assert_sql(self.bind, + callable_, statements, with_sequences) + + @classmethod + def _load_fixtures(cls): + """Insert rows as represented by the fixtures() method.""" + headers, rows = {}, {} + for table, data in cls.fixtures().iteritems(): + if len(data) < 2: + continue + if isinstance(table, basestring): + table = cls.tables[table] + headers[table] = data[0] + rows[table] = data[1:] + for table in cls.metadata.sorted_tables: + if table not in headers: + continue + cls.bind.execute( + table.insert(), + [dict(zip(headers[table], column_values)) + for column_values in rows[table]]) + + +class _ORMTest(object): + __requires__ = ('subqueries',) + + @classmethod + def teardown_class(cls): + sa.orm.session.Session.close_all() + sa.orm.clear_mappers() + +class ORMTest(_ORMTest, TestBase): + pass + +class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): + # 'once', 'each', None + run_setup_classes = 'once' + + # 'once', 'each', None + run_setup_mappers = 'each' + + classes = None + + @classmethod + def setup_class(cls): + cls._init_class() + + if cls.classes is None: + cls.classes = adict() + + cls._setup_once_tables() + cls._setup_once_classes() + cls._setup_once_mappers() + cls._setup_once_inserts() + + @classmethod + def teardown_class(cls): + cls._teardown_once_class() + cls._teardown_once_metadata_bind() + + def setup(self): + self._setup_each_tables() + self._setup_each_mappers() + self._setup_each_inserts() + + def teardown(self): + sa.orm.session.Session.close_all() + self._teardown_each_mappers() + self._teardown_each_tables() + + @classmethod + def _teardown_once_class(cls): + cls.classes.clear() + _ORMTest.teardown_class() + + + @classmethod + def _setup_once_classes(cls): + if cls.run_setup_classes == 'once': + cls._with_register_classes(cls.setup_classes) + + @classmethod + def _setup_once_mappers(cls): + if cls.run_setup_mappers == 'once': + cls._with_register_classes(cls.setup_mappers) + + def _setup_each_mappers(self): + if self.run_setup_mappers == 'each': + self._with_register_classes(self.setup_mappers) + + @classmethod + def _with_register_classes(cls, fn): + """Run a setup method, framing the operation with a Base class + that will catch new subclasses to be established within + the "classes" registry. + + """ + cls_registry = cls.classes + class FindFixture(type): + def __init__(cls, classname, bases, dict_): + cls_registry[classname] = cls + return type.__init__(cls, classname, bases, dict_) + + + class _Base(object): + __metaclass__ = FindFixture + class Basic(BasicEntity, _Base): + pass + class Comparable(ComparableEntity, _Base): + pass + cls.Basic = Basic + cls.Comparable = Comparable + fn() + + def _teardown_each_mappers(self): + # some tests create mappers in the test bodies + # and will define setup_mappers as None - + # clear mappers in any case + if self.run_setup_mappers != 'once': + sa.orm.clear_mappers() + + @classmethod + def setup_classes(cls): + pass + + @classmethod + def setup_mappers(cls): + pass + +class DeclarativeMappedTest(MappedTest): + run_setup_classes = 'once' + run_setup_mappers = 'once' + + @classmethod + def _setup_once_tables(cls): + pass + + @classmethod + def _with_register_classes(cls, fn): + cls_registry = cls.classes + class FindFixtureDeclarative(DeclarativeMeta): + def __init__(cls, classname, bases, dict_): + cls_registry[classname] = cls + return DeclarativeMeta.__init__( + cls, classname, bases, dict_) + class DeclarativeBasic(object): + __table_cls__ = schema.Table + _DeclBase = declarative_base(metadata=cls.metadata, + metaclass=FindFixtureDeclarative, + cls=DeclarativeBasic) + cls.DeclarativeBasic = _DeclBase + fn() + if cls.metadata.tables: + cls.metadata.create_all(config.db) diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py new file mode 100644 index 000000000..f5b8b827c --- /dev/null +++ b/lib/sqlalchemy/testing/pickleable.py @@ -0,0 +1,107 @@ +"""Classes used in pickling tests, need to be at the module level for unpickling.""" + +from . import fixtures + +class User(fixtures.ComparableEntity): + pass + +class Order(fixtures.ComparableEntity): + pass + +class Dingaling(fixtures.ComparableEntity): + pass + +class EmailUser(User): + pass + +class Address(fixtures.ComparableEntity): + pass + +# TODO: these are kind of arbitrary.... +class Child1(fixtures.ComparableEntity): + pass + +class Child2(fixtures.ComparableEntity): + pass + +class Parent(fixtures.ComparableEntity): + pass + +class Screen(object): + def __init__(self, obj, parent=None): + self.obj = obj + self.parent = parent + +class Foo(object): + def __init__(self, moredata): + self.data = 'im data' + self.stuff = 'im stuff' + self.moredata = moredata + __hash__ = object.__hash__ + def __eq__(self, other): + return other.data == self.data and \ + other.stuff == self.stuff and \ + other.moredata == self.moredata + + +class Bar(object): + def __init__(self, x, y): + self.x = x + self.y = y + __hash__ = object.__hash__ + def __eq__(self, other): + return other.__class__ is self.__class__ and \ + other.x == self.x and \ + other.y == self.y + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + +class OldSchool: + def __init__(self, x, y): + self.x = x + self.y = y + def __eq__(self, other): + return other.__class__ is self.__class__ and \ + other.x == self.x and \ + other.y == self.y + +class OldSchoolWithoutCompare: + def __init__(self, x, y): + self.x = x + self.y = y + +class BarWithoutCompare(object): + def __init__(self, x, y): + self.x = x + self.y = y + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + + +class NotComparable(object): + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return NotImplemented + + def __ne__(self, other): + return NotImplemented + + +class BrokenComparable(object): + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + raise NotImplementedError + diff --git a/lib/sqlalchemy/testing/plugin/__init__.py b/lib/sqlalchemy/testing/plugin/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/__init__.py diff --git a/lib/sqlalchemy/testing/plugin/config.py b/lib/sqlalchemy/testing/plugin/config.py new file mode 100644 index 000000000..08b9753dc --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/config.py @@ -0,0 +1,186 @@ +"""Option and configuration implementations, run by the nose plugin +on test suite startup.""" + +import time +import warnings +import sys +import re + +logging = None +db = None +db_label = None +db_url = None +db_opts = {} +options = None +file_config = None + +def _log(option, opt_str, value, parser): + global logging + if not logging: + import logging + logging.basicConfig() + + if opt_str.endswith('-info'): + logging.getLogger(value).setLevel(logging.INFO) + elif opt_str.endswith('-debug'): + logging.getLogger(value).setLevel(logging.DEBUG) + + +def _list_dbs(*args): + print "Available --db options (use --dburi to override)" + for macro in sorted(file_config.options('db')): + print "%20s\t%s" % (macro, file_config.get('db', macro)) + sys.exit(0) + +def _server_side_cursors(options, opt_str, value, parser): + db_opts['server_side_cursors'] = True + +def _zero_timeout(options, opt_str, value, parser): + warnings.warn("--zero-timeout testing option is now on in all cases") + +def _engine_strategy(options, opt_str, value, parser): + if value: + db_opts['strategy'] = value + +pre_configure = [] +post_configure = [] + +def _setup_options(opt, file_config): + global options + from sqlalchemy.testing import config + config.options = options = opt +pre_configure.append(_setup_options) + +def _monkeypatch_cdecimal(options, file_config): + if options.cdecimal: + import sys + import cdecimal + sys.modules['decimal'] = cdecimal +pre_configure.append(_monkeypatch_cdecimal) + +def _engine_uri(options, file_config): + global db_label, db_url + db_label = 'sqlite' + if options.dburi: + db_url = options.dburi + db_label = db_url[:db_url.index(':')] + elif options.db: + db_label = options.db + db_url = None + + if db_url is None: + if db_label not in file_config.options('db'): + raise RuntimeError( + "Unknown engine. Specify --dbs for known engines.") + db_url = file_config.get('db', db_label) +post_configure.append(_engine_uri) + +def _require(options, file_config): + if not(options.require or + (file_config.has_section('require') and + file_config.items('require'))): + return + + try: + import pkg_resources + except ImportError: + raise RuntimeError("setuptools is required for version requirements") + + cmdline = [] + for requirement in options.require: + pkg_resources.require(requirement) + cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0]) + + if file_config.has_section('require'): + for label, requirement in file_config.items('require'): + if not label == db_label or label.startswith('%s.' % db_label): + continue + seen = [c for c in cmdline if requirement.startswith(c)] + if seen: + continue + pkg_resources.require(requirement) +post_configure.append(_require) + +def _engine_pool(options, file_config): + if options.mockpool: + from sqlalchemy import pool + db_opts['poolclass'] = pool.AssertionPool +post_configure.append(_engine_pool) + +def _create_testing_engine(options, file_config): + from sqlalchemy.testing import engines, config + from sqlalchemy import testing + global db + config.db = testing.db = db = engines.testing_engine(db_url, db_opts) + config.db_opts = db_opts + config.db_url = db_url + +post_configure.append(_create_testing_engine) + +def _prep_testing_database(options, file_config): + from sqlalchemy.testing import engines + from sqlalchemy import schema + + # also create alt schemas etc. here? + if options.dropfirst: + e = engines.utf8_engine() + existing = e.table_names() + if existing: + print "Dropping existing tables in database: " + db_url + try: + print "Tables: %s" % ', '.join(existing) + except: + pass + print "Abort within 5 seconds..." + time.sleep(5) + md = schema.MetaData(e, reflect=True) + md.drop_all() + e.dispose() + +post_configure.append(_prep_testing_database) + +def _set_table_options(options, file_config): + from sqlalchemy.testing import schema + + table_options = schema.table_options + for spec in options.tableopts: + key, value = spec.split('=') + table_options[key] = value + + if options.mysql_engine: + table_options['mysql_engine'] = options.mysql_engine +post_configure.append(_set_table_options) + +def _reverse_topological(options, file_config): + if options.reversetop: + from sqlalchemy.orm import unitofwork, session, mapper, dependency + from sqlalchemy.util import topological + from sqlalchemy.testing.util import RandomSet + topological.set = unitofwork.set = session.set = mapper.set = \ + dependency.set = RandomSet +post_configure.append(_reverse_topological) + +def _requirements(options, file_config): + from sqlalchemy.testing import config + from sqlalchemy import testing + requirement_cls = file_config.get('sqla_testing', "requirement_cls") + + modname, clsname = requirement_cls.split(":") + + # importlib.import_module() only introduced in 2.7, a little + # late + mod = __import__(modname) + for component in modname.split(".")[1:]: + mod = getattr(mod, component) + req_cls = getattr(mod, clsname) + config.requirements = testing.requires = req_cls(db, config) + +post_configure.append(_requirements) + +def _setup_profiling(options, file_config): + from sqlalchemy.testing import profiling + profiling._profile_stats = profiling.ProfileStatsFile( + file_config.get('sqla_testing', 'profile_file')) + +post_configure.append(_setup_profiling) + diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py new file mode 100644 index 000000000..c9e12c305 --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/noseplugin.py @@ -0,0 +1,199 @@ +"""Enhance nose with extra options and behaviors for running SQLAlchemy tests. + +This module is imported relative to the "plugins" package as a top level +package by the sqla_nose.py runner, so that the plugin can be loaded with +the rest of nose including the coverage plugin before any of SQLAlchemy itself +is imported, so that coverage works. + +When third party libraries use this library, it can be imported +normally as "from sqlalchemy.testing.plugin import noseplugin". + +""" +import os +import ConfigParser + +from nose.plugins import Plugin +from nose import SkipTest +from . import config + +from .config import _log, _list_dbs, _zero_timeout, \ + _engine_strategy, _server_side_cursors, pre_configure,\ + post_configure + +# late imports +fixtures = None +engines = None +exclusions = None +warnings = None +profiling = None +assertions = None +requirements = None +util = None +file_config = None + +class NoseSQLAlchemy(Plugin): + """ + Handles the setup and extra properties required for testing SQLAlchemy + """ + enabled = True + + name = 'sqla_testing' + score = 100 + + def options(self, parser, env=os.environ): + Plugin.options(self, parser, env) + opt = parser.add_option + opt("--log-info", action="callback", type="string", callback=_log, + help="turn on info logging for <LOG> (multiple OK)") + opt("--log-debug", action="callback", type="string", callback=_log, + help="turn on debug logging for <LOG> (multiple OK)") + opt("--require", action="append", dest="require", default=[], + help="require a particular driver or module version (multiple OK)") + opt("--db", action="store", dest="db", default="sqlite", + help="Use prefab database uri") + opt('--dbs', action='callback', callback=_list_dbs, + help="List available prefab dbs") + opt("--dburi", action="store", dest="dburi", + help="Database uri (overrides --db)") + opt("--dropfirst", action="store_true", dest="dropfirst", + help="Drop all tables in the target database first") + opt("--mockpool", action="store_true", dest="mockpool", + help="Use mock pool (asserts only one connection used)") + opt("--zero-timeout", action="callback", callback=_zero_timeout, + help="Set pool_timeout to zero, applies to QueuePool only") + opt("--low-connections", action="store_true", dest="low_connections", + help="Use a low number of distinct connections - i.e. for Oracle TNS" + ) + opt("--enginestrategy", action="callback", type="string", + callback=_engine_strategy, + help="Engine strategy (plain or threadlocal, defaults to plain)") + opt("--reversetop", action="store_true", dest="reversetop", default=False, + help="Use a random-ordering set implementation in the ORM (helps " + "reveal dependency issues)") + opt("--with-cdecimal", action="store_true", dest="cdecimal", default=False, + help="Monkeypatch the cdecimal library into Python 'decimal' for all tests") + opt("--unhashable", action="store_true", dest="unhashable", default=False, + help="Disallow SQLAlchemy from performing a hash() on mapped test objects.") + opt("--noncomparable", action="store_true", dest="noncomparable", default=False, + help="Disallow SQLAlchemy from performing == on mapped test objects.") + opt("--truthless", action="store_true", dest="truthless", default=False, + help="Disallow SQLAlchemy from truth-evaluating mapped test objects.") + opt("--serverside", action="callback", callback=_server_side_cursors, + help="Turn on server side cursors for PG") + opt("--mysql-engine", action="store", dest="mysql_engine", default=None, + help="Use the specified MySQL storage engine for all tables, default is " + "a db-default/InnoDB combo.") + opt("--table-option", action="append", dest="tableopts", default=[], + help="Add a dialect-specific table option, key=value") + opt("--write-profiles", action="store_true", dest="write_profiles", default=False, + help="Write/update profiling data.") + global file_config + file_config = ConfigParser.ConfigParser() + file_config.read(['setup.cfg', 'test.cfg', os.path.expanduser('~/.satest.cfg')]) + config.file_config = file_config + + def configure(self, options, conf): + Plugin.configure(self, options, conf) + self.options = options + for fn in pre_configure: + fn(self.options, file_config) + + def begin(self): + # Lazy setup of other options (post coverage) + for fn in post_configure: + fn(self.options, file_config) + + # late imports, has to happen after config as well + # as nose plugins like coverage + global util, fixtures, engines, exclusions, \ + assertions, warnings, profiling + from sqlalchemy.testing import fixtures, engines, exclusions, \ + assertions, warnings, profiling + from sqlalchemy import util + + def describeTest(self, test): + return "" + + def wantFunction(self, fn): + if fn.__module__.startswith('test.lib') or \ + fn.__module__.startswith('test.bootstrap'): + return False + + def wantClass(self, cls): + """Return true if you want the main test selector to collect + tests from this class, false if you don't, and None if you don't + care. + + :Parameters: + cls : class + The class being examined by the selector + + """ + + if not issubclass(cls, fixtures.TestBase): + return False + elif cls.__name__.startswith('_'): + return False + else: + return True + + def _do_skips(self, cls): + from sqlalchemy.testing import config + if hasattr(cls, '__requires__'): + def test_suite(): + return 'ok' + test_suite.__name__ = cls.__name__ + for requirement in cls.__requires__: + check = getattr(config.requirements, requirement) + check(test_suite)() + + if cls.__unsupported_on__: + spec = exclusions.db_spec(*cls.__unsupported_on__) + if spec(config.db): + raise SkipTest( + "'%s' unsupported on DB implementation '%s'" % ( + cls.__name__, config.db.name) + ) + + if getattr(cls, '__only_on__', None): + spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) + if not spec(config.db): + raise SkipTest( + "'%s' unsupported on DB implementation '%s'" % ( + cls.__name__, config.db.name) + ) + + if getattr(cls, '__skip_if__', False): + for c in getattr(cls, '__skip_if__'): + if c(): + raise SkipTest("'%s' skipped by %s" % ( + cls.__name__, c.__name__) + ) + + for db, op, spec in getattr(cls, '__excluded_on__', ()): + exclusions.exclude(db, op, spec, + "'%s' unsupported on DB %s version %s" % ( + cls.__name__, config.db.name, + exclusions._server_version(config.db))) + + def beforeTest(self, test): + warnings.resetwarnings() + profiling._current_test = test.id() + + def afterTest(self, test): + engines.testing_reaper._after_test_ctx() + warnings.resetwarnings() + + def startContext(self, ctx): + if not isinstance(ctx, type) \ + or not issubclass(ctx, fixtures.TestBase): + return + self._do_skips(ctx) + + def stopContext(self, ctx): + if not isinstance(ctx, type) \ + or not issubclass(ctx, fixtures.TestBase): + return + engines.testing_reaper._stop_test_ctx() + if not config.options.low_connections: + assertions.global_cleanup_assertions() diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py new file mode 100644 index 000000000..be32b1d1d --- /dev/null +++ b/lib/sqlalchemy/testing/profiling.py @@ -0,0 +1,292 @@ +"""Profiling support for unit and performance tests. + +These are special purpose profiling methods which operate +in a more fine-grained way than nose's profiling plugin. + +""" + +import os +import sys +from .util import gc_collect, decorator +from . import config +from nose import SkipTest +import pstats +import time +import collections +from sqlalchemy import util +try: + import cProfile +except ImportError: + cProfile = None +from sqlalchemy.util.compat import jython, pypy, win32 + +_current_test = None + +def profiled(target=None, **target_opts): + """Function profiling. + + @profiled('label') + or + @profiled('label', report=True, sort=('calls',), limit=20) + + Enables profiling for a function when 'label' is targetted for + profiling. Report options can be supplied, and override the global + configuration and command-line options. + """ + + profile_config = {'targets': set(), + 'report': True, + 'print_callers': False, + 'print_callees': False, + 'graphic': False, + 'sort': ('time', 'calls'), + 'limit': None} + if target is None: + target = 'anonymous_target' + + filename = "%s.prof" % target + + @decorator + def decorate(fn, *args, **kw): + elapsed, load_stats, result = _profile( + filename, fn, *args, **kw) + + graphic = target_opts.get('graphic', profile_config['graphic']) + if graphic: + os.system("runsnake %s" % filename) + else: + report = target_opts.get('report', profile_config['report']) + if report: + sort_ = target_opts.get('sort', profile_config['sort']) + limit = target_opts.get('limit', profile_config['limit']) + print ("Profile report for target '%s' (%s)" % ( + target, filename) + ) + + stats = load_stats() + stats.sort_stats(*sort_) + if limit: + stats.print_stats(limit) + else: + stats.print_stats() + + print_callers = target_opts.get('print_callers', + profile_config['print_callers']) + if print_callers: + stats.print_callers() + + print_callees = target_opts.get('print_callees', + profile_config['print_callees']) + if print_callees: + stats.print_callees() + + os.unlink(filename) + return result + return decorate + + +class ProfileStatsFile(object): + """"Store per-platform/fn profiling results in a file. + + We're still targeting Py2.5, 2.4 on 0.7 with no dependencies, + so no json lib :( need to roll something silly + + """ + def __init__(self, filename): + self.write = config.options is not None and config.options.write_profiles + self.fname = os.path.abspath(filename) + self.short_fname = os.path.split(self.fname)[-1] + self.data = collections.defaultdict(lambda: collections.defaultdict(dict)) + self._read() + if self.write: + # rewrite for the case where features changed, + # etc. + self._write() + + @util.memoized_property + def platform_key(self): + + dbapi_key = config.db.name + "_" + config.db.driver + + # keep it at 2.7, 3.1, 3.2, etc. for now. + py_version = '.'.join([str(v) for v in sys.version_info[0:2]]) + + platform_tokens = [py_version] + platform_tokens.append(dbapi_key) + if jython: + platform_tokens.append("jython") + if pypy: + platform_tokens.append("pypy") + if win32: + platform_tokens.append("win") + _has_cext = config.requirements._has_cextensions() + platform_tokens.append(_has_cext and "cextensions" or "nocextensions") + return "_".join(platform_tokens) + + def has_stats(self): + test_key = _current_test + return test_key in self.data and self.platform_key in self.data[test_key] + + def result(self, callcount): + test_key = _current_test + per_fn = self.data[test_key] + per_platform = per_fn[self.platform_key] + + if 'counts' not in per_platform: + per_platform['counts'] = counts = [] + else: + counts = per_platform['counts'] + + if 'current_count' not in per_platform: + per_platform['current_count'] = current_count = 0 + else: + current_count = per_platform['current_count'] + + has_count = len(counts) > current_count + + if not has_count: + counts.append(callcount) + if self.write: + self._write() + result = None + else: + result = per_platform['lineno'], counts[current_count] + per_platform['current_count'] += 1 + return result + + + def _header(self): + return \ + "# %s\n"\ + "# This file is written out on a per-environment basis.\n"\ + "# For each test in aaa_profiling, the corresponding function and \n"\ + "# environment is located within this file. If it doesn't exist,\n"\ + "# the test is skipped.\n"\ + "# If a callcount does exist, it is compared to what we received. \n"\ + "# assertions are raised if the counts do not match.\n"\ + "# \n"\ + "# To add a new callcount test, apply the function_call_count \n"\ + "# decorator and re-run the tests using the --write-profiles option - \n"\ + "# this file will be rewritten including the new count.\n"\ + "# \n"\ + "" % (self.fname) + + def _read(self): + try: + profile_f = open(self.fname) + except IOError: + return + for lineno, line in enumerate(profile_f): + line = line.strip() + if not line or line.startswith("#"): + continue + + test_key, platform_key, counts = line.split() + per_fn = self.data[test_key] + per_platform = per_fn[platform_key] + per_platform['counts'] = [int(count) for count in counts.split(",")] + per_platform['lineno'] = lineno + 1 + per_platform['current_count'] = 0 + profile_f.close() + + def _write(self): + print("Writing profile file %s" % self.fname) + profile_f = open(self.fname, "w") + profile_f.write(self._header()) + for test_key in sorted(self.data): + + per_fn = self.data[test_key] + profile_f.write("\n# TEST: %s\n\n" % test_key) + for platform_key in sorted(per_fn): + per_platform = per_fn[platform_key] + profile_f.write( + "%s %s %s\n" % ( + test_key, + platform_key, ",".join(str(count) for count in per_platform['counts']) + ) + ) + profile_f.close() + +from sqlalchemy.util.compat import update_wrapper + +def function_call_count(variance=0.05): + """Assert a target for a test case's function call count. + + The main purpose of this assertion is to detect changes in + callcounts for various functions - the actual number is not as important. + Callcounts are stored in a file keyed to Python version and OS platform + information. This file is generated automatically for new tests, + and versioned so that unexpected changes in callcounts will be detected. + + """ + + def decorate(fn): + def wrap(*args, **kw): + + + if cProfile is None: + raise SkipTest("cProfile is not installed") + + if not _profile_stats.has_stats() and not _profile_stats.write: + # run the function anyway, to support dependent tests + # (not a great idea but we have these in test_zoomark) + fn(*args, **kw) + raise SkipTest("No profiling stats available on this " + "platform for this function. Run tests with " + "--write-profiles to add statistics to %s for " + "this platform." % _profile_stats.short_fname) + + gc_collect() + + + timespent, load_stats, fn_result = _profile( + fn, *args, **kw + ) + stats = load_stats() + callcount = stats.total_calls + + expected = _profile_stats.result(callcount) + if expected is None: + expected_count = None + else: + line_no, expected_count = expected + + print("Pstats calls: %d Expected %s" % ( + callcount, + expected_count + ) + ) + stats.print_stats() + #stats.print_callers() + + if expected_count: + deviance = int(callcount * variance) + if abs(callcount - expected_count) > deviance: + raise AssertionError( + "Adjusted function call count %s not within %s%% " + "of expected %s. (Delete line %d of file %s to regenerate " + "this callcount, when tests are run with --write-profiles.)" + % ( + callcount, (variance * 100), + expected_count, line_no, + _profile_stats.fname)) + return fn_result + return update_wrapper(wrap, fn) + return decorate + + +def _profile(fn, *args, **kw): + filename = "%s.prof" % fn.__name__ + + def load_stats(): + st = pstats.Stats(filename) + os.unlink(filename) + return st + + began = time.time() + cProfile.runctx('result = fn(*args, **kw)', globals(), locals(), + filename=filename) + ended = time.time() + + return ended - began, load_stats, locals()['result'] + diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py new file mode 100644 index 000000000..eca883d4e --- /dev/null +++ b/lib/sqlalchemy/testing/requirements.py @@ -0,0 +1,38 @@ +"""Global database feature support policy. + +Provides decorators to mark tests requiring specific feature support from the +target database. + +""" + +from .exclusions import \ + skip, \ + skip_if,\ + only_if,\ + only_on,\ + fails_on,\ + fails_on_everything_except,\ + fails_if,\ + SpecPredicate,\ + against + +def no_support(db, reason): + return SpecPredicate(db, description=reason) + +def exclude(db, op, spec, description=None): + return SpecPredicate(db, op, spec, description=description) + + +def _chain_decorators_on(*decorators): + def decorate(fn): + for decorator in reversed(decorators): + fn = decorator(fn) + return fn + return decorate + +class Requirements(object): + def __init__(self, db, config): + self.db = db + self.config = config + + diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py new file mode 100644 index 000000000..03da78c64 --- /dev/null +++ b/lib/sqlalchemy/testing/schema.py @@ -0,0 +1,85 @@ +"""Enhanced versions of schema.Table and schema.Column which establish +desired state for different backends. +""" + +from . import exclusions +from sqlalchemy import schema, event +from . import config + +__all__ = 'Table', 'Column', + +table_options = {} + +def Table(*args, **kw): + """A schema.Table wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k, kw.pop(k)) for k in kw.keys() + if k.startswith('test_')]) + + kw.update(table_options) + + if exclusions.against('mysql'): + if 'mysql_engine' not in kw and 'mysql_type' not in kw: + if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: + kw['mysql_engine'] = 'InnoDB' + else: + kw['mysql_engine'] = 'MyISAM' + + # Apply some default cascading rules for self-referential foreign keys. + # MySQL InnoDB has some issues around seleting self-refs too. + if exclusions.against('firebird'): + table_name = args[0] + unpack = (config.db.dialect. + identifier_preparer.unformat_identifiers) + + # Only going after ForeignKeys in Columns. May need to + # expand to ForeignKeyConstraint too. + fks = [fk + for col in args if isinstance(col, schema.Column) + for fk in col.foreign_keys] + + for fk in fks: + # root around in raw spec + ref = fk._colspec + if isinstance(ref, schema.Column): + name = ref.table.name + else: + # take just the table name: on FB there cannot be + # a schema, so the first element is always the + # table name, possibly followed by the field name + name = unpack(ref)[0] + if name == table_name: + if fk.ondelete is None: + fk.ondelete = 'CASCADE' + if fk.onupdate is None: + fk.onupdate = 'CASCADE' + + return schema.Table(*args, **kw) + + +def Column(*args, **kw): + """A schema.Column wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k, kw.pop(k)) for k in kw.keys() + if k.startswith('test_')]) + + col = schema.Column(*args, **kw) + if 'test_needs_autoincrement' in test_opts and \ + kw.get('primary_key', False) and \ + exclusions.against('firebird', 'oracle'): + def add_seq(c, tbl): + c._init_items( + schema.Sequence(_truncate_name( + config.db.dialect, tbl.name + '_' + c.name + '_seq'), + optional=True) + ) + event.listen(col, 'after_parent_attach', add_seq, propagate=True) + return col + +def _truncate_name(dialect, name): + if len(name) > dialect.max_identifier_length: + return name[0:max(dialect.max_identifier_length - 6, 0)] + \ + "_" + hex(hash(name) % 64)[2:] + else: + return name + diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/lib/sqlalchemy/testing/suite/__init__.py diff --git a/lib/sqlalchemy/testing/suite/requirements.py b/lib/sqlalchemy/testing/suite/requirements.py new file mode 100644 index 000000000..5eda39b2b --- /dev/null +++ b/lib/sqlalchemy/testing/suite/requirements.py @@ -0,0 +1,24 @@ +from ..requirements import Requirements +from .. import exclusions + + +class SuiteRequirements(Requirements): + + @property + def create_table(self): + """target platform can emit basic CreateTable DDL.""" + + return exclusions.open + + @property + def drop_table(self): + """target platform can emit basic DropTable DDL.""" + + return exclusions.open + + @property + def autoincrement_insert(self): + """target platform generates new surrogate integer primary key values + when insert() is executed, excluding the pk column.""" + + return exclusions.open diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py new file mode 100644 index 000000000..1285c4196 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -0,0 +1,48 @@ +from .. import fixtures, config, util +from ..config import requirements +from ..assertions import eq_ + +from sqlalchemy import Table, Column, Integer, String + + +class TableDDLTest(fixtures.TestBase): + + def _simple_fixture(self): + return Table('test_table', self.metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + def _simple_roundtrip(self): + with config.db.begin() as conn: + conn.execute("insert into test_table(id, data) values " + "(1, 'some data')") + result = conn.execute("select id, data from test_table") + eq_( + result.first(), + (1, 'some data') + ) + + + @requirements.create_table + @util.provide_metadata + def test_create_table(self): + table = self._simple_fixture() + table.create( + config.db, checkfirst=False + ) + self._simple_roundtrip() + + + @requirements.drop_table + @util.provide_metadata + def test_drop_table(self): + table = self._simple_fixture() + table.create( + config.db, checkfirst=False + ) + table.drop( + config.db, checkfirst=False + ) + +__all__ = ('TableDDLTest',)
\ No newline at end of file diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_reflection.py diff --git a/lib/sqlalchemy/testing/suite/test_sequencing.py b/lib/sqlalchemy/testing/suite/test_sequencing.py new file mode 100644 index 000000000..7b09ecb76 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_sequencing.py @@ -0,0 +1,36 @@ +from .. import fixtures, config, util +from ..config import requirements +from ..assertions import eq_ + +from sqlalchemy import Table, Column, Integer, String + + +class InsertSequencingTest(fixtures.TablesTest): + run_deletes = 'each' + + @classmethod + def define_tables(cls, metadata): + Table('plain_pk', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)) + ) + + def _assert_round_trip(self, table): + row = config.db.execute(table.select()).first() + eq_( + row, + (1, "some data") + ) + + @requirements.autoincrement_insert + def test_autoincrement_on_insert(self): + + config.db.execute( + self.tables.plain_pk.insert(), + data="some data" + ) + self._assert_round_trip(self.tables.plain_pk) + + + +__all__ = ('InsertSequencingTest',)
\ No newline at end of file diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py new file mode 100644 index 000000000..625b9e6a5 --- /dev/null +++ b/lib/sqlalchemy/testing/util.py @@ -0,0 +1,196 @@ +from sqlalchemy.util import jython, pypy, defaultdict, decorator +from sqlalchemy.util.compat import decimal + +import gc +import time +import random +import sys +import types + +if jython: + def jython_gc_collect(*args): + """aggressive gc.collect for tests.""" + gc.collect() + time.sleep(0.1) + gc.collect() + gc.collect() + return 0 + + # "lazy" gc, for VM's that don't GC on refcount == 0 + lazy_gc = jython_gc_collect +elif pypy: + def pypy_gc_collect(*args): + gc.collect() + gc.collect() + lazy_gc = pypy_gc_collect +else: + # assume CPython - straight gc.collect, lazy_gc() is a pass + gc_collect = gc.collect + def lazy_gc(): + pass + +def picklers(): + picklers = set() + # Py2K + try: + import cPickle + picklers.add(cPickle) + except ImportError: + pass + # end Py2K + import pickle + picklers.add(pickle) + + # yes, this thing needs this much testing + for pickle_ in picklers: + for protocol in -1, 0, 1, 2: + yield pickle_.loads, lambda d: pickle_.dumps(d, protocol) + + +def round_decimal(value, prec): + if isinstance(value, float): + return round(value, prec) + + # can also use shift() here but that is 2.6 only + return (value * decimal.Decimal("1" + "0" * prec) + ).to_integral(decimal.ROUND_FLOOR) / \ + pow(10, prec) + +class RandomSet(set): + def __iter__(self): + l = list(set.__iter__(self)) + random.shuffle(l) + return iter(l) + + def pop(self): + index = random.randint(0, len(self) - 1) + item = list(set.__iter__(self))[index] + self.remove(item) + return item + + def union(self, other): + return RandomSet(set.union(self, other)) + + def difference(self, other): + return RandomSet(set.difference(self, other)) + + def intersection(self, other): + return RandomSet(set.intersection(self, other)) + + def copy(self): + return RandomSet(self) + +def conforms_partial_ordering(tuples, sorted_elements): + """True if the given sorting conforms to the given partial ordering.""" + + deps = defaultdict(set) + for parent, child in tuples: + deps[parent].add(child) + for i, node in enumerate(sorted_elements): + for n in sorted_elements[i:]: + if node in deps[n]: + return False + else: + return True + +def all_partial_orderings(tuples, elements): + edges = defaultdict(set) + for parent, child in tuples: + edges[child].add(parent) + + def _all_orderings(elements): + + if len(elements) == 1: + yield list(elements) + else: + for elem in elements: + subset = set(elements).difference([elem]) + if not subset.intersection(edges[elem]): + for sub_ordering in _all_orderings(subset): + yield [elem] + sub_ordering + + return iter(_all_orderings(elements)) + + +def function_named(fn, name): + """Return a function with a given __name__. + + Will assign to __name__ and return the original function if possible on + the Python implementation, otherwise a new function will be constructed. + + This function should be phased out as much as possible + in favor of @decorator. Tests that "generate" many named tests + should be modernized. + + """ + try: + fn.__name__ = name + except TypeError: + fn = types.FunctionType(fn.func_code, fn.func_globals, name, + fn.func_defaults, fn.func_closure) + return fn + + + +def run_as_contextmanager(ctx, fn, *arg, **kw): + """Run the given function under the given contextmanager, + simulating the behavior of 'with' to support older + Python versions. + + """ + + obj = ctx.__enter__() + try: + result = fn(obj, *arg, **kw) + ctx.__exit__(None, None, None) + return result + except: + exc_info = sys.exc_info() + raise_ = ctx.__exit__(*exc_info) + if raise_ is None: + raise + else: + return raise_ + +def rowset(results): + """Converts the results of sql execution into a plain set of column tuples. + + Useful for asserting the results of an unordered query. + """ + + return set([tuple(row) for row in results]) + + +def fail(msg): + assert False, msg + + +@decorator +def provide_metadata(fn, *args, **kw): + """Provide bound MetaData for a single test, dropping afterwards.""" + + from . import config + from sqlalchemy import schema + + metadata = schema.MetaData(config.db) + self = args[0] + prev_meta = getattr(self, 'metadata', None) + self.metadata = metadata + try: + return fn(*args, **kw) + finally: + metadata.drop_all() + self.metadata = prev_meta + +class adict(dict): + """Dict keys available as attributes. Shadows.""" + def __getattribute__(self, key): + try: + return self[key] + except KeyError: + return dict.__getattribute__(self, key) + + def get_all(self, *keys): + return tuple([self[key] for key in keys]) + + diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py new file mode 100644 index 000000000..799fca128 --- /dev/null +++ b/lib/sqlalchemy/testing/warnings.py @@ -0,0 +1,43 @@ +from __future__ import absolute_import + +import warnings +from sqlalchemy import exc as sa_exc +from sqlalchemy import util + +def testing_warn(msg, stacklevel=3): + """Replaces sqlalchemy.util.warn during tests.""" + + filename = "sqlalchemy.testing.warnings" + lineno = 1 + if isinstance(msg, basestring): + warnings.warn_explicit(msg, sa_exc.SAWarning, filename, lineno) + else: + warnings.warn_explicit(msg, filename, lineno) + +def resetwarnings(): + """Reset warning behavior to testing defaults.""" + + util.warn = util.langhelpers.warn = testing_warn + + warnings.filterwarnings('ignore', + category=sa_exc.SAPendingDeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SAWarning) + +def assert_warnings(fn, warnings): + """Assert that each of the given warnings are emitted by fn.""" + + from .assertions import eq_, emits_warning + + canary = [] + orig_warn = util.warn + def capture_warnings(*args, **kw): + orig_warn(*args, **kw) + popwarn = warnings.pop(0) + canary.append(popwarn) + eq_(args[0], popwarn) + util.warn = util.langhelpers.warn = capture_warnings + + result = emits_warning()(fn)() + assert canary, "No warning was emitted" + return result |
