summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index d955d1554..864ce5b4d 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -144,7 +144,7 @@ class RegexSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
- def __init__(self, statement, params):
+ def __init__(self, statement, params=None):
SQLMatchRule.__init__(self)
self.statement = statement
self.params = params
@@ -153,14 +153,19 @@ class CompiledSQL(SQLMatchRule):
executemany):
if not context:
return
+ from sqlalchemy.schema import _DDLCompiles
_received_parameters = list(context.compiled_parameters)
# recompile from the context, using the default dialect
- compiled = \
- context.compiled.statement.compile(dialect=DefaultDialect(),
+ if isinstance(context.compiled.statement, _DDLCompiles):
+ compiled = \
+ context.compiled.statement.compile(dialect=DefaultDialect())
+ else:
+ compiled = \
+ context.compiled.statement.compile(dialect=DefaultDialect(),
column_keys=context.compiled.column_keys)
- _received_statement = re.sub(r'\n', '', str(compiled))
+ _received_statement = re.sub(r'[\n\t]', '', str(compiled))
equivalent = self.statement == _received_statement
if self.params:
if util.callable(self.params):