diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-30 18:01:58 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-04 09:26:43 -0400 |
commit | 3b4d62f4f72e8dfad7f38db192a6a90a8551608c (patch) | |
tree | d0334c4bb52f803bd7dad661f2e6a12e25f5880c /lib/sqlalchemy/sql/functions.py | |
parent | 4e603e23755f31278f27a45449120a8dea470a45 (diff) | |
download | sqlalchemy-3b4d62f4f72e8dfad7f38db192a6a90a8551608c.tar.gz |
pep484 - sql.selectable
the pep484 task becomes more intense as there is mounting
pressure to come up with a consistency in how data moves
from end-user to instance variable.
current thinking is coming into:
1. there are _typing._XYZArgument objects that represent "what the
user sent"
2. there's the roles, which represent a kind of "filter" for different
kinds of objects. These are mostly important as the argument
we pass to coerce().
3. there's the thing that coerce() returns, which should be what the
construct uses as its internal representation of the thing.
This is _typing._XYZElement.
but there's some controversy over whether or
not we should pass actual ClauseElements around by their role
or not. I think we shouldn't at the moment, but this makes the
"role-ness" of something a little less portable. Like, we have
to set DMLTableRole for TableClause, Join, and Alias, but then
also we have to repeat those three types in order to set up
_DMLTableElement.
Other change introduced here, there was a deannotate=True
for the left/right of a sql.join(). All tests pass without that.
I'd rather not have that there as if we have a join(A, B) where
A, B are mapped classes, we want them inside of the _annotations.
The rationale seems to be performance, but this performance can
be illustrated to be on the compile side which we hope is cached
in the normal case.
CTEs now accommodate for text selects including recursive.
Get typing to accommodate "util.preloaded" cleanly; add "preloaded"
as a real module. This seemed like we would have needed
pep562 `__getattr__()` but we don't, just set names in
globals() as we import them.
References: #6810
Change-Id: I34d17f617de2fe2c086fc556bd55748dc782faf0
Diffstat (limited to 'lib/sqlalchemy/sql/functions.py')
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 246 |
1 files changed, 146 insertions, 100 deletions
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 3bca8b502..db4bb5837 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -11,10 +11,15 @@ from __future__ import annotations +import datetime from typing import Any +from typing import cast +from typing import Dict +from typing import Mapping from typing import Optional from typing import overload -from typing import Sequence +from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -24,7 +29,9 @@ from . import operators from . import roles from . import schema from . import sqltypes +from . import type_api from . import util as sqlutil +from ._typing import is_table_value_type from .base import _entity_namespace from .base import ColumnCollection from .base import Executable @@ -46,16 +53,21 @@ from .elements import WithinGroup from .selectable import FromClause from .selectable import Select from .selectable import TableValuedAlias +from .sqltypes import _N +from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import InternalTraversal from .. import util + if TYPE_CHECKING: from ._typing import _TypeEngineArgument _T = TypeVar("_T", bound=Any) -_registry = util.defaultdict(dict) +_registry: util.defaultdict[ + str, Dict[str, Type[Function[Any]]] +] = util.defaultdict(dict) def register_function(identifier, fn, package="_default"): @@ -103,11 +115,18 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): ("_table_value_type", InternalTraversal.dp_has_cache_key), ] - packagenames = () + packagenames: Tuple[str, ...] = () _has_args = False _with_ordinality = False - _table_value_type = None + _table_value_type: Optional[TableValueType] = None + + # some attributes that are defined between both ColumnElement and + # FromClause are set to Any here to avoid typing errors + primary_key: Any + _is_clone_of: Any + + clause_expr: Grouping[Any] def __init__(self, *clauses: Any): r"""Construct a :class:`.FunctionElement`. @@ -135,9 +154,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): for c in clauses ] self._has_args = self._has_args or bool(args) - self.clause_expr = ClauseList( - operator=operators.comma_op, group_contents=True, *args - ).self_group() + self.clause_expr = Grouping( + ClauseList(operator=operators.comma_op, group_contents=True, *args) + ) _non_anon_label = None @@ -263,9 +282,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): expr += (with_ordinality,) new_func._with_ordinality = True - new_func.type = new_func._table_value_type = sqltypes.TableValueType( - *expr - ) + new_func.type = new_func._table_value_type = TableValueType(*expr) return new_func.alias(name=name, joins_implicitly=joins_implicitly) @@ -332,7 +349,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): @property def _all_selected_columns(self): - if self.type._is_table_value: + if is_table_value_type(self.type): cols = self.type._elements else: cols = [self.label(None)] @@ -344,12 +361,12 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return self.columns @HasMemoized.memoized_attribute - def clauses(self): + def clauses(self) -> ClauseList: """Return the underlying :class:`.ClauseList` which contains the arguments for this :class:`.FunctionElement`. """ - return self.clause_expr.element + return cast(ClauseList, self.clause_expr.element) def over(self, partition_by=None, order_by=None, rows=None, range_=None): """Produce an OVER clause against this function. @@ -647,7 +664,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return _entity_namespace(self.clause_expr) -class FunctionAsBinary(BinaryExpression): +class FunctionAsBinary(BinaryExpression[Any]): _traverse_internals = [ ("sql_function", InternalTraversal.dp_clauseelement), ("left_index", InternalTraversal.dp_plain_obj), @@ -655,10 +672,16 @@ class FunctionAsBinary(BinaryExpression): ("modifiers", InternalTraversal.dp_plain_dict), ] + sql_function: FunctionElement[Any] + left_index: int + right_index: int + def _gen_cache_key(self, anon_map, bindparams): return ColumnElement._gen_cache_key(self, anon_map, bindparams) - def __init__(self, fn, left_index, right_index): + def __init__( + self, fn: FunctionElement[Any], left_index: int, right_index: int + ): self.sql_function = fn self.left_index = left_index self.right_index = right_index @@ -670,23 +693,30 @@ class FunctionAsBinary(BinaryExpression): self.modifiers = {} @property - def left(self): + def left_expr(self) -> ColumnElement[Any]: return self.sql_function.clauses.clauses[self.left_index - 1] - @left.setter - def left(self, value): + @left_expr.setter + def left_expr(self, value: ColumnElement[Any]) -> None: self.sql_function.clauses.clauses[self.left_index - 1] = value @property - def right(self): + def right_expr(self) -> ColumnElement[Any]: return self.sql_function.clauses.clauses[self.right_index - 1] - @right.setter - def right(self, value): + @right_expr.setter + def right_expr(self, value: ColumnElement[Any]) -> None: self.sql_function.clauses.clauses[self.right_index - 1] = value + if not TYPE_CHECKING: + # mypy can't accommodate @property to replace an instance + # variable + + left = left_expr + right = right_expr + -class ScalarFunctionColumn(NamedColumn): +class ScalarFunctionColumn(NamedColumn[_T]): __visit_name__ = "scalar_function_column" _traverse_internals = [ @@ -698,10 +728,18 @@ class ScalarFunctionColumn(NamedColumn): is_literal = False table = None - def __init__(self, fn, name, type_=None): + def __init__( + self, + fn: FunctionElement[_T], + name: str, + type_: Optional[_TypeEngineArgument[_T]] = None, + ): self.fn = fn self.name = name - self.type = sqltypes.to_instance(type_) + + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore class _FunctionGenerator: @@ -789,7 +827,7 @@ class _FunctionGenerator: # passthru __ attributes; fixes pydoc if name.startswith("__"): try: - return self.__dict__[name] + return self.__dict__[name] # type: ignore except KeyError: raise AttributeError(name) @@ -883,8 +921,6 @@ class Function(FunctionElement[_T]): identifier: str - packagenames: Sequence[str] - type: TypeEngine[_T] """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. @@ -907,7 +943,7 @@ class Function(FunctionElement[_T]): name: str, *clauses: Any, type_: Optional[_TypeEngineArgument[_T]] = None, - packagenames: Optional[Sequence[str]] = None, + packagenames: Optional[Tuple[str, ...]] = None, ): """Construct a :class:`.Function`. @@ -918,7 +954,9 @@ class Function(FunctionElement[_T]): self.packagenames = packagenames or () self.name = name - self.type = sqltypes.to_instance(type_) + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore FunctionElement.__init__(self, *clauses) @@ -934,7 +972,7 @@ class Function(FunctionElement[_T]): ) -class GenericFunction(Function): +class GenericFunction(Function[_T]): """Define a 'generic' function. A generic function is a pre-established :class:`.Function` @@ -957,7 +995,7 @@ class GenericFunction(Function): from sqlalchemy.types import DateTime class as_utc(GenericFunction): - type = DateTime + type = DateTime() inherit_cache = True print(select(func.as_utc())) @@ -971,7 +1009,7 @@ class GenericFunction(Function): "time":: class as_utc(GenericFunction): - type = DateTime + type = DateTime() package = "time" inherit_cache = True @@ -987,7 +1025,7 @@ class GenericFunction(Function): the usage of ``name`` as the rendered name:: class GeoBuffer(GenericFunction): - type = Geometry + type = Geometry() package = "geo" name = "ST_Buffer" identifier = "buffer" @@ -1006,7 +1044,7 @@ class GenericFunction(Function): from sqlalchemy.sql import quoted_name class GeoBuffer(GenericFunction): - type = Geometry + type = Geometry() package = "geo" name = quoted_name("ST_Buffer", True) identifier = "buffer" @@ -1028,6 +1066,8 @@ class GenericFunction(Function): coerce_arguments = True inherit_cache = True + _register: bool + name = "GenericFunction" def __init_subclass__(cls) -> None: @@ -1036,7 +1076,9 @@ class GenericFunction(Function): super().__init_subclass__() @classmethod - def _register_generic_function(cls, clsname, clsdict): + def _register_generic_function( + cls, clsname: str, clsdict: Mapping[str, Any] + ) -> None: cls.name = name = clsdict.get("name", clsname) cls.identifier = identifier = clsdict.get("identifier", name) package = clsdict.get("package", "_default") @@ -1068,11 +1110,14 @@ class GenericFunction(Function): ] self._has_args = self._has_args or bool(parsed_args) self.packagenames = () - self.clause_expr = ClauseList( - operator=operators.comma_op, group_contents=True, *parsed_args - ).self_group() - self.type = sqltypes.to_instance( + self.clause_expr = Grouping( + ClauseList( + operator=operators.comma_op, group_contents=True, *parsed_args + ) + ) + + self.type = type_api.to_instance( # type: ignore kwargs.pop("type_", None) or getattr(self, "type", None) ) @@ -1081,7 +1126,7 @@ register_function("cast", Cast) register_function("extract", Extract) -class next_value(GenericFunction): +class next_value(GenericFunction[int]): """Represent the 'next value', given a :class:`.Sequence` as its single argument. @@ -1103,7 +1148,7 @@ class next_value(GenericFunction): seq, schema.Sequence ), "next_value() accepts a Sequence object as input." self.sequence = seq - self.type = sqltypes.to_instance( + self.type = sqltypes.to_instance( # type: ignore seq.data_type or getattr(self, "type", None) ) @@ -1118,7 +1163,7 @@ class next_value(GenericFunction): return [] -class AnsiFunction(GenericFunction): +class AnsiFunction(GenericFunction[_T]): """Define a function in "ansi" format, which doesn't render parenthesis.""" inherit_cache = True @@ -1127,13 +1172,13 @@ class AnsiFunction(GenericFunction): GenericFunction.__init__(self, *args, **kwargs) -class ReturnTypeFromArgs(GenericFunction): +class ReturnTypeFromArgs(GenericFunction[_T]): """Define a function whose return type is the same as its arguments.""" inherit_cache = True def __init__(self, *args, **kwargs): - args = [ + fn_args = [ coercions.expect( roles.ExpressionElementRole, c, @@ -1142,35 +1187,35 @@ class ReturnTypeFromArgs(GenericFunction): ) for c in args ] - kwargs.setdefault("type_", _type_from_args(args)) - kwargs["_parsed_args"] = args - super(ReturnTypeFromArgs, self).__init__(*args, **kwargs) + kwargs.setdefault("type_", _type_from_args(fn_args)) + kwargs["_parsed_args"] = fn_args + super(ReturnTypeFromArgs, self).__init__(*fn_args, **kwargs) -class coalesce(ReturnTypeFromArgs): +class coalesce(ReturnTypeFromArgs[_T]): _has_args = True inherit_cache = True -class max(ReturnTypeFromArgs): # noqa A001 +class max(ReturnTypeFromArgs[_T]): # noqa A001 """The SQL MAX() aggregate function.""" inherit_cache = True -class min(ReturnTypeFromArgs): # noqa A001 +class min(ReturnTypeFromArgs[_T]): # noqa A001 """The SQL MIN() aggregate function.""" inherit_cache = True -class sum(ReturnTypeFromArgs): # noqa A001 +class sum(ReturnTypeFromArgs[_T]): # noqa A001 """The SQL SUM() aggregate function.""" inherit_cache = True -class now(GenericFunction): +class now(GenericFunction[datetime.datetime]): """The SQL now() datetime function. SQLAlchemy dialects will usually render this particular function @@ -1178,11 +1223,11 @@ class now(GenericFunction): """ - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class concat(GenericFunction): +class concat(GenericFunction[str]): """The SQL CONCAT() function, which concatenates strings. E.g.:: @@ -1200,28 +1245,30 @@ class concat(GenericFunction): """ - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class char_length(GenericFunction): +class char_length(GenericFunction[int]): """The CHAR_LENGTH() SQL function.""" - type = sqltypes.Integer + type = sqltypes.Integer() inherit_cache = True - def __init__(self, arg, **kwargs): - GenericFunction.__init__(self, arg, **kwargs) + def __init__(self, arg, **kw): + # slight hack to limit to just one positional argument + # not sure why this one function has this special treatment + super().__init__(arg, **kw) -class random(GenericFunction): +class random(GenericFunction[float]): """The RANDOM() SQL function.""" _has_args = True inherit_cache = True -class count(GenericFunction): +class count(GenericFunction[int]): r"""The ANSI COUNT aggregate function. With no arguments, emits COUNT \*. @@ -1242,7 +1289,7 @@ class count(GenericFunction): """ - type = sqltypes.Integer + type = sqltypes.Integer() inherit_cache = True def __init__(self, expression=None, **kwargs): @@ -1251,70 +1298,70 @@ class count(GenericFunction): super(count, self).__init__(expression, **kwargs) -class current_date(AnsiFunction): +class current_date(AnsiFunction[datetime.date]): """The CURRENT_DATE() SQL function.""" - type = sqltypes.Date + type = sqltypes.Date() inherit_cache = True -class current_time(AnsiFunction): +class current_time(AnsiFunction[datetime.time]): """The CURRENT_TIME() SQL function.""" - type = sqltypes.Time + type = sqltypes.Time() inherit_cache = True -class current_timestamp(AnsiFunction): +class current_timestamp(AnsiFunction[datetime.datetime]): """The CURRENT_TIMESTAMP() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class current_user(AnsiFunction): +class current_user(AnsiFunction[str]): """The CURRENT_USER() SQL function.""" - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class localtime(AnsiFunction): +class localtime(AnsiFunction[datetime.datetime]): """The localtime() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class localtimestamp(AnsiFunction): +class localtimestamp(AnsiFunction[datetime.datetime]): """The localtimestamp() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class session_user(AnsiFunction): +class session_user(AnsiFunction[str]): """The SESSION_USER() SQL function.""" - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class sysdate(AnsiFunction): +class sysdate(AnsiFunction[datetime.datetime]): """The SYSDATE() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class user(AnsiFunction): +class user(AnsiFunction[str]): """The USER() SQL function.""" - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class array_agg(GenericFunction): +class array_agg(GenericFunction[_T]): """Support for the ARRAY_AGG function. The ``func.array_agg(expr)`` construct returns an expression of @@ -1334,11 +1381,10 @@ class array_agg(GenericFunction): """ - type = sqltypes.ARRAY inherit_cache = True def __init__(self, *args, **kwargs): - args = [ + fn_args = [ coercions.expect( roles.ExpressionElementRole, c, apply_propagate_attrs=self ) @@ -1348,16 +1394,16 @@ class array_agg(GenericFunction): default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY) if "type_" not in kwargs: - type_from_args = _type_from_args(args) + type_from_args = _type_from_args(fn_args) if isinstance(type_from_args, sqltypes.ARRAY): kwargs["type_"] = type_from_args else: kwargs["type_"] = default_array_type(type_from_args) - kwargs["_parsed_args"] = args - super(array_agg, self).__init__(*args, **kwargs) + kwargs["_parsed_args"] = fn_args + super(array_agg, self).__init__(*fn_args, **kwargs) -class OrderedSetAgg(GenericFunction): +class OrderedSetAgg(GenericFunction[_T]): """Define a function where the return type is based on the sort expression type as defined by the expression passed to the :meth:`.FunctionElement.within_group` method.""" @@ -1366,7 +1412,7 @@ class OrderedSetAgg(GenericFunction): inherit_cache = True def within_group_type(self, within_group): - func_clauses = self.clause_expr.element + func_clauses = cast(ClauseList, self.clause_expr.element) order_by = sqlutil.unwrap_order_by(within_group.order_by) if self.array_for_multi_clause and len(func_clauses.clauses) > 1: return sqltypes.ARRAY(order_by[0].type) @@ -1374,7 +1420,7 @@ class OrderedSetAgg(GenericFunction): return order_by[0].type -class mode(OrderedSetAgg): +class mode(OrderedSetAgg[_T]): """Implement the ``mode`` ordered-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1389,7 +1435,7 @@ class mode(OrderedSetAgg): inherit_cache = True -class percentile_cont(OrderedSetAgg): +class percentile_cont(OrderedSetAgg[_T]): """Implement the ``percentile_cont`` ordered-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1407,7 +1453,7 @@ class percentile_cont(OrderedSetAgg): inherit_cache = True -class percentile_disc(OrderedSetAgg): +class percentile_disc(OrderedSetAgg[_T]): """Implement the ``percentile_disc`` ordered-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1425,7 +1471,7 @@ class percentile_disc(OrderedSetAgg): inherit_cache = True -class rank(GenericFunction): +class rank(GenericFunction[int]): """Implement the ``rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1441,7 +1487,7 @@ class rank(GenericFunction): inherit_cache = True -class dense_rank(GenericFunction): +class dense_rank(GenericFunction[int]): """Implement the ``dense_rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1457,7 +1503,7 @@ class dense_rank(GenericFunction): inherit_cache = True -class percent_rank(GenericFunction): +class percent_rank(GenericFunction[_N]): """Implement the ``percent_rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1469,11 +1515,11 @@ class percent_rank(GenericFunction): """ - type = sqltypes.Numeric() + type: sqltypes.Numeric[_N] = sqltypes.Numeric() inherit_cache = True -class cume_dist(GenericFunction): +class cume_dist(GenericFunction[_N]): """Implement the ``cume_dist`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1485,11 +1531,11 @@ class cume_dist(GenericFunction): """ - type = sqltypes.Numeric() + type: sqltypes.Numeric[_N] = sqltypes.Numeric() inherit_cache = True -class cube(GenericFunction): +class cube(GenericFunction[_T]): r"""Implement the ``CUBE`` grouping operation. This function is used as part of the GROUP BY of a statement, @@ -1506,7 +1552,7 @@ class cube(GenericFunction): inherit_cache = True -class rollup(GenericFunction): +class rollup(GenericFunction[_T]): r"""Implement the ``ROLLUP`` grouping operation. This function is used as part of the GROUP BY of a statement, @@ -1523,7 +1569,7 @@ class rollup(GenericFunction): inherit_cache = True -class grouping_sets(GenericFunction): +class grouping_sets(GenericFunction[_T]): r"""Implement the ``GROUPING SETS`` grouping operation. This function is used as part of the GROUP BY of a statement, |