summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-04-25 16:44:00 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2023-04-25 16:44:00 +0000
commite8173073ffd4b08a6efd36b7d290642ae96f5de3 (patch)
treef9f87aacccec61a65cd9715b7919670e536c4d2b /lib/sqlalchemy/dialects
parent32a17e60ba63f0278a754e1ab7e9ebf9460e07c5 (diff)
parentf3bc7e5e2b0f8242661c8d89797bfcb3503d9948 (diff)
downloadsqlalchemy-e8173073ffd4b08a6efd36b7d290642ae96f5de3.tar.gz
Merge "Adding typing to Postgres dialect file." into main
Diffstat (limited to 'lib/sqlalchemy/dialects')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/types.py67
1 files changed, 49 insertions, 18 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py
index edab23935..0db2721c8 100644
--- a/lib/sqlalchemy/dialects/postgresql/types.py
+++ b/lib/sqlalchemy/dialects/postgresql/types.py
@@ -3,24 +3,46 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
import datetime as dt
+from typing import Any
+from typing import Optional
+from typing import overload
+from typing import Type
+from typing import TYPE_CHECKING
+from uuid import UUID as _python_UUID
from ...sql import sqltypes
-
+from ...sql import type_api
+from ...util.typing import Literal
_DECIMAL_TYPES = (1231, 1700)
_FLOAT_TYPES = (700, 701, 1021, 1022)
_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
-class PGUuid(sqltypes.UUID):
+class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]):
render_bind_cast = True
render_literal_cast = True
+ if TYPE_CHECKING:
+
+ @overload
+ def __init__(
+ self: PGUuid[_python_UUID], as_uuid: Literal[True] = ...
+ ) -> None:
+ ...
+
+ @overload
+ def __init__(self: PGUuid[str], as_uuid: Literal[False] = ...) -> None:
+ ...
+
+ def __init__(self, as_uuid: bool = True) -> None:
+ ...
+
-class BYTEA(sqltypes.LargeBinary[bytes]):
+class BYTEA(sqltypes.LargeBinary):
__visit_name__ = "BYTEA"
@@ -53,7 +75,6 @@ PGMacAddr8 = MACADDR8
class MONEY(sqltypes.TypeEngine[str]):
-
r"""Provide the PostgreSQL MONEY type.
Depending on driver, result rows using this type may return a
@@ -146,7 +167,9 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
__visit_name__ = "TIMESTAMP"
- def __init__(self, timezone=False, precision=None):
+ def __init__(
+ self, timezone: bool = False, precision: Optional[int] = None
+ ) -> None:
"""Construct a TIMESTAMP.
:param timezone: boolean value if timezone present, default False
@@ -165,7 +188,9 @@ class TIME(sqltypes.TIME):
__visit_name__ = "TIME"
- def __init__(self, timezone=False, precision=None):
+ def __init__(
+ self, timezone: bool = False, precision: Optional[int] = None
+ ) -> None:
"""Construct a TIME.
:param timezone: boolean value if timezone present, default False
@@ -178,14 +203,16 @@ class TIME(sqltypes.TIME):
self.precision = precision
-class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval):
"""PostgreSQL INTERVAL type."""
__visit_name__ = "INTERVAL"
native = True
- def __init__(self, precision=None, fields=None):
+ def __init__(
+ self, precision: Optional[int] = None, fields: Optional[str] = None
+ ) -> None:
"""Construct an INTERVAL.
:param precision: optional integer precision value
@@ -200,18 +227,20 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
self.fields = fields
@classmethod
- def adapt_emulated_to_native(cls, interval, **kw):
+ def adapt_emulated_to_native(
+ cls, interval: sqltypes.Interval, **kw: Any # type: ignore[override]
+ ) -> INTERVAL:
return INTERVAL(precision=interval.second_precision)
@property
- def _type_affinity(self):
+ def _type_affinity(self) -> Type[sqltypes.Interval]:
return sqltypes.Interval
- def as_generic(self, allow_nulltype=False):
+ def as_generic(self, allow_nulltype: bool = False) -> sqltypes.Interval:
return sqltypes.Interval(native=True, second_precision=self.precision)
@property
- def python_type(self):
+ def python_type(self) -> Type[dt.timedelta]:
return dt.timedelta
@@ -221,13 +250,15 @@ PGInterval = INTERVAL
class BIT(sqltypes.TypeEngine[int]):
__visit_name__ = "BIT"
- def __init__(self, length=None, varying=False):
- if not varying:
+ def __init__(
+ self, length: Optional[int] = None, varying: bool = False
+ ) -> None:
+ if varying:
+ # BIT VARYING can be unlimited-length, so no default
+ self.length = length
+ else:
# BIT without VARYING defaults to length 1
self.length = length or 1
- else:
- # but BIT VARYING can be unlimited-length, so no default
- self.length = length
self.varying = varying