summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-16 12:07:25 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-17 09:42:29 -0400
commit3b520e758a715cf817075e4a90ae1b5813ffadd3 (patch)
tree260f9517af499e7fb789d188f1631cd823a59929 /lib/sqlalchemy/sql
parent6acf5d2fca4a988a77481b82662174e8015a6b37 (diff)
downloadsqlalchemy-3b520e758a715cf817075e4a90ae1b5813ffadd3.tar.gz
pep484 for hybrid
Change-Id: I53274b13094d996e11b04acb03f9613edbddf87f References: #6810
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/_typing.py10
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
-rw-r--r--lib/sqlalchemy/sql/elements.py11
-rw-r--r--lib/sqlalchemy/sql/functions.py39
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py39
5 files changed, 84 insertions, 19 deletions
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index 389f7e8d0..2be98b88f 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -56,3 +56,13 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]:
def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]:
return t._is_tuple_type
+
+
+def is_has_clause_element(s: object) -> TypeGuard[roles.HasClauseElement]:
+ return hasattr(s, "__clause_element__")
+
+
+def is_has_column_element_clause_element(
+ s: object,
+) -> TypeGuard[roles.HasColumnElementClauseElement]:
+ return hasattr(s, "__clause_element__")
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index db88496a0..8f878b66c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -265,7 +265,7 @@ OPERATORS = {
operators.nulls_last_op: " NULLS LAST",
}
-FUNCTIONS: Dict[Type[Function], str] = {
+FUNCTIONS: Dict[Type[Function[Any]], str] = {
functions.coalesce: "coalesce",
functions.current_date: "CURRENT_DATE",
functions.current_time: "CURRENT_TIME",
@@ -2043,7 +2043,7 @@ class SQLCompiler(Compiled):
def visit_function(
self,
- func: Function,
+ func: Function[Any],
add_to_result_map: Optional[_ResultMapAppender] = None,
**kwargs: Any,
) -> str:
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index fdb3fc8bb..c1a7d8476 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -103,8 +103,8 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import CacheStats
from ..engine.result import Result
-_NUMERIC = Union[complex, Decimal]
-_NUMBER = Union[complex, int, Decimal]
+_NUMERIC = Union[float, Decimal]
+_NUMBER = Union[float, int, Decimal]
_T = TypeVar("_T", bound="Any")
_OPT = TypeVar("_OPT", bound="Any")
@@ -348,6 +348,7 @@ class ClauseElement(
the _copy_internals() method.
"""
+
skip = self._memoized_keys
c = self.__class__.__new__(self.__class__)
c.__dict__ = {k: v for k, v in self.__dict__.items() if k not in skip}
@@ -995,11 +996,15 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
@overload
def __truediv__(
- self: _SQO[_NMT], other: Any
+ self: _SQO[int], other: Any
) -> ColumnElement[_NUMERIC]:
...
@overload
+ def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]:
+ ...
+
+ @overload
def __truediv__(self, other: Any) -> ColumnElement[Any]:
...
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 6e5eec127..563b58418 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -12,7 +12,10 @@
from __future__ import annotations
from typing import Any
+from typing import Optional
+from typing import overload
from typing import Sequence
+from typing import TYPE_CHECKING
from typing import TypeVar
from . import annotation
@@ -47,6 +50,8 @@ from .type_api import TypeEngine
from .visitors import InternalTraversal
from .. import util
+if TYPE_CHECKING:
+ from ._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
@@ -104,7 +109,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
_with_ordinality = False
_table_value_type = None
- def __init__(self, *clauses, **kwargs):
+ def __init__(self, *clauses: Any):
r"""Construct a :class:`.FunctionElement`.
:param \*clauses: list of column expressions that form the arguments
@@ -752,7 +757,7 @@ class _FunctionGenerator:
self.__names = []
self.opts = opts
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> _FunctionGenerator:
# passthru __ attributes; fixes pydoc
if name.startswith("__"):
try:
@@ -766,7 +771,17 @@ class _FunctionGenerator:
f.__names = list(self.__names) + [name]
return f
- def __call__(self, *c, **kwargs):
+ @overload
+ def __call__(
+ self, *c: Any, type_: TypeEngine[_T], **kwargs: Any
+ ) -> Function[_T]:
+ ...
+
+ @overload
+ def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]:
+ ...
+
+ def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]:
o = self.opts.copy()
o.update(kwargs)
@@ -795,7 +810,7 @@ func.__doc__ = _FunctionGenerator.__doc__
modifier = _FunctionGenerator(group=False)
-class Function(FunctionElement):
+class Function(FunctionElement[_T]):
r"""Describe a named SQL function.
The :class:`.Function` object is typically generated from the
@@ -842,7 +857,7 @@ class Function(FunctionElement):
packagenames: Sequence[str]
- type: TypeEngine = sqltypes.NULLTYPE
+ type: TypeEngine[_T]
"""A :class:`_types.TypeEngine` object which refers to the SQL return
type represented by this SQL function.
@@ -859,19 +874,25 @@ class Function(FunctionElement):
"""
- def __init__(self, name, *clauses, **kw):
+ def __init__(
+ self,
+ name: str,
+ *clauses: Any,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ packagenames: Optional[Sequence[str]] = None,
+ ):
"""Construct a :class:`.Function`.
The :data:`.func` construct is normally used to construct
new :class:`.Function` instances.
"""
- self.packagenames = kw.pop("packagenames", None) or ()
+ self.packagenames = packagenames or ()
self.name = name
- self.type = sqltypes.to_instance(kw.get("type_", None))
+ self.type = sqltypes.to_instance(type_)
- FunctionElement.__init__(self, *clauses, **kw)
+ FunctionElement.__init__(self, *clauses)
def _bind_param(self, operator, obj, type_=None, **kw):
return BindParameter(
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index e64ec0843..4d0169370 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -17,6 +17,7 @@ import enum
import json
import pickle
from typing import Any
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TypeVar
@@ -48,6 +49,7 @@ from .. import util
from ..engine import processors
from ..util import langhelpers
from ..util import OrderedDict
+from ..util.typing import Literal
_T = TypeVar("_T", bound="Any")
@@ -373,9 +375,10 @@ class BigInteger(Integer):
__visit_name__ = "big_integer"
-class Numeric(
- _LookupExpressionAdapter, TypeEngine[Union[decimal.Decimal, float]]
-):
+_N = TypeVar("_N", bound=Union[decimal.Decimal, float])
+
+
+class Numeric(_LookupExpressionAdapter, TypeEngine[_N]):
"""A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``.
@@ -542,7 +545,7 @@ class Numeric(
}
-class Float(Numeric):
+class Float(Numeric[_N]):
"""Type representing floating point types, such as ``FLOAT`` or ``REAL``.
@@ -567,8 +570,34 @@ class Float(Numeric):
scale = None
+ @overload
+ def __init__(
+ self: Float[float],
+ precision=...,
+ decimal_return_scale=...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: Float[decimal.Decimal],
+ precision=...,
+ asdecimal: Literal[True] = ...,
+ decimal_return_scale=...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: Float[float],
+ precision=...,
+ asdecimal: Literal[False] = ...,
+ decimal_return_scale=...,
+ ):
+ ...
+
def __init__(
- self: "Float",
+ self: Float[_N],
precision=None,
asdecimal=False,
decimal_return_scale=None,