diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2023-04-25 16:44:00 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2023-04-25 16:44:00 +0000 |
commit | e8173073ffd4b08a6efd36b7d290642ae96f5de3 (patch) | |
tree | f9f87aacccec61a65cd9715b7919670e536c4d2b /lib/sqlalchemy/dialects/postgresql | |
parent | 32a17e60ba63f0278a754e1ab7e9ebf9460e07c5 (diff) | |
parent | f3bc7e5e2b0f8242661c8d89797bfcb3503d9948 (diff) | |
download | sqlalchemy-e8173073ffd4b08a6efd36b7d290642ae96f5de3.tar.gz |
Merge "Adding typing to Postgres dialect file." into main
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/types.py | 67 |
1 files changed, 49 insertions, 18 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index edab23935..0db2721c8 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -3,24 +3,46 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations import datetime as dt +from typing import Any +from typing import Optional +from typing import overload +from typing import Type +from typing import TYPE_CHECKING +from uuid import UUID as _python_UUID from ...sql import sqltypes - +from ...sql import type_api +from ...util.typing import Literal _DECIMAL_TYPES = (1231, 1700) _FLOAT_TYPES = (700, 701, 1021, 1022) _INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) -class PGUuid(sqltypes.UUID): +class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]): render_bind_cast = True render_literal_cast = True + if TYPE_CHECKING: + + @overload + def __init__( + self: PGUuid[_python_UUID], as_uuid: Literal[True] = ... + ) -> None: + ... + + @overload + def __init__(self: PGUuid[str], as_uuid: Literal[False] = ...) -> None: + ... + + def __init__(self, as_uuid: bool = True) -> None: + ... + -class BYTEA(sqltypes.LargeBinary[bytes]): +class BYTEA(sqltypes.LargeBinary): __visit_name__ = "BYTEA" @@ -53,7 +75,6 @@ PGMacAddr8 = MACADDR8 class MONEY(sqltypes.TypeEngine[str]): - r"""Provide the PostgreSQL MONEY type. Depending on driver, result rows using this type may return a @@ -146,7 +167,9 @@ class TIMESTAMP(sqltypes.TIMESTAMP): __visit_name__ = "TIMESTAMP" - def __init__(self, timezone=False, precision=None): + def __init__( + self, timezone: bool = False, precision: Optional[int] = None + ) -> None: """Construct a TIMESTAMP. :param timezone: boolean value if timezone present, default False @@ -165,7 +188,9 @@ class TIME(sqltypes.TIME): __visit_name__ = "TIME" - def __init__(self, timezone=False, precision=None): + def __init__( + self, timezone: bool = False, precision: Optional[int] = None + ) -> None: """Construct a TIME. :param timezone: boolean value if timezone present, default False @@ -178,14 +203,16 @@ class TIME(sqltypes.TIME): self.precision = precision -class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): +class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval): """PostgreSQL INTERVAL type.""" __visit_name__ = "INTERVAL" native = True - def __init__(self, precision=None, fields=None): + def __init__( + self, precision: Optional[int] = None, fields: Optional[str] = None + ) -> None: """Construct an INTERVAL. :param precision: optional integer precision value @@ -200,18 +227,20 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): self.fields = fields @classmethod - def adapt_emulated_to_native(cls, interval, **kw): + def adapt_emulated_to_native( + cls, interval: sqltypes.Interval, **kw: Any # type: ignore[override] + ) -> INTERVAL: return INTERVAL(precision=interval.second_precision) @property - def _type_affinity(self): + def _type_affinity(self) -> Type[sqltypes.Interval]: return sqltypes.Interval - def as_generic(self, allow_nulltype=False): + def as_generic(self, allow_nulltype: bool = False) -> sqltypes.Interval: return sqltypes.Interval(native=True, second_precision=self.precision) @property - def python_type(self): + def python_type(self) -> Type[dt.timedelta]: return dt.timedelta @@ -221,13 +250,15 @@ PGInterval = INTERVAL class BIT(sqltypes.TypeEngine[int]): __visit_name__ = "BIT" - def __init__(self, length=None, varying=False): - if not varying: + def __init__( + self, length: Optional[int] = None, varying: bool = False + ) -> None: + if varying: + # BIT VARYING can be unlimited-length, so no default + self.length = length + else: # BIT without VARYING defaults to length 1 self.length = length or 1 - else: - # but BIT VARYING can be unlimited-length, so no default - self.length = length self.varying = varying |