diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 92 |
1 files changed, 67 insertions, 25 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index bcc999fe3..2ac0605a2 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -8,6 +8,9 @@ from ..engine.default import DefaultDialect from .. import util import re +import collections +import contextlib +from .. import event class AssertRule(object): @@ -321,39 +324,78 @@ def _process_assertion_statement(query, context): return query -class SQLAssert(object): +class SQLExecuteObserved( + collections.namedtuple( + "SQLExecuteObserved", ["clauseelement", "multiparams", "params"]) +): + def process(self, rules): + if rules is not None: + if not rules: + assert False, \ + 'All rules have been exhausted, but further '\ + 'statements remain' + rule = rules[0] + rule.process_execute( + self.clauseelement, *self.multiparams, **self.params) + if rule.is_consumed(): + rules.pop(0) - rules = None - def add_rules(self, rules): - self.rules = list(rules) +class SQLCursorExecuteObserved( + collections.namedtuple( + "SQLCursorExecuteObserved", + ["statement", "parameters", "context", "executemany"]) +): + def process(self, rules): + if rules: + rule = rules[0] + rule.process_cursor_execute( + self.statement, self.parameters, + self.context, self.executemany) - def statement_complete(self): - for rule in self.rules: + +class SQLAsserter(object): + def __init__(self): + self.accumulated = [] + + def _close(self): + # safety feature in case event.remove + # goes haywire + self._final = self.accumulated + del self.accumulated + + def assert_(self, *rules): + rules = list(rules) + for observed in self._final: + observed.process(rules) + + for rule in rules: if not rule.consume_final(): assert False, \ 'All statements are complete, but pending '\ 'assertion rules remain' - def clear_rules(self): - del self.rules - def execute(self, conn, clauseelement, multiparams, params, result): - if self.rules is not None: - if not self.rules: - assert False, \ - 'All rules have been exhausted, but further '\ - 'statements remain' - rule = self.rules[0] - rule.process_execute(clauseelement, *multiparams, **params) - if rule.is_consumed(): - self.rules.pop(0) +@contextlib.contextmanager +def assert_engine(engine): + asserter = SQLAsserter() - def cursor_execute(self, conn, cursor, statement, parameters, - context, executemany): - if self.rules: - rule = self.rules[0] - rule.process_cursor_execute(statement, parameters, context, - executemany) + @event.listens_for(engine, "after_execute") + def execute(conn, clauseelement, multiparams, params, result): + asserter.accumulated.append( + SQLExecuteObserved( + clauseelement, multiparams, params)) -asserter = SQLAssert() + @event.listens_for(engine, "after_cursor_execute") + def cursor_execute(conn, cursor, statement, parameters, + context, executemany): + asserter.accumulated.append( + SQLCursorExecuteObserved( + statement, parameters, context, executemany)) + + try: + yield asserter + finally: + asserter._close() + event.remove(engine, "after_cursor_execute", cursor_execute) + event.remove(engine, "after_execute", execute) |