diff options
Diffstat (limited to 'lib/sqlalchemy/dialects')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/named_types.py | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py index e2a683e18..b0427b569 100644 --- a/lib/sqlalchemy/dialects/postgresql/named_types.py +++ b/lib/sqlalchemy/dialects/postgresql/named_types.py @@ -20,6 +20,7 @@ from ...sql import elements from ...sql import roles from ...sql import sqltypes from ...sql import type_api +from ...sql.base import _NoArg from ...sql.ddl import InvokeCreateDDLBase from ...sql.ddl import InvokeDropDDLBase @@ -244,7 +245,13 @@ class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum): DDLGenerator = EnumGenerator DDLDropper = EnumDropper - def __init__(self, *enums, name: str, create_type: bool = True, **kw): + def __init__( + self, + *enums, + name: Union[str, _NoArg, None] = _NoArg.NO_ARG, + create_type: bool = True, + **kw, + ): """Construct an :class:`_postgresql.ENUM`. Arguments are the same as that of @@ -280,7 +287,19 @@ class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum): "non-native enum." ) self.create_type = create_type - super().__init__(*enums, name=name, **kw) + if name is not _NoArg.NO_ARG: + kw["name"] = name + super().__init__(*enums, **kw) + + def coerce_compared_value(self, op, value): + super_coerced_type = super().coerce_compared_value(op, value) + if ( + super_coerced_type._type_affinity + is type_api.STRINGTYPE._type_affinity + ): + return self + else: + return super_coerced_type @classmethod def __test_init__(cls): |