summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-16 12:07:25 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-17 09:42:29 -0400
commit3b520e758a715cf817075e4a90ae1b5813ffadd3 (patch)
tree260f9517af499e7fb789d188f1631cd823a59929 /lib
parent6acf5d2fca4a988a77481b82662174e8015a6b37 (diff)
downloadsqlalchemy-3b520e758a715cf817075e4a90ae1b5813ffadd3.tar.gz
pep484 for hybrid
Change-Id: I53274b13094d996e11b04acb03f9613edbddf87f References: #6810
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/ext/hybrid.py224
-rw-r--r--lib/sqlalchemy/orm/attributes.py10
-rw-r--r--lib/sqlalchemy/orm/base.py4
-rw-r--r--lib/sqlalchemy/orm/interfaces.py4
-rw-r--r--lib/sqlalchemy/sql/_typing.py10
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
-rw-r--r--lib/sqlalchemy/sql/elements.py11
-rw-r--r--lib/sqlalchemy/sql/functions.py39
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py39
9 files changed, 272 insertions, 73 deletions
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
index dc34a2ef5..92b3ce54f 100644
--- a/lib/sqlalchemy/ext/hybrid.py
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -802,17 +802,41 @@ advanced and/or patient developers, there's probably a whole lot of amazing
things it can be used for.
""" # noqa
+
+from __future__ import annotations
+
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Generic
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from .. import util
from ..orm import attributes
from ..orm import InspectionAttrExtensionType
from ..orm import interfaces
from ..orm import ORMDescriptor
+from ..sql._typing import is_has_column_element_clause_element
+from ..sql.elements import ColumnElement
+from ..sql.elements import SQLCoreOperations
+from ..util.typing import Literal
+from ..util.typing import Protocol
+if TYPE_CHECKING:
+ from ..orm.util import AliasedInsp
+ from ..sql.operators import OperatorType
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
+_T_con = TypeVar("_T_con", bound=Any, contravariant=True)
class HybridExtensionType(InspectionAttrExtensionType):
@@ -844,7 +868,34 @@ class HybridExtensionType(InspectionAttrExtensionType):
"""
-class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
+class _HybridGetterType(Protocol[_T_co]):
+ def __call__(s, self: Any) -> _T_co:
+ ...
+
+
+class _HybridSetterType(Protocol[_T_con]):
+ def __call__(self, instance: Any, value: _T_con) -> None:
+ ...
+
+
+class _HybridUpdaterType(Protocol[_T]):
+ def __call__(
+ self, cls: Type[Any], value: Union[_T, SQLCoreOperations[_T]]
+ ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
+ ...
+
+
+class _HybridDeleterType(Protocol[_T_co]):
+ def __call__(self, instance: Any) -> None:
+ ...
+
+
+class _HybridExprCallableType(Protocol[_T]):
+ def __call__(self, cls: Any) -> SQLCoreOperations[_T]:
+ ...
+
+
+class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]):
"""A decorator which allows definition of a Python object method with both
instance-level and class-level behavior.
@@ -853,7 +904,11 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
is_attribute = True
extension_type = HybridExtensionType.HYBRID_METHOD
- def __init__(self, func, expr=None):
+ def __init__(
+ self,
+ func: Callable[..., _T],
+ expr: Optional[Callable[..., SQLCoreOperations[_T]]] = None,
+ ):
"""Create a new :class:`.hybrid_method`.
Usage is typically via decorator::
@@ -873,13 +928,29 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
self.func = func
self.expression(expr or func)
- def __get__(self, instance, owner):
+ @overload
+ def __get__(
+ self, instance: Literal[None], owner: Type[object]
+ ) -> Callable[[Any], SQLCoreOperations[_T]]:
+ ...
+
+ @overload
+ def __get__(
+ self, instance: object, owner: Type[object]
+ ) -> Callable[[Any], _T]:
+ ...
+
+ def __get__(
+ self, instance: Optional[object], owner: Type[object]
+ ) -> Union[Callable[[Any], _T], Callable[[Any], SQLCoreOperations[_T]]]:
if instance is None:
- return self.expr.__get__(owner, owner.__class__)
+ return self.expr.__get__(owner, owner) # type: ignore
else:
- return self.func.__get__(instance, owner)
+ return self.func.__get__(instance, owner) # type: ignore
- def expression(self, expr):
+ def expression(
+ self, expr: Callable[..., SQLCoreOperations[_T]]
+ ) -> hybrid_method[_T]:
"""Provide a modifying decorator that defines a
SQL-expression producing method."""
@@ -889,7 +960,12 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
return self
-class hybrid_property(interfaces.InspectionAttrInfo):
+Selfhybrid_property = TypeVar(
+ "Selfhybrid_property", bound="hybrid_property[Any]"
+)
+
+
+class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
"""A decorator which allows definition of a Python descriptor with both
instance-level and class-level behavior.
@@ -898,14 +974,16 @@ class hybrid_property(interfaces.InspectionAttrInfo):
is_attribute = True
extension_type = HybridExtensionType.HYBRID_PROPERTY
+ __name__: str
+
def __init__(
self,
- fget,
- fset=None,
- fdel=None,
- expr=None,
- custom_comparator=None,
- update_expr=None,
+ fget: _HybridGetterType[_T],
+ fset: Optional[_HybridSetterType[_T]] = None,
+ fdel: Optional[_HybridDeleterType[_T]] = None,
+ expr: Optional[_HybridExprCallableType[_T]] = None,
+ custom_comparator: Optional[Comparator[_T]] = None,
+ update_expr: Optional[_HybridUpdaterType[_T]] = None,
):
"""Create a new :class:`.hybrid_property`.
@@ -931,23 +1009,43 @@ class hybrid_property(interfaces.InspectionAttrInfo):
self.update_expr = update_expr
util.update_wrapper(self, fget)
- def __get__(self, instance, owner):
- if instance is None:
+ @overload
+ def __get__(
+ self: Selfhybrid_property, instance: Any, owner: Literal[None]
+ ) -> Selfhybrid_property:
+ ...
+
+ @overload
+ def __get__(
+ self, instance: Literal[None], owner: Type[object]
+ ) -> SQLCoreOperations[_T]:
+ ...
+
+ @overload
+ def __get__(self, instance: object, owner: Type[object]) -> _T:
+ ...
+
+ def __get__(
+ self, instance: Optional[object], owner: Optional[Type[object]]
+ ) -> Union[hybrid_property[_T], SQLCoreOperations[_T], _T]:
+ if owner is None:
+ return self
+ elif instance is None:
return self._expr_comparator(owner)
else:
return self.fget(instance)
- def __set__(self, instance, value):
+ def __set__(self, instance: object, value: Any) -> None:
if self.fset is None:
raise AttributeError("can't set attribute")
self.fset(instance, value)
- def __delete__(self, instance):
+ def __delete__(self, instance: object) -> None:
if self.fdel is None:
raise AttributeError("can't delete attribute")
self.fdel(instance)
- def _copy(self, **kw):
+ def _copy(self, **kw: Any) -> hybrid_property[_T]:
defaults = {
key: value
for key, value in self.__dict__.items()
@@ -957,7 +1055,7 @@ class hybrid_property(interfaces.InspectionAttrInfo):
return type(self)(**defaults)
@property
- def overrides(self):
+ def overrides(self: Selfhybrid_property) -> Selfhybrid_property:
"""Prefix for a method that is overriding an existing attribute.
The :attr:`.hybrid_property.overrides` accessor just returns
@@ -992,7 +1090,7 @@ class hybrid_property(interfaces.InspectionAttrInfo):
"""
return self
- def getter(self, fget):
+ def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines a getter method.
.. versionadded:: 1.2
@@ -1001,17 +1099,19 @@ class hybrid_property(interfaces.InspectionAttrInfo):
return self._copy(fget=fget)
- def setter(self, fset):
+ def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines a setter method."""
return self._copy(fset=fset)
- def deleter(self, fdel):
+ def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines a deletion method."""
return self._copy(fdel=fdel)
- def expression(self, expr):
+ def expression(
+ self, expr: _HybridExprCallableType[_T]
+ ) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines a SQL-expression
producing method.
@@ -1043,7 +1143,7 @@ class hybrid_property(interfaces.InspectionAttrInfo):
return self._copy(expr=expr)
- def comparator(self, comparator):
+ def comparator(self, comparator: Comparator[_T]) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines a custom
comparator producing method.
@@ -1078,7 +1178,9 @@ class hybrid_property(interfaces.InspectionAttrInfo):
"""
return self._copy(custom_comparator=comparator)
- def update_expression(self, meth):
+ def update_expression(
+ self, meth: _HybridUpdaterType[_T]
+ ) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines an UPDATE tuple
producing method.
@@ -1115,27 +1217,35 @@ class hybrid_property(interfaces.InspectionAttrInfo):
return self._copy(update_expr=meth)
@util.memoized_property
- def _expr_comparator(self):
+ def _expr_comparator(
+ self,
+ ) -> Callable[[Any], interfaces.PropComparator[_T]]:
if self.custom_comparator is not None:
return self._get_comparator(self.custom_comparator)
elif self.expr is not None:
return self._get_expr(self.expr)
else:
- return self._get_expr(self.fget)
+ return self._get_expr(cast(_HybridExprCallableType[_T], self.fget))
- def _get_expr(self, expr):
- def _expr(cls):
+ def _get_expr(
+ self, expr: _HybridExprCallableType[_T]
+ ) -> Callable[[Any], interfaces.PropComparator[_T]]:
+ def _expr(cls: Any) -> ExprComparator[_T]:
return ExprComparator(cls, expr(cls), self)
util.update_wrapper(_expr, expr)
return self._get_comparator(_expr)
- def _get_comparator(self, comparator):
+ def _get_comparator(
+ self, comparator: Any
+ ) -> Callable[[Any], interfaces.PropComparator[_T]]:
proxy_attr = attributes.create_proxied_attribute(self)
- def expr_comparator(owner):
+ def expr_comparator(
+ owner: Type[object],
+ ) -> interfaces.PropComparator[_T]:
# because this is the descriptor protocol, we don't really know
# what our attribute name is. so search for it through the
# MRO.
@@ -1163,36 +1273,48 @@ class Comparator(interfaces.PropComparator[_T]):
:class:`~.orm.interfaces.PropComparator`
classes for usage with hybrids."""
- property = None
-
- def __init__(self, expression):
+ def __init__(self, expression: SQLCoreOperations[_T]):
self.expression = expression
- def __clause_element__(self):
+ def __clause_element__(self) -> ColumnElement[_T]:
expr = self.expression
- if hasattr(expr, "__clause_element__"):
+ if is_has_column_element_clause_element(expr):
expr = expr.__clause_element__()
+
+ elif TYPE_CHECKING:
+ assert isinstance(expr, ColumnElement)
return expr
- def adapt_to_entity(self, adapt_to_entity):
+ @util.non_memoized_property
+ def property(self) -> Any:
+ return None
+
+ def adapt_to_entity(self, adapt_to_entity: AliasedInsp) -> Comparator[_T]:
# interesting....
return self
class ExprComparator(Comparator[_T]):
- def __init__(self, cls, expression, hybrid):
+ def __init__(
+ self,
+ cls: Type[Any],
+ expression: SQLCoreOperations[_T],
+ hybrid: hybrid_property[_T],
+ ):
self.cls = cls
self.expression = expression
self.hybrid = hybrid
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
return getattr(self.expression, key)
- @property
- def info(self):
+ @util.non_memoized_property
+ def info(self) -> Dict[Any, Any]:
return self.hybrid.info
- def _bulk_update_tuples(self, value):
+ def _bulk_update_tuples(
+ self, value: Any
+ ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
if isinstance(self.expression, attributes.QueryableAttribute):
return self.expression._bulk_update_tuples(value)
elif self.hybrid.update_expr is not None:
@@ -1200,12 +1322,16 @@ class ExprComparator(Comparator[_T]):
else:
return [(self.expression, value)]
- @property
- def property(self):
- return self.expression.property
+ @util.non_memoized_property
+ def property(self) -> Any:
+ return self.expression.property # type: ignore
- def operate(self, op, *other, **kwargs):
- return op(self.expression, *other, **kwargs)
+ def operate(
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(self.expression, *other, **kwargs) # type: ignore
- def reverse_operate(self, op, other, **kwargs):
- return op(other, self.expression, **kwargs)
+ def reverse_operate(
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(other, self.expression, **kwargs) # type: ignore
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index ce3a645ad..2b6ca400e 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -20,6 +20,7 @@ from collections import namedtuple
import operator
import typing
from typing import Any
+from typing import Callable
from typing import List
from typing import NamedTuple
from typing import Tuple
@@ -68,6 +69,7 @@ from ..sql import visitors
if typing.TYPE_CHECKING:
from ..sql.elements import ColumnElement
+ from ..sql.elements import SQLCoreOperations
_T = TypeVar("_T")
@@ -277,7 +279,9 @@ class QueryableAttribute(
def _from_objects(self):
return self.expression._from_objects
- def _bulk_update_tuples(self, value):
+ def _bulk_update_tuples(
+ self, value: Any
+ ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
"""Return setter tuples for a bulk UPDATE."""
return self.comparator._bulk_update_tuples(value)
@@ -416,7 +420,9 @@ HasEntityNamespace = namedtuple("HasEntityNamespace", ["entity_namespace"])
HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False
-def create_proxied_attribute(descriptor):
+def create_proxied_attribute(
+ descriptor: Any,
+) -> Callable[..., QueryableAttribute[Any]]:
"""Create an QueryableAttribute / user descriptor hybrid.
Returns a new QueryableAttribute type that delegates descriptor
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index c63a89c70..cb3070103 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -638,7 +638,7 @@ class ORMDescriptor(Generic[_T], TypingOnly):
@overload
def __get__(
self, instance: Literal[None], owner: Any
- ) -> SQLORMOperations[_T]:
+ ) -> SQLCoreOperations[_T]:
...
@overload
@@ -647,7 +647,7 @@ class ORMDescriptor(Generic[_T], TypingOnly):
def __get__(
self, instance: object, owner: Any
- ) -> Union[SQLORMOperations[_T], _T]:
+ ) -> Union[ORMDescriptor[_T], SQLCoreOperations[_T], _T]:
...
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 00ddbcca7..d79774187 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -466,7 +466,9 @@ class PropComparator(SQLORMOperations[_T]):
def __clause_element__(self):
raise NotImplementedError("%r" % self)
- def _bulk_update_tuples(self, value):
+ def _bulk_update_tuples(
+ self, value: Any
+ ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
"""Receive a SQL expression that represents a value in the SET
clause of an UPDATE statement.
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index 389f7e8d0..2be98b88f 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -56,3 +56,13 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]:
def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]:
return t._is_tuple_type
+
+
+def is_has_clause_element(s: object) -> TypeGuard[roles.HasClauseElement]:
+ return hasattr(s, "__clause_element__")
+
+
+def is_has_column_element_clause_element(
+ s: object,
+) -> TypeGuard[roles.HasColumnElementClauseElement]:
+ return hasattr(s, "__clause_element__")
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index db88496a0..8f878b66c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -265,7 +265,7 @@ OPERATORS = {
operators.nulls_last_op: " NULLS LAST",
}
-FUNCTIONS: Dict[Type[Function], str] = {
+FUNCTIONS: Dict[Type[Function[Any]], str] = {
functions.coalesce: "coalesce",
functions.current_date: "CURRENT_DATE",
functions.current_time: "CURRENT_TIME",
@@ -2043,7 +2043,7 @@ class SQLCompiler(Compiled):
def visit_function(
self,
- func: Function,
+ func: Function[Any],
add_to_result_map: Optional[_ResultMapAppender] = None,
**kwargs: Any,
) -> str:
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index fdb3fc8bb..c1a7d8476 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -103,8 +103,8 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import CacheStats
from ..engine.result import Result
-_NUMERIC = Union[complex, Decimal]
-_NUMBER = Union[complex, int, Decimal]
+_NUMERIC = Union[float, Decimal]
+_NUMBER = Union[float, int, Decimal]
_T = TypeVar("_T", bound="Any")
_OPT = TypeVar("_OPT", bound="Any")
@@ -348,6 +348,7 @@ class ClauseElement(
the _copy_internals() method.
"""
+
skip = self._memoized_keys
c = self.__class__.__new__(self.__class__)
c.__dict__ = {k: v for k, v in self.__dict__.items() if k not in skip}
@@ -995,11 +996,15 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
@overload
def __truediv__(
- self: _SQO[_NMT], other: Any
+ self: _SQO[int], other: Any
) -> ColumnElement[_NUMERIC]:
...
@overload
+ def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]:
+ ...
+
+ @overload
def __truediv__(self, other: Any) -> ColumnElement[Any]:
...
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 6e5eec127..563b58418 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -12,7 +12,10 @@
from __future__ import annotations
from typing import Any
+from typing import Optional
+from typing import overload
from typing import Sequence
+from typing import TYPE_CHECKING
from typing import TypeVar
from . import annotation
@@ -47,6 +50,8 @@ from .type_api import TypeEngine
from .visitors import InternalTraversal
from .. import util
+if TYPE_CHECKING:
+ from ._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
@@ -104,7 +109,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
_with_ordinality = False
_table_value_type = None
- def __init__(self, *clauses, **kwargs):
+ def __init__(self, *clauses: Any):
r"""Construct a :class:`.FunctionElement`.
:param \*clauses: list of column expressions that form the arguments
@@ -752,7 +757,7 @@ class _FunctionGenerator:
self.__names = []
self.opts = opts
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> _FunctionGenerator:
# passthru __ attributes; fixes pydoc
if name.startswith("__"):
try:
@@ -766,7 +771,17 @@ class _FunctionGenerator:
f.__names = list(self.__names) + [name]
return f
- def __call__(self, *c, **kwargs):
+ @overload
+ def __call__(
+ self, *c: Any, type_: TypeEngine[_T], **kwargs: Any
+ ) -> Function[_T]:
+ ...
+
+ @overload
+ def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]:
+ ...
+
+ def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]:
o = self.opts.copy()
o.update(kwargs)
@@ -795,7 +810,7 @@ func.__doc__ = _FunctionGenerator.__doc__
modifier = _FunctionGenerator(group=False)
-class Function(FunctionElement):
+class Function(FunctionElement[_T]):
r"""Describe a named SQL function.
The :class:`.Function` object is typically generated from the
@@ -842,7 +857,7 @@ class Function(FunctionElement):
packagenames: Sequence[str]
- type: TypeEngine = sqltypes.NULLTYPE
+ type: TypeEngine[_T]
"""A :class:`_types.TypeEngine` object which refers to the SQL return
type represented by this SQL function.
@@ -859,19 +874,25 @@ class Function(FunctionElement):
"""
- def __init__(self, name, *clauses, **kw):
+ def __init__(
+ self,
+ name: str,
+ *clauses: Any,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ packagenames: Optional[Sequence[str]] = None,
+ ):
"""Construct a :class:`.Function`.
The :data:`.func` construct is normally used to construct
new :class:`.Function` instances.
"""
- self.packagenames = kw.pop("packagenames", None) or ()
+ self.packagenames = packagenames or ()
self.name = name
- self.type = sqltypes.to_instance(kw.get("type_", None))
+ self.type = sqltypes.to_instance(type_)
- FunctionElement.__init__(self, *clauses, **kw)
+ FunctionElement.__init__(self, *clauses)
def _bind_param(self, operator, obj, type_=None, **kw):
return BindParameter(
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index e64ec0843..4d0169370 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -17,6 +17,7 @@ import enum
import json
import pickle
from typing import Any
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TypeVar
@@ -48,6 +49,7 @@ from .. import util
from ..engine import processors
from ..util import langhelpers
from ..util import OrderedDict
+from ..util.typing import Literal
_T = TypeVar("_T", bound="Any")
@@ -373,9 +375,10 @@ class BigInteger(Integer):
__visit_name__ = "big_integer"
-class Numeric(
- _LookupExpressionAdapter, TypeEngine[Union[decimal.Decimal, float]]
-):
+_N = TypeVar("_N", bound=Union[decimal.Decimal, float])
+
+
+class Numeric(_LookupExpressionAdapter, TypeEngine[_N]):
"""A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``.
@@ -542,7 +545,7 @@ class Numeric(
}
-class Float(Numeric):
+class Float(Numeric[_N]):
"""Type representing floating point types, such as ``FLOAT`` or ``REAL``.
@@ -567,8 +570,34 @@ class Float(Numeric):
scale = None
+ @overload
+ def __init__(
+ self: Float[float],
+ precision=...,
+ decimal_return_scale=...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: Float[decimal.Decimal],
+ precision=...,
+ asdecimal: Literal[True] = ...,
+ decimal_return_scale=...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: Float[float],
+ precision=...,
+ asdecimal: Literal[False] = ...,
+ decimal_return_scale=...,
+ ):
+ ...
+
def __init__(
- self: "Float",
+ self: Float[_N],
precision=None,
asdecimal=False,
decimal_return_scale=None,