summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-12-07 18:54:52 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2014-12-07 18:54:52 -0500
commit60e6ac8856e5f7f257e1797280d1510682ae8fb7 (patch)
tree41ada63c12c08415cc91dfdb812d85fd4430e1ca /lib/sqlalchemy/testing/assertsql.py
parent1b260c7959c9b89e6a3993d5d96bc6b0918a8fb0 (diff)
downloadsqlalchemy-60e6ac8856e5f7f257e1797280d1510682ae8fb7.tar.gz
- rework the assert_sql system so that we have a context manager to work with,
use events that are local to the engine and to the run and are removed afterwards.
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py92
1 files changed, 67 insertions, 25 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index bcc999fe3..2ac0605a2 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -8,6 +8,9 @@
from ..engine.default import DefaultDialect
from .. import util
import re
+import collections
+import contextlib
+from .. import event
class AssertRule(object):
@@ -321,39 +324,78 @@ def _process_assertion_statement(query, context):
return query
-class SQLAssert(object):
+class SQLExecuteObserved(
+ collections.namedtuple(
+ "SQLExecuteObserved", ["clauseelement", "multiparams", "params"])
+):
+ def process(self, rules):
+ if rules is not None:
+ if not rules:
+ assert False, \
+ 'All rules have been exhausted, but further '\
+ 'statements remain'
+ rule = rules[0]
+ rule.process_execute(
+ self.clauseelement, *self.multiparams, **self.params)
+ if rule.is_consumed():
+ rules.pop(0)
- rules = None
- def add_rules(self, rules):
- self.rules = list(rules)
+class SQLCursorExecuteObserved(
+ collections.namedtuple(
+ "SQLCursorExecuteObserved",
+ ["statement", "parameters", "context", "executemany"])
+):
+ def process(self, rules):
+ if rules:
+ rule = rules[0]
+ rule.process_cursor_execute(
+ self.statement, self.parameters,
+ self.context, self.executemany)
- def statement_complete(self):
- for rule in self.rules:
+
+class SQLAsserter(object):
+ def __init__(self):
+ self.accumulated = []
+
+ def _close(self):
+ # safety feature in case event.remove
+ # goes haywire
+ self._final = self.accumulated
+ del self.accumulated
+
+ def assert_(self, *rules):
+ rules = list(rules)
+ for observed in self._final:
+ observed.process(rules)
+
+ for rule in rules:
if not rule.consume_final():
assert False, \
'All statements are complete, but pending '\
'assertion rules remain'
- def clear_rules(self):
- del self.rules
- def execute(self, conn, clauseelement, multiparams, params, result):
- if self.rules is not None:
- if not self.rules:
- assert False, \
- 'All rules have been exhausted, but further '\
- 'statements remain'
- rule = self.rules[0]
- rule.process_execute(clauseelement, *multiparams, **params)
- if rule.is_consumed():
- self.rules.pop(0)
+@contextlib.contextmanager
+def assert_engine(engine):
+ asserter = SQLAsserter()
- def cursor_execute(self, conn, cursor, statement, parameters,
- context, executemany):
- if self.rules:
- rule = self.rules[0]
- rule.process_cursor_execute(statement, parameters, context,
- executemany)
+ @event.listens_for(engine, "after_execute")
+ def execute(conn, clauseelement, multiparams, params, result):
+ asserter.accumulated.append(
+ SQLExecuteObserved(
+ clauseelement, multiparams, params))
-asserter = SQLAssert()
+ @event.listens_for(engine, "after_cursor_execute")
+ def cursor_execute(conn, cursor, statement, parameters,
+ context, executemany):
+ asserter.accumulated.append(
+ SQLCursorExecuteObserved(
+ statement, parameters, context, executemany))
+
+ try:
+ yield asserter
+ finally:
+ asserter._close()
+ event.remove(engine, "after_cursor_execute", cursor_execute)
+ event.remove(engine, "after_execute", execute)