summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-03-24 19:53:14 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-03-24 19:53:14 +0000
commitfa1f67a01e80367d73cf5d1d93b6f7f51dc1746b (patch)
tree7c0c5e37a7d1708289a68fcdf30df9dad9013088 /lib/sqlalchemy/sql/compiler.py
parente817d1415d825bcf8c8f33927baaf56cd5d07b95 (diff)
parentcadfc608d63f4e0df46c0daaa28902423fd88d71 (diff)
downloadsqlalchemy-fa1f67a01e80367d73cf5d1d93b6f7f51dc1746b.tar.gz
Merge "Convert schema_translate to a post compile"
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py100
1 files changed, 75 insertions, 25 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 1f183b5c1..ae9c3c73a 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -26,6 +26,7 @@ To generate user-defined SQL strings, see
import collections
import contextlib
import itertools
+import operator
import re
from . import base
@@ -39,6 +40,7 @@ from . import schema
from . import selectable
from . import sqltypes
from .base import NO_ARG
+from .elements import quoted_name
from .. import exc
from .. import util
@@ -369,6 +371,8 @@ class Compiled(object):
_cached_metadata = None
+ schema_translate_map = None
+
execution_options = util.immutabledict()
"""
Execution options propagated from the statement. In some cases,
@@ -381,6 +385,7 @@ class Compiled(object):
statement,
bind=None,
schema_translate_map=None,
+ render_schema_translate=False,
compile_kwargs=util.immutabledict(),
):
"""Construct a new :class:`.Compiled` object.
@@ -411,6 +416,7 @@ class Compiled(object):
self.bind = bind
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
+ self.schema_translate_map = schema_translate_map
self.preparer = self.preparer._with_schema_translate(
schema_translate_map
)
@@ -422,6 +428,11 @@ class Compiled(object):
self.execution_options = statement._execution_options
self.string = self.process(self.statement, **compile_kwargs)
+ if render_schema_translate:
+ self.string = self.preparer._render_schema_translates(
+ self.string, schema_translate_map
+ )
+
@util.deprecated(
"0.7",
"The :meth:`.Compiled.compile` method is deprecated and will be "
@@ -3365,18 +3376,18 @@ class DDLCompiler(Compiled):
return self.sql_compiler.post_process_text(ddl.statement % context)
- def visit_create_schema(self, create):
+ def visit_create_schema(self, create, **kw):
schema = self.preparer.format_schema(create.element)
return "CREATE SCHEMA " + schema
- def visit_drop_schema(self, drop):
+ def visit_drop_schema(self, drop, **kw):
schema = self.preparer.format_schema(drop.element)
text = "DROP SCHEMA " + schema
if drop.cascade:
text += " CASCADE"
return text
- def visit_create_table(self, create):
+ def visit_create_table(self, create, **kw):
table = create.element
preparer = self.preparer
@@ -3426,7 +3437,7 @@ class DDLCompiler(Compiled):
text += "\n)%s\n\n" % self.post_create_table(table)
return text
- def visit_create_column(self, create, first_pk=False):
+ def visit_create_column(self, create, first_pk=False, **kw):
column = create.element
if column.system:
@@ -3442,7 +3453,7 @@ class DDLCompiler(Compiled):
return text
def create_table_constraints(
- self, table, _include_foreign_key_constraints=None
+ self, table, _include_foreign_key_constraints=None, **kw
):
# On some DB order is significant: visit PK first, then the
@@ -3482,10 +3493,10 @@ class DDLCompiler(Compiled):
if p is not None
)
- def visit_drop_table(self, drop):
+ def visit_drop_table(self, drop, **kw):
return "\nDROP TABLE " + self.preparer.format_table(drop.element)
- def visit_drop_view(self, drop):
+ def visit_drop_view(self, drop, **kw):
return "\nDROP VIEW " + self.preparer.format_table(drop.element)
def _verify_index_table(self, index):
@@ -3495,7 +3506,7 @@ class DDLCompiler(Compiled):
)
def visit_create_index(
- self, create, include_schema=False, include_table_schema=True
+ self, create, include_schema=False, include_table_schema=True, **kw
):
index = create.element
self._verify_index_table(index)
@@ -3521,7 +3532,7 @@ class DDLCompiler(Compiled):
)
return text
- def visit_drop_index(self, drop):
+ def visit_drop_index(self, drop, **kw):
index = drop.element
if index.name is None:
@@ -3548,13 +3559,13 @@ class DDLCompiler(Compiled):
index_name = schema_name + "." + index_name
return index_name
- def visit_add_constraint(self, create):
+ def visit_add_constraint(self, create, **kw):
return "ALTER TABLE %s ADD %s" % (
self.preparer.format_table(create.element.table),
self.process(create.element),
)
- def visit_set_table_comment(self, create):
+ def visit_set_table_comment(self, create, **kw):
return "COMMENT ON TABLE %s IS %s" % (
self.preparer.format_table(create.element),
self.sql_compiler.render_literal_value(
@@ -3562,12 +3573,12 @@ class DDLCompiler(Compiled):
),
)
- def visit_drop_table_comment(self, drop):
+ def visit_drop_table_comment(self, drop, **kw):
return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
drop.element
)
- def visit_set_column_comment(self, create):
+ def visit_set_column_comment(self, create, **kw):
return "COMMENT ON COLUMN %s IS %s" % (
self.preparer.format_column(
create.element, use_table=True, use_schema=True
@@ -3577,12 +3588,12 @@ class DDLCompiler(Compiled):
),
)
- def visit_drop_column_comment(self, drop):
+ def visit_drop_column_comment(self, drop, **kw):
return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
drop.element, use_table=True
)
- def visit_create_sequence(self, create):
+ def visit_create_sequence(self, create, **kw):
text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
create.element
)
@@ -3606,10 +3617,10 @@ class DDLCompiler(Compiled):
text += " CYCLE"
return text
- def visit_drop_sequence(self, drop):
+ def visit_drop_sequence(self, drop, **kw):
return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
- def visit_drop_constraint(self, drop):
+ def visit_drop_constraint(self, drop, **kw):
constraint = drop.element
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
@@ -3671,7 +3682,7 @@ class DDLCompiler(Compiled):
else:
return self.visit_check_constraint(constraint)
- def visit_check_constraint(self, constraint):
+ def visit_check_constraint(self, constraint, **kw):
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
@@ -3683,7 +3694,7 @@ class DDLCompiler(Compiled):
text += self.define_constraint_deferrability(constraint)
return text
- def visit_column_check_constraint(self, constraint):
+ def visit_column_check_constraint(self, constraint, **kw):
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
@@ -3695,7 +3706,7 @@ class DDLCompiler(Compiled):
text += self.define_constraint_deferrability(constraint)
return text
- def visit_primary_key_constraint(self, constraint):
+ def visit_primary_key_constraint(self, constraint, **kw):
if len(constraint) == 0:
return ""
text = ""
@@ -3715,7 +3726,7 @@ class DDLCompiler(Compiled):
text += self.define_constraint_deferrability(constraint)
return text
- def visit_foreign_key_constraint(self, constraint):
+ def visit_foreign_key_constraint(self, constraint, **kw):
preparer = self.preparer
text = ""
if constraint.name is not None:
@@ -3744,7 +3755,7 @@ class DDLCompiler(Compiled):
return preparer.format_table(table)
- def visit_unique_constraint(self, constraint):
+ def visit_unique_constraint(self, constraint, **kw):
if len(constraint) == 0:
return ""
text = ""
@@ -3789,7 +3800,7 @@ class DDLCompiler(Compiled):
text += " MATCH %s" % constraint.match
return text
- def visit_computed_column(self, generated):
+ def visit_computed_column(self, generated, **kw):
text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
generated.sqltext, include_table=False, literal_binds=True
)
@@ -3975,7 +3986,16 @@ class IdentifierPreparer(object):
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
- schema_for_object = schema._schema_getter(None)
+ schema_for_object = operator.attrgetter("schema")
+ """Return the .schema attribute for an object.
+
+ For the default IdentifierPreparer, the schema for an object is always
+ the value of the ".schema" attribute. if the preparer is replaced
+ with one that has a non-empty schema_translate_map, the value of the
+ ".schema" attribute is rendered a symbol that will be converted to a
+ real schema name from the mapping post-compile.
+
+ """
def __init__(
self,
@@ -4016,9 +4036,39 @@ class IdentifierPreparer(object):
def _with_schema_translate(self, schema_translate_map):
prep = self.__class__.__new__(self.__class__)
prep.__dict__.update(self.__dict__)
- prep.schema_for_object = schema._schema_getter(schema_translate_map)
+
+ def symbol_getter(obj):
+ name = obj.schema
+ if name in schema_translate_map and obj._use_schema_map:
+ return quoted_name(
+ "[SCHEMA_%s]" % (name or "_none"), quote=False
+ )
+ else:
+ return obj.schema
+
+ prep.schema_for_object = symbol_getter
return prep
+ def _render_schema_translates(self, statement, schema_translate_map):
+ d = schema_translate_map
+ if None in d:
+ d["_none"] = d[None]
+
+ def replace(m):
+ name = m.group(2)
+ effective_schema = d[name]
+ if not effective_schema:
+ effective_schema = self.dialect.default_schema_name
+ if not effective_schema:
+ # TODO: no coverage here
+ raise exc.CompileError(
+ "Dialect has no default schema name; can't "
+ "use None as dynamic schema target."
+ )
+ return self.quote(effective_schema)
+
+ return re.sub(r"(\[SCHEMA_([\w\d_]+)\])", replace, statement)
+
def _escape_identifier(self, value):
"""Escape an identifier.