diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 48 |
1 files changed, 35 insertions, 13 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index e39b6315d..86d850733 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -282,6 +282,32 @@ class AllOf(AssertRule): self.errormessage = list(self.rules)[0].errormessage +class EachOf(AssertRule): + + def __init__(self, *rules): + self.rules = list(rules) + + def process_statement(self, execute_observed): + while self.rules: + rule = self.rules[0] + rule.process_statement(execute_observed) + if rule.is_consumed: + self.rules.pop(0) + elif rule.errormessage: + self.errormessage = rule.errormessage + if rule.consume_statement: + break + + if not self.rules: + self.is_consumed = True + + def no_more_statements(self): + if self.rules and not self.rules[0].is_consumed: + self.rules[0].no_more_statements() + elif self.rules: + super(EachOf, self).no_more_statements() + + class Or(AllOf): def process_statement(self, execute_observed): @@ -319,24 +345,20 @@ class SQLAsserter(object): del self.accumulated def assert_(self, *rules): - rules = list(rules) - observed = list(self._final) + rule = EachOf(*rules) - while observed and rules: - rule = rules[0] - rule.process_statement(observed[0]) + observed = list(self._final) + while observed: + statement = observed.pop(0) + rule.process_statement(statement) if rule.is_consumed: - rules.pop(0) + break elif rule.errormessage: assert False, rule.errormessage - - if rule.consume_statement: - observed.pop(0) - - if not observed and rules: - rules[0].no_more_statements() - elif not rules and observed: + if observed: assert False, "Additional SQL statements remain" + elif not rule.is_consumed: + rule.no_more_statements() @contextlib.contextmanager |