summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2016-03-15 16:41:17 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2016-03-15 16:41:17 -0400
commit224b03f9c006b12e3bbae9190ca9d0132e843208 (patch)
tree4c54f3404af67962835c3a78849f8e89ebb98da0 /lib/sqlalchemy/testing/assertsql.py
parenta87b3c2101114d82f999c23d113ad2018629ed48 (diff)
parent8bc370ed382a45654101fa34bac4a2886ce089c3 (diff)
downloadsqlalchemy-224b03f9c006b12e3bbae9190ca9d0132e843208.tar.gz
Merge branch 'master' into pr157
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py40
1 files changed, 31 insertions, 9 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 5c746e8f1..0aae12dcc 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -1,5 +1,5 @@
# testing/assertsql.py
-# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -13,6 +13,7 @@ import contextlib
from .. import event
from sqlalchemy.schema import _DDLCompiles
from sqlalchemy.engine.util import _distill_params
+from sqlalchemy.engine import url
class AssertRule(object):
@@ -58,16 +59,25 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
- def __init__(self, statement, params=None):
+ def __init__(self, statement, params=None, dialect='default'):
self.statement = statement
self.params = params
+ self.dialect = dialect
def _compare_sql(self, execute_observed, received_statement):
stmt = re.sub(r'[\n\t]', '', self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
- return DefaultDialect()
+ if self.dialect == 'default':
+ return DefaultDialect()
+ else:
+ # ugh
+ if self.dialect == 'postgresql':
+ params = {'implicit_returning': True}
+ else:
+ params = {}
+ return url.URL(self.dialect).get_dialect()(**params)
def _received_statement(self, execute_observed):
"""reconstruct the statement and params in terms
@@ -77,15 +87,20 @@ class CompiledSQL(SQLMatchRule):
compare_dialect = self._compile_dialect(execute_observed)
if isinstance(context.compiled.statement, _DDLCompiles):
compiled = \
- context.compiled.statement.compile(dialect=compare_dialect)
+ context.compiled.statement.compile(
+ dialect=compare_dialect,
+ schema_translate_map=context.
+ execution_options.get('schema_translate_map'))
else:
compiled = (
context.compiled.statement.compile(
dialect=compare_dialect,
column_keys=context.compiled.column_keys,
- inline=context.compiled.inline)
+ inline=context.compiled.inline,
+ schema_translate_map=context.
+ execution_options.get('schema_translate_map'))
)
- _received_statement = re.sub(r'[\n\t]', '', str(compiled))
+ _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
parameters = execute_observed.parameters
if not parameters:
@@ -159,7 +174,7 @@ class CompiledSQL(SQLMatchRule):
'Testing for compiled statement %r partial params %r, '
'received %%(received_statement)r with params '
'%%(received_parameters)r' % (
- self.statement, expected_params
+ self.statement.replace('%', '%%'), expected_params
)
)
@@ -170,6 +185,7 @@ class RegexSQL(CompiledSQL):
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
+ self.dialect = 'default'
def _failure_message(self, expected_params):
return (
@@ -188,21 +204,27 @@ class DialectSQL(CompiledSQL):
def _compile_dialect(self, execute_observed):
return execute_observed.context.dialect
+ def _compare_no_space(self, real_stmt, received_stmt):
+ stmt = re.sub(r'[\n\t]', '', real_stmt)
+ return received_stmt == stmt
+
def _received_statement(self, execute_observed):
received_stmt, received_params = super(DialectSQL, self).\
_received_statement(execute_observed)
+
+ # TODO: why do we need this part?
for real_stmt in execute_observed.statements:
- if real_stmt.statement == received_stmt:
+ if self._compare_no_space(real_stmt.statement, received_stmt):
break
else:
raise AssertionError(
"Can't locate compiled statement %r in list of "
"statements actually invoked" % received_stmt)
+
return received_stmt, execute_observed.context.compiled_parameters
def _compare_sql(self, execute_observed, received_statement):
stmt = re.sub(r'[\n\t]', '', self.statement)
-
# convert our comparison statement to have the
# paramstyle of the received
paramstyle = execute_observed.context.dialect.paramstyle