summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
authorjonathan vanasco <jonathan@2xlp.com>2015-04-02 13:30:26 -0400
committerjonathan vanasco <jonathan@2xlp.com>2015-04-02 13:30:26 -0400
commit6de3d490a2adb0fff43f98e15a53407b46668b61 (patch)
treed5e0e2077dfe7dc69ce30e9d0a8c89ceff78e3fe /lib/sqlalchemy/testing/assertsql.py
parentefca4af93603faa7abfeacbab264cad85ee4105c (diff)
parent5e04995a82c00e801a99765cde7726f5e73e18c2 (diff)
downloadsqlalchemy-6de3d490a2adb0fff43f98e15a53407b46668b61.tar.gz
Merge branch 'master' of bitbucket.org:zzzeek/sqlalchemy
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py532
1 files changed, 264 insertions, 268 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index bcc999fe3..a596d9743 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -1,5 +1,5 @@
# testing/assertsql.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2015 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -8,84 +8,141 @@
from ..engine.default import DefaultDialect
from .. import util
import re
+import collections
+import contextlib
+from .. import event
+from sqlalchemy.schema import _DDLCompiles
+from sqlalchemy.engine.util import _distill_params
class AssertRule(object):
- def process_execute(self, clauseelement, *multiparams, **params):
- pass
+ is_consumed = False
+ errormessage = None
+ consume_statement = True
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
+ def process_statement(self, execute_observed):
pass
- def is_consumed(self):
- """Return True if this rule has been consumed, False if not.
-
- Should raise an AssertionError if this rule's condition has
- definitely failed.
-
- """
-
- raise NotImplementedError()
+ def no_more_statements(self):
+ assert False, 'All statements are complete, but pending '\
+ 'assertion rules remain'
- def rule_passed(self):
- """Return True if the last test of this rule passed, False if
- failed, None if no test was applied."""
- raise NotImplementedError()
-
- def consume_final(self):
- """Return True if this rule has been consumed.
-
- Should raise an AssertionError if this rule's condition has not
- been consumed or has failed.
+class SQLMatchRule(AssertRule):
+ pass
- """
- if self._result is None:
- assert False, 'Rule has not been consumed'
- return self.is_consumed()
+class CursorSQL(SQLMatchRule):
+ consume_statement = False
+ def __init__(self, statement, params=None):
+ self.statement = statement
+ self.params = params
-class SQLMatchRule(AssertRule):
- def __init__(self):
- self._result = None
- self._errmsg = ""
+ def process_statement(self, execute_observed):
+ stmt = execute_observed.statements[0]
+ if self.statement != stmt.statement or (
+ self.params is not None and self.params != stmt.parameters):
+ self.errormessage = \
+ "Testing for exact SQL %s parameters %s received %s %s" % (
+ self.statement, self.params,
+ stmt.statement, stmt.parameters
+ )
+ else:
+ execute_observed.statements.pop(0)
+ self.is_consumed = True
+ if not execute_observed.statements:
+ self.consume_statement = True
- def rule_passed(self):
- return self._result
- def is_consumed(self):
- if self._result is None:
- return False
+class CompiledSQL(SQLMatchRule):
- assert self._result, self._errmsg
+ def __init__(self, statement, params=None):
+ self.statement = statement
+ self.params = params
- return True
+ def _compare_sql(self, execute_observed, received_statement):
+ stmt = re.sub(r'[\n\t]', '', self.statement)
+ return received_statement == stmt
+ def _compile_dialect(self, execute_observed):
+ return DefaultDialect()
-class ExactSQL(SQLMatchRule):
+ def _received_statement(self, execute_observed):
+ """reconstruct the statement and params in terms
+ of a target dialect, which for CompiledSQL is just DefaultDialect."""
- def __init__(self, sql, params=None):
- SQLMatchRule.__init__(self)
- self.sql = sql
- self.params = params
+ context = execute_observed.context
+ compare_dialect = self._compile_dialect(execute_observed)
+ if isinstance(context.compiled.statement, _DDLCompiles):
+ compiled = \
+ context.compiled.statement.compile(dialect=compare_dialect)
+ else:
+ compiled = (
+ context.compiled.statement.compile(
+ dialect=compare_dialect,
+ column_keys=context.compiled.column_keys,
+ inline=context.compiled.inline)
+ )
+ _received_statement = re.sub(r'[\n\t]', '', str(compiled))
+ parameters = execute_observed.parameters
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- if not context:
- return
- _received_statement = \
- _process_engine_statement(context.unicode_statement,
- context)
- _received_parameters = context.compiled_parameters
+ if not parameters:
+ _received_parameters = [compiled.construct_params()]
+ else:
+ _received_parameters = [
+ compiled.construct_params(m) for m in parameters]
+
+ return _received_statement, _received_parameters
+
+ def process_statement(self, execute_observed):
+ context = execute_observed.context
+
+ _received_statement, _received_parameters = \
+ self._received_statement(execute_observed)
+ params = self._all_params(context)
+
+ equivalent = self._compare_sql(execute_observed, _received_statement)
+
+ if equivalent:
+ if params is not None:
+ all_params = list(params)
+ all_received = list(_received_parameters)
+ while all_params and all_received:
+ param = dict(all_params.pop(0))
+
+ for idx, received in enumerate(list(all_received)):
+ # do a positive compare only
+ for param_key in param:
+ # a key in param did not match current
+ # 'received'
+ if param_key not in received or \
+ received[param_key] != param[param_key]:
+ break
+ else:
+ # all keys in param matched 'received';
+ # onto next param
+ del all_received[idx]
+ break
+ else:
+ # param did not match any entry
+ # in all_received
+ equivalent = False
+ break
+ if all_params or all_received:
+ equivalent = False
- # TODO: remove this step once all unit tests are migrated, as
- # ExactSQL should really be *exact* SQL
+ if equivalent:
+ self.is_consumed = True
+ self.errormessage = None
+ else:
+ self.errormessage = self._failure_message(params) % {
+ 'received_statement': _received_statement,
+ 'received_parameters': _received_parameters
+ }
- sql = _process_assertion_statement(self.sql, context)
- equivalent = _received_statement == sql
+ def _all_params(self, context):
if self.params:
if util.callable(self.params):
params = self.params(context)
@@ -93,127 +150,77 @@ class ExactSQL(SQLMatchRule):
params = self.params
if not isinstance(params, list):
params = [params]
- equivalent = equivalent and params \
- == context.compiled_parameters
+ return params
else:
- params = {}
- self._result = equivalent
- if not self._result:
- self._errmsg = (
- 'Testing for exact statement %r exact params %r, '
- 'received %r with params %r' %
- (sql, params, _received_statement, _received_parameters))
-
+ return None
+
+ def _failure_message(self, expected_params):
+ return (
+ 'Testing for compiled statement %r partial params %r, '
+ 'received %%(received_statement)r with params '
+ '%%(received_parameters)r' % (
+ self.statement, expected_params
+ )
+ )
-class RegexSQL(SQLMatchRule):
+class RegexSQL(CompiledSQL):
def __init__(self, regex, params=None):
SQLMatchRule.__init__(self)
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- if not context:
- return
- _received_statement = \
- _process_engine_statement(context.unicode_statement,
- context)
- _received_parameters = context.compiled_parameters
- equivalent = bool(self.regex.match(_received_statement))
- if self.params:
- if util.callable(self.params):
- params = self.params(context)
- else:
- params = self.params
- if not isinstance(params, list):
- params = [params]
-
- # do a positive compare only
-
- for param, received in zip(params, _received_parameters):
- for k, v in param.items():
- if k not in received or received[k] != v:
- equivalent = False
- break
- else:
- params = {}
- self._result = equivalent
- if not self._result:
- self._errmsg = \
- 'Testing for regex %r partial params %r, received %r '\
- 'with params %r' % (self.orig_regex, params,
- _received_statement,
- _received_parameters)
-
+ def _failure_message(self, expected_params):
+ return (
+ 'Testing for compiled statement ~%r partial params %r, '
+ 'received %%(received_statement)r with params '
+ '%%(received_parameters)r' % (
+ self.orig_regex, expected_params
+ )
+ )
-class CompiledSQL(SQLMatchRule):
+ def _compare_sql(self, execute_observed, received_statement):
+ return bool(self.regex.match(received_statement))
- def __init__(self, statement, params=None):
- SQLMatchRule.__init__(self)
- self.statement = statement
- self.params = params
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- if not context:
- return
- from sqlalchemy.schema import _DDLCompiles
- _received_parameters = list(context.compiled_parameters)
-
- # recompile from the context, using the default dialect
+class DialectSQL(CompiledSQL):
+ def _compile_dialect(self, execute_observed):
+ return execute_observed.context.dialect
- if isinstance(context.compiled.statement, _DDLCompiles):
- compiled = \
- context.compiled.statement.compile(dialect=DefaultDialect())
+ def _received_statement(self, execute_observed):
+ received_stmt, received_params = super(DialectSQL, self).\
+ _received_statement(execute_observed)
+ for real_stmt in execute_observed.statements:
+ if real_stmt.statement == received_stmt:
+ break
else:
- compiled = (
- context.compiled.statement.compile(
- dialect=DefaultDialect(),
- column_keys=context.compiled.column_keys)
- )
- _received_statement = re.sub(r'[\n\t]', '', str(compiled))
- equivalent = self.statement == _received_statement
- if self.params:
- if util.callable(self.params):
- params = self.params(context)
- else:
- params = self.params
- if not isinstance(params, list):
- params = [params]
- else:
- params = list(params)
- all_params = list(params)
- all_received = list(_received_parameters)
- while params:
- param = dict(params.pop(0))
- for k, v in context.compiled.params.items():
- param.setdefault(k, v)
- if param not in _received_parameters:
- equivalent = False
- break
- else:
- _received_parameters.remove(param)
- if _received_parameters:
- equivalent = False
+ raise AssertionError(
+ "Can't locate compiled statement %r in list of "
+ "statements actually invoked" % received_stmt)
+ return received_stmt, execute_observed.context.compiled_parameters
+
+ def _compare_sql(self, execute_observed, received_statement):
+ stmt = re.sub(r'[\n\t]', '', self.statement)
+
+ # convert our comparison statement to have the
+ # paramstyle of the received
+ paramstyle = execute_observed.context.dialect.paramstyle
+ if paramstyle == 'pyformat':
+ stmt = re.sub(
+ r':([\w_]+)', r"%(\1)s", stmt)
else:
- params = {}
- all_params = {}
- all_received = []
- self._result = equivalent
- if not self._result:
- print('Testing for compiled statement %r partial params '
- '%r, received %r with params %r' %
- (self.statement, all_params,
- _received_statement, all_received))
- self._errmsg = (
- 'Testing for compiled statement %r partial params %r, '
- 'received %r with params %r' %
- (self.statement, all_params,
- _received_statement, all_received))
-
- # print self._errmsg
+ # positional params
+ repl = None
+ if paramstyle == 'qmark':
+ repl = "?"
+ elif paramstyle == 'format':
+ repl = r"%s"
+ elif paramstyle == 'numeric':
+ repl = None
+ stmt = re.sub(r':([\w_]+)', repl, stmt)
+
+ return received_statement == stmt
class CountStatements(AssertRule):
@@ -222,21 +229,13 @@ class CountStatements(AssertRule):
self.count = count
self._statement_count = 0
- def process_execute(self, clauseelement, *multiparams, **params):
+ def process_statement(self, execute_observed):
self._statement_count += 1
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- pass
-
- def is_consumed(self):
- return False
-
- def consume_final(self):
- assert self.count == self._statement_count, \
- 'desired statement count %d does not match %d' \
- % (self.count, self._statement_count)
- return True
+ def no_more_statements(self):
+ if self.count != self._statement_count:
+ assert False, 'desired statement count %d does not match %d' \
+ % (self.count, self._statement_count)
class AllOf(AssertRule):
@@ -244,116 +243,113 @@ class AllOf(AssertRule):
def __init__(self, *rules):
self.rules = set(rules)
- def process_execute(self, clauseelement, *multiparams, **params):
- for rule in self.rules:
- rule.process_execute(clauseelement, *multiparams, **params)
-
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- for rule in self.rules:
- rule.process_cursor_execute(statement, parameters, context,
- executemany)
-
- def is_consumed(self):
- if not self.rules:
- return True
+ def process_statement(self, execute_observed):
for rule in list(self.rules):
- if rule.rule_passed(): # a rule passed, move on
- self.rules.remove(rule)
- return len(self.rules) == 0
- return False
+ rule.errormessage = None
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.rules.discard(rule)
+ if not self.rules:
+ self.is_consumed = True
+ break
+ elif not rule.errormessage:
+ # rule is not done yet
+ self.errormessage = None
+ break
+ else:
+ self.errormessage = list(self.rules)[0].errormessage
- def rule_passed(self):
- return self.is_consumed()
- def consume_final(self):
- return len(self.rules) == 0
+class Or(AllOf):
+ def process_statement(self, execute_observed):
+ for rule in self.rules:
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.is_consumed = True
+ break
+ else:
+ self.errormessage = list(self.rules)[0].errormessage
-class Or(AllOf):
- def __init__(self, *rules):
- self.rules = set(rules)
- self._consume_final = False
- def is_consumed(self):
- if not self.rules:
- return True
- for rule in list(self.rules):
- if rule.rule_passed(): # a rule passed
- self._consume_final = True
- return True
- return False
+class SQLExecuteObserved(object):
+ def __init__(self, context, clauseelement, multiparams, params):
+ self.context = context
+ self.clauseelement = clauseelement
+ self.parameters = _distill_params(multiparams, params)
+ self.statements = []
- def consume_final(self):
- assert self._consume_final, "Unsatisified rules remain"
+class SQLCursorExecuteObserved(
+ collections.namedtuple(
+ "SQLCursorExecuteObserved",
+ ["statement", "parameters", "context", "executemany"])
+):
+ pass
-def _process_engine_statement(query, context):
- if util.jython:
- # oracle+zxjdbc passes a PyStatement when returning into
+class SQLAsserter(object):
+ def __init__(self):
+ self.accumulated = []
- query = str(query)
- if context.engine.name == 'mssql' \
- and query.endswith('; select scope_identity()'):
- query = query[:-25]
- query = re.sub(r'\n', '', query)
- return query
+ def _close(self):
+ self._final = self.accumulated
+ del self.accumulated
+ def assert_(self, *rules):
+ rules = list(rules)
+ observed = list(self._final)
-def _process_assertion_statement(query, context):
- paramstyle = context.dialect.paramstyle
- if paramstyle == 'named':
- pass
- elif paramstyle == 'pyformat':
- query = re.sub(r':([\w_]+)', r"%(\1)s", query)
- else:
- # positional params
- repl = None
- if paramstyle == 'qmark':
- repl = "?"
- elif paramstyle == 'format':
- repl = r"%s"
- elif paramstyle == 'numeric':
- repl = None
- query = re.sub(r':([\w_]+)', repl, query)
+ while observed and rules:
+ rule = rules[0]
+ rule.process_statement(observed[0])
+ if rule.is_consumed:
+ rules.pop(0)
+ elif rule.errormessage:
+ assert False, rule.errormessage
- return query
+ if rule.consume_statement:
+ observed.pop(0)
+ if not observed and rules:
+ rules[0].no_more_statements()
+ elif not rules and observed:
+ assert False, "Additional SQL statements remain"
-class SQLAssert(object):
- rules = None
+@contextlib.contextmanager
+def assert_engine(engine):
+ asserter = SQLAsserter()
- def add_rules(self, rules):
- self.rules = list(rules)
+ orig = []
- def statement_complete(self):
- for rule in self.rules:
- if not rule.consume_final():
- assert False, \
- 'All statements are complete, but pending '\
- 'assertion rules remain'
-
- def clear_rules(self):
- del self.rules
-
- def execute(self, conn, clauseelement, multiparams, params, result):
- if self.rules is not None:
- if not self.rules:
- assert False, \
- 'All rules have been exhausted, but further '\
- 'statements remain'
- rule = self.rules[0]
- rule.process_execute(clauseelement, *multiparams, **params)
- if rule.is_consumed():
- self.rules.pop(0)
-
- def cursor_execute(self, conn, cursor, statement, parameters,
- context, executemany):
- if self.rules:
- rule = self.rules[0]
- rule.process_cursor_execute(statement, parameters, context,
- executemany)
+ @event.listens_for(engine, "before_execute")
+ def connection_execute(conn, clauseelement, multiparams, params):
+ # grab the original statement + params before any cursor
+ # execution
+ orig[:] = clauseelement, multiparams, params
-asserter = SQLAssert()
+ @event.listens_for(engine, "after_cursor_execute")
+ def cursor_execute(conn, cursor, statement, parameters,
+ context, executemany):
+ if not context:
+ return
+ # then grab real cursor statements and associate them all
+ # around a single context
+ if asserter.accumulated and \
+ asserter.accumulated[-1].context is context:
+ obs = asserter.accumulated[-1]
+ else:
+ obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
+ asserter.accumulated.append(obs)
+ obs.statements.append(
+ SQLCursorExecuteObserved(
+ statement, parameters, context, executemany)
+ )
+
+ try:
+ yield asserter
+ finally:
+ event.remove(engine, "after_cursor_execute", cursor_execute)
+ event.remove(engine, "before_execute", connection_execute)
+ asserter._close()