diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/test/__init__.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/assertsql.py | 283 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/config.py | 177 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/engines.py | 245 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/noseplugin.py | 156 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/orm.py | 111 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/pickleable.py | 75 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/profiling.py | 207 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/requires.py | 127 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/schema.py | 74 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/testing.py | 701 |
11 files changed, 2182 insertions, 0 deletions
diff --git a/lib/sqlalchemy/test/__init__.py b/lib/sqlalchemy/test/__init__.py new file mode 100644 index 000000000..d69cedefd --- /dev/null +++ b/lib/sqlalchemy/test/__init__.py @@ -0,0 +1,26 @@ +"""Testing environment and utilities. + +This package contains base classes and routines used by +the unit tests. Tests are based on Nose and bootstrapped +by noseplugin.NoseSQLAlchemy. + +""" + +from sqlalchemy.test import testing, engines, requires, profiling, pickleable, config +from sqlalchemy.test.schema import Column, Table +from sqlalchemy.test.testing import \ + AssertsCompiledSQL, \ + AssertsExecutionResults, \ + ComparesTables, \ + TestBase, \ + rowset + + +__all__ = ('testing', + 'Column', 'Table', + 'rowset', + 'TestBase', 'AssertsExecutionResults', + 'AssertsCompiledSQL', 'ComparesTables', + 'engines', 'profiling', 'pickleable') + + diff --git a/lib/sqlalchemy/test/assertsql.py b/lib/sqlalchemy/test/assertsql.py new file mode 100644 index 000000000..dc2c6d40f --- /dev/null +++ b/lib/sqlalchemy/test/assertsql.py @@ -0,0 +1,283 @@ + +from sqlalchemy.interfaces import ConnectionProxy +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.engine.base import Connection +from sqlalchemy import util +import testing +import re + +class AssertRule(object): + def process_execute(self, clauseelement, *multiparams, **params): + pass + + def process_cursor_execute(self, statement, parameters, context, executemany): + pass + + def is_consumed(self): + """Return True if this rule has been consumed, False if not. + + Should raise an AssertionError if this rule's condition has definitely failed. + + """ + raise NotImplementedError() + + def rule_passed(self): + """Return True if the last test of this rule passed, False if failed, None if no test was applied.""" + + raise NotImplementedError() + + def consume_final(self): + """Return True if this rule has been consumed. + + Should raise an AssertionError if this rule's condition has not been consumed or has failed. + + """ + + if self._result is None: + assert False, "Rule has not been consumed" + + return self.is_consumed() + +class SQLMatchRule(AssertRule): + def __init__(self): + self._result = None + self._errmsg = "" + + def rule_passed(self): + return self._result + + def is_consumed(self): + if self._result is None: + return False + + assert self._result, self._errmsg + + return True + +class ExactSQL(SQLMatchRule): + def __init__(self, sql, params=None): + SQLMatchRule.__init__(self) + self.sql = sql + self.params = params + + def process_cursor_execute(self, statement, parameters, context, executemany): + if not context: + return + + _received_statement = _process_engine_statement(statement, context) + _received_parameters = context.compiled_parameters + + # TODO: remove this step once all unit tests + # are migrated, as ExactSQL should really be *exact* SQL + sql = _process_assertion_statement(self.sql, context) + + equivalent = _received_statement == sql + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + + if not isinstance(params, list): + params = [params] + equivalent = equivalent and params == context.compiled_parameters + else: + params = {} + + + self._result = equivalent + if not self._result: + self._errmsg = "Testing for exact statement %r exact params %r, " \ + "received %r with params %r" % (sql, params, _received_statement, _received_parameters) + + +class RegexSQL(SQLMatchRule): + def __init__(self, regex, params=None): + SQLMatchRule.__init__(self) + self.regex = re.compile(regex) + self.orig_regex = regex + self.params = params + + def process_cursor_execute(self, statement, parameters, context, executemany): + if not context: + return + + _received_statement = _process_engine_statement(statement, context) + _received_parameters = context.compiled_parameters + + equivalent = bool(self.regex.match(_received_statement)) + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + + if not isinstance(params, list): + params = [params] + + # do a positive compare only + for param, received in zip(params, _received_parameters): + for k, v in param.iteritems(): + if k not in received or received[k] != v: + equivalent = False + break + else: + params = {} + + self._result = equivalent + if not self._result: + self._errmsg = "Testing for regex %r partial params %r, "\ + "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters) + +class CompiledSQL(SQLMatchRule): + def __init__(self, statement, params): + SQLMatchRule.__init__(self) + self.statement = statement + self.params = params + + def process_cursor_execute(self, statement, parameters, context, executemany): + if not context: + return + + _received_parameters = context.compiled_parameters + + # recompile from the context, using the default dialect + compiled = context.compiled.statement.\ + compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys) + + _received_statement = re.sub(r'\n', '', str(compiled)) + + equivalent = self.statement == _received_statement + if self.params: + if util.callable(self.params): + params = self.params(context) + else: + params = self.params + + if not isinstance(params, list): + params = [params] + + # do a positive compare only + for param, received in zip(params, _received_parameters): + for k, v in param.iteritems(): + if k not in received or received[k] != v: + equivalent = False + break + else: + params = {} + + self._result = equivalent + if not self._result: + self._errmsg = "Testing for compiled statement %r partial params %r, " \ + "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters) + + +class CountStatements(AssertRule): + def __init__(self, count): + self.count = count + self._statement_count = 0 + + def process_execute(self, clauseelement, *multiparams, **params): + self._statement_count += 1 + + def process_cursor_execute(self, statement, parameters, context, executemany): + pass + + def is_consumed(self): + return False + + def consume_final(self): + assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count) + return True + +class AllOf(AssertRule): + def __init__(self, *rules): + self.rules = set(rules) + + def process_execute(self, clauseelement, *multiparams, **params): + for rule in self.rules: + rule.process_execute(clauseelement, *multiparams, **params) + + def process_cursor_execute(self, statement, parameters, context, executemany): + for rule in self.rules: + rule.process_cursor_execute(statement, parameters, context, executemany) + + def is_consumed(self): + if not self.rules: + return True + + for rule in list(self.rules): + if rule.rule_passed(): # a rule passed, move on + self.rules.remove(rule) + return len(self.rules) == 0 + + assert False, "No assertion rules were satisfied for statement" + + def consume_final(self): + return len(self.rules) == 0 + +def _process_engine_statement(query, context): + if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'): + query = query[:-25] + + query = re.sub(r'\n', '', query) + + return query + +def _process_assertion_statement(query, context): + paramstyle = context.dialect.paramstyle + if paramstyle == 'named': + pass + elif paramstyle =='pyformat': + query = re.sub(r':([\w_]+)', r"%(\1)s", query) + else: + # positional params + repl = None + if paramstyle=='qmark': + repl = "?" + elif paramstyle=='format': + repl = r"%s" + elif paramstyle=='numeric': + repl = None + query = re.sub(r':([\w_]+)', repl, query) + + return query + +class SQLAssert(ConnectionProxy): + rules = None + + def add_rules(self, rules): + self.rules = list(rules) + + def statement_complete(self): + for rule in self.rules: + if not rule.consume_final(): + assert False, "All statements are complete, but pending assertion rules remain" + + def clear_rules(self): + del self.rules + + def execute(self, conn, execute, clauseelement, *multiparams, **params): + result = execute(clauseelement, *multiparams, **params) + + if self.rules is not None: + if not self.rules: + assert False, "All rules have been exhausted, but further statements remain" + rule = self.rules[0] + rule.process_execute(clauseelement, *multiparams, **params) + if rule.is_consumed(): + self.rules.pop(0) + + return result + + def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + result = execute(cursor, statement, parameters, context) + + if self.rules: + rule = self.rules[0] + rule.process_cursor_execute(statement, parameters, context, executemany) + + return result + +asserter = SQLAssert() + diff --git a/lib/sqlalchemy/test/config.py b/lib/sqlalchemy/test/config.py new file mode 100644 index 000000000..6ea5667cc --- /dev/null +++ b/lib/sqlalchemy/test/config.py @@ -0,0 +1,177 @@ +import optparse, os, sys, re, ConfigParser, StringIO, time, warnings +logging = None + +__all__ = 'parser', 'configure', 'options', + +db = None +db_label, db_url, db_opts = None, None, {} + +options = None +file_config = None + +base_config = """ +[db] +sqlite=sqlite:///:memory: +sqlite_file=sqlite:///querytest.db +postgres=postgres://scott:tiger@127.0.0.1:5432/test +mysql=mysql://scott:tiger@127.0.0.1:3306/test +oracle=oracle://scott:tiger@127.0.0.1:1521 +oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 +mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test +firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb +maxdb=maxdb://MONA:RED@/maxdb1 +""" + +def _log(option, opt_str, value, parser): + global logging + if not logging: + import logging + logging.basicConfig() + + if opt_str.endswith('-info'): + logging.getLogger(value).setLevel(logging.INFO) + elif opt_str.endswith('-debug'): + logging.getLogger(value).setLevel(logging.DEBUG) + + +def _list_dbs(*args): + print "Available --db options (use --dburi to override)" + for macro in sorted(file_config.options('db')): + print "%20s\t%s" % (macro, file_config.get('db', macro)) + sys.exit(0) + +def _server_side_cursors(options, opt_str, value, parser): + db_opts['server_side_cursors'] = True + +def _engine_strategy(options, opt_str, value, parser): + if value: + db_opts['strategy'] = value + +class _ordered_map(object): + def __init__(self): + self._keys = list() + self._data = dict() + + def __setitem__(self, key, value): + if key not in self._keys: + self._keys.append(key) + self._data[key] = value + + def __iter__(self): + for key in self._keys: + yield self._data[key] + +# at one point in refactoring, modules were injecting into the config +# process. this could probably just become a list now. +post_configure = _ordered_map() + +def _engine_uri(options, file_config): + global db_label, db_url + db_label = 'sqlite' + if options.dburi: + db_url = options.dburi + db_label = db_url[:db_url.index(':')] + elif options.db: + db_label = options.db + db_url = None + + if db_url is None: + if db_label not in file_config.options('db'): + raise RuntimeError( + "Unknown engine. Specify --dbs for known engines.") + db_url = file_config.get('db', db_label) +post_configure['engine_uri'] = _engine_uri + +def _require(options, file_config): + if not(options.require or + (file_config.has_section('require') and + file_config.items('require'))): + return + + try: + import pkg_resources + except ImportError: + raise RuntimeError("setuptools is required for version requirements") + + cmdline = [] + for requirement in options.require: + pkg_resources.require(requirement) + cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0]) + + if file_config.has_section('require'): + for label, requirement in file_config.items('require'): + if not label == db_label or label.startswith('%s.' % db_label): + continue + seen = [c for c in cmdline if requirement.startswith(c)] + if seen: + continue + pkg_resources.require(requirement) +post_configure['require'] = _require + +def _engine_pool(options, file_config): + if options.mockpool: + from sqlalchemy import pool + db_opts['poolclass'] = pool.AssertionPool +post_configure['engine_pool'] = _engine_pool + +def _create_testing_engine(options, file_config): + from sqlalchemy.test import engines, testing + global db + db = engines.testing_engine(db_url, db_opts) + testing.db = db +post_configure['create_engine'] = _create_testing_engine + +def _prep_testing_database(options, file_config): + from sqlalchemy.test import engines + from sqlalchemy import schema + + try: + # also create alt schemas etc. here? + if options.dropfirst: + e = engines.utf8_engine() + existing = e.table_names() + if existing: + print "Dropping existing tables in database: " + db_url + try: + print "Tables: %s" % ', '.join(existing) + except: + pass + print "Abort within 5 seconds..." + time.sleep(5) + md = schema.MetaData(e, reflect=True) + md.drop_all() + e.dispose() + except (KeyboardInterrupt, SystemExit): + raise + except Exception, e: + warnings.warn(RuntimeWarning( + "Error checking for existing tables in testing " + "database: %s" % e)) +post_configure['prep_db'] = _prep_testing_database + +def _set_table_options(options, file_config): + from sqlalchemy.test import schema + + table_options = schema.table_options + for spec in options.tableopts: + key, value = spec.split('=') + table_options[key] = value + + if options.mysql_engine: + table_options['mysql_engine'] = options.mysql_engine +post_configure['table_options'] = _set_table_options + +def _reverse_topological(options, file_config): + if options.reversetop: + from sqlalchemy.orm import unitofwork + from sqlalchemy import topological + class RevQueueDepSort(topological.QueueDependencySorter): + def __init__(self, tuples, allitems): + self.tuples = list(tuples) + self.allitems = list(allitems) + self.tuples.reverse() + self.allitems.reverse() + topological.QueueDependencySorter = RevQueueDepSort + unitofwork.DependencySorter = RevQueueDepSort +post_configure['topological'] = _reverse_topological + diff --git a/lib/sqlalchemy/test/engines.py b/lib/sqlalchemy/test/engines.py new file mode 100644 index 000000000..f0001978b --- /dev/null +++ b/lib/sqlalchemy/test/engines.py @@ -0,0 +1,245 @@ +import sys, types, weakref +from collections import deque +import config +from sqlalchemy.util import function_named, callable + +class ConnectionKiller(object): + def __init__(self): + self.proxy_refs = weakref.WeakKeyDictionary() + + def checkout(self, dbapi_con, con_record, con_proxy): + self.proxy_refs[con_proxy] = True + + def _apply_all(self, methods): + for rec in self.proxy_refs: + if rec is not None and rec.is_valid: + try: + for name in methods: + if callable(name): + name(rec) + else: + getattr(rec, name)() + except (SystemExit, KeyboardInterrupt): + raise + except Exception, e: + # fixme + sys.stderr.write("\n" + str(e) + "\n") + + def rollback_all(self): + self._apply_all(('rollback',)) + + def close_all(self): + self._apply_all(('rollback', 'close')) + + def assert_all_closed(self): + for rec in self.proxy_refs: + if rec.is_valid: + assert False + +testing_reaper = ConnectionKiller() + +def assert_conns_closed(fn): + def decorated(*args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.assert_all_closed() + return function_named(decorated, fn.__name__) + +def rollback_open_connections(fn): + """Decorator that rolls back all open connections after fn execution.""" + + def decorated(*args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.rollback_all() + return function_named(decorated, fn.__name__) + +def close_open_connections(fn): + """Decorator that closes all connections after fn execution.""" + + def decorated(*args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.close_all() + return function_named(decorated, fn.__name__) + +def all_dialects(): + import sqlalchemy.databases as d + for name in d.__all__: + mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) + yield mod.dialect() + +class ReconnectFixture(object): + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + conn = self.dbapi.connect(*args, **kwargs) + self.connections.append(conn) + return conn + + def shutdown(self): + for c in list(self.connections): + c.close() + self.connections = [] + +def reconnecting_engine(url=None, options=None): + url = url or config.db_url + dbapi = config.db.dialect.dbapi + if not options: + options = {} + options['module'] = ReconnectFixture(dbapi) + engine = testing_engine(url, options) + engine.test_shutdown = engine.dialect.dbapi.shutdown + return engine + +def testing_engine(url=None, options=None): + """Produce an engine configured by --options with optional overrides.""" + + from sqlalchemy import create_engine + from sqlalchemy.test.assertsql import asserter + + url = url or config.db_url + options = options or config.db_opts + + options.setdefault('proxy', asserter) + + listeners = options.setdefault('listeners', []) + listeners.append(testing_reaper) + + engine = create_engine(url, **options) + + return engine + +def utf8_engine(url=None, options=None): + """Hook for dialects or drivers that don't handle utf8 by default.""" + + from sqlalchemy.engine import url as engine_url + + if config.db.name == 'mysql': + dbapi_ver = config.db.dialect.dbapi.version_info + if (dbapi_ver < (1, 2, 1) or + dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2), + (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))): + raise RuntimeError('Character set support unavailable with this ' + 'driver version: %s' % repr(dbapi_ver)) + else: + url = url or config.db_url + url = engine_url.make_url(url) + url.query['charset'] = 'utf8' + url.query['use_unicode'] = '0' + url = str(url) + + return testing_engine(url, options) + +def mock_engine(db=None): + """Provides a mocking engine based on the current testing.db.""" + + from sqlalchemy import create_engine + + dbi = db or config.db + buffer = [] + def executor(sql, *a, **kw): + buffer.append(sql) + engine = create_engine(dbi.name + '://', + strategy='mock', executor=executor) + assert not hasattr(engine, 'mock') + engine.mock = buffer + return engine + +class ReplayableSession(object): + """A simple record/playback tool. + + This is *not* a mock testing class. It only records a session for later + playback and makes no assertions on call consistency whatsoever. It's + unlikely to be suitable for anything other than DB-API recording. + + """ + + Callable = object() + NoAttribute = object() + Natives = set([getattr(types, t) + for t in dir(types) if not t.startswith('_')]). \ + difference([getattr(types, t) + for t in ('FunctionType', 'BuiltinFunctionType', + 'MethodType', 'BuiltinMethodType', + 'LambdaType', 'UnboundMethodType',)]) + def __init__(self): + self.buffer = deque() + + def recorder(self, base): + return self.Recorder(self.buffer, base) + + def player(self): + return self.Player(self.buffer) + + class Recorder(object): + def __init__(self, buffer, subject): + self._buffer = buffer + self._subject = subject + + def __call__(self, *args, **kw): + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + + result = subject(*args, **kw) + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + + subject, buffer = [object.__getattribute__(self, x) + for x in ('_subject', '_buffer')] + try: + result = type(subject).__getattribute__(subject, key) + except AttributeError: + buffer.append(ReplayableSession.NoAttribute) + raise + else: + if type(result) not in ReplayableSession.Natives: + buffer.append(ReplayableSession.Callable) + return type(self)(buffer, result) + else: + buffer.append(result) + return result + + class Player(object): + def __init__(self, buffer): + self._buffer = buffer + + def __call__(self, *args, **kw): + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + else: + return result + + def __getattribute__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + buffer = object.__getattribute__(self, '_buffer') + result = buffer.popleft() + if result is ReplayableSession.Callable: + return self + elif result is ReplayableSession.NoAttribute: + raise AttributeError(key) + else: + return result diff --git a/lib/sqlalchemy/test/noseplugin.py b/lib/sqlalchemy/test/noseplugin.py new file mode 100644 index 000000000..263d2d783 --- /dev/null +++ b/lib/sqlalchemy/test/noseplugin.py @@ -0,0 +1,156 @@ +import logging +import os +import re +import sys +import time +import warnings +import ConfigParser +import StringIO +from config import db, db_label, db_url, file_config, base_config, \ + post_configure, \ + _list_dbs, _server_side_cursors, _engine_strategy, \ + _engine_uri, _require, _engine_pool, \ + _create_testing_engine, _prep_testing_database, \ + _set_table_options, _reverse_topological, _log +from sqlalchemy.test import testing, config, requires +from nose.plugins import Plugin +from nose.util import tolist +import nose.case + +log = logging.getLogger('nose.plugins.sqlalchemy') + +class NoseSQLAlchemy(Plugin): + """ + Handles the setup and extra properties required for testing SQLAlchemy + """ + enabled = True + name = 'sqlalchemy' + score = 100 + + def options(self, parser, env=os.environ): + Plugin.options(self, parser, env) + opt = parser.add_option + #opt("--verbose", action="store_true", dest="verbose", + #help="enable stdout echoing/printing") + #opt("--quiet", action="store_true", dest="quiet", help="suppress output") + opt("--log-info", action="callback", type="string", callback=_log, + help="turn on info logging for <LOG> (multiple OK)") + opt("--log-debug", action="callback", type="string", callback=_log, + help="turn on debug logging for <LOG> (multiple OK)") + opt("--require", action="append", dest="require", default=[], + help="require a particular driver or module version (multiple OK)") + opt("--db", action="store", dest="db", default="sqlite", + help="Use prefab database uri") + opt('--dbs', action='callback', callback=_list_dbs, + help="List available prefab dbs") + opt("--dburi", action="store", dest="dburi", + help="Database uri (overrides --db)") + opt("--dropfirst", action="store_true", dest="dropfirst", + help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)") + opt("--mockpool", action="store_true", dest="mockpool", + help="Use mock pool (asserts only one connection used)") + opt("--enginestrategy", action="callback", type="string", + callback=_engine_strategy, + help="Engine strategy (plain or threadlocal, defaults to plain)") + opt("--reversetop", action="store_true", dest="reversetop", default=False, + help="Reverse the collection ordering for topological sorts (helps " + "reveal dependency issues)") + opt("--unhashable", action="store_true", dest="unhashable", default=False, + help="Disallow SQLAlchemy from performing a hash() on mapped test objects.") + opt("--noncomparable", action="store_true", dest="noncomparable", default=False, + help="Disallow SQLAlchemy from performing == on mapped test objects.") + opt("--truthless", action="store_true", dest="truthless", default=False, + help="Disallow SQLAlchemy from truth-evaluating mapped test objects.") + opt("--serverside", action="callback", callback=_server_side_cursors, + help="Turn on server side cursors for PG") + opt("--mysql-engine", action="store", dest="mysql_engine", default=None, + help="Use the specified MySQL storage engine for all tables, default is " + "a db-default/InnoDB combo.") + opt("--table-option", action="append", dest="tableopts", default=[], + help="Add a dialect-specific table option, key=value") + + global file_config + file_config = ConfigParser.ConfigParser() + file_config.readfp(StringIO.StringIO(base_config)) + file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')]) + config.file_config = file_config + + def configure(self, options, conf): + Plugin.configure(self, options, conf) + + import testing, requires + testing.db = db + testing.requires = requires + + # Lazy setup of other options (post coverage) + for fn in post_configure: + fn(options, file_config) + + def describeTest(self, test): + return "" + + def wantClass(self, cls): + """Return true if you want the main test selector to collect + tests from this class, false if you don't, and None if you don't + care. + + :Parameters: + cls : class + The class being examined by the selector + + """ + + if not issubclass(cls, testing.TestBase): + return False + else: + if (hasattr(cls, '__whitelist__') and + testing.db.name in cls.__whitelist__): + return True + else: + return not self.__should_skip_for(cls) + + def __should_skip_for(self, cls): + if hasattr(cls, '__requires__'): + def test_suite(): return 'ok' + for requirement in cls.__requires__: + check = getattr(requires, requirement) + if check(test_suite)() != 'ok': + # The requirement will perform messaging. + return True + if (hasattr(cls, '__unsupported_on__') and + testing.db.name in cls.__unsupported_on__): + print "'%s' unsupported on DB implementation '%s'" % ( + cls.__class__.__name__, testing.db.name) + return True + if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)): + print "'%s' unsupported on DB implementation '%s'" % ( + cls.__class__.__name__, testing.db.name) + return True + if (getattr(cls, '__skip_if__', False)): + for c in getattr(cls, '__skip_if__'): + if c(): + print "'%s' skipped by %s" % ( + cls.__class__.__name__, c.__name__) + return True + for rule in getattr(cls, '__excluded_on__', ()): + if testing._is_excluded(*rule): + print "'%s' unsupported on DB %s version %s" % ( + cls.__class__.__name__, testing.db.name, + _server_version()) + return True + return False + + #def begin(self): + #pass + + def beforeTest(self, test): + testing.resetwarnings() + + def afterTest(self, test): + testing.resetwarnings() + + #def handleError(self, test, err): + #pass + + #def finalize(self, result=None): + #pass diff --git a/lib/sqlalchemy/test/orm.py b/lib/sqlalchemy/test/orm.py new file mode 100644 index 000000000..7ec13c555 --- /dev/null +++ b/lib/sqlalchemy/test/orm.py @@ -0,0 +1,111 @@ +import inspect, re +import config, testing +from sqlalchemy import orm + +__all__ = 'mapper', + + +_whitespace = re.compile(r'^(\s+)') + +def _find_pragma(lines, current): + m = _whitespace.match(lines[current]) + basis = m and m.group() or '' + + for line in reversed(lines[0:current]): + if 'testlib.pragma' in line: + return line + m = _whitespace.match(line) + indent = m and m.group() or '' + + # simplistic detection: + + # >> # testlib.pragma foo + # >> center_line() + if indent == basis: + break + # >> # testlib.pragma foo + # >> if fleem: + # >> center_line() + if line.endswith(':'): + break + return None + +def _make_blocker(method_name, fallback): + """Creates tripwired variant of a method, raising when called. + + To excempt an invocation from blockage, there are two options. + + 1) add a pragma in a comment:: + + # testlib.pragma exempt:methodname + offending_line() + + 2) add a magic cookie to the function's namespace:: + __sa_baremethodname_exempt__ = True + ... + offending_line() + another_offending_lines() + + The second is useful for testing and development. + """ + + if method_name.startswith('__') and method_name.endswith('__'): + frame_marker = '__sa_%s_exempt__' % method_name[2:-2] + else: + frame_marker = '__sa_%s_exempt__' % method_name + pragma_marker = 'exempt:' + method_name + + def method(self, *args, **kw): + frame_r = None + try: + frame = inspect.stack()[1][0] + frame_r = inspect.getframeinfo(frame, 9) + + module = frame.f_globals.get('__name__', '') + + type_ = type(self) + + pragma = _find_pragma(*frame_r[3:5]) + + exempt = ( + (not module.startswith('sqlalchemy')) or + (pragma and pragma_marker in pragma) or + (frame_marker in frame.f_locals) or + ('self' in frame.f_locals and + getattr(frame.f_locals['self'], frame_marker, False))) + + if exempt: + supermeth = getattr(super(type_, self), method_name, None) + if (supermeth is None or + getattr(supermeth, 'im_func', None) is method): + return fallback(self, *args, **kw) + else: + return supermeth(*args, **kw) + else: + raise AssertionError( + "%s.%s called in %s, line %s in %s" % ( + type_.__name__, method_name, module, frame_r[1], frame_r[2])) + finally: + del frame + method.__name__ = method_name + return method + +def mapper(type_, *args, **kw): + forbidden = [ + ('__hash__', 'unhashable', lambda s: id(s)), + ('__eq__', 'noncomparable', lambda s, o: s is o), + ('__ne__', 'noncomparable', lambda s, o: s is not o), + ('__cmp__', 'noncomparable', lambda s, o: object.__cmp__(s, o)), + ('__le__', 'noncomparable', lambda s, o: object.__le__(s, o)), + ('__lt__', 'noncomparable', lambda s, o: object.__lt__(s, o)), + ('__ge__', 'noncomparable', lambda s, o: object.__ge__(s, o)), + ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)), + ('__nonzero__', 'truthless', lambda s: 1), ] + + if isinstance(type_, type) and type_.__bases__ == (object,): + for method_name, option, fallback in forbidden: + if (getattr(config.options, option, False) and + method_name not in type_.__dict__): + setattr(type_, method_name, _make_blocker(method_name, fallback)) + + return orm.mapper(type_, *args, **kw) diff --git a/lib/sqlalchemy/test/pickleable.py b/lib/sqlalchemy/test/pickleable.py new file mode 100644 index 000000000..9794e424d --- /dev/null +++ b/lib/sqlalchemy/test/pickleable.py @@ -0,0 +1,75 @@ +""" + +some objects used for pickle tests, declared in their own module so that they +are easily pickleable. + +""" + + +class Foo(object): + def __init__(self, moredata): + self.data = 'im data' + self.stuff = 'im stuff' + self.moredata = moredata + __hash__ = object.__hash__ + def __eq__(self, other): + return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata + + +class Bar(object): + def __init__(self, x, y): + self.x = x + self.y = y + __hash__ = object.__hash__ + def __eq__(self, other): + return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + +class OldSchool: + def __init__(self, x, y): + self.x = x + self.y = y + def __eq__(self, other): + return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y + +class OldSchoolWithoutCompare: + def __init__(self, x, y): + self.x = x + self.y = y + +class BarWithoutCompare(object): + def __init__(self, x, y): + self.x = x + self.y = y + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + + +class NotComparable(object): + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return NotImplemented + + def __ne__(self, other): + return NotImplemented + + +class BrokenComparable(object): + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + raise NotImplementedError + diff --git a/lib/sqlalchemy/test/profiling.py b/lib/sqlalchemy/test/profiling.py new file mode 100644 index 000000000..ca4b31cbd --- /dev/null +++ b/lib/sqlalchemy/test/profiling.py @@ -0,0 +1,207 @@ +"""Profiling support for unit and performance tests. + +These are special purpose profiling methods which operate +in a more fine-grained way than nose's profiling plugin. + +""" + +import os, sys +from sqlalchemy.util import function_named +import config + +__all__ = 'profiled', 'function_call_count', 'conditional_call_count' + +all_targets = set() +profile_config = { 'targets': set(), + 'report': True, + 'sort': ('time', 'calls'), + 'limit': None } +profiler = None + +def profiled(target=None, **target_opts): + """Optional function profiling. + + @profiled('label') + or + @profiled('label', report=True, sort=('calls',), limit=20) + + Enables profiling for a function when 'label' is targetted for + profiling. Report options can be supplied, and override the global + configuration and command-line options. + """ + + # manual or automatic namespacing by module would remove conflict issues + if target is None: + target = 'anonymous_target' + elif target in all_targets: + print "Warning: redefining profile target '%s'" % target + all_targets.add(target) + + filename = "%s.prof" % target + + def decorator(fn): + def profiled(*args, **kw): + if (target not in profile_config['targets'] and + not target_opts.get('always', None)): + return fn(*args, **kw) + + elapsed, load_stats, result = _profile( + filename, fn, *args, **kw) + + report = target_opts.get('report', profile_config['report']) + if report: + sort_ = target_opts.get('sort', profile_config['sort']) + limit = target_opts.get('limit', profile_config['limit']) + print "Profile report for target '%s' (%s)" % ( + target, filename) + + stats = load_stats() + stats.sort_stats(*sort_) + if limit: + stats.print_stats(limit) + else: + stats.print_stats() + #stats.print_callers() + os.unlink(filename) + return result + return function_named(profiled, fn.__name__) + return decorator + +def function_call_count(count=None, versions={}, variance=0.05): + """Assert a target for a test case's function call count. + + count + Optional, general target function call count. + + versions + Optional, a dictionary of Python version strings to counts, + for example:: + + { '2.5.1': 110, + '2.5': 100, + '2.4': 150 } + + The best match for the current running python will be used. + If none match, 'count' will be used as the fallback. + + variance + An +/- deviation percentage, defaults to 5%. + """ + + # this could easily dump the profile report if --verbose is in effect + + version_info = list(sys.version_info) + py_version = '.'.join([str(v) for v in sys.version_info]) + + while version_info: + version = '.'.join([str(v) for v in version_info]) + if version in versions: + count = versions[version] + break + version_info.pop() + + if count is None: + return lambda fn: fn + + def decorator(fn): + def counted(*args, **kw): + try: + filename = "%s.prof" % fn.__name__ + + elapsed, stat_loader, result = _profile( + filename, fn, *args, **kw) + + stats = stat_loader() + calls = stats.total_calls + + stats.sort_stats('calls', 'cumulative') + stats.print_stats() + #stats.print_callers() + deviance = int(count * variance) + if (calls < (count - deviance) or + calls > (count + deviance)): + raise AssertionError( + "Function call count %s not within %s%% " + "of expected %s. (Python version %s)" % ( + calls, (variance * 100), count, py_version)) + + return result + finally: + if os.path.exists(filename): + os.unlink(filename) + return function_named(counted, fn.__name__) + return decorator + +def conditional_call_count(discriminator, categories): + """Apply a function call count conditionally at runtime. + + Takes two arguments, a callable that returns a key value, and a dict + mapping key values to a tuple of arguments to function_call_count. + + The callable is not evaluated until the decorated function is actually + invoked. If the `discriminator` returns a key not present in the + `categories` dictionary, no call count assertion is applied. + + Useful for integration tests, where running a named test in isolation may + have a function count penalty not seen in the full suite, due to lazy + initialization in the DB-API, SA, etc. + """ + + def decorator(fn): + def at_runtime(*args, **kw): + criteria = categories.get(discriminator(), None) + if criteria is None: + return fn(*args, **kw) + + rewrapped = function_call_count(*criteria)(fn) + return rewrapped(*args, **kw) + return function_named(at_runtime, fn.__name__) + return decorator + + +def _profile(filename, fn, *args, **kw): + global profiler + if not profiler: + profiler = 'hotshot' + if sys.version_info > (2, 5): + try: + import cProfile + profiler = 'cProfile' + except ImportError: + pass + + if profiler == 'cProfile': + return _profile_cProfile(filename, fn, *args, **kw) + else: + return _profile_hotshot(filename, fn, *args, **kw) + +def _profile_cProfile(filename, fn, *args, **kw): + import cProfile, gc, pstats, time + + load_stats = lambda: pstats.Stats(filename) + gc.collect() + + began = time.time() + cProfile.runctx('result = fn(*args, **kw)', globals(), locals(), + filename=filename) + ended = time.time() + + return ended - began, load_stats, locals()['result'] + +def _profile_hotshot(filename, fn, *args, **kw): + import gc, hotshot, hotshot.stats, time + load_stats = lambda: hotshot.stats.load(filename) + + gc.collect() + prof = hotshot.Profile(filename) + began = time.time() + prof.start() + try: + result = fn(*args, **kw) + finally: + prof.stop() + ended = time.time() + prof.close() + + return ended - began, load_stats, result + diff --git a/lib/sqlalchemy/test/requires.py b/lib/sqlalchemy/test/requires.py new file mode 100644 index 000000000..b23b8620d --- /dev/null +++ b/lib/sqlalchemy/test/requires.py @@ -0,0 +1,127 @@ +"""Global database feature support policy. + +Provides decorators to mark tests requiring specific feature support from the +target database. + +""" + +from testing import \ + _block_unconditionally as no_support, \ + _chain_decorators_on, \ + exclude, \ + emits_warning_on + + +def deferrable_constraints(fn): + """Target database must support derferable constraints.""" + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('mysql', 'not supported by database'), + no_support('mssql', 'not supported by database'), + ) + +def foreign_keys(fn): + """Target database must support foreign keys.""" + return _chain_decorators_on( + fn, + no_support('sqlite', 'not supported by database'), + ) + +def identity(fn): + """Target database must support GENERATED AS IDENTITY or a facsimile. + + Includes GENERATED AS IDENTITY, AUTOINCREMENT, AUTO_INCREMENT, or other + column DDL feature that fills in a DB-generated identifier at INSERT-time + without requiring pre-execution of a SEQUENCE or other artifact. + + """ + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('oracle', 'not supported by database'), + no_support('postgres', 'not supported by database'), + no_support('sybase', 'not supported by database'), + ) + +def independent_connections(fn): + """Target must support simultaneous, independent database connections.""" + + # This is also true of some configurations of UnixODBC and probably win32 + # ODBC as well. + return _chain_decorators_on( + fn, + no_support('sqlite', 'no driver support') + ) + +def row_triggers(fn): + """Target must support standard statement-running EACH ROW triggers.""" + return _chain_decorators_on( + fn, + # no access to same table + no_support('mysql', 'requires SUPER priv'), + exclude('mysql', '<', (5, 0, 10), 'not supported by database'), + no_support('postgres', 'not supported by database: no statements'), + ) + +def savepoints(fn): + """Target database must support savepoints.""" + return _chain_decorators_on( + fn, + emits_warning_on('mssql', 'Savepoint support in mssql is experimental and may lead to data loss.'), + no_support('access', 'not supported by database'), + no_support('sqlite', 'not supported by database'), + no_support('sybase', 'FIXME: guessing, needs confirmation'), + exclude('mysql', '<', (5, 0, 3), 'not supported by database'), + ) + +def sequences(fn): + """Target database must support SEQUENCEs.""" + return _chain_decorators_on( + fn, + no_support('access', 'no SEQUENCE support'), + no_support('mssql', 'no SEQUENCE support'), + no_support('mysql', 'no SEQUENCE support'), + no_support('sqlite', 'no SEQUENCE support'), + no_support('sybase', 'no SEQUENCE support'), + ) + +def subqueries(fn): + """Target database must support subqueries.""" + return _chain_decorators_on( + fn, + exclude('mysql', '<', (4, 1, 1), 'no subquery support'), + ) + +def two_phase_transactions(fn): + """Target database must support two-phase transactions.""" + return _chain_decorators_on( + fn, + no_support('access', 'not supported by database'), + no_support('firebird', 'no SA implementation'), + no_support('maxdb', 'not supported by database'), + no_support('mssql', 'FIXME: guessing, needs confirmation'), + no_support('oracle', 'no SA implementation'), + no_support('sqlite', 'not supported by database'), + no_support('sybase', 'FIXME: guessing, needs confirmation'), + exclude('mysql', '<', (5, 0, 3), 'not supported by database'), + ) + +def unicode_connections(fn): + """Target driver must support some encoding of Unicode across the wire.""" + # TODO: expand to exclude MySQLdb versions w/ broken unicode + return _chain_decorators_on( + fn, + exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'), + ) + +def unicode_ddl(fn): + """Target driver must support some encoding of Unicode across the wire.""" + # TODO: expand to exclude MySQLdb versions w/ broken unicode + return _chain_decorators_on( + fn, + no_support('maxdb', 'database support flakey'), + no_support('oracle', 'FIXME: no support in database?'), + no_support('sybase', 'FIXME: guessing, needs confirmation'), + exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'), + ) diff --git a/lib/sqlalchemy/test/schema.py b/lib/sqlalchemy/test/schema.py new file mode 100644 index 000000000..f96805fe4 --- /dev/null +++ b/lib/sqlalchemy/test/schema.py @@ -0,0 +1,74 @@ +"""Enhanced versions of schema.Table and schema.Column which establish +desired state for different backends. +""" + +from sqlalchemy.test import testing +from sqlalchemy import schema + +__all__ = 'Table', 'Column', + +table_options = {} + +def Table(*args, **kw): + """A schema.Table wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k,kw.pop(k)) for k in kw.keys() + if k.startswith('test_')]) + + kw.update(table_options) + + if testing.against('mysql'): + if 'mysql_engine' not in kw and 'mysql_type' not in kw: + if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: + kw['mysql_engine'] = 'InnoDB' + + # Apply some default cascading rules for self-referential foreign keys. + # MySQL InnoDB has some issues around seleting self-refs too. + if testing.against('firebird'): + table_name = args[0] + unpack = (testing.config.db.dialect. + identifier_preparer.unformat_identifiers) + + # Only going after ForeignKeys in Columns. May need to + # expand to ForeignKeyConstraint too. + fks = [fk + for col in args if isinstance(col, schema.Column) + for fk in col.args if isinstance(fk, schema.ForeignKey)] + + for fk in fks: + # root around in raw spec + ref = fk._colspec + if isinstance(ref, schema.Column): + name = ref.table.name + else: + # take just the table name: on FB there cannot be + # a schema, so the first element is always the + # table name, possibly followed by the field name + name = unpack(ref)[0] + if name == table_name: + if fk.ondelete is None: + fk.ondelete = 'CASCADE' + if fk.onupdate is None: + fk.onupdate = 'CASCADE' + + if testing.against('firebird', 'oracle'): + pk_seqs = [col for col in args + if (isinstance(col, schema.Column) + and col.primary_key + and getattr(col, '_needs_autoincrement', False))] + for c in pk_seqs: + c.args.append(schema.Sequence(args[0] + '_' + c.name + '_seq', optional=True)) + return schema.Table(*args, **kw) + + +def Column(*args, **kw): + """A schema.Column wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k,kw.pop(k)) for k in kw.keys() + if k.startswith('test_')]) + + c = schema.Column(*args, **kw) + if testing.against('firebird', 'oracle'): + if 'test_needs_autoincrement' in test_opts: + c._needs_autoincrement = True + return c diff --git a/lib/sqlalchemy/test/testing.py b/lib/sqlalchemy/test/testing.py new file mode 100644 index 000000000..36c7d340a --- /dev/null +++ b/lib/sqlalchemy/test/testing.py @@ -0,0 +1,701 @@ +"""TestCase and TestSuite artifacts and testing decorators.""" + +import itertools +import operator +import re +import sys +import types +import warnings +from cStringIO import StringIO + +from sqlalchemy.test import config, assertsql +from sqlalchemy.util import function_named + +from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema + +_ops = { '<': operator.lt, + '>': operator.gt, + '==': operator.eq, + '!=': operator.ne, + '<=': operator.le, + '>=': operator.ge, + 'in': operator.contains, + 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + } + +# sugar ('testing.db'); set here by config() at runtime +db = None + +# more sugar, installed by __init__ +requires = None + +def fails_if(callable_): + """Mark a test as expected to fail if callable_ returns True. + + If the callable returns false, the test is run and reported as normal. + However if the callable returns true, the test is expected to fail and the + unit test logic is inverted: if the test fails, a success is reported. If + the test succeeds, a failure is reported. + """ + + docstring = getattr(callable_, '__doc__', None) or callable_.__name__ + description = docstring.split('\n')[0] + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if not callable_(): + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected (condition: %s): %s " % ( + fn_name, description, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for '%s' (condition: %s)" % + (fn_name, description)) + return function_named(maybe, fn_name) + return decorate + + +def future(fn): + """Mark a test as expected to unconditionally fail. + + Takes no arguments, omit parens when using as a decorator. + """ + + fn_name = fn.__name__ + def decorated(*args, **kw): + try: + fn(*args, **kw) + except Exception, ex: + print ("Future test '%s' failed as expected: %s " % ( + fn_name, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for future test '%s'" % fn_name) + return function_named(decorated, fn_name) + +def fails_on(dbs, reason): + """Mark a test as expected to fail on the specified database + implementation. + + Unlike ``crashes``, tests marked as ``fails_on`` will be run + for the named databases. The test is expected to fail and the unit test + logic is inverted: if the test fails, a success is reported. If the test + succeeds, a failure is reported. + """ + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if config.db.name != dbs: + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected on DB implementation " + "'%s': %s" % ( + fn_name, config.db.name, reason)) + return True + else: + raise AssertionError( + "Unexpected success for '%s' on DB implementation '%s'" % + (fn_name, config.db.name)) + return function_named(maybe, fn_name) + return decorate + +def fails_on_everything_except(*dbs): + """Mark a test as expected to fail on most database implementations. + + Like ``fails_on``, except failure is the expected outcome on all + databases except those listed. + """ + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if config.db.name in dbs: + return fn(*args, **kw) + else: + try: + fn(*args, **kw) + except Exception, ex: + print ("'%s' failed as expected on DB implementation " + "'%s': %s" % ( + fn_name, config.db.name, str(ex))) + return True + else: + raise AssertionError( + "Unexpected success for '%s' on DB implementation '%s'" % + (fn_name, config.db.name)) + return function_named(maybe, fn_name) + return decorate + +def crashes(db, reason): + """Mark a test as unsupported by a database implementation. + + ``crashes`` tests will be skipped unconditionally. Use for feature tests + that cause deadlocks or other fatal problems. + + """ + carp = _should_carp_about_exclusion(reason) + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if config.db.name == db: + msg = "'%s' unsupported on DB implementation '%s': %s" % ( + fn_name, config.db.name, reason) + print msg + if carp: + print >> sys.stderr, msg + return True + else: + return fn(*args, **kw) + return function_named(maybe, fn_name) + return decorate + +def _block_unconditionally(db, reason): + """Mark a test as unsupported by a database implementation. + + Will never run the test against any version of the given database, ever, + no matter what. Use when your assumptions are infallible; past, present + and future. + + """ + carp = _should_carp_about_exclusion(reason) + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if config.db.name == db: + msg = "'%s' unsupported on DB implementation '%s': %s" % ( + fn_name, config.db.name, reason) + print msg + if carp: + print >> sys.stderr, msg + return True + else: + return fn(*args, **kw) + return function_named(maybe, fn_name) + return decorate + + +def exclude(db, op, spec, reason): + """Mark a test as unsupported by specific database server versions. + + Stackable, both with other excludes and other decorators. Examples:: + + # Not supported by mydb versions less than 1, 0 + @exclude('mydb', '<', (1,0)) + # Other operators work too + @exclude('bigdb', '==', (9,0,9)) + @exclude('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) + + """ + carp = _should_carp_about_exclusion(reason) + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if _is_excluded(db, op, spec): + msg = "'%s' unsupported on DB %s version '%s': %s" % ( + fn_name, config.db.name, _server_version(), reason) + print msg + if carp: + print >> sys.stderr, msg + return True + else: + return fn(*args, **kw) + return function_named(maybe, fn_name) + return decorate + +def _should_carp_about_exclusion(reason): + """Guard against forgotten exclusions.""" + assert reason + for _ in ('todo', 'fixme', 'xxx'): + if _ in reason.lower(): + return True + else: + if len(reason) < 4: + return True + +def _is_excluded(db, op, spec): + """Return True if the configured db matches an exclusion specification. + + db: + A dialect name + op: + An operator or stringified operator, such as '==' + spec: + A value that will be compared to the dialect's server_version_info + using the supplied operator. + + Examples:: + # Not supported by mydb versions less than 1, 0 + _is_excluded('mydb', '<', (1,0)) + # Other operators work too + _is_excluded('bigdb', '==', (9,0,9)) + _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) + """ + + if config.db.name != db: + return False + + version = _server_version() + + oper = hasattr(op, '__call__') and op or _ops[op] + return oper(version, spec) + +def _server_version(bind=None): + """Return a server_version_info tuple.""" + + if bind is None: + bind = config.db + return bind.dialect.server_version_info(bind.contextual_connect()) + +def skip_if(predicate, reason=None): + """Skip a test if predicate is true.""" + reason = reason or predicate.__name__ + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if predicate(): + msg = "'%s' skipped on DB %s version '%s': %s" % ( + fn_name, config.db.name, _server_version(), reason) + print msg + return True + else: + return fn(*args, **kw) + return function_named(maybe, fn_name) + return decorate + +def emits_warning(*messages): + """Mark a test as emitting a warning. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + + # TODO: it would be nice to assert that a named warning was + # emitted. should work with some monkeypatching of warnings, + # and may work on non-CPython if they keep to the spirit of + # warnings.showwarning's docstring. + # - update: jython looks ok, it uses cpython's module + def decorate(fn): + def safe(*args, **kw): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SAWarning)) + else: + filters.extend(dict(action='ignore', + message=message, + category=sa_exc.SAWarning) + for message in messages) + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return function_named(safe, fn.__name__) + return decorate + +def emits_warning_on(db, *warnings): + """Mark a test as emitting a warning on a specific dialect. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + def decorate(fn): + def maybe(*args, **kw): + if isinstance(db, basestring): + if config.db.name != db: + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + else: + if not _is_excluded(*db): + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + return function_named(maybe, fn.__name__) + return decorate + +def uses_deprecated(*messages): + """Mark a test as immune from fatal deprecation warnings. + + With no arguments, squelches all SADeprecationWarning failures. + Or pass one or more strings; these will be matched to the root + of the warning description by warnings.filterwarnings(). + + As a special case, you may pass a function name prefixed with // + and it will be re-written as needed to match the standard warning + verbiage emitted by the sqlalchemy.util.deprecated decorator. + """ + + def decorate(fn): + def safe(*args, **kw): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SADeprecationWarning)) + else: + filters.extend( + [dict(action='ignore', + message=message, + category=sa_exc.SADeprecationWarning) + for message in + [ (m.startswith('//') and + ('Call to deprecated function ' + m[2:]) or m) + for m in messages] ]) + + for f in filters: + warnings.filterwarnings(**f) + try: + return fn(*args, **kw) + finally: + resetwarnings() + return function_named(safe, fn.__name__) + return decorate + +def resetwarnings(): + """Reset warning behavior to testing defaults.""" + + warnings.filterwarnings('ignore', + category=sa_exc.SAPendingDeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) + warnings.filterwarnings('error', category=sa_exc.SAWarning) + +# warnings.simplefilter('error') + + if sys.version_info < (2, 4): + warnings.filterwarnings('ignore', category=FutureWarning) + + +def against(*queries): + """Boolean predicate, compares to testing database configuration. + + Given one or more dialect names, returns True if one is the configured + database engine. + + Also supports comparison to database version when provided with one or + more 3-tuples of dialect name, operator, and version specification:: + + testing.against('mysql', 'postgres') + testing.against(('mysql', '>=', (5, 0, 0)) + """ + + for query in queries: + if isinstance(query, basestring): + if config.db.name == query: + return True + else: + name, op, spec = query + if config.db.name != name: + continue + + have = config.db.dialect.server_version_info( + config.db.contextual_connect()) + + oper = hasattr(op, '__call__') and op or _ops[op] + if oper(have, spec): + return True + return False + +def _chain_decorators_on(fn, *decorators): + """Apply a series of decorators to fn, returning a decorated function.""" + for decorator in reversed(decorators): + fn = decorator(fn) + return fn + +def rowset(results): + """Converts the results of sql execution into a plain set of column tuples. + + Useful for asserting the results of an unordered query. + """ + + return set([tuple(row) for row in results]) + + +def eq_(a, b, msg=None): + """Assert a == b, with repr messaging on failure.""" + assert a == b, msg or "%r != %r" % (a, b) + +def ne_(a, b, msg=None): + """Assert a != b, with repr messaging on failure.""" + assert a != b, msg or "%r == %r" % (a, b) + +def is_(a, b, msg=None): + """Assert a is b, with repr messaging on failure.""" + assert a is b, msg or "%r is not %r" % (a, b) + +def is_not_(a, b, msg=None): + """Assert a is not b, with repr messaging on failure.""" + assert a is not b, msg or "%r is %r" % (a, b) + +def startswith_(a, fragment, msg=None): + """Assert a.startswith(fragment), with repr messaging on failure.""" + assert a.startswith(fragment), msg or "%r does not start with %r" % ( + a, fragment) + +def assert_raises(except_cls, callable_, *args, **kw): + try: + callable_(*args, **kw) + assert False, "Callable did not raise an exception" + except except_cls, e: + pass + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + try: + callable_(*args, **kwargs) + assert False, "Callable did not raise an exception" + except except_cls, e: + assert re.search(msg, str(e)), "%r !~ %s" % (msg, e) + +def fail(msg): + assert False, msg + +def fixture(table, columns, *rows): + """Insert data into table after creation.""" + def onload(event, schema_item, connection): + insert = table.insert() + column_names = [col.key for col in columns] + connection.execute(insert, [dict(zip(column_names, column_values)) + for column_values in rows]) + table.append_ddl_listener('after-create', onload) + +def resolve_artifact_names(fn): + """Decorator, augment function globals with tables and classes. + + Swaps out the function's globals at execution time. The 'global' statement + will not work as expected inside a decorated function. + + """ + # This could be automatically applied to framework and test_ methods in + # the MappedTest-derived test suites but... *some* explicitness for this + # magic is probably good. Especially as 'global' won't work- these + # rebound functions aren't regular Python.. + # + # Also: it's lame that CPython accepts a dict-subclass for globals, but + # only calls dict methods. That would allow 'global' to pass through to + # the func_globals. + def resolved(*args, **kwargs): + self = args[0] + context = dict(fn.func_globals) + for source in self._artifact_registries: + context.update(getattr(self, source)) + # jython bug #1034 + rebound = types.FunctionType( + fn.func_code, context, fn.func_name, fn.func_defaults, + fn.func_closure) + return rebound(*args, **kwargs) + return function_named(resolved, fn.func_name) + +class adict(dict): + """Dict keys available as attributes. Shadows.""" + def __getattribute__(self, key): + try: + return self[key] + except KeyError: + return dict.__getattribute__(self, key) + + def get_all(self, *keys): + return tuple([self[key] for key in keys]) + + +class TestBase(object): + # A sequence of database names to always run, regardless of the + # constraints below. + __whitelist__ = () + + # A sequence of requirement names matching testing.requires decorators + __requires__ = () + + # A sequence of dialect names to exclude from the test class. + __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None + + # A sequence of no-arg callables. If any are True, the entire testcase is + # skipped. + __skip_if__ = None + + _artifact_registries = () + + def assert_(self, val, msg=None): + assert val, msg + +class AssertsCompiledSQL(object): + def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None): + if dialect is None: + dialect = getattr(self, '__dialect__', None) + + if params is None: + keys = None + else: + keys = params.keys() + + c = clause.compile(column_keys=keys, dialect=dialect) + + print "\nSQL String:\n" + str(c) + repr(c.params) + + cc = re.sub(r'\n', '', str(c)) + + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) + + if checkparams is not None: + eq_(c.construct_params(params), checkparams) + +class ComparesTables(object): + def assert_tables_equal(self, table, reflected_table): + base_mro = sqltypes.TypeEngine.__mro__ + assert len(table.c) == len(reflected_table.c) + for c, reflected_c in zip(table.c, reflected_table.c): + eq_(c.name, reflected_c.name) + assert reflected_c is reflected_table.c[c.name] + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) + assert len( + set(type(reflected_c.type).__mro__).difference(base_mro).intersection( + set(type(c.type).__mro__).difference(base_mro) + ) + ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + + if isinstance(c.type, sqltypes.String): + eq_(c.type.length, reflected_c.type.length) + + eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) + if c.default: + assert isinstance(reflected_c.server_default, + schema.FetchedValue) + elif against(('mysql', '<', (5, 0))): + # ignore reflection of bogus db-generated DefaultClause() + pass + elif not c.primary_key or not against('postgres'): + print repr(c) + assert reflected_c.default is None, reflected_c.default + + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] + + +class AssertsExecutionResults(object): + def assert_result(self, result, class_, *objects): + result = list(result) + print repr(result) + self.assert_list(result, class_, objects) + + def assert_list(self, result, class_, list): + self.assert_(len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + class_.__name__) + for i in range(0, len(list)): + self.assert_row(class_, result[i], list[i]) + + def assert_row(self, class_, rowobj, desc): + self.assert_(rowobj.__class__ is class_, + "item class is not " + repr(class_)) + for key, value in desc.iteritems(): + if isinstance(value, tuple): + if isinstance(value[1], list): + self.assert_list(getattr(rowobj, key), value[0], value[1]) + else: + self.assert_row(value[0], getattr(rowobj, key), value[1]) + else: + self.assert_(getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" % ( + key, getattr(rowobj, key), value)) + + def assert_unordered_result(self, result, cls, *expected): + """As assert_result, but the order of objects is not considered. + + The algorithm is very expensive but not a big deal for the small + numbers of rows that the test suite manipulates. + """ + + class frozendict(dict): + def __hash__(self): + return id(self) + + found = util.IdentitySet(result) + expected = set([frozendict(e) for e in expected]) + + for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found): + fail('Unexpected type "%s", expected "%s"' % ( + type(wrong).__name__, cls.__name__)) + + if len(found) != len(expected): + fail('Unexpected object count "%s", expected "%s"' % ( + len(found), len(expected))) + + NOVALUE = object() + def _compare_item(obj, spec): + for key, value in spec.iteritems(): + if isinstance(value, tuple): + try: + self.assert_unordered_result( + getattr(obj, key), value[0], *value[1]) + except AssertionError: + return False + else: + if getattr(obj, key, NOVALUE) != value: + return False + return True + + for expected_item in expected: + for found_item in found: + if _compare_item(found_item, expected_item): + found.remove(found_item) + break + else: + fail( + "Expected %s instance with attributes %s not found." % ( + cls.__name__, repr(expected_item))) + return True + + def assert_sql_execution(self, db, callable_, *rules): + assertsql.asserter.add_rules(rules) + try: + callable_() + assertsql.asserter.statement_complete() + finally: + assertsql.asserter.clear_rules() + + def assert_sql(self, db, callable_, list_, with_sequences=None): + if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): + rules = with_sequences + else: + rules = list_ + + newrules = [] + for rule in rules: + if isinstance(rule, dict): + newrule = assertsql.AllOf(*[ + assertsql.ExactSQL(k, v) for k, v in rule.iteritems() + ]) + else: + newrule = assertsql.ExactSQL(*rule) + newrules.append(newrule) + + self.assert_sql_execution(db, callable_, *newrules) + + def assert_sql_count(self, db, callable_, count): + self.assert_sql_execution(db, callable_, assertsql.CountStatements(count)) + + |
