diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-02-08 11:58:15 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-02-08 13:14:52 -0500 |
| commit | c0521e6f0688b794048e44ff3df429249a093b72 (patch) | |
| tree | d509123bb03c6e938952c4eb6444fd00d5edbc67 /lib/sqlalchemy | |
| parent | 80010c63149be411e89c7434a9d52096f9de56b8 (diff) | |
| download | sqlalchemy-c0521e6f0688b794048e44ff3df429249a093b72.tar.gz | |
Add identifier_preparer per-execution context for schema translates
Fixed bug where the "schema_translate_map" feature failed to be taken into
account for the use case of direct execution of
:class:`_schema.DefaultGenerator` objects such as sequences, which included
the case where they were "pre-executed" in order to generate primary key
values when implicit_returning was disabled.
Fixes: #5929
Change-Id: I3fed1d0af28be5ce9c9bb572524dcc8411633f60
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/firebird/base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 19 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/provision.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_sequence.py | 68 |
8 files changed, 112 insertions, 13 deletions
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 82861e30f..7fc914f1b 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -614,7 +614,7 @@ class FBExecutionContext(default.DefaultExecutionContext): return self._execute_scalar( "SELECT gen_id(%s, 1) FROM rdb$database" - % self.dialect.identifier_preparer.format_sequence(seq), + % self.identifier_preparer.format_sequence(seq), type_, ) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 9d0e5d322..674d54179 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -17,7 +17,7 @@ External Dialects In addition to the above DBAPI layers with native SQLAlchemy support, there are third-party dialects for other DBAPI layers that are compatible with SQL Server. See the "External Dialects" list on the -:ref:`dialect_toplevel` page. +:ref:`dialect_toplevel` page. .. _mssql_identity: @@ -1560,7 +1560,7 @@ class MSExecutionContext(default.DefaultExecutionContext): self.cursor, self._opt_encode( "SET IDENTITY_INSERT %s ON" - % self.dialect.identifier_preparer.format_table(tbl) + % self.identifier_preparer.format_table(tbl) ), (), self, @@ -1606,7 +1606,7 @@ class MSExecutionContext(default.DefaultExecutionContext): self.cursor, self._opt_encode( "SET IDENTITY_INSERT %s OFF" - % self.dialect.identifier_preparer.format_table( + % self.identifier_preparer.format_table( self.compiled.statement.table ) ), @@ -1630,7 +1630,7 @@ class MSExecutionContext(default.DefaultExecutionContext): self.cursor.execute( self._opt_encode( "SET IDENTITY_INSERT %s OFF" - % self.dialect.identifier_preparer.format_table( + % self.identifier_preparer.format_table( self.compiled.statement.table ) ) @@ -1650,7 +1650,7 @@ class MSExecutionContext(default.DefaultExecutionContext): return self._execute_scalar( ( "SELECT NEXT VALUE FOR %s" - % self.dialect.identifier_preparer.format_sequence(seq) + % self.identifier_preparer.format_sequence(seq) ), type_, ) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 063f750fa..c80ff3f19 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -988,9 +988,13 @@ from ...types import BLOB from ...types import BOOLEAN from ...types import DATE from ...types import VARBINARY +from ...util import compat from ...util import topological +if compat.TYPE_CHECKING: + from typing import Any + RESERVED_WORDS = set( [ "accessible", @@ -1394,7 +1398,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext): return self._execute_scalar( ( "select nextval(%s)" - % self.dialect.identifier_preparer.format_sequence(seq) + % self.identifier_preparer.format_sequence(seq) ), type_, ) @@ -3263,6 +3267,7 @@ class MySQLDialect(default.DefaultDialect): return parser.parse(sql, charset) def _detect_charset(self, connection): + # type: (Any) -> str raise NotImplementedError() def _detect_casing(self, connection): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 9344abeee..f9805abeb 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1012,9 +1012,7 @@ class OracleCompiler(compiler.SQLCompiler): return self.process(vc.column, **kw) + "(+)" def visit_sequence(self, seq, **kw): - return ( - self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" - ) + return self.preparer.format_sequence(seq) + ".nextval" def get_render_as_alias_suffix(self, alias_name_text): """Oracle doesn't like ``FROM table AS alias``""" @@ -1441,7 +1439,7 @@ class OracleExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar( "SELECT " - + self.dialect.identifier_preparer.format_sequence(seq) + + self.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", type_, ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index f067e6537..7e821acde 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2936,7 +2936,7 @@ class PGExecutionContext(default.DefaultExecutionContext): return self._execute_scalar( ( "select nextval('%s')" - % self.dialect.identifier_preparer.format_sequence(seq) + % self.identifier_preparer.format_sequence(seq) ), type_, ) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 7fddf2814..0c48fcba3 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1141,6 +1141,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return "unknown" @util.memoized_property + def identifier_preparer(self): + if self.compiled: + return self.compiled.preparer + elif "schema_translate_map" in self.execution_options: + return self.dialect.identifier_preparer._with_schema_translate( + self.execution_options["schema_translate_map"] + ) + else: + return self.dialect.identifier_preparer + + @util.memoized_property def engine(self): return self.root_connection.engine @@ -1197,6 +1208,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ): stmt = self.dialect._encoder(stmt)[0] + if "schema_translate_map" in self.execution_options: + schema_translate_map = self.execution_options.get( + "schema_translate_map", {} + ) + + rst = self.identifier_preparer._render_schema_translates + stmt = rst(stmt, schema_translate_map) + if not parameters: if self.dialect.positional: parameters = self.dialect.execute_sequence_format() diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 2fade1c32..a976abee0 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -262,7 +262,6 @@ def drop_all_schema_objects(cfg, eng): ) util.drop_all_tables(eng, inspector) - if config.requirements.schemas.enabled_for_config(cfg): util.drop_all_tables(eng, inspector, schema=cfg.test_schema) util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2) @@ -273,6 +272,16 @@ def drop_all_schema_objects(cfg, eng): with eng.begin() as conn: for seq in inspector.get_sequence_names(): conn.execute(ddl.DropSequence(schema.Sequence(seq))) + if config.requirements.schemas.enabled_for_config(cfg): + for schema_name in [cfg.test_schema, cfg.test_schema_2]: + for seq in inspector.get_sequence_names( + schema=schema_name + ): + conn.execute( + ddl.DropSequence( + schema.Sequence(seq, schema=schema_name) + ) + ) @register.init diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index 7445ade00..d6747d253 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -45,6 +45,34 @@ class SequenceTest(fixtures.TablesTest): Column("data", String(50)), ) + Table( + "seq_no_returning", + metadata, + Column( + "id", + Integer, + Sequence("noret_id_seq"), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + ) + + if testing.requires.schemas.enabled: + Table( + "seq_no_returning_sch", + metadata, + Column( + "id", + Integer, + Sequence("noret_sch_id_seq", schema=config.test_schema), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + schema=config.test_schema, + ) + def test_insert_roundtrip(self, connection): connection.execute(self.tables.seq_pk.insert(), dict(data="some data")) self._assert_round_trip(self.tables.seq_pk, connection) @@ -72,6 +100,46 @@ class SequenceTest(fixtures.TablesTest): row = conn.execute(table.select()).first() eq_(row, (testing.db.dialect.default_sequence_base, "some data")) + def test_insert_roundtrip_no_implicit_returning(self, connection): + connection.execute( + self.tables.seq_no_returning.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.seq_no_returning, connection) + + @testing.combinations((True,), (False,), argnames="implicit_returning") + @testing.requires.schemas + def test_insert_roundtrip_translate(self, connection, implicit_returning): + + seq_no_returning = Table( + "seq_no_returning_sch", + MetaData(), + Column( + "id", + Integer, + Sequence("noret_sch_id_seq", schema="alt_schema"), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=implicit_returning, + schema="alt_schema", + ) + + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + connection.execute(seq_no_returning.insert(), dict(data="some data")) + self._assert_round_trip(seq_no_returning, connection) + + @testing.requires.schemas + def test_nextval_direct_schema_translate(self, connection): + seq = Sequence("noret_sch_id_seq", schema="alt_schema") + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + + r = connection.execute(seq) + eq_(r, testing.db.dialect.default_sequence_base) + class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): __requires__ = ("sequences",) |
