summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-02-08 11:58:15 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2021-02-08 13:14:52 -0500
commitc0521e6f0688b794048e44ff3df429249a093b72 (patch)
treed509123bb03c6e938952c4eb6444fd00d5edbc67 /lib/sqlalchemy
parent80010c63149be411e89c7434a9d52096f9de56b8 (diff)
downloadsqlalchemy-c0521e6f0688b794048e44ff3df429249a093b72.tar.gz
Add identifier_preparer per-execution context for schema translates
Fixed bug where the "schema_translate_map" feature failed to be taken into account for the use case of direct execution of :class:`_schema.DefaultGenerator` objects such as sequences, which included the case where they were "pre-executed" in order to generate primary key values when implicit_returning was disabled. Fixes: #5929 Change-Id: I3fed1d0af28be5ce9c9bb572524dcc8411633f60
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/firebird/base.py2
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py10
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py7
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py6
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py2
-rw-r--r--lib/sqlalchemy/engine/default.py19
-rw-r--r--lib/sqlalchemy/testing/provision.py11
-rw-r--r--lib/sqlalchemy/testing/suite/test_sequence.py68
8 files changed, 112 insertions, 13 deletions
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
index 82861e30f..7fc914f1b 100644
--- a/lib/sqlalchemy/dialects/firebird/base.py
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -614,7 +614,7 @@ class FBExecutionContext(default.DefaultExecutionContext):
return self._execute_scalar(
"SELECT gen_id(%s, 1) FROM rdb$database"
- % self.dialect.identifier_preparer.format_sequence(seq),
+ % self.identifier_preparer.format_sequence(seq),
type_,
)
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 9d0e5d322..674d54179 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -17,7 +17,7 @@ External Dialects
In addition to the above DBAPI layers with native SQLAlchemy support, there
are third-party dialects for other DBAPI layers that are compatible
with SQL Server. See the "External Dialects" list on the
-:ref:`dialect_toplevel` page.
+:ref:`dialect_toplevel` page.
.. _mssql_identity:
@@ -1560,7 +1560,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
self.cursor,
self._opt_encode(
"SET IDENTITY_INSERT %s ON"
- % self.dialect.identifier_preparer.format_table(tbl)
+ % self.identifier_preparer.format_table(tbl)
),
(),
self,
@@ -1606,7 +1606,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
self.cursor,
self._opt_encode(
"SET IDENTITY_INSERT %s OFF"
- % self.dialect.identifier_preparer.format_table(
+ % self.identifier_preparer.format_table(
self.compiled.statement.table
)
),
@@ -1630,7 +1630,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
self.cursor.execute(
self._opt_encode(
"SET IDENTITY_INSERT %s OFF"
- % self.dialect.identifier_preparer.format_table(
+ % self.identifier_preparer.format_table(
self.compiled.statement.table
)
)
@@ -1650,7 +1650,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
return self._execute_scalar(
(
"SELECT NEXT VALUE FOR %s"
- % self.dialect.identifier_preparer.format_sequence(seq)
+ % self.identifier_preparer.format_sequence(seq)
),
type_,
)
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 063f750fa..c80ff3f19 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -988,9 +988,13 @@ from ...types import BLOB
from ...types import BOOLEAN
from ...types import DATE
from ...types import VARBINARY
+from ...util import compat
from ...util import topological
+if compat.TYPE_CHECKING:
+ from typing import Any
+
RESERVED_WORDS = set(
[
"accessible",
@@ -1394,7 +1398,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
return self._execute_scalar(
(
"select nextval(%s)"
- % self.dialect.identifier_preparer.format_sequence(seq)
+ % self.identifier_preparer.format_sequence(seq)
),
type_,
)
@@ -3263,6 +3267,7 @@ class MySQLDialect(default.DefaultDialect):
return parser.parse(sql, charset)
def _detect_charset(self, connection):
+ # type: (Any) -> str
raise NotImplementedError()
def _detect_casing(self, connection):
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 9344abeee..f9805abeb 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -1012,9 +1012,7 @@ class OracleCompiler(compiler.SQLCompiler):
return self.process(vc.column, **kw) + "(+)"
def visit_sequence(self, seq, **kw):
- return (
- self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
- )
+ return self.preparer.format_sequence(seq) + ".nextval"
def get_render_as_alias_suffix(self, alias_name_text):
"""Oracle doesn't like ``FROM table AS alias``"""
@@ -1441,7 +1439,7 @@ class OracleExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
return self._execute_scalar(
"SELECT "
- + self.dialect.identifier_preparer.format_sequence(seq)
+ + self.identifier_preparer.format_sequence(seq)
+ ".nextval FROM DUAL",
type_,
)
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index f067e6537..7e821acde 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -2936,7 +2936,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
return self._execute_scalar(
(
"select nextval('%s')"
- % self.dialect.identifier_preparer.format_sequence(seq)
+ % self.identifier_preparer.format_sequence(seq)
),
type_,
)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 7fddf2814..0c48fcba3 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -1141,6 +1141,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return "unknown"
@util.memoized_property
+ def identifier_preparer(self):
+ if self.compiled:
+ return self.compiled.preparer
+ elif "schema_translate_map" in self.execution_options:
+ return self.dialect.identifier_preparer._with_schema_translate(
+ self.execution_options["schema_translate_map"]
+ )
+ else:
+ return self.dialect.identifier_preparer
+
+ @util.memoized_property
def engine(self):
return self.root_connection.engine
@@ -1197,6 +1208,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
):
stmt = self.dialect._encoder(stmt)[0]
+ if "schema_translate_map" in self.execution_options:
+ schema_translate_map = self.execution_options.get(
+ "schema_translate_map", {}
+ )
+
+ rst = self.identifier_preparer._render_schema_translates
+ stmt = rst(stmt, schema_translate_map)
+
if not parameters:
if self.dialect.positional:
parameters = self.dialect.execute_sequence_format()
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index 2fade1c32..a976abee0 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -262,7 +262,6 @@ def drop_all_schema_objects(cfg, eng):
)
util.drop_all_tables(eng, inspector)
-
if config.requirements.schemas.enabled_for_config(cfg):
util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
@@ -273,6 +272,16 @@ def drop_all_schema_objects(cfg, eng):
with eng.begin() as conn:
for seq in inspector.get_sequence_names():
conn.execute(ddl.DropSequence(schema.Sequence(seq)))
+ if config.requirements.schemas.enabled_for_config(cfg):
+ for schema_name in [cfg.test_schema, cfg.test_schema_2]:
+ for seq in inspector.get_sequence_names(
+ schema=schema_name
+ ):
+ conn.execute(
+ ddl.DropSequence(
+ schema.Sequence(seq, schema=schema_name)
+ )
+ )
@register.init
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
index 7445ade00..d6747d253 100644
--- a/lib/sqlalchemy/testing/suite/test_sequence.py
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -45,6 +45,34 @@ class SequenceTest(fixtures.TablesTest):
Column("data", String(50)),
)
+ Table(
+ "seq_no_returning",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_id_seq"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=False,
+ )
+
+ if testing.requires.schemas.enabled:
+ Table(
+ "seq_no_returning_sch",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_sch_id_seq", schema=config.test_schema),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=False,
+ schema=config.test_schema,
+ )
+
def test_insert_roundtrip(self, connection):
connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
self._assert_round_trip(self.tables.seq_pk, connection)
@@ -72,6 +100,46 @@ class SequenceTest(fixtures.TablesTest):
row = conn.execute(table.select()).first()
eq_(row, (testing.db.dialect.default_sequence_base, "some data"))
+ def test_insert_roundtrip_no_implicit_returning(self, connection):
+ connection.execute(
+ self.tables.seq_no_returning.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.seq_no_returning, connection)
+
+ @testing.combinations((True,), (False,), argnames="implicit_returning")
+ @testing.requires.schemas
+ def test_insert_roundtrip_translate(self, connection, implicit_returning):
+
+ seq_no_returning = Table(
+ "seq_no_returning_sch",
+ MetaData(),
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_sch_id_seq", schema="alt_schema"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=implicit_returning,
+ schema="alt_schema",
+ )
+
+ connection = connection.execution_options(
+ schema_translate_map={"alt_schema": config.test_schema}
+ )
+ connection.execute(seq_no_returning.insert(), dict(data="some data"))
+ self._assert_round_trip(seq_no_returning, connection)
+
+ @testing.requires.schemas
+ def test_nextval_direct_schema_translate(self, connection):
+ seq = Sequence("noret_sch_id_seq", schema="alt_schema")
+ connection = connection.execution_options(
+ schema_translate_map={"alt_schema": config.test_schema}
+ )
+
+ r = connection.execute(seq)
+ eq_(r, testing.db.dialect.default_sequence_base)
+
class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
__requires__ = ("sequences",)