diff options
| -rw-r--r-- | doc/build/changelog/unreleased_13/5158.rst | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 11 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_compiler.py | 16 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 52 |
4 files changed, 87 insertions, 3 deletions
diff --git a/doc/build/changelog/unreleased_13/5158.rst b/doc/build/changelog/unreleased_13/5158.rst new file mode 100644 index 000000000..adab86d40 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5158.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, postgresql + :tickets: 5158 + + Fixed issue where the "schema_translate_map" feature would not work with a + PostgreSQL native enumeration type (i.e. :class:`.Enum`, + :class:`.postgresql.ENUM`) in that while the "CREATE TYPE" statement would + be emitted with the correct schema, the schema would not be rendered in + the CREATE TABLE statement at the point at which the enumeration was + referenced. + diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index ceefc20b0..45911d4c0 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1887,7 +1887,9 @@ class PGDDLCompiler(compiler.DDLCompiler): colspec += " SERIAL" else: colspec += " " + self.dialect.type_compiler.process( - column.type, type_expression=column + column.type, + type_expression=column, + identifier_preparer=self.preparer, ) default = self.get_column_default_string(column) if default is not None: @@ -2149,8 +2151,11 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): else: return self.visit_ENUM(type_, **kw) - def visit_ENUM(self, type_, **kw): - return self.dialect.identifier_preparer.format_type(type_) + def visit_ENUM(self, type_, identifier_preparer=None, **kw): + if identifier_preparer is None: + identifier_preparer = self.dialect.identifier_preparer + + return identifier_preparer.format_type(type_) def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 4c4c43281..aabbc3ac3 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -237,6 +237,22 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): schema_translate_map=schema_translate_map, ) + def test_create_table_with_schema_type_schema_translate(self): + e1 = Enum("x", "y", "z", name="somename") + e2 = Enum("x", "y", "z", name="somename", schema="someschema") + schema_translate_map = {None: "foo", "someschema": "bar"} + + table = Table( + "some_table", MetaData(), Column("q", e1), Column("p", e2) + ) + from sqlalchemy.schema import CreateTable + + self.assert_compile( + CreateTable(table), + "CREATE TABLE foo.some_table (q foo.somename, p bar.somename)", + schema_translate_map=schema_translate_map, + ) + def test_create_table_with_tablespace(self): m = MetaData() tbl = Table( diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 6a61fb33b..be05dec7b 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -175,6 +175,58 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): [(1, "two"), (2, "three"), (3, "three")], ) + @testing.combinations(None, "foo") + def test_create_table_schema_translate_map(self, symbol_name): + # note we can't use the fixture here because it will not drop + # from the correct schema + metadata = MetaData() + + t1 = Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column( + "value", + Enum( + "one", + "two", + "three", + name="schema_enum", + schema=symbol_name, + ), + ), + schema=symbol_name, + ) + with testing.db.connect() as conn: + conn = conn.execution_options( + schema_translate_map={symbol_name: testing.config.test_schema} + ) + t1.create(conn) + assert "schema_enum" in [ + e["name"] + for e in inspect(conn).get_enums( + schema=testing.config.test_schema + ) + ] + t1.create(conn, checkfirst=True) + + conn.execute(t1.insert(), value="two") + conn.execute(t1.insert(), value="three") + conn.execute(t1.insert(), value="three") + eq_( + conn.execute(t1.select().order_by(t1.c.id)).fetchall(), + [(1, "two"), (2, "three"), (3, "three")], + ) + + t1.drop(conn) + assert "schema_enum" not in [ + e["name"] + for e in inspect(conn).get_enums( + schema=testing.config.test_schema + ) + ] + t1.drop(conn, checkfirst=True) + def test_name_required(self): metadata = MetaData(testing.db) etype = Enum("four", "five", "six", metadata=metadata) |
