diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2020-03-24 19:53:14 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-03-24 19:53:14 +0000 |
commit | fa1f67a01e80367d73cf5d1d93b6f7f51dc1746b (patch) | |
tree | 7c0c5e37a7d1708289a68fcdf30df9dad9013088 /lib/sqlalchemy/sql/compiler.py | |
parent | e817d1415d825bcf8c8f33927baaf56cd5d07b95 (diff) | |
parent | cadfc608d63f4e0df46c0daaa28902423fd88d71 (diff) | |
download | sqlalchemy-fa1f67a01e80367d73cf5d1d93b6f7f51dc1746b.tar.gz |
Merge "Convert schema_translate to a post compile"
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 100 |
1 files changed, 75 insertions, 25 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1f183b5c1..ae9c3c73a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -26,6 +26,7 @@ To generate user-defined SQL strings, see import collections import contextlib import itertools +import operator import re from . import base @@ -39,6 +40,7 @@ from . import schema from . import selectable from . import sqltypes from .base import NO_ARG +from .elements import quoted_name from .. import exc from .. import util @@ -369,6 +371,8 @@ class Compiled(object): _cached_metadata = None + schema_translate_map = None + execution_options = util.immutabledict() """ Execution options propagated from the statement. In some cases, @@ -381,6 +385,7 @@ class Compiled(object): statement, bind=None, schema_translate_map=None, + render_schema_translate=False, compile_kwargs=util.immutabledict(), ): """Construct a new :class:`.Compiled` object. @@ -411,6 +416,7 @@ class Compiled(object): self.bind = bind self.preparer = self.dialect.identifier_preparer if schema_translate_map: + self.schema_translate_map = schema_translate_map self.preparer = self.preparer._with_schema_translate( schema_translate_map ) @@ -422,6 +428,11 @@ class Compiled(object): self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) + if render_schema_translate: + self.string = self.preparer._render_schema_translates( + self.string, schema_translate_map + ) + @util.deprecated( "0.7", "The :meth:`.Compiled.compile` method is deprecated and will be " @@ -3365,18 +3376,18 @@ class DDLCompiler(Compiled): return self.sql_compiler.post_process_text(ddl.statement % context) - def visit_create_schema(self, create): + def visit_create_schema(self, create, **kw): schema = self.preparer.format_schema(create.element) return "CREATE SCHEMA " + schema - def visit_drop_schema(self, drop): + def visit_drop_schema(self, drop, **kw): schema = self.preparer.format_schema(drop.element) text = "DROP SCHEMA " + schema if drop.cascade: text += " CASCADE" return text - def visit_create_table(self, create): + def visit_create_table(self, create, **kw): table = create.element preparer = self.preparer @@ -3426,7 +3437,7 @@ class DDLCompiler(Compiled): text += "\n)%s\n\n" % self.post_create_table(table) return text - def visit_create_column(self, create, first_pk=False): + def visit_create_column(self, create, first_pk=False, **kw): column = create.element if column.system: @@ -3442,7 +3453,7 @@ class DDLCompiler(Compiled): return text def create_table_constraints( - self, table, _include_foreign_key_constraints=None + self, table, _include_foreign_key_constraints=None, **kw ): # On some DB order is significant: visit PK first, then the @@ -3482,10 +3493,10 @@ class DDLCompiler(Compiled): if p is not None ) - def visit_drop_table(self, drop): + def visit_drop_table(self, drop, **kw): return "\nDROP TABLE " + self.preparer.format_table(drop.element) - def visit_drop_view(self, drop): + def visit_drop_view(self, drop, **kw): return "\nDROP VIEW " + self.preparer.format_table(drop.element) def _verify_index_table(self, index): @@ -3495,7 +3506,7 @@ class DDLCompiler(Compiled): ) def visit_create_index( - self, create, include_schema=False, include_table_schema=True + self, create, include_schema=False, include_table_schema=True, **kw ): index = create.element self._verify_index_table(index) @@ -3521,7 +3532,7 @@ class DDLCompiler(Compiled): ) return text - def visit_drop_index(self, drop): + def visit_drop_index(self, drop, **kw): index = drop.element if index.name is None: @@ -3548,13 +3559,13 @@ class DDLCompiler(Compiled): index_name = schema_name + "." + index_name return index_name - def visit_add_constraint(self, create): + def visit_add_constraint(self, create, **kw): return "ALTER TABLE %s ADD %s" % ( self.preparer.format_table(create.element.table), self.process(create.element), ) - def visit_set_table_comment(self, create): + def visit_set_table_comment(self, create, **kw): return "COMMENT ON TABLE %s IS %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( @@ -3562,12 +3573,12 @@ class DDLCompiler(Compiled): ), ) - def visit_drop_table_comment(self, drop): + def visit_drop_table_comment(self, drop, **kw): return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( drop.element ) - def visit_set_column_comment(self, create): + def visit_set_column_comment(self, create, **kw): return "COMMENT ON COLUMN %s IS %s" % ( self.preparer.format_column( create.element, use_table=True, use_schema=True @@ -3577,12 +3588,12 @@ class DDLCompiler(Compiled): ), ) - def visit_drop_column_comment(self, drop): + def visit_drop_column_comment(self, drop, **kw): return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( drop.element, use_table=True ) - def visit_create_sequence(self, create): + def visit_create_sequence(self, create, **kw): text = "CREATE SEQUENCE %s" % self.preparer.format_sequence( create.element ) @@ -3606,10 +3617,10 @@ class DDLCompiler(Compiled): text += " CYCLE" return text - def visit_drop_sequence(self, drop): + def visit_drop_sequence(self, drop, **kw): return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) - def visit_drop_constraint(self, drop): + def visit_drop_constraint(self, drop, **kw): constraint = drop.element if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3671,7 +3682,7 @@ class DDLCompiler(Compiled): else: return self.visit_check_constraint(constraint) - def visit_check_constraint(self, constraint): + def visit_check_constraint(self, constraint, **kw): text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3683,7 +3694,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_column_check_constraint(self, constraint): + def visit_column_check_constraint(self, constraint, **kw): text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3695,7 +3706,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_primary_key_constraint(self, constraint): + def visit_primary_key_constraint(self, constraint, **kw): if len(constraint) == 0: return "" text = "" @@ -3715,7 +3726,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_foreign_key_constraint(self, constraint): + def visit_foreign_key_constraint(self, constraint, **kw): preparer = self.preparer text = "" if constraint.name is not None: @@ -3744,7 +3755,7 @@ class DDLCompiler(Compiled): return preparer.format_table(table) - def visit_unique_constraint(self, constraint): + def visit_unique_constraint(self, constraint, **kw): if len(constraint) == 0: return "" text = "" @@ -3789,7 +3800,7 @@ class DDLCompiler(Compiled): text += " MATCH %s" % constraint.match return text - def visit_computed_column(self, generated): + def visit_computed_column(self, generated, **kw): text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( generated.sqltext, include_table=False, literal_binds=True ) @@ -3975,7 +3986,16 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - schema_for_object = schema._schema_getter(None) + schema_for_object = operator.attrgetter("schema") + """Return the .schema attribute for an object. + + For the default IdentifierPreparer, the schema for an object is always + the value of the ".schema" attribute. if the preparer is replaced + with one that has a non-empty schema_translate_map, the value of the + ".schema" attribute is rendered a symbol that will be converted to a + real schema name from the mapping post-compile. + + """ def __init__( self, @@ -4016,9 +4036,39 @@ class IdentifierPreparer(object): def _with_schema_translate(self, schema_translate_map): prep = self.__class__.__new__(self.__class__) prep.__dict__.update(self.__dict__) - prep.schema_for_object = schema._schema_getter(schema_translate_map) + + def symbol_getter(obj): + name = obj.schema + if name in schema_translate_map and obj._use_schema_map: + return quoted_name( + "[SCHEMA_%s]" % (name or "_none"), quote=False + ) + else: + return obj.schema + + prep.schema_for_object = symbol_getter return prep + def _render_schema_translates(self, statement, schema_translate_map): + d = schema_translate_map + if None in d: + d["_none"] = d[None] + + def replace(m): + name = m.group(2) + effective_schema = d[name] + if not effective_schema: + effective_schema = self.dialect.default_schema_name + if not effective_schema: + # TODO: no coverage here + raise exc.CompileError( + "Dialect has no default schema name; can't " + "use None as dynamic schema target." + ) + return self.quote(effective_schema) + + return re.sub(r"(\[SCHEMA_([\w\d_]+)\])", replace, statement) + def _escape_identifier(self, value): """Escape an identifier. |