summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/9621.rst18
-rw-r--r--lib/sqlalchemy/dialects/postgresql/named_types.py23
-rw-r--r--test/dialect/postgresql/test_types.py81
3 files changed, 120 insertions, 2 deletions
diff --git a/doc/build/changelog/unreleased_20/9621.rst b/doc/build/changelog/unreleased_20/9621.rst
new file mode 100644
index 000000000..de09479d3
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9621.rst
@@ -0,0 +1,18 @@
+.. change::
+ :tags: bug, postgresql
+ :tickets: 9611
+
+ Restored the :paramref:`_postgresql.ENUM.name` parameter as optional in the
+ signature for :class:`_postgresql.ENUM`, as this is chosen automatically
+ from a given pep-435 ``Enum`` type.
+
+
+.. change::
+ :tags: bug, postgresql
+ :tickets: 9621
+
+ Fixed issue where the comparison for :class:`_postgresql.ENUM` against a
+ plain string would cast that right-hand side type as VARCHAR, which due to
+ more explicit casting added to dialects such as asyncpg would produce a
+ PostgreSQL type mismatch error.
+
diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py
index e2a683e18..b0427b569 100644
--- a/lib/sqlalchemy/dialects/postgresql/named_types.py
+++ b/lib/sqlalchemy/dialects/postgresql/named_types.py
@@ -20,6 +20,7 @@ from ...sql import elements
from ...sql import roles
from ...sql import sqltypes
from ...sql import type_api
+from ...sql.base import _NoArg
from ...sql.ddl import InvokeCreateDDLBase
from ...sql.ddl import InvokeDropDDLBase
@@ -244,7 +245,13 @@ class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum):
DDLGenerator = EnumGenerator
DDLDropper = EnumDropper
- def __init__(self, *enums, name: str, create_type: bool = True, **kw):
+ def __init__(
+ self,
+ *enums,
+ name: Union[str, _NoArg, None] = _NoArg.NO_ARG,
+ create_type: bool = True,
+ **kw,
+ ):
"""Construct an :class:`_postgresql.ENUM`.
Arguments are the same as that of
@@ -280,7 +287,19 @@ class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum):
"non-native enum."
)
self.create_type = create_type
- super().__init__(*enums, name=name, **kw)
+ if name is not _NoArg.NO_ARG:
+ kw["name"] = name
+ super().__init__(*enums, **kw)
+
+ def coerce_compared_value(self, op, value):
+ super_coerced_type = super().coerce_compared_value(op, value)
+ if (
+ super_coerced_type._type_affinity
+ is type_api.STRINGTYPE._type_affinity
+ ):
+ return self
+ else:
+ return super_coerced_type
@classmethod
def __test_init__(cls):
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 0ee909541..5f5be3c57 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -16,6 +16,7 @@ from sqlalchemy import Enum
from sqlalchemy import exc
from sqlalchemy import Float
from sqlalchemy import func
+from sqlalchemy import insert
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import literal
@@ -491,6 +492,86 @@ class NamedTypeTest(
else:
assert False
+ @testing.variation("name", ["noname", "nonename", "explicit_name"])
+ @testing.variation("enum_type", ["pg", "plain"])
+ def test_native_enum_string_from_pep435(self, name, enum_type):
+ """test #9611"""
+
+ class MyEnum(_PY_Enum):
+ one = "one"
+ two = "two"
+
+ if enum_type.plain:
+ cls = Enum
+ elif enum_type.pg:
+ cls = ENUM
+ else:
+ enum_type.fail()
+
+ if name.noname:
+ e1 = cls(MyEnum)
+ eq_(e1.name, "myenum")
+ elif name.nonename:
+ e1 = cls(MyEnum, name=None)
+ eq_(e1.name, None)
+ elif name.explicit_name:
+ e1 = cls(MyEnum, name="abc")
+ eq_(e1.name, "abc")
+
+ @testing.variation("backend_type", ["native", "non_native", "pg_native"])
+ @testing.variation("enum_type", ["pep435", "str"])
+ def test_compare_to_string_round_trip(
+ self, connection, backend_type, enum_type, metadata
+ ):
+ """test #9621"""
+
+ if enum_type.pep435:
+
+ class MyEnum(_PY_Enum):
+ one = "one"
+ two = "two"
+
+ if backend_type.pg_native:
+ typ = ENUM(MyEnum, name="myenum2")
+ else:
+ typ = Enum(
+ MyEnum,
+ native_enum=bool(backend_type.native),
+ name="myenum2",
+ )
+ data = [{"someenum": MyEnum.one}, {"someenum": MyEnum.two}]
+ expected = MyEnum.two
+ elif enum_type.str:
+ if backend_type.pg_native:
+ typ = ENUM("one", "two", name="myenum2")
+ else:
+ typ = Enum(
+ "one",
+ "two",
+ native_enum=bool(backend_type.native),
+ name="myenum2",
+ )
+ data = [{"someenum": "one"}, {"someenum": "two"}]
+ expected = "two"
+ else:
+ enum_type.fail()
+
+ enum_table = Table(
+ "et2",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("someenum", typ),
+ )
+ metadata.create_all(connection)
+
+ connection.execute(insert(enum_table), data)
+ expr = select(enum_table.c.someenum).where(
+ enum_table.c.someenum == "two"
+ )
+
+ row = connection.execute(expr).one()
+ eq_(row, (expected,))
+
@testing.combinations(
(Enum("one", "two", "three")),
(ENUM("one", "two", "three", name=None)),