diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-16 12:07:25 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-17 09:42:29 -0400 |
| commit | 3b520e758a715cf817075e4a90ae1b5813ffadd3 (patch) | |
| tree | 260f9517af499e7fb789d188f1631cd823a59929 /lib/sqlalchemy/sql | |
| parent | 6acf5d2fca4a988a77481b82662174e8015a6b37 (diff) | |
| download | sqlalchemy-3b520e758a715cf817075e4a90ae1b5813ffadd3.tar.gz | |
pep484 for hybrid
Change-Id: I53274b13094d996e11b04acb03f9613edbddf87f
References: #6810
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/_typing.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 39 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 39 |
5 files changed, 84 insertions, 19 deletions
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 389f7e8d0..2be98b88f 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -56,3 +56,13 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]: def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: return t._is_tuple_type + + +def is_has_clause_element(s: object) -> TypeGuard[roles.HasClauseElement]: + return hasattr(s, "__clause_element__") + + +def is_has_column_element_clause_element( + s: object, +) -> TypeGuard[roles.HasColumnElementClauseElement]: + return hasattr(s, "__clause_element__") diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index db88496a0..8f878b66c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -265,7 +265,7 @@ OPERATORS = { operators.nulls_last_op: " NULLS LAST", } -FUNCTIONS: Dict[Type[Function], str] = { +FUNCTIONS: Dict[Type[Function[Any]], str] = { functions.coalesce: "coalesce", functions.current_date: "CURRENT_DATE", functions.current_time: "CURRENT_TIME", @@ -2043,7 +2043,7 @@ class SQLCompiler(Compiled): def visit_function( self, - func: Function, + func: Function[Any], add_to_result_map: Optional[_ResultMapAppender] = None, **kwargs: Any, ) -> str: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index fdb3fc8bb..c1a7d8476 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -103,8 +103,8 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import CacheStats from ..engine.result import Result -_NUMERIC = Union[complex, Decimal] -_NUMBER = Union[complex, int, Decimal] +_NUMERIC = Union[float, Decimal] +_NUMBER = Union[float, int, Decimal] _T = TypeVar("_T", bound="Any") _OPT = TypeVar("_OPT", bound="Any") @@ -348,6 +348,7 @@ class ClauseElement( the _copy_internals() method. """ + skip = self._memoized_keys c = self.__class__.__new__(self.__class__) c.__dict__ = {k: v for k, v in self.__dict__.items() if k not in skip} @@ -995,11 +996,15 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): @overload def __truediv__( - self: _SQO[_NMT], other: Any + self: _SQO[int], other: Any ) -> ColumnElement[_NUMERIC]: ... @overload + def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: + ... + + @overload def __truediv__(self, other: Any) -> ColumnElement[Any]: ... diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 6e5eec127..563b58418 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -12,7 +12,10 @@ from __future__ import annotations from typing import Any +from typing import Optional +from typing import overload from typing import Sequence +from typing import TYPE_CHECKING from typing import TypeVar from . import annotation @@ -47,6 +50,8 @@ from .type_api import TypeEngine from .visitors import InternalTraversal from .. import util +if TYPE_CHECKING: + from ._typing import _TypeEngineArgument _T = TypeVar("_T", bound=Any) @@ -104,7 +109,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): _with_ordinality = False _table_value_type = None - def __init__(self, *clauses, **kwargs): + def __init__(self, *clauses: Any): r"""Construct a :class:`.FunctionElement`. :param \*clauses: list of column expressions that form the arguments @@ -752,7 +757,7 @@ class _FunctionGenerator: self.__names = [] self.opts = opts - def __getattr__(self, name): + def __getattr__(self, name: str) -> _FunctionGenerator: # passthru __ attributes; fixes pydoc if name.startswith("__"): try: @@ -766,7 +771,17 @@ class _FunctionGenerator: f.__names = list(self.__names) + [name] return f - def __call__(self, *c, **kwargs): + @overload + def __call__( + self, *c: Any, type_: TypeEngine[_T], **kwargs: Any + ) -> Function[_T]: + ... + + @overload + def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: + ... + + def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: o = self.opts.copy() o.update(kwargs) @@ -795,7 +810,7 @@ func.__doc__ = _FunctionGenerator.__doc__ modifier = _FunctionGenerator(group=False) -class Function(FunctionElement): +class Function(FunctionElement[_T]): r"""Describe a named SQL function. The :class:`.Function` object is typically generated from the @@ -842,7 +857,7 @@ class Function(FunctionElement): packagenames: Sequence[str] - type: TypeEngine = sqltypes.NULLTYPE + type: TypeEngine[_T] """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. @@ -859,19 +874,25 @@ class Function(FunctionElement): """ - def __init__(self, name, *clauses, **kw): + def __init__( + self, + name: str, + *clauses: Any, + type_: Optional[_TypeEngineArgument[_T]] = None, + packagenames: Optional[Sequence[str]] = None, + ): """Construct a :class:`.Function`. The :data:`.func` construct is normally used to construct new :class:`.Function` instances. """ - self.packagenames = kw.pop("packagenames", None) or () + self.packagenames = packagenames or () self.name = name - self.type = sqltypes.to_instance(kw.get("type_", None)) + self.type = sqltypes.to_instance(type_) - FunctionElement.__init__(self, *clauses, **kw) + FunctionElement.__init__(self, *clauses) def _bind_param(self, operator, obj, type_=None, **kw): return BindParameter( diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index e64ec0843..4d0169370 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -17,6 +17,7 @@ import enum import json import pickle from typing import Any +from typing import overload from typing import Sequence from typing import Tuple from typing import TypeVar @@ -48,6 +49,7 @@ from .. import util from ..engine import processors from ..util import langhelpers from ..util import OrderedDict +from ..util.typing import Literal _T = TypeVar("_T", bound="Any") @@ -373,9 +375,10 @@ class BigInteger(Integer): __visit_name__ = "big_integer" -class Numeric( - _LookupExpressionAdapter, TypeEngine[Union[decimal.Decimal, float]] -): +_N = TypeVar("_N", bound=Union[decimal.Decimal, float]) + + +class Numeric(_LookupExpressionAdapter, TypeEngine[_N]): """A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``. @@ -542,7 +545,7 @@ class Numeric( } -class Float(Numeric): +class Float(Numeric[_N]): """Type representing floating point types, such as ``FLOAT`` or ``REAL``. @@ -567,8 +570,34 @@ class Float(Numeric): scale = None + @overload + def __init__( + self: Float[float], + precision=..., + decimal_return_scale=..., + ): + ... + + @overload + def __init__( + self: Float[decimal.Decimal], + precision=..., + asdecimal: Literal[True] = ..., + decimal_return_scale=..., + ): + ... + + @overload + def __init__( + self: Float[float], + precision=..., + asdecimal: Literal[False] = ..., + decimal_return_scale=..., + ): + ... + def __init__( - self: "Float", + self: Float[_N], precision=None, asdecimal=False, decimal_return_scale=None, |
