diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 68 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 27 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/mock.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/reflection.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 100 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 57 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_reflection.py | 1 |
12 files changed, 147 insertions, 148 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index a63ce0033..31425d4c0 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1631,6 +1631,9 @@ class SQLiteDialect(default.DefaultDialect): ) return bool(info) + def _get_default_schema_name(self, connection): + return "main" + @reflection.cache def get_view_names(self, connection, schema=None, **kw): if schema is not None: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index aa21fb13b..4ed3b9af7 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -17,7 +17,6 @@ from .. import inspection from .. import log from .. import util from ..sql import compiler -from ..sql import schema from ..sql import util as sql_util @@ -51,21 +50,7 @@ class Connection(Connectable): """ - schema_for_object = schema._schema_getter(None) - """Return the ".schema" attribute for an object. - - Used for :class:`.Table`, :class:`.Sequence` and similar objects, - and takes into account - the :paramref:`.Connection.execution_options.schema_translate_map` - parameter. - - .. versionadded:: 1.1 - - .. seealso:: - - :ref:`schema_translating` - - """ + _schema_translate_map = None def __init__( self, @@ -92,7 +77,7 @@ class Connection(Connectable): self.should_close_with_result = False self.dispatch = _dispatch self._has_events = _branch_from._has_events - self.schema_for_object = _branch_from.schema_for_object + self._schema_translate_map = _branch_from._schema_translate_map else: self.__connection = ( connection @@ -122,6 +107,24 @@ class Connection(Connectable): if self._has_events or self.engine._has_events: self.dispatch.engine_connect(self, self.__branch) + def schema_for_object(self, obj): + """return the schema name for the given schema item taking into + account current schema translate map. + + """ + + name = obj.schema + schema_translate_map = self._schema_translate_map + + if ( + schema_translate_map + and name in schema_translate_map + and obj._use_schema_map + ): + return schema_translate_map[name] + else: + return name + def _branch(self): """Return a new Connection which references this Connection's engine and connection; but does not have close_with_result enabled, @@ -1066,10 +1069,7 @@ class Connection(Connectable): dialect = self.dialect compiled = ddl.compile( - dialect=dialect, - schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default - else None, + dialect=dialect, schema_translate_map=self._schema_translate_map ) ret = self._execute_context( dialect, @@ -1103,7 +1103,7 @@ class Connection(Connectable): dialect, elem, tuple(sorted(keys)), - self.schema_for_object.hash_key, + bool(self._schema_translate_map), len(distilled_params) > 1, ) compiled_sql = self._execution_options["compiled_cache"].get(key) @@ -1112,9 +1112,7 @@ class Connection(Connectable): dialect=dialect, column_keys=keys, inline=len(distilled_params) > 1, - schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default - else None, + schema_translate_map=self._schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, ) @@ -1124,9 +1122,7 @@ class Connection(Connectable): dialect=dialect, column_keys=keys, inline=len(distilled_params) > 1, - schema_translate_map=self.schema_for_object - if not self.schema_for_object.is_default - else None, + schema_translate_map=self._schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, ) @@ -1974,21 +1970,7 @@ class Engine(Connectable, log.Identified): _has_events = False _connection_cls = Connection - schema_for_object = schema._schema_getter(None) - """Return the ".schema" attribute for an object. - - Used for :class:`.Table`, :class:`.Sequence` and similar objects, - and takes into account - the :paramref:`.Connection.execution_options.schema_translate_map` - parameter. - - .. versionadded:: 1.1 - - .. seealso:: - - :ref:`schema_translating` - - """ + _schema_translate_map = None def __init__( self, diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index b151b6e48..d0940decf 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -28,7 +28,6 @@ from .. import types as sqltypes from .. import util from ..sql import compiler from ..sql import expression -from ..sql import schema from ..sql.elements import quoted_name AUTOCOMMIT_REGEXP = re.compile( @@ -129,6 +128,8 @@ class DefaultDialect(interfaces.Dialect): server_version_info = None + default_schema_name = None + construct_arguments = None """Optional set of argument specifiers for various SQLAlchemy constructs, typically schema items. @@ -495,20 +496,18 @@ class DefaultDialect(interfaces.Dialect): self._set_connection_isolation(connection, isolation_level) if "schema_translate_map" in opts: - getter = schema._schema_getter(opts["schema_translate_map"]) - engine.schema_for_object = getter + engine._schema_translate_map = map_ = opts["schema_translate_map"] @event.listens_for(engine, "engine_connect") def set_schema_translate_map(connection, branch): - connection.schema_for_object = getter + connection._schema_translate_map = map_ def set_connection_execution_options(self, connection, opts): if "isolation_level" in opts: self._set_connection_isolation(connection, opts["isolation_level"]) if "schema_translate_map" in opts: - getter = schema._schema_getter(opts["schema_translate_map"]) - connection.schema_for_object = getter + connection._schema_translate_map = opts["schema_translate_map"] def _set_connection_isolation(self, connection, level): if connection.in_transaction(): @@ -701,11 +700,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.execution_options = dict(self.execution_options) self.execution_options.update(connection._execution_options) + self.unicode_statement = util.text_type(compiled) + if compiled.schema_translate_map: + rst = compiled.preparer._render_schema_translates + self.unicode_statement = rst( + self.unicode_statement, connection._schema_translate_map + ) + if not dialect.supports_unicode_statements: - self.unicode_statement = util.text_type(compiled) self.statement = dialect._encoder(self.unicode_statement)[0] else: - self.statement = self.unicode_statement = util.text_type(compiled) + self.statement = self.unicode_statement self.cursor = self.create_cursor() self.compiled_parameters = [] @@ -807,6 +812,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): elif compiled.positional: positiontup = self.compiled.positiontup + if compiled.schema_translate_map: + rst = compiled.preparer._render_schema_translates + self.unicode_statement = rst( + self.unicode_statement, connection._schema_translate_map + ) + # final self.unicode_statement is now assigned, encode if needed # by dialect if not dialect.supports_unicode_statements: diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index 570ee2d04..bda9e91b5 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -11,7 +11,6 @@ from . import base from . import url as _url from .. import util from ..sql import ddl -from ..sql import schema class MockConnection(base.Connectable): @@ -23,7 +22,8 @@ class MockConnection(base.Connectable): dialect = property(attrgetter("_dialect")) name = property(lambda s: s._dialect.name) - schema_for_object = schema._schema_getter(None) + def schema_for_object(self, obj): + return obj.schema def connect(self, **kwargs): return self diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 203369ed8..8ef0d572f 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -701,7 +701,8 @@ class Inspector(object): dialect = self.bind.dialect - schema = self.bind.schema_for_object(table) + with self._operation_context() as conn: + schema = conn.schema_for_object(table) table_name = table.name diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1f183b5c1..ae9c3c73a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -26,6 +26,7 @@ To generate user-defined SQL strings, see import collections import contextlib import itertools +import operator import re from . import base @@ -39,6 +40,7 @@ from . import schema from . import selectable from . import sqltypes from .base import NO_ARG +from .elements import quoted_name from .. import exc from .. import util @@ -369,6 +371,8 @@ class Compiled(object): _cached_metadata = None + schema_translate_map = None + execution_options = util.immutabledict() """ Execution options propagated from the statement. In some cases, @@ -381,6 +385,7 @@ class Compiled(object): statement, bind=None, schema_translate_map=None, + render_schema_translate=False, compile_kwargs=util.immutabledict(), ): """Construct a new :class:`.Compiled` object. @@ -411,6 +416,7 @@ class Compiled(object): self.bind = bind self.preparer = self.dialect.identifier_preparer if schema_translate_map: + self.schema_translate_map = schema_translate_map self.preparer = self.preparer._with_schema_translate( schema_translate_map ) @@ -422,6 +428,11 @@ class Compiled(object): self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) + if render_schema_translate: + self.string = self.preparer._render_schema_translates( + self.string, schema_translate_map + ) + @util.deprecated( "0.7", "The :meth:`.Compiled.compile` method is deprecated and will be " @@ -3365,18 +3376,18 @@ class DDLCompiler(Compiled): return self.sql_compiler.post_process_text(ddl.statement % context) - def visit_create_schema(self, create): + def visit_create_schema(self, create, **kw): schema = self.preparer.format_schema(create.element) return "CREATE SCHEMA " + schema - def visit_drop_schema(self, drop): + def visit_drop_schema(self, drop, **kw): schema = self.preparer.format_schema(drop.element) text = "DROP SCHEMA " + schema if drop.cascade: text += " CASCADE" return text - def visit_create_table(self, create): + def visit_create_table(self, create, **kw): table = create.element preparer = self.preparer @@ -3426,7 +3437,7 @@ class DDLCompiler(Compiled): text += "\n)%s\n\n" % self.post_create_table(table) return text - def visit_create_column(self, create, first_pk=False): + def visit_create_column(self, create, first_pk=False, **kw): column = create.element if column.system: @@ -3442,7 +3453,7 @@ class DDLCompiler(Compiled): return text def create_table_constraints( - self, table, _include_foreign_key_constraints=None + self, table, _include_foreign_key_constraints=None, **kw ): # On some DB order is significant: visit PK first, then the @@ -3482,10 +3493,10 @@ class DDLCompiler(Compiled): if p is not None ) - def visit_drop_table(self, drop): + def visit_drop_table(self, drop, **kw): return "\nDROP TABLE " + self.preparer.format_table(drop.element) - def visit_drop_view(self, drop): + def visit_drop_view(self, drop, **kw): return "\nDROP VIEW " + self.preparer.format_table(drop.element) def _verify_index_table(self, index): @@ -3495,7 +3506,7 @@ class DDLCompiler(Compiled): ) def visit_create_index( - self, create, include_schema=False, include_table_schema=True + self, create, include_schema=False, include_table_schema=True, **kw ): index = create.element self._verify_index_table(index) @@ -3521,7 +3532,7 @@ class DDLCompiler(Compiled): ) return text - def visit_drop_index(self, drop): + def visit_drop_index(self, drop, **kw): index = drop.element if index.name is None: @@ -3548,13 +3559,13 @@ class DDLCompiler(Compiled): index_name = schema_name + "." + index_name return index_name - def visit_add_constraint(self, create): + def visit_add_constraint(self, create, **kw): return "ALTER TABLE %s ADD %s" % ( self.preparer.format_table(create.element.table), self.process(create.element), ) - def visit_set_table_comment(self, create): + def visit_set_table_comment(self, create, **kw): return "COMMENT ON TABLE %s IS %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( @@ -3562,12 +3573,12 @@ class DDLCompiler(Compiled): ), ) - def visit_drop_table_comment(self, drop): + def visit_drop_table_comment(self, drop, **kw): return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( drop.element ) - def visit_set_column_comment(self, create): + def visit_set_column_comment(self, create, **kw): return "COMMENT ON COLUMN %s IS %s" % ( self.preparer.format_column( create.element, use_table=True, use_schema=True @@ -3577,12 +3588,12 @@ class DDLCompiler(Compiled): ), ) - def visit_drop_column_comment(self, drop): + def visit_drop_column_comment(self, drop, **kw): return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( drop.element, use_table=True ) - def visit_create_sequence(self, create): + def visit_create_sequence(self, create, **kw): text = "CREATE SEQUENCE %s" % self.preparer.format_sequence( create.element ) @@ -3606,10 +3617,10 @@ class DDLCompiler(Compiled): text += " CYCLE" return text - def visit_drop_sequence(self, drop): + def visit_drop_sequence(self, drop, **kw): return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) - def visit_drop_constraint(self, drop): + def visit_drop_constraint(self, drop, **kw): constraint = drop.element if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3671,7 +3682,7 @@ class DDLCompiler(Compiled): else: return self.visit_check_constraint(constraint) - def visit_check_constraint(self, constraint): + def visit_check_constraint(self, constraint, **kw): text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3683,7 +3694,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_column_check_constraint(self, constraint): + def visit_column_check_constraint(self, constraint, **kw): text = "" if constraint.name is not None: formatted_name = self.preparer.format_constraint(constraint) @@ -3695,7 +3706,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_primary_key_constraint(self, constraint): + def visit_primary_key_constraint(self, constraint, **kw): if len(constraint) == 0: return "" text = "" @@ -3715,7 +3726,7 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_foreign_key_constraint(self, constraint): + def visit_foreign_key_constraint(self, constraint, **kw): preparer = self.preparer text = "" if constraint.name is not None: @@ -3744,7 +3755,7 @@ class DDLCompiler(Compiled): return preparer.format_table(table) - def visit_unique_constraint(self, constraint): + def visit_unique_constraint(self, constraint, **kw): if len(constraint) == 0: return "" text = "" @@ -3789,7 +3800,7 @@ class DDLCompiler(Compiled): text += " MATCH %s" % constraint.match return text - def visit_computed_column(self, generated): + def visit_computed_column(self, generated, **kw): text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( generated.sqltext, include_table=False, literal_binds=True ) @@ -3975,7 +3986,16 @@ class IdentifierPreparer(object): illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS - schema_for_object = schema._schema_getter(None) + schema_for_object = operator.attrgetter("schema") + """Return the .schema attribute for an object. + + For the default IdentifierPreparer, the schema for an object is always + the value of the ".schema" attribute. if the preparer is replaced + with one that has a non-empty schema_translate_map, the value of the + ".schema" attribute is rendered a symbol that will be converted to a + real schema name from the mapping post-compile. + + """ def __init__( self, @@ -4016,9 +4036,39 @@ class IdentifierPreparer(object): def _with_schema_translate(self, schema_translate_map): prep = self.__class__.__new__(self.__class__) prep.__dict__.update(self.__dict__) - prep.schema_for_object = schema._schema_getter(schema_translate_map) + + def symbol_getter(obj): + name = obj.schema + if name in schema_translate_map and obj._use_schema_map: + return quoted_name( + "[SCHEMA_%s]" % (name or "_none"), quote=False + ) + else: + return obj.schema + + prep.schema_for_object = symbol_getter return prep + def _render_schema_translates(self, statement, schema_translate_map): + d = schema_translate_map + if None in d: + d["_none"] = d[None] + + def replace(m): + name = m.group(2) + effective_schema = d[name] + if not effective_schema: + effective_schema = self.dialect.default_schema_name + if not effective_schema: + # TODO: no coverage here + raise exc.CompileError( + "Dialect has no default schema name; can't " + "use None as dynamic schema target." + ) + return self.quote(effective_schema) + + return re.sub(r"(\[SCHEMA_([\w\d_]+)\])", replace, statement) + def _escape_identifier(self, value): """Escape an identifier. diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 69f60ba24..02c14d751 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -31,7 +31,6 @@ as components in SQL expressions. from __future__ import absolute_import import collections -import operator import sqlalchemy from . import coercions @@ -143,8 +142,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): schema_item.dispatch._update(self.dispatch) return schema_item - def _translate_schema(self, effective_schema, map_): - return map_.get(effective_schema, effective_schema) + _use_schema_map = True class Table(DialectKWArgs, SchemaItem, TableClause): @@ -4270,59 +4268,6 @@ class ThreadLocalMetaData(MetaData): e.dispose() -class _SchemaTranslateMap(object): - """Provide translation of schema names based on a mapping. - - Also provides helpers for producing cache keys and optimized - access when no mapping is present. - - Used by the :paramref:`.Connection.execution_options.schema_translate_map` - feature. - - .. versionadded:: 1.1 - - - """ - - __slots__ = "map_", "__call__", "hash_key", "is_default" - - _default_schema_getter = operator.attrgetter("schema") - - def __init__(self, map_): - self.map_ = map_ - if map_ is not None: - - def schema_for_object(obj): - effective_schema = self._default_schema_getter(obj) - effective_schema = obj._translate_schema( - effective_schema, map_ - ) - return effective_schema - - self.__call__ = schema_for_object - self.hash_key = ";".join( - "%s=%s" % (k, map_[k]) for k in sorted(map_, key=str) - ) - self.is_default = False - else: - self.hash_key = 0 - self.__call__ = self._default_schema_getter - self.is_default = True - - @classmethod - def _schema_getter(cls, map_): - if map_ is None: - return _default_schema_map - elif isinstance(map_, _SchemaTranslateMap): - return map_ - else: - return _SchemaTranslateMap(map_) - - -_default_schema_map = _SchemaTranslateMap(None) -_schema_getter = _SchemaTranslateMap._schema_getter - - class Computed(FetchedValue, SchemaItem): """Defines a generated column, i.e. "GENERATED ALWAYS AS" syntax. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 45b9e7f9d..ab13b21c4 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -346,8 +346,7 @@ class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable): _is_from_clause = True _is_join = False - def _translate_schema(self, effective_schema, map_): - return effective_schema + _use_schema_map = False _memoized_property = util.group_expirable_memoized_property(["_columns"]) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 3d69d1177..e106684bc 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1000,6 +1000,8 @@ class SchemaType(SchemaEventTarget): """ + _use_schema_map = True + def __init__( self, name=None, @@ -1030,9 +1032,6 @@ class SchemaType(SchemaEventTarget): util.portable_instancemethod(self._on_metadata_drop), ) - def _translate_schema(self, effective_schema, map_): - return map_.get(effective_schema, effective_schema) - def _set_parent(self, column): column._on_table_attach(util.portable_instancemethod(self._set_table)) diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index e0bf4326e..7dada1394 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -352,6 +352,8 @@ class AssertsCompiledSQL(object): literal_binds=False, render_postcompile=False, schema_translate_map=None, + render_schema_translate=False, + default_schema_name=None, inline_flag=None, ): if use_default_dialect: @@ -371,6 +373,9 @@ class AssertsCompiledSQL(object): elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() + if default_schema_name: + dialect.default_schema_name = default_schema_name + kw = {} compile_kwargs = {} @@ -386,6 +391,9 @@ class AssertsCompiledSQL(object): if render_postcompile: compile_kwargs["render_postcompile"] = True + if render_schema_translate: + kw["render_schema_translate"] = True + from sqlalchemy import orm if isinstance(clause, orm.Query): diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index e38c7ddd8..f0da69400 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -91,21 +91,23 @@ class CompiledSQL(SQLMatchRule): context = execute_observed.context compare_dialect = self._compile_dialect(execute_observed) + + if "schema_translate_map" in context.execution_options: + map_ = context.execution_options["schema_translate_map"] + else: + map_ = None + if isinstance(context.compiled.statement, _DDLCompiles): + compiled = context.compiled.statement.compile( - dialect=compare_dialect, - schema_translate_map=context.execution_options.get( - "schema_translate_map" - ), + dialect=compare_dialect, schema_translate_map=map_ ) else: compiled = context.compiled.statement.compile( dialect=compare_dialect, column_keys=context.compiled.column_keys, inline=context.compiled.inline, - schema_translate_map=context.execution_options.get( - "schema_translate_map" - ), + schema_translate_map=map_, ) _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled)) parameters = execute_observed.parameters diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 473c98116..68a43feb7 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -360,7 +360,6 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_dialect_initialize(self): engine = engines.testing_engine() - assert not hasattr(engine.dialect, "default_schema_name") inspect(engine) assert hasattr(engine.dialect, "default_schema_name") |
