summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-12-07 18:54:52 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2014-12-07 18:54:52 -0500
commit60e6ac8856e5f7f257e1797280d1510682ae8fb7 (patch)
tree41ada63c12c08415cc91dfdb812d85fd4430e1ca
parent1b260c7959c9b89e6a3993d5d96bc6b0918a8fb0 (diff)
downloadsqlalchemy-60e6ac8856e5f7f257e1797280d1510682ae8fb7.tar.gz
- rework the assert_sql system so that we have a context manager to work with,
use events that are local to the engine and to the run and are removed afterwards.
-rw-r--r--lib/sqlalchemy/testing/assertions.py13
-rw-r--r--lib/sqlalchemy/testing/assertsql.py92
-rw-r--r--lib/sqlalchemy/testing/engines.py3
3 files changed, 75 insertions, 33 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index bf7c27a89..66d1f3cb0 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -405,13 +405,16 @@ class AssertsExecutionResults(object):
cls.__name__, repr(expected_item)))
return True
+ def sql_execution_asserter(self, db=None):
+ if db is None:
+ from . import db as db
+
+ return assertsql.assert_engine(db)
+
def assert_sql_execution(self, db, callable_, *rules):
- assertsql.asserter.add_rules(rules)
- try:
+ with self.sql_execution_asserter(db) as asserter:
callable_()
- assertsql.asserter.statement_complete()
- finally:
- assertsql.asserter.clear_rules()
+ asserter.assert_(*rules)
def assert_sql(self, db, callable_, list_, with_sequences=None):
if (with_sequences is not None and
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index bcc999fe3..2ac0605a2 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -8,6 +8,9 @@
from ..engine.default import DefaultDialect
from .. import util
import re
+import collections
+import contextlib
+from .. import event
class AssertRule(object):
@@ -321,39 +324,78 @@ def _process_assertion_statement(query, context):
return query
-class SQLAssert(object):
+class SQLExecuteObserved(
+ collections.namedtuple(
+ "SQLExecuteObserved", ["clauseelement", "multiparams", "params"])
+):
+ def process(self, rules):
+ if rules is not None:
+ if not rules:
+ assert False, \
+ 'All rules have been exhausted, but further '\
+ 'statements remain'
+ rule = rules[0]
+ rule.process_execute(
+ self.clauseelement, *self.multiparams, **self.params)
+ if rule.is_consumed():
+ rules.pop(0)
- rules = None
- def add_rules(self, rules):
- self.rules = list(rules)
+class SQLCursorExecuteObserved(
+ collections.namedtuple(
+ "SQLCursorExecuteObserved",
+ ["statement", "parameters", "context", "executemany"])
+):
+ def process(self, rules):
+ if rules:
+ rule = rules[0]
+ rule.process_cursor_execute(
+ self.statement, self.parameters,
+ self.context, self.executemany)
- def statement_complete(self):
- for rule in self.rules:
+
+class SQLAsserter(object):
+ def __init__(self):
+ self.accumulated = []
+
+ def _close(self):
+ # safety feature in case event.remove
+ # goes haywire
+ self._final = self.accumulated
+ del self.accumulated
+
+ def assert_(self, *rules):
+ rules = list(rules)
+ for observed in self._final:
+ observed.process(rules)
+
+ for rule in rules:
if not rule.consume_final():
assert False, \
'All statements are complete, but pending '\
'assertion rules remain'
- def clear_rules(self):
- del self.rules
- def execute(self, conn, clauseelement, multiparams, params, result):
- if self.rules is not None:
- if not self.rules:
- assert False, \
- 'All rules have been exhausted, but further '\
- 'statements remain'
- rule = self.rules[0]
- rule.process_execute(clauseelement, *multiparams, **params)
- if rule.is_consumed():
- self.rules.pop(0)
+@contextlib.contextmanager
+def assert_engine(engine):
+ asserter = SQLAsserter()
- def cursor_execute(self, conn, cursor, statement, parameters,
- context, executemany):
- if self.rules:
- rule = self.rules[0]
- rule.process_cursor_execute(statement, parameters, context,
- executemany)
+ @event.listens_for(engine, "after_execute")
+ def execute(conn, clauseelement, multiparams, params, result):
+ asserter.accumulated.append(
+ SQLExecuteObserved(
+ clauseelement, multiparams, params))
-asserter = SQLAssert()
+ @event.listens_for(engine, "after_cursor_execute")
+ def cursor_execute(conn, cursor, statement, parameters,
+ context, executemany):
+ asserter.accumulated.append(
+ SQLCursorExecuteObserved(
+ statement, parameters, context, executemany))
+
+ try:
+ yield asserter
+ finally:
+ asserter._close()
+ event.remove(engine, "after_cursor_execute", cursor_execute)
+ event.remove(engine, "after_execute", execute)
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index 0f6f59401..7d73e7423 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -204,7 +204,6 @@ def testing_engine(url=None, options=None):
"""Produce an engine configured by --options with optional overrides."""
from sqlalchemy import create_engine
- from .assertsql import asserter
if not options:
use_reaper = True
@@ -219,8 +218,6 @@ def testing_engine(url=None, options=None):
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
engine.pool._max_overflow = 0
- event.listen(engine, 'after_execute', asserter.execute)
- event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
if use_reaper:
event.listen(engine.pool, 'connect', testing_reaper.connect)
event.listen(engine.pool, 'checkout', testing_reaper.checkout)