summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-09-27 02:37:33 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2012-09-27 02:37:33 -0400
commit20cdc64588b0f6ae52f8380c11d0ed848005377b (patch)
tree08f6cc8f82263f1e402c1c05c83b66a1f4b016ac /lib/sqlalchemy/testing/assertsql.py
parent21cac5b598a83ef0e24423dc523629b475aa3af0 (diff)
downloadsqlalchemy-20cdc64588b0f6ae52f8380c11d0ed848005377b.tar.gz
trying different approaches to test layout. in this one, the testing modules
become an externally usable package but still remains within the main sqlalchemy parent package. in this system, we use kind of an ugly hack to get the noseplugin imported outside of the "sqlalchemy" package, while still making it available within sqlalchemy for usage by third party libraries.
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py316
1 files changed, 316 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
new file mode 100644
index 000000000..897f4b3b1
--- /dev/null
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -0,0 +1,316 @@
+
+from sqlalchemy.interfaces import ConnectionProxy
+from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.engine.base import Connection
+from sqlalchemy import util
+import re
+
+class AssertRule(object):
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ pass
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ pass
+
+ def is_consumed(self):
+ """Return True if this rule has been consumed, False if not.
+
+ Should raise an AssertionError if this rule's condition has
+ definitely failed.
+
+ """
+
+ raise NotImplementedError()
+
+ def rule_passed(self):
+ """Return True if the last test of this rule passed, False if
+ failed, None if no test was applied."""
+
+ raise NotImplementedError()
+
+ def consume_final(self):
+ """Return True if this rule has been consumed.
+
+ Should raise an AssertionError if this rule's condition has not
+ been consumed or has failed.
+
+ """
+
+ if self._result is None:
+ assert False, 'Rule has not been consumed'
+ return self.is_consumed()
+
+class SQLMatchRule(AssertRule):
+ def __init__(self):
+ self._result = None
+ self._errmsg = ""
+
+ def rule_passed(self):
+ return self._result
+
+ def is_consumed(self):
+ if self._result is None:
+ return False
+
+ assert self._result, self._errmsg
+
+ return True
+
+class ExactSQL(SQLMatchRule):
+
+ def __init__(self, sql, params=None):
+ SQLMatchRule.__init__(self)
+ self.sql = sql
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ if not context:
+ return
+ _received_statement = \
+ _process_engine_statement(context.unicode_statement,
+ context)
+ _received_parameters = context.compiled_parameters
+
+ # TODO: remove this step once all unit tests are migrated, as
+ # ExactSQL should really be *exact* SQL
+
+ sql = _process_assertion_statement(self.sql, context)
+ equivalent = _received_statement == sql
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+ if not isinstance(params, list):
+ params = [params]
+ equivalent = equivalent and params \
+ == context.compiled_parameters
+ else:
+ params = {}
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = \
+ 'Testing for exact statement %r exact params %r, '\
+ 'received %r with params %r' % (sql, params,
+ _received_statement, _received_parameters)
+
+
+class RegexSQL(SQLMatchRule):
+
+ def __init__(self, regex, params=None):
+ SQLMatchRule.__init__(self)
+ self.regex = re.compile(regex)
+ self.orig_regex = regex
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ if not context:
+ return
+ _received_statement = \
+ _process_engine_statement(context.unicode_statement,
+ context)
+ _received_parameters = context.compiled_parameters
+ equivalent = bool(self.regex.match(_received_statement))
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+ if not isinstance(params, list):
+ params = [params]
+
+ # do a positive compare only
+
+ for param, received in zip(params, _received_parameters):
+ for k, v in param.iteritems():
+ if k not in received or received[k] != v:
+ equivalent = False
+ break
+ else:
+ params = {}
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = \
+ 'Testing for regex %r partial params %r, received %r '\
+ 'with params %r' % (self.orig_regex, params,
+ _received_statement,
+ _received_parameters)
+
+class CompiledSQL(SQLMatchRule):
+
+ def __init__(self, statement, params):
+ SQLMatchRule.__init__(self)
+ self.statement = statement
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ if not context:
+ return
+ _received_parameters = list(context.compiled_parameters)
+
+ # recompile from the context, using the default dialect
+
+ compiled = \
+ context.compiled.statement.compile(dialect=DefaultDialect(),
+ column_keys=context.compiled.column_keys)
+ _received_statement = re.sub(r'\n', '', str(compiled))
+ equivalent = self.statement == _received_statement
+ if self.params:
+ if util.callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+ if not isinstance(params, list):
+ params = [params]
+ all_params = list(params)
+ all_received = list(_received_parameters)
+ while params:
+ param = dict(params.pop(0))
+ for k, v in context.compiled.params.iteritems():
+ param.setdefault(k, v)
+ if param not in _received_parameters:
+ equivalent = False
+ break
+ else:
+ _received_parameters.remove(param)
+ if _received_parameters:
+ equivalent = False
+ else:
+ params = {}
+ all_params = {}
+ all_received = []
+ self._result = equivalent
+ if not self._result:
+ print 'Testing for compiled statement %r partial params '\
+ '%r, received %r with params %r' % (self.statement,
+ all_params, _received_statement, all_received)
+ self._errmsg = \
+ 'Testing for compiled statement %r partial params %r, '\
+ 'received %r with params %r' % (self.statement,
+ all_params, _received_statement, all_received)
+
+
+ # print self._errmsg
+
+class CountStatements(AssertRule):
+
+ def __init__(self, count):
+ self.count = count
+ self._statement_count = 0
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ self._statement_count += 1
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ pass
+
+ def is_consumed(self):
+ return False
+
+ def consume_final(self):
+ assert self.count == self._statement_count, \
+ 'desired statement count %d does not match %d' \
+ % (self.count, self._statement_count)
+ return True
+
+class AllOf(AssertRule):
+
+ def __init__(self, *rules):
+ self.rules = set(rules)
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ for rule in self.rules:
+ rule.process_execute(clauseelement, *multiparams, **params)
+
+ def process_cursor_execute(self, statement, parameters, context,
+ executemany):
+ for rule in self.rules:
+ rule.process_cursor_execute(statement, parameters, context,
+ executemany)
+
+ def is_consumed(self):
+ if not self.rules:
+ return True
+ for rule in list(self.rules):
+ if rule.rule_passed(): # a rule passed, move on
+ self.rules.remove(rule)
+ return len(self.rules) == 0
+ assert False, 'No assertion rules were satisfied for statement'
+
+ def consume_final(self):
+ return len(self.rules) == 0
+
+def _process_engine_statement(query, context):
+ if util.jython:
+
+ # oracle+zxjdbc passes a PyStatement when returning into
+
+ query = unicode(query)
+ if context.engine.name == 'mssql' \
+ and query.endswith('; select scope_identity()'):
+ query = query[:-25]
+ query = re.sub(r'\n', '', query)
+ return query
+
+def _process_assertion_statement(query, context):
+ paramstyle = context.dialect.paramstyle
+ if paramstyle == 'named':
+ pass
+ elif paramstyle =='pyformat':
+ query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+ else:
+ # positional params
+ repl = None
+ if paramstyle=='qmark':
+ repl = "?"
+ elif paramstyle=='format':
+ repl = r"%s"
+ elif paramstyle=='numeric':
+ repl = None
+ query = re.sub(r':([\w_]+)', repl, query)
+
+ return query
+
+class SQLAssert(object):
+
+ rules = None
+
+ def add_rules(self, rules):
+ self.rules = list(rules)
+
+ def statement_complete(self):
+ for rule in self.rules:
+ if not rule.consume_final():
+ assert False, \
+ 'All statements are complete, but pending '\
+ 'assertion rules remain'
+
+ def clear_rules(self):
+ del self.rules
+
+ def execute(self, conn, clauseelement, multiparams, params, result):
+ if self.rules is not None:
+ if not self.rules:
+ assert False, \
+ 'All rules have been exhausted, but further '\
+ 'statements remain'
+ rule = self.rules[0]
+ rule.process_execute(clauseelement, *multiparams, **params)
+ if rule.is_consumed():
+ self.rules.pop(0)
+
+ def cursor_execute(self, conn, cursor, statement, parameters,
+ context, executemany):
+ if self.rules:
+ rule = self.rules[0]
+ rule.process_cursor_execute(statement, parameters, context,
+ executemany)
+
+asserter = SQLAssert()
+