# testing/assertsql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
#
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Assertion rules for checking the SQL an engine executes during a test.

:func:`assert_engine` records every ``after_execute`` and
``after_cursor_execute`` event of an engine; :meth:`SQLAsserter.assert_`
then replays those observations, in order, against a list of
:class:`AssertRule` objects.
"""

import collections
import contextlib
import itertools
import re

from ..engine.default import DefaultDialect
from .. import util
from .. import event


# A ``:name`` style bind parameter, as written in assertion statements.
_BIND_PARAM_RE = re.compile(r':([\w_]+)')

# Trailing suffix stripped from mssql statements before comparison.
_MSSQL_SCOPE_IDENTITY = '; select scope_identity()'


class AssertRule(object):
    """Base class for one expectation about the SQL that was executed.

    Rules are fed observations through :meth:`process_execute` and
    :meth:`process_cursor_execute` and report their state through
    :meth:`is_consumed`, :meth:`rule_passed` and :meth:`consume_final`.
    """

    def process_execute(self, clauseelement, *multiparams, **params):
        """Receive an ``after_execute`` observation; no-op by default."""
        pass

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Receive an ``after_cursor_execute`` observation; no-op by
        default."""
        pass

    def is_consumed(self):
        """Return True if this rule has been consumed, False if not.

        Should raise an AssertionError if this rule's condition has
        definitely failed.

        """
        raise NotImplementedError()

    def rule_passed(self):
        """Return True if the last test of this rule passed, False if
        failed, None if no test was applied."""
        raise NotImplementedError()

    def consume_final(self):
        """Return True if this rule has been consumed.

        Should raise an AssertionError if this rule's condition has not
        been consumed or has failed.

        """
        # ``_result`` is only initialised by subclasses such as
        # SQLMatchRule; subclasses without it override this method.
        if self._result is None:
            assert False, 'Rule has not been consumed'
        return self.is_consumed()


class SQLMatchRule(AssertRule):
    """Base for rules that examine statements one at a time.

    ``_result`` is None until a statement has been examined, then
    True/False for whether the most recently examined statement matched;
    ``_errmsg`` describes the most recent mismatch.
    """

    def __init__(self):
        self._result = None
        self._errmsg = ""

    def rule_passed(self):
        return self._result

    def is_consumed(self):
        if self._result is None:
            return False

        assert self._result, self._errmsg
        return True

    def _expected_params(self, context):
        """Return ``self.params`` resolved to a new list of parameter sets.

        ``self.params`` may be a callable taking the execution context,
        a list of parameter sets, or a single parameter set (which is
        wrapped in a one-element list).
        """
        params = self.params
        if util.callable(params):
            params = params(context)
        if isinstance(params, list):
            return list(params)
        return [params]


class ExactSQL(SQLMatchRule):
    """Match a statement's SQL string exactly and, if *params* is given,
    its complete list of compiled parameter sets.

    *sql* is written with ``:name`` binds and converted to the dialect's
    paramstyle before comparison.
    """

    def __init__(self, sql, params=None):
        SQLMatchRule.__init__(self)
        self.sql = sql
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        if not context:
            return

        _received_statement = _process_engine_statement(
            context.unicode_statement, context)
        _received_parameters = context.compiled_parameters

        # TODO: remove this step once all unit tests are migrated, as
        # ExactSQL should really be *exact* SQL
        sql = _process_assertion_statement(self.sql, context)

        equivalent = _received_statement == sql
        if self.params:
            params = self._expected_params(context)
            equivalent = equivalent and params == _received_parameters
        else:
            params = {}

        self._result = equivalent
        if not self._result:
            self._errmsg = (
                'Testing for exact statement %r exact params %r, '
                'received %r with params %r' %
                (sql, params, _received_statement, _received_parameters))


class RegexSQL(SQLMatchRule):
    """Match a statement's SQL against a regular expression and, if
    *params* is given, check that each expected key/value is present in
    the corresponding received parameter set (extra keys are ignored)."""

    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        if not context:
            return

        _received_statement = _process_engine_statement(
            context.unicode_statement, context)
        _received_parameters = context.compiled_parameters

        equivalent = bool(self.regex.match(_received_statement))
        if self.params:
            params = self._expected_params(context)
            # do a positive compare only: pairs expected/received sets
            # positionally and checks only the keys the expectation names.
            for param, received in zip(params, _received_parameters):
                for k, v in param.items():
                    if k not in received or received[k] != v:
                        equivalent = False
                        break
        else:
            params = {}

        self._result = equivalent
        if not self._result:
            self._errmsg = \
                'Testing for regex %r partial params %r, received %r '\
                'with params %r' % (self.orig_regex, params,
                                    _received_statement,
                                    _received_parameters)


class CompiledSQL(SQLMatchRule):
    """Match a statement by recompiling it with the default dialect and
    comparing the resulting SQL string; if *params* is given, every
    received parameter set must be matched by exactly one expected set."""

    def __init__(self, statement, params=None):
        SQLMatchRule.__init__(self)
        self.statement = statement
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        if not context:
            return

        # imported here rather than at module level -- presumably to avoid
        # an import cycle; confirm before hoisting.
        from sqlalchemy.schema import _DDLCompiles
        _received_parameters = list(context.compiled_parameters)

        # recompile from the context, using the default dialect
        if isinstance(context.compiled.statement, _DDLCompiles):
            compiled = context.compiled.statement.compile(
                dialect=DefaultDialect())
        else:
            compiled = context.compiled.statement.compile(
                dialect=DefaultDialect(),
                column_keys=context.compiled.column_keys)
        _received_statement = re.sub(r'[\n\t]', '', str(compiled))

        equivalent = self.statement == _received_statement
        if self.params:
            params = self._expected_params(context)
            all_params = list(params)
            all_received = list(_received_parameters)

            # Each expected set, filled in from context.compiled.params
            # for keys it does not name, must equal one still-unmatched
            # received set; matched sets are removed so that leftovers
            # (unexpected executions) also cause a mismatch.
            for expected in params:
                param = dict(expected)
                for k, v in context.compiled.params.items():
                    param.setdefault(k, v)
                if param not in _received_parameters:
                    equivalent = False
                    break
                _received_parameters.remove(param)
            if _received_parameters:
                equivalent = False
        else:
            all_params = {}
            all_received = []

        self._result = equivalent
        if not self._result:
            self._errmsg = (
                'Testing for compiled statement %r partial params %r, '
                'received %r with params %r' %
                (self.statement, all_params,
                 _received_statement, all_received))


class CountStatements(AssertRule):
    """Assert the number of ``after_execute`` observations it receives.

    It never reports itself consumed, so while it is at the head of the
    rule list any rules after it are never reached.
    """

    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_execute(self, clauseelement, *multiparams, **params):
        self._statement_count += 1

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        pass

    def is_consumed(self):
        return False

    def consume_final(self):
        assert self.count == self._statement_count, \
            'desired statement count %d does not match %d' \
            % (self.count, self._statement_count)
        return True


class AllOf(AssertRule):
    """Consumed once every sub-rule has passed, in any order."""

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_execute(self, clauseelement, *multiparams, **params):
        for rule in self.rules:
            rule.process_execute(clauseelement, *multiparams, **params)

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        for rule in self.rules:
            rule.process_cursor_execute(statement, parameters, context,
                                        executemany)

    def is_consumed(self):
        if not self.rules:
            return True
        for rule in list(self.rules):
            if rule.rule_passed():
                # a rule passed, move on.  Only one sub-rule is retired per
                # call, so one statement satisfies at most one sub-rule.
                self.rules.remove(rule)
                return len(self.rules) == 0
        return False

    def rule_passed(self):
        return self.is_consumed()

    def consume_final(self):
        return len(self.rules) == 0


class Or(AllOf):
    """Consumed as soon as any one sub-rule passes."""

    def __init__(self, *rules):
        self.rules = set(rules)
        self._consume_final = False

    def is_consumed(self):
        if not self.rules:
            return True
        for rule in list(self.rules):
            if rule.rule_passed():  # a rule passed
                self._consume_final = True
                return True
        return False

    def consume_final(self):
        assert self._consume_final, "Unsatisified rules remain"
        # SQLAsserter.assert_ treats a falsy return as a failure, so
        # success must be reported explicitly.
        return True


def _process_engine_statement(query, context):
    """Normalise a statement received from the engine for comparison.

    Coerces to ``str`` under Jython, strips the mssql scope_identity()
    suffix, and removes newlines.
    """
    if util.jython:
        # oracle+zxjdbc passes a PyStatement when returning into
        query = str(query)
    if context.engine.name == 'mssql' \
            and query.endswith(_MSSQL_SCOPE_IDENTITY):
        query = query[:-len(_MSSQL_SCOPE_IDENTITY)]
    return re.sub(r'\n', '', query)


def _process_assertion_statement(query, context):
    """Rewrite ``:name`` binds in *query* into ``context.dialect``'s
    paramstyle so it can be compared with the statement actually sent.

    :raises ValueError: if the dialect reports an unknown paramstyle.
    """
    paramstyle = context.dialect.paramstyle
    if paramstyle == 'named':
        return query
    if paramstyle == 'pyformat':
        return _BIND_PARAM_RE.sub(r"%(\1)s", query)
    # positional params
    if paramstyle == 'qmark':
        return _BIND_PARAM_RE.sub("?", query)
    if paramstyle == 'format':
        return _BIND_PARAM_RE.sub(r"%s", query)
    if paramstyle == 'numeric':
        # numeric binds are numbered by position: :1, :2, ...  (passing
        # None as the replacement, as before, made re.sub raise TypeError)
        position = itertools.count(1)
        return _BIND_PARAM_RE.sub(
            lambda match: ':%d' % next(position), query)
    raise ValueError("Unknown paramstyle %r" % (paramstyle, ))


class SQLExecuteObserved(
    collections.namedtuple(
        "SQLExecuteObserved", ["clauseelement", "multiparams", "params"])
):
    """A recorded ``after_execute`` event."""

    def process(self, rules):
        """Feed this event to the head rule, popping it if consumed.

        *rules* is mutated in place.
        """
        if rules is not None:
            if not rules:
                assert False, \
                    'All rules have been exhausted, but further '\
                    'statements remain'
            rule = rules[0]
            rule.process_execute(
                self.clauseelement, *self.multiparams, **self.params)
            if rule.is_consumed():
                rules.pop(0)


class SQLCursorExecuteObserved(
    collections.namedtuple(
        "SQLCursorExecuteObserved",
        ["statement", "parameters", "context", "executemany"])
):
    """A recorded ``after_cursor_execute`` event."""

    def process(self, rules):
        """Feed this event to the head rule, if any; never pops."""
        if rules:
            rule = rules[0]
            rule.process_cursor_execute(
                self.statement, self.parameters,
                self.context, self.executemany)


class SQLAsserter(object):
    """Collects observed events and checks them against rules."""

    def __init__(self):
        self.accumulated = []

    def _close(self):
        # safety feature in case event.remove
        # goes haywire
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        """Replay the recorded events against *rules*, in order, then
        require every remaining rule to report itself consumed."""
        rules = list(rules)
        for observed in self._final:
            observed.process(rules)

        for rule in rules:
            if not rule.consume_final():
                assert False, \
                    'All statements are complete, but pending '\
                    'assertion rules remain'


@contextlib.contextmanager
def assert_engine(engine):
    """Record the statements *engine* executes within the ``with`` block.

    Yields a :class:`SQLAsserter`; call its :meth:`~SQLAsserter.assert_`
    after the block exits.
    """
    asserter = SQLAsserter()

    @event.listens_for(engine, "after_execute")
    def execute(conn, clauseelement, multiparams, params, result):
        asserter.accumulated.append(
            SQLExecuteObserved(
                clauseelement, multiparams, params))

    @event.listens_for(engine, "after_cursor_execute")
    def cursor_execute(conn, cursor, statement, parameters,
                       context, executemany):
        asserter.accumulated.append(
            SQLCursorExecuteObserved(
                statement, parameters, context, executemany))

    try:
        yield asserter
    finally:
        asserter._close()
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "after_execute", execute)