diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-03-15 16:41:17 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-03-15 16:41:17 -0400 |
| commit | 224b03f9c006b12e3bbae9190ca9d0132e843208 (patch) | |
| tree | 4c54f3404af67962835c3a78849f8e89ebb98da0 /lib/sqlalchemy/testing/assertsql.py | |
| parent | a87b3c2101114d82f999c23d113ad2018629ed48 (diff) | |
| parent | 8bc370ed382a45654101fa34bac4a2886ce089c3 (diff) | |
| download | sqlalchemy-224b03f9c006b12e3bbae9190ca9d0132e843208.tar.gz | |
Merge branch 'master' into pr157
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
| -rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 40 |
1 files changed, 31 insertions, 9 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 5c746e8f1..0aae12dcc 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -1,5 +1,5 @@ # testing/assertsql.py -# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors # <see AUTHORS file> # # This module is part of SQLAlchemy and is released under @@ -13,6 +13,7 @@ import contextlib from .. import event from sqlalchemy.schema import _DDLCompiles from sqlalchemy.engine.util import _distill_params +from sqlalchemy.engine import url class AssertRule(object): @@ -58,16 +59,25 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - def __init__(self, statement, params=None): + def __init__(self, statement, params=None, dialect='default'): self.statement = statement self.params = params + self.dialect = dialect def _compare_sql(self, execute_observed, received_statement): stmt = re.sub(r'[\n\t]', '', self.statement) return received_statement == stmt def _compile_dialect(self, execute_observed): - return DefaultDialect() + if self.dialect == 'default': + return DefaultDialect() + else: + # ugh + if self.dialect == 'postgresql': + params = {'implicit_returning': True} + else: + params = {} + return url.URL(self.dialect).get_dialect()(**params) def _received_statement(self, execute_observed): """reconstruct the statement and params in terms @@ -77,15 +87,20 @@ class CompiledSQL(SQLMatchRule): compare_dialect = self._compile_dialect(execute_observed) if isinstance(context.compiled.statement, _DDLCompiles): compiled = \ - context.compiled.statement.compile(dialect=compare_dialect) + context.compiled.statement.compile( + dialect=compare_dialect, + schema_translate_map=context. + execution_options.get('schema_translate_map')) else: compiled = ( context.compiled.statement.compile( dialect=compare_dialect, column_keys=context.compiled.column_keys, - inline=context.compiled.inline) + inline=context.compiled.inline, + schema_translate_map=context. + execution_options.get('schema_translate_map')) ) - _received_statement = re.sub(r'[\n\t]', '', str(compiled)) + _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled)) parameters = execute_observed.parameters if not parameters: @@ -159,7 +174,7 @@ class CompiledSQL(SQLMatchRule): 'Testing for compiled statement %r partial params %r, ' 'received %%(received_statement)r with params ' '%%(received_parameters)r' % ( - self.statement, expected_params + self.statement.replace('%', '%%'), expected_params ) ) @@ -170,6 +185,7 @@ class RegexSQL(CompiledSQL): self.regex = re.compile(regex) self.orig_regex = regex self.params = params + self.dialect = 'default' def _failure_message(self, expected_params): return ( @@ -188,21 +204,27 @@ class DialectSQL(CompiledSQL): def _compile_dialect(self, execute_observed): return execute_observed.context.dialect + def _compare_no_space(self, real_stmt, received_stmt): + stmt = re.sub(r'[\n\t]', '', real_stmt) + return received_stmt == stmt + def _received_statement(self, execute_observed): received_stmt, received_params = super(DialectSQL, self).\ _received_statement(execute_observed) + + # TODO: why do we need this part? for real_stmt in execute_observed.statements: - if real_stmt.statement == received_stmt: + if self._compare_no_space(real_stmt.statement, received_stmt): break else: raise AssertionError( "Can't locate compiled statement %r in list of " "statements actually invoked" % received_stmt) + return received_stmt, execute_observed.context.compiled_parameters def _compare_sql(self, execute_observed, received_statement): stmt = re.sub(r'[\n\t]', '', self.statement) - # convert our comparison statement to have the # paramstyle of the received paramstyle = execute_observed.context.dialect.paramstyle |
