diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index d955d1554..864ce5b4d 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -144,7 +144,7 @@ class RegexSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - def __init__(self, statement, params): + def __init__(self, statement, params=None): SQLMatchRule.__init__(self) self.statement = statement self.params = params @@ -153,14 +153,19 @@ class CompiledSQL(SQLMatchRule): executemany): if not context: return + from sqlalchemy.schema import _DDLCompiles _received_parameters = list(context.compiled_parameters) # recompile from the context, using the default dialect - compiled = \ - context.compiled.statement.compile(dialect=DefaultDialect(), + if isinstance(context.compiled.statement, _DDLCompiles): + compiled = \ + context.compiled.statement.compile(dialect=DefaultDialect()) + else: + compiled = \ + context.compiled.statement.compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys) - _received_statement = re.sub(r'\n', '', str(compiled)) + _received_statement = re.sub(r'[\n\t]', '', str(compiled)) equivalent = self.statement == _received_statement if self.params: if util.callable(self.params): |