summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/functions.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-30 18:01:58 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-04 09:26:43 -0400
commit3b4d62f4f72e8dfad7f38db192a6a90a8551608c (patch)
treed0334c4bb52f803bd7dad661f2e6a12e25f5880c /lib/sqlalchemy/sql/functions.py
parent4e603e23755f31278f27a45449120a8dea470a45 (diff)
downloadsqlalchemy-3b4d62f4f72e8dfad7f38db192a6a90a8551608c.tar.gz
pep484 - sql.selectable
the pep484 task becomes more intense as there is mounting pressure to come up with a consistency in how data moves from end-user to instance variable. current thinking is coming into: 1. there are _typing._XYZArgument objects that represent "what the user sent" 2. there's the roles, which represent a kind of "filter" for different kinds of objects. These are mostly important as the argument we pass to coerce(). 3. there's the thing that coerce() returns, which should be what the construct uses as its internal representation of the thing. This is _typing._XYZElement. but there's some controversy over whether or not we should pass actual ClauseElements around by their role or not. I think we shouldn't at the moment, but this makes the "role-ness" of something a little less portable. Like, we have to set DMLTableRole for TableClause, Join, and Alias, but then also we have to repeat those three types in order to set up _DMLTableElement. Other change introduced here, there was a deannotate=True for the left/right of a sql.join(). All tests pass without that. I'd rather not have that there as if we have a join(A, B) where A, B are mapped classes, we want them inside of the _annotations. The rationale seems to be performance, but this performance can be illustrated to be on the compile side which we hope is cached in the normal case. CTEs now accommodate for text selects including recursive. Get typing to accommodate "util.preloaded" cleanly; add "preloaded" as a real module. This seemed like we would have needed pep562 `__getattr__()` but we don't, just set names in globals() as we import them. References: #6810 Change-Id: I34d17f617de2fe2c086fc556bd55748dc782faf0
Diffstat (limited to 'lib/sqlalchemy/sql/functions.py')
-rw-r--r--lib/sqlalchemy/sql/functions.py246
1 files changed, 146 insertions, 100 deletions
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 3bca8b502..db4bb5837 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -11,10 +11,15 @@
from __future__ import annotations
+import datetime
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Mapping
from typing import Optional
from typing import overload
-from typing import Sequence
+from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -24,7 +29,9 @@ from . import operators
from . import roles
from . import schema
from . import sqltypes
+from . import type_api
from . import util as sqlutil
+from ._typing import is_table_value_type
from .base import _entity_namespace
from .base import ColumnCollection
from .base import Executable
@@ -46,16 +53,21 @@ from .elements import WithinGroup
from .selectable import FromClause
from .selectable import Select
from .selectable import TableValuedAlias
+from .sqltypes import _N
+from .sqltypes import TableValueType
from .type_api import TypeEngine
from .visitors import InternalTraversal
from .. import util
+
if TYPE_CHECKING:
from ._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
-_registry = util.defaultdict(dict)
+_registry: util.defaultdict[
+ str, Dict[str, Type[Function[Any]]]
+] = util.defaultdict(dict)
def register_function(identifier, fn, package="_default"):
@@ -103,11 +115,18 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
("_table_value_type", InternalTraversal.dp_has_cache_key),
]
- packagenames = ()
+ packagenames: Tuple[str, ...] = ()
_has_args = False
_with_ordinality = False
- _table_value_type = None
+ _table_value_type: Optional[TableValueType] = None
+
+ # some attributes that are defined between both ColumnElement and
+ # FromClause are set to Any here to avoid typing errors
+ primary_key: Any
+ _is_clone_of: Any
+
+ clause_expr: Grouping[Any]
def __init__(self, *clauses: Any):
r"""Construct a :class:`.FunctionElement`.
@@ -135,9 +154,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
for c in clauses
]
self._has_args = self._has_args or bool(args)
- self.clause_expr = ClauseList(
- operator=operators.comma_op, group_contents=True, *args
- ).self_group()
+ self.clause_expr = Grouping(
+ ClauseList(operator=operators.comma_op, group_contents=True, *args)
+ )
_non_anon_label = None
@@ -263,9 +282,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
expr += (with_ordinality,)
new_func._with_ordinality = True
- new_func.type = new_func._table_value_type = sqltypes.TableValueType(
- *expr
- )
+ new_func.type = new_func._table_value_type = TableValueType(*expr)
return new_func.alias(name=name, joins_implicitly=joins_implicitly)
@@ -332,7 +349,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
@property
def _all_selected_columns(self):
- if self.type._is_table_value:
+ if is_table_value_type(self.type):
cols = self.type._elements
else:
cols = [self.label(None)]
@@ -344,12 +361,12 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
return self.columns
@HasMemoized.memoized_attribute
- def clauses(self):
+ def clauses(self) -> ClauseList:
"""Return the underlying :class:`.ClauseList` which contains
the arguments for this :class:`.FunctionElement`.
"""
- return self.clause_expr.element
+ return cast(ClauseList, self.clause_expr.element)
def over(self, partition_by=None, order_by=None, rows=None, range_=None):
"""Produce an OVER clause against this function.
@@ -647,7 +664,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
return _entity_namespace(self.clause_expr)
-class FunctionAsBinary(BinaryExpression):
+class FunctionAsBinary(BinaryExpression[Any]):
_traverse_internals = [
("sql_function", InternalTraversal.dp_clauseelement),
("left_index", InternalTraversal.dp_plain_obj),
@@ -655,10 +672,16 @@ class FunctionAsBinary(BinaryExpression):
("modifiers", InternalTraversal.dp_plain_dict),
]
+ sql_function: FunctionElement[Any]
+ left_index: int
+ right_index: int
+
def _gen_cache_key(self, anon_map, bindparams):
return ColumnElement._gen_cache_key(self, anon_map, bindparams)
- def __init__(self, fn, left_index, right_index):
+ def __init__(
+ self, fn: FunctionElement[Any], left_index: int, right_index: int
+ ):
self.sql_function = fn
self.left_index = left_index
self.right_index = right_index
@@ -670,23 +693,30 @@ class FunctionAsBinary(BinaryExpression):
self.modifiers = {}
@property
- def left(self):
+ def left_expr(self) -> ColumnElement[Any]:
return self.sql_function.clauses.clauses[self.left_index - 1]
- @left.setter
- def left(self, value):
+ @left_expr.setter
+ def left_expr(self, value: ColumnElement[Any]) -> None:
self.sql_function.clauses.clauses[self.left_index - 1] = value
@property
- def right(self):
+ def right_expr(self) -> ColumnElement[Any]:
return self.sql_function.clauses.clauses[self.right_index - 1]
- @right.setter
- def right(self, value):
+ @right_expr.setter
+ def right_expr(self, value: ColumnElement[Any]) -> None:
self.sql_function.clauses.clauses[self.right_index - 1] = value
+ if not TYPE_CHECKING:
+ # mypy can't accommodate @property to replace an instance
+ # variable
+
+ left = left_expr
+ right = right_expr
+
-class ScalarFunctionColumn(NamedColumn):
+class ScalarFunctionColumn(NamedColumn[_T]):
__visit_name__ = "scalar_function_column"
_traverse_internals = [
@@ -698,10 +728,18 @@ class ScalarFunctionColumn(NamedColumn):
is_literal = False
table = None
- def __init__(self, fn, name, type_=None):
+ def __init__(
+ self,
+ fn: FunctionElement[_T],
+ name: str,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ ):
self.fn = fn
self.name = name
- self.type = sqltypes.to_instance(type_)
+
+ # if type is None, we get NULLTYPE, which is our _T. But I don't
+ # know how to get the overloads to express that correctly
+ self.type = type_api.to_instance(type_) # type: ignore
class _FunctionGenerator:
@@ -789,7 +827,7 @@ class _FunctionGenerator:
# passthru __ attributes; fixes pydoc
if name.startswith("__"):
try:
- return self.__dict__[name]
+ return self.__dict__[name] # type: ignore
except KeyError:
raise AttributeError(name)
@@ -883,8 +921,6 @@ class Function(FunctionElement[_T]):
identifier: str
- packagenames: Sequence[str]
-
type: TypeEngine[_T]
"""A :class:`_types.TypeEngine` object which refers to the SQL return
type represented by this SQL function.
@@ -907,7 +943,7 @@ class Function(FunctionElement[_T]):
name: str,
*clauses: Any,
type_: Optional[_TypeEngineArgument[_T]] = None,
- packagenames: Optional[Sequence[str]] = None,
+ packagenames: Optional[Tuple[str, ...]] = None,
):
"""Construct a :class:`.Function`.
@@ -918,7 +954,9 @@ class Function(FunctionElement[_T]):
self.packagenames = packagenames or ()
self.name = name
- self.type = sqltypes.to_instance(type_)
+ # if type is None, we get NULLTYPE, which is our _T. But I don't
+ # know how to get the overloads to express that correctly
+ self.type = type_api.to_instance(type_) # type: ignore
FunctionElement.__init__(self, *clauses)
@@ -934,7 +972,7 @@ class Function(FunctionElement[_T]):
)
-class GenericFunction(Function):
+class GenericFunction(Function[_T]):
"""Define a 'generic' function.
A generic function is a pre-established :class:`.Function`
@@ -957,7 +995,7 @@ class GenericFunction(Function):
from sqlalchemy.types import DateTime
class as_utc(GenericFunction):
- type = DateTime
+ type = DateTime()
inherit_cache = True
print(select(func.as_utc()))
@@ -971,7 +1009,7 @@ class GenericFunction(Function):
"time"::
class as_utc(GenericFunction):
- type = DateTime
+ type = DateTime()
package = "time"
inherit_cache = True
@@ -987,7 +1025,7 @@ class GenericFunction(Function):
the usage of ``name`` as the rendered name::
class GeoBuffer(GenericFunction):
- type = Geometry
+ type = Geometry()
package = "geo"
name = "ST_Buffer"
identifier = "buffer"
@@ -1006,7 +1044,7 @@ class GenericFunction(Function):
from sqlalchemy.sql import quoted_name
class GeoBuffer(GenericFunction):
- type = Geometry
+ type = Geometry()
package = "geo"
name = quoted_name("ST_Buffer", True)
identifier = "buffer"
@@ -1028,6 +1066,8 @@ class GenericFunction(Function):
coerce_arguments = True
inherit_cache = True
+ _register: bool
+
name = "GenericFunction"
def __init_subclass__(cls) -> None:
@@ -1036,7 +1076,9 @@ class GenericFunction(Function):
super().__init_subclass__()
@classmethod
- def _register_generic_function(cls, clsname, clsdict):
+ def _register_generic_function(
+ cls, clsname: str, clsdict: Mapping[str, Any]
+ ) -> None:
cls.name = name = clsdict.get("name", clsname)
cls.identifier = identifier = clsdict.get("identifier", name)
package = clsdict.get("package", "_default")
@@ -1068,11 +1110,14 @@ class GenericFunction(Function):
]
self._has_args = self._has_args or bool(parsed_args)
self.packagenames = ()
- self.clause_expr = ClauseList(
- operator=operators.comma_op, group_contents=True, *parsed_args
- ).self_group()
- self.type = sqltypes.to_instance(
+ self.clause_expr = Grouping(
+ ClauseList(
+ operator=operators.comma_op, group_contents=True, *parsed_args
+ )
+ )
+
+ self.type = type_api.to_instance( # type: ignore
kwargs.pop("type_", None) or getattr(self, "type", None)
)
@@ -1081,7 +1126,7 @@ register_function("cast", Cast)
register_function("extract", Extract)
-class next_value(GenericFunction):
+class next_value(GenericFunction[int]):
"""Represent the 'next value', given a :class:`.Sequence`
as its single argument.
@@ -1103,7 +1148,7 @@ class next_value(GenericFunction):
seq, schema.Sequence
), "next_value() accepts a Sequence object as input."
self.sequence = seq
- self.type = sqltypes.to_instance(
+ self.type = sqltypes.to_instance( # type: ignore
seq.data_type or getattr(self, "type", None)
)
@@ -1118,7 +1163,7 @@ class next_value(GenericFunction):
return []
-class AnsiFunction(GenericFunction):
+class AnsiFunction(GenericFunction[_T]):
"""Define a function in "ansi" format, which doesn't render parenthesis."""
inherit_cache = True
@@ -1127,13 +1172,13 @@ class AnsiFunction(GenericFunction):
GenericFunction.__init__(self, *args, **kwargs)
-class ReturnTypeFromArgs(GenericFunction):
+class ReturnTypeFromArgs(GenericFunction[_T]):
"""Define a function whose return type is the same as its arguments."""
inherit_cache = True
def __init__(self, *args, **kwargs):
- args = [
+ fn_args = [
coercions.expect(
roles.ExpressionElementRole,
c,
@@ -1142,35 +1187,35 @@ class ReturnTypeFromArgs(GenericFunction):
)
for c in args
]
- kwargs.setdefault("type_", _type_from_args(args))
- kwargs["_parsed_args"] = args
- super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
+ kwargs.setdefault("type_", _type_from_args(fn_args))
+ kwargs["_parsed_args"] = fn_args
+ super(ReturnTypeFromArgs, self).__init__(*fn_args, **kwargs)
-class coalesce(ReturnTypeFromArgs):
+class coalesce(ReturnTypeFromArgs[_T]):
_has_args = True
inherit_cache = True
-class max(ReturnTypeFromArgs): # noqa A001
+class max(ReturnTypeFromArgs[_T]): # noqa A001
"""The SQL MAX() aggregate function."""
inherit_cache = True
-class min(ReturnTypeFromArgs): # noqa A001
+class min(ReturnTypeFromArgs[_T]): # noqa A001
"""The SQL MIN() aggregate function."""
inherit_cache = True
-class sum(ReturnTypeFromArgs): # noqa A001
+class sum(ReturnTypeFromArgs[_T]): # noqa A001
"""The SQL SUM() aggregate function."""
inherit_cache = True
-class now(GenericFunction):
+class now(GenericFunction[datetime.datetime]):
"""The SQL now() datetime function.
SQLAlchemy dialects will usually render this particular function
@@ -1178,11 +1223,11 @@ class now(GenericFunction):
"""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class concat(GenericFunction):
+class concat(GenericFunction[str]):
"""The SQL CONCAT() function, which concatenates strings.
E.g.::
@@ -1200,28 +1245,30 @@ class concat(GenericFunction):
"""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class char_length(GenericFunction):
+class char_length(GenericFunction[int]):
"""The CHAR_LENGTH() SQL function."""
- type = sqltypes.Integer
+ type = sqltypes.Integer()
inherit_cache = True
- def __init__(self, arg, **kwargs):
- GenericFunction.__init__(self, arg, **kwargs)
+ def __init__(self, arg, **kw):
+ # slight hack to limit to just one positional argument
+ # not sure why this one function has this special treatment
+ super().__init__(arg, **kw)
-class random(GenericFunction):
+class random(GenericFunction[float]):
"""The RANDOM() SQL function."""
_has_args = True
inherit_cache = True
-class count(GenericFunction):
+class count(GenericFunction[int]):
r"""The ANSI COUNT aggregate function. With no arguments,
emits COUNT \*.
@@ -1242,7 +1289,7 @@ class count(GenericFunction):
"""
- type = sqltypes.Integer
+ type = sqltypes.Integer()
inherit_cache = True
def __init__(self, expression=None, **kwargs):
@@ -1251,70 +1298,70 @@ class count(GenericFunction):
super(count, self).__init__(expression, **kwargs)
-class current_date(AnsiFunction):
+class current_date(AnsiFunction[datetime.date]):
"""The CURRENT_DATE() SQL function."""
- type = sqltypes.Date
+ type = sqltypes.Date()
inherit_cache = True
-class current_time(AnsiFunction):
+class current_time(AnsiFunction[datetime.time]):
"""The CURRENT_TIME() SQL function."""
- type = sqltypes.Time
+ type = sqltypes.Time()
inherit_cache = True
-class current_timestamp(AnsiFunction):
+class current_timestamp(AnsiFunction[datetime.datetime]):
"""The CURRENT_TIMESTAMP() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class current_user(AnsiFunction):
+class current_user(AnsiFunction[str]):
"""The CURRENT_USER() SQL function."""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class localtime(AnsiFunction):
+class localtime(AnsiFunction[datetime.datetime]):
"""The localtime() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class localtimestamp(AnsiFunction):
+class localtimestamp(AnsiFunction[datetime.datetime]):
"""The localtimestamp() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class session_user(AnsiFunction):
+class session_user(AnsiFunction[str]):
"""The SESSION_USER() SQL function."""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class sysdate(AnsiFunction):
+class sysdate(AnsiFunction[datetime.datetime]):
"""The SYSDATE() SQL function."""
- type = sqltypes.DateTime
+ type = sqltypes.DateTime()
inherit_cache = True
-class user(AnsiFunction):
+class user(AnsiFunction[str]):
"""The USER() SQL function."""
- type = sqltypes.String
+ type = sqltypes.String()
inherit_cache = True
-class array_agg(GenericFunction):
+class array_agg(GenericFunction[_T]):
"""Support for the ARRAY_AGG function.
The ``func.array_agg(expr)`` construct returns an expression of
@@ -1334,11 +1381,10 @@ class array_agg(GenericFunction):
"""
- type = sqltypes.ARRAY
inherit_cache = True
def __init__(self, *args, **kwargs):
- args = [
+ fn_args = [
coercions.expect(
roles.ExpressionElementRole, c, apply_propagate_attrs=self
)
@@ -1348,16 +1394,16 @@ class array_agg(GenericFunction):
default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
if "type_" not in kwargs:
- type_from_args = _type_from_args(args)
+ type_from_args = _type_from_args(fn_args)
if isinstance(type_from_args, sqltypes.ARRAY):
kwargs["type_"] = type_from_args
else:
kwargs["type_"] = default_array_type(type_from_args)
- kwargs["_parsed_args"] = args
- super(array_agg, self).__init__(*args, **kwargs)
+ kwargs["_parsed_args"] = fn_args
+ super(array_agg, self).__init__(*fn_args, **kwargs)
-class OrderedSetAgg(GenericFunction):
+class OrderedSetAgg(GenericFunction[_T]):
"""Define a function where the return type is based on the sort
expression type as defined by the expression passed to the
:meth:`.FunctionElement.within_group` method."""
@@ -1366,7 +1412,7 @@ class OrderedSetAgg(GenericFunction):
inherit_cache = True
def within_group_type(self, within_group):
- func_clauses = self.clause_expr.element
+ func_clauses = cast(ClauseList, self.clause_expr.element)
order_by = sqlutil.unwrap_order_by(within_group.order_by)
if self.array_for_multi_clause and len(func_clauses.clauses) > 1:
return sqltypes.ARRAY(order_by[0].type)
@@ -1374,7 +1420,7 @@ class OrderedSetAgg(GenericFunction):
return order_by[0].type
-class mode(OrderedSetAgg):
+class mode(OrderedSetAgg[_T]):
"""Implement the ``mode`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1389,7 +1435,7 @@ class mode(OrderedSetAgg):
inherit_cache = True
-class percentile_cont(OrderedSetAgg):
+class percentile_cont(OrderedSetAgg[_T]):
"""Implement the ``percentile_cont`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1407,7 +1453,7 @@ class percentile_cont(OrderedSetAgg):
inherit_cache = True
-class percentile_disc(OrderedSetAgg):
+class percentile_disc(OrderedSetAgg[_T]):
"""Implement the ``percentile_disc`` ordered-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1425,7 +1471,7 @@ class percentile_disc(OrderedSetAgg):
inherit_cache = True
-class rank(GenericFunction):
+class rank(GenericFunction[int]):
"""Implement the ``rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1441,7 +1487,7 @@ class rank(GenericFunction):
inherit_cache = True
-class dense_rank(GenericFunction):
+class dense_rank(GenericFunction[int]):
"""Implement the ``dense_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1457,7 +1503,7 @@ class dense_rank(GenericFunction):
inherit_cache = True
-class percent_rank(GenericFunction):
+class percent_rank(GenericFunction[_N]):
"""Implement the ``percent_rank`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1469,11 +1515,11 @@ class percent_rank(GenericFunction):
"""
- type = sqltypes.Numeric()
+ type: sqltypes.Numeric[_N] = sqltypes.Numeric()
inherit_cache = True
-class cume_dist(GenericFunction):
+class cume_dist(GenericFunction[_N]):
"""Implement the ``cume_dist`` hypothetical-set aggregate function.
This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1485,11 +1531,11 @@ class cume_dist(GenericFunction):
"""
- type = sqltypes.Numeric()
+ type: sqltypes.Numeric[_N] = sqltypes.Numeric()
inherit_cache = True
-class cube(GenericFunction):
+class cube(GenericFunction[_T]):
r"""Implement the ``CUBE`` grouping operation.
This function is used as part of the GROUP BY of a statement,
@@ -1506,7 +1552,7 @@ class cube(GenericFunction):
inherit_cache = True
-class rollup(GenericFunction):
+class rollup(GenericFunction[_T]):
r"""Implement the ``ROLLUP`` grouping operation.
This function is used as part of the GROUP BY of a statement,
@@ -1523,7 +1569,7 @@ class rollup(GenericFunction):
inherit_cache = True
-class grouping_sets(GenericFunction):
+class grouping_sets(GenericFunction[_T]):
r"""Implement the ``GROUPING SETS`` grouping operation.
This function is used as part of the GROUP BY of a statement,