summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/assertions.py8
-rw-r--r--lib/sqlalchemy/testing/assertsql.py16
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py1
3 files changed, 17 insertions, 8 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index e0bf4326e..7dada1394 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -352,6 +352,8 @@ class AssertsCompiledSQL(object):
literal_binds=False,
render_postcompile=False,
schema_translate_map=None,
+ render_schema_translate=False,
+ default_schema_name=None,
inline_flag=None,
):
if use_default_dialect:
@@ -371,6 +373,9 @@ class AssertsCompiledSQL(object):
elif isinstance(dialect, util.string_types):
dialect = url.URL(dialect).get_dialect()()
+ if default_schema_name:
+ dialect.default_schema_name = default_schema_name
+
kw = {}
compile_kwargs = {}
@@ -386,6 +391,9 @@ class AssertsCompiledSQL(object):
if render_postcompile:
compile_kwargs["render_postcompile"] = True
+ if render_schema_translate:
+ kw["render_schema_translate"] = True
+
from sqlalchemy import orm
if isinstance(clause, orm.Query):
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index e38c7ddd8..f0da69400 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -91,21 +91,23 @@ class CompiledSQL(SQLMatchRule):
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
+
+ if "schema_translate_map" in context.execution_options:
+ map_ = context.execution_options["schema_translate_map"]
+ else:
+ map_ = None
+
if isinstance(context.compiled.statement, _DDLCompiles):
+
compiled = context.compiled.statement.compile(
- dialect=compare_dialect,
- schema_translate_map=context.execution_options.get(
- "schema_translate_map"
- ),
+ dialect=compare_dialect, schema_translate_map=map_
)
else:
compiled = context.compiled.statement.compile(
dialect=compare_dialect,
column_keys=context.compiled.column_keys,
inline=context.compiled.inline,
- schema_translate_map=context.execution_options.get(
- "schema_translate_map"
- ),
+ schema_translate_map=map_,
)
_received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
parameters = execute_observed.parameters
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 473c98116..68a43feb7 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -360,7 +360,6 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.schema_reflection
def test_dialect_initialize(self):
engine = engines.testing_engine()
- assert not hasattr(engine.dialect, "default_schema_name")
inspect(engine)
assert hasattr(engine.dialect, "default_schema_name")