diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-04-14 12:38:43 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-04-14 12:38:43 +0000 |
| commit | 997e4dfae98fc68463fa5121c2ce95498ad238d1 (patch) | |
| tree | 20e2629903b3ab2cf377d5115a5a4bc199d71929 /lib/sqlalchemy/sql | |
| parent | 42c20b015d01ced441c42a8a6c5e9ed823316682 (diff) | |
| parent | 428262a2d5374613f4a4cf925bbd9e94e0e34acc (diff) | |
| download | sqlalchemy-997e4dfae98fc68463fa5121c2ce95498ad238d1.tar.gz | |
Merge "implement multi-element expression constructs" into main
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 37 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 227 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 22 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 5 |
6 files changed, 254 insertions, 56 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 522a0bd4a..9c074db33 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2013,6 +2013,24 @@ class SQLCompiler(Compiled): return self._generate_delimited_list(clauselist.clauses, sep, **kw) + def visit_expression_clauselist(self, clauselist, **kw): + operator_ = clauselist.operator + + disp = self._get_operator_dispatch( + operator_, "expression_clauselist", None + ) + if disp: + return disp(clauselist, operator_, **kw) + + try: + opstring = OPERATORS[operator_] + except KeyError as err: + raise exc.UnsupportedCompilationError(self, operator_) from err + else: + return self._generate_delimited_list( + clauselist.clauses, opstring, **kw + ) + def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 944a0a5ce..512fca8d0 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -27,11 +27,12 @@ from . import type_api from .elements import and_ from .elements import BinaryExpression from .elements import ClauseElement -from .elements import ClauseList from .elements import CollationClause from .elements import CollectionAggregate +from .elements import ExpressionClauseList from .elements import False_ from .elements import Null +from .elements import OperatorExpression from .elements import or_ from .elements import True_ from .elements import UnaryExpression @@ -56,11 +57,9 @@ def _boolean_compare( reverse: bool = False, _python_is_types: Tuple[Type[Any], ...] = (type(None), bool), _any_all_expr: bool = False, - result_type: Optional[ - Union[Type[TypeEngine[bool]], TypeEngine[bool]] - ] = None, + result_type: Optional[TypeEngine[bool]] = None, **kwargs: Any, -) -> BinaryExpression[bool]: +) -> OperatorExpression[bool]: if result_type is None: result_type = type_api.BOOLEANTYPE @@ -71,7 +70,7 @@ def _boolean_compare( if op in (operators.eq, operators.ne) and isinstance( obj, (bool, True_, False_) ): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), op, @@ -83,7 +82,7 @@ def _boolean_compare( operators.is_distinct_from, operators.is_not_distinct_from, ): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), op, @@ -98,7 +97,7 @@ def _boolean_compare( else: # all other None uses IS, IS NOT if op in (operators.eq, operators.is_): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), operators.is_, @@ -106,7 +105,7 @@ def _boolean_compare( type_=result_type, ) elif op in (operators.ne, operators.is_not): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), operators.is_not, @@ -125,7 +124,7 @@ def _boolean_compare( ) if reverse: - return BinaryExpression( + return OperatorExpression._construct_for_op( obj, expr, op, @@ -134,7 +133,7 @@ def _boolean_compare( modifiers=kwargs, ) else: - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, obj, op, @@ -169,11 +168,9 @@ def _binary_operate( obj: roles.BinaryElementRole[Any], *, reverse: bool = False, - result_type: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] - ] = None, + result_type: Optional[TypeEngine[_T]] = None, **kw: Any, -) -> BinaryExpression[_T]: +) -> OperatorExpression[_T]: coerced_obj = coercions.expect( roles.BinaryElementRole, obj, expr=expr, operator=op @@ -189,7 +186,9 @@ def _binary_operate( op, right.comparator ) - return BinaryExpression(left, right, op, type_=result_type, modifiers=kw) + return OperatorExpression._construct_for_op( + left, right, op, type_=result_type, modifiers=kw + ) def _conjunction_operate( @@ -311,7 +310,9 @@ def _between_impl( """See :meth:`.ColumnOperators.between`.""" return BinaryExpression( expr, - ClauseList( + ExpressionClauseList._construct_for_list( + operators.and_, + type_api.NULLTYPE, coercions.expect( roles.BinaryElementRole, cleft, @@ -324,9 +325,7 @@ def _between_impl( expr=expr, operator=operators.and_, ), - operator=operators.and_, group=False, - group_contents=False, ), op, negate=operators.not_between_op diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 805758283..d47d138f7 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1323,7 +1323,11 @@ class ColumnElement( if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return AsBoolean(self, operators.is_false, operators.is_true) else: - return cast("UnaryExpression[_T]", super()._negate()) + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression( + grouped, operator=operators.inv, wraps_column_expression=True + ) type: TypeEngine[_T] @@ -2501,6 +2505,8 @@ class ClauseList( __visit_name__ = "clauselist" + # this is used only by the ORM in a legacy use case for + # composite attributes _is_clause_list = True _traverse_internals: _TraverseInternalsType = [ @@ -2516,18 +2522,14 @@ class ClauseList( operator: OperatorType = operators.comma_op, group: bool = True, group_contents: bool = True, - _flatten_sub_clauses: bool = False, _literal_as_text_role: Type[roles.SQLRole] = roles.WhereHavingRole, ): self.operator = operator self.group = group self.group_contents = group_contents clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses - if _flatten_sub_clauses: - clauses_iterator = util.flatten_iterator(clauses_iterator) - - self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role text_converter_role: Type[roles.SQLRole] = _literal_as_text_role + self._text_converter_role = text_converter_role if self.group_contents: self.clauses = [ @@ -2594,8 +2596,176 @@ class ClauseList( return self -class BooleanClauseList(ClauseList, ColumnElement[bool]): - __visit_name__ = "clauselist" +class OperatorExpression(ColumnElement[_T]): + """base for expressions that contain an operator and operands + + .. versionadded:: 2.0 + + """ + + operator: OperatorType + type: TypeEngine[_T] + + group: bool = True + + @property + def is_comparison(self): + return operators.is_comparison(self.operator) + + def self_group(self, against=None): + if ( + self.group + and operators.is_precedent(self.operator, against) + or ( + # a negate against a non-boolean operator + # doesn't make too much sense but we should + # group for that + against is operators.inv + and not operators.is_boolean(self.operator) + ) + ): + return Grouping(self) + else: + return self + + @property + def _flattened_operator_clauses( + self, + ) -> typing_Tuple[ColumnElement[Any], ...]: + raise NotImplementedError() + + @classmethod + def _construct_for_op( + cls, + left: ColumnElement[Any], + right: ColumnElement[Any], + op: OperatorType, + *, + type_: TypeEngine[_T], + negate: Optional[OperatorType] = None, + modifiers: Optional[Mapping[str, Any]] = None, + ) -> OperatorExpression[_T]: + + if operators.is_associative(op): + assert ( + negate is None + ), f"negate not supported for associative operator {op}" + + multi = False + if getattr( + left, "operator", None + ) is op and type_._compare_type_affinity(left.type): + multi = True + left_flattened = left._flattened_operator_clauses + else: + left_flattened = (left,) + + if getattr( + right, "operator", None + ) is op and type_._compare_type_affinity(right.type): + multi = True + right_flattened = right._flattened_operator_clauses + else: + right_flattened = (right,) + + if multi: + return ExpressionClauseList._construct_for_list( + op, type_, *(left_flattened + right_flattened) + ) + + return BinaryExpression( + left, right, op, type_=type_, negate=negate, modifiers=modifiers + ) + + +class ExpressionClauseList(OperatorExpression[_T]): + """Describe a list of clauses, separated by an operator, + in a column expression context. + + :class:`.ExpressionClauseList` differs from :class:`.ClauseList` in that + it represents a column-oriented DQL expression only, not an open ended + list of anything comma separated. + + .. versionadded:: 2.0 + + """ + + __visit_name__ = "expression_clauselist" + + _traverse_internals: _TraverseInternalsType = [ + ("clauses", InternalTraversal.dp_clauseelement_tuple), + ("operator", InternalTraversal.dp_operator), + ] + + clauses: typing_Tuple[ColumnElement[Any], ...] + + group: bool + + def __init__( + self, + operator: OperatorType, + *clauses: _ColumnExpressionArgument[Any], + type_: Optional[_TypeEngineArgument[_T]] = None, + ): + self.operator = operator + + self.clauses = tuple( + coercions.expect( + roles.ExpressionElementRole, clause, apply_propagate_attrs=self + ) + for clause in clauses + ) + self._is_implicitly_boolean = operators.is_boolean(self.operator) + self.type = type_api.to_instance(type_) # type: ignore + + @property + def _flattened_operator_clauses( + self, + ) -> typing_Tuple[ColumnElement[Any], ...]: + return self.clauses + + def __iter__(self) -> Iterator[ColumnElement[Any]]: + return iter(self.clauses) + + def __len__(self) -> int: + return len(self.clauses) + + @property + def _select_iterable(self) -> _SelectIterable: + return (self,) + + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + return list(itertools.chain(*[c._from_objects for c in self.clauses])) + + def _append_inplace(self, clause: ColumnElement[Any]) -> None: + self.clauses += (clause,) + + @classmethod + def _construct_for_list( + cls, + operator: OperatorType, + type_: TypeEngine[_T], + *clauses: ColumnElement[Any], + group: bool = True, + ) -> ExpressionClauseList[_T]: + self = cls.__new__(cls) + self.group = group + self.clauses = clauses + self.operator = operator + self.type = type_ + return self + + def _negate(self) -> Any: + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression( + grouped, operator=operators.inv, wraps_column_expression=True + ) + + +class BooleanClauseList(ExpressionClauseList[bool]): + __visit_name__ = "expression_clauselist" inherit_cache = True def __init__(self, *arg, **kw): @@ -2668,7 +2838,15 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): if lcc > 1: # multiple elements. Return regular BooleanClauseList # which will link elements against the operator. - return cls._construct_raw(operator, convert_clauses) # type: ignore # noqa: E501 + + flattened_clauses = itertools.chain.from_iterable( + (c for c in to_flat._flattened_operator_clauses) + if getattr(to_flat, "operator", None) is operator + else (to_flat,) + for to_flat in convert_clauses + ) + + return cls._construct_raw(operator, flattened_clauses) # type: ignore # noqa: E501 elif lcc == 1: # just one element. return it as a single boolean element, # not a list and discard the operator. @@ -2726,10 +2904,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): clauses: Optional[Sequence[ColumnElement[Any]]] = None, ) -> BooleanClauseList: self = cls.__new__(cls) - self.clauses = list(clauses) if clauses else [] + self.clauses = tuple(clauses) if clauses else () self.group = True self.operator = operator - self.group_contents = True self.type = type_api.BOOLEANTYPE self._is_implicitly_boolean = True return self @@ -2768,9 +2945,6 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): else: return super(BooleanClauseList, self).self_group(against=against) - def _negate(self): - return ClauseList._negate(self) - and_ = BooleanClauseList.and_ or_ = BooleanClauseList.or_ @@ -3357,7 +3531,7 @@ class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]): return AsBoolean(self.element, self.negate, self.operator) -class BinaryExpression(ColumnElement[_T]): +class BinaryExpression(OperatorExpression[_T]): """Represent an expression that is ``LEFT <operator> RIGHT``. A :class:`.BinaryExpression` is generated automatically @@ -3394,12 +3568,12 @@ class BinaryExpression(ColumnElement[_T]): modifiers: Optional[Mapping[str, Any]] left: ColumnElement[Any] - right: Union[ColumnElement[Any], ClauseList] + right: ColumnElement[Any] def __init__( self, left: ColumnElement[Any], - right: Union[ColumnElement[Any], ClauseList], + right: ColumnElement[Any], operator: OperatorType, type_: Optional[_TypeEngineArgument[_T]] = None, negate: Optional[OperatorType] = None, @@ -3427,6 +3601,12 @@ class BinaryExpression(ColumnElement[_T]): else: self.modifiers = modifiers + @property + def _flattened_operator_clauses( + self, + ) -> typing_Tuple[ColumnElement[Any], ...]: + return (self.left, self.right) + def __bool__(self): """Implement Python-side "bool" for BinaryExpression as a simple "identity" check for the left and right attributes, @@ -3465,8 +3645,6 @@ class BinaryExpression(ColumnElement[_T]): else: raise TypeError("Boolean value of this clause is not defined") - __nonzero__ = __bool__ - if typing.TYPE_CHECKING: def __invert__( @@ -3474,21 +3652,10 @@ class BinaryExpression(ColumnElement[_T]): ) -> "BinaryExpression[_T]": ... - @property - def is_comparison(self): - return operators.is_comparison(self.operator) - @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return self.left._from_objects + self.right._from_objects - def self_group(self, against=None): - - if operators.is_precedent(self.operator, against): - return Grouping(self) - else: - return self - def _negate(self): if self.negate is not None: return BinaryExpression( diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 455e74f7b..d08bbf4eb 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -81,6 +81,7 @@ from .elements import ClauseList as ClauseList from .elements import CollectionAggregate as CollectionAggregate from .elements import ColumnClause as ColumnClause from .elements import ColumnElement as ColumnElement +from .elements import ExpressionClauseList as ExpressionClauseList from .elements import Extract as Extract from .elements import False_ as False_ from .elements import FunctionFilter as FunctionFilter diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 803e85654..8d98f893f 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -84,6 +84,7 @@ class HasExpressionLookup(TypeEngineMixin): raise NotImplementedError() class Comparator(TypeEngine.Comparator[_CT]): + __slots__ = () _blank_dict = util.EMPTY_DICT @@ -114,6 +115,8 @@ class Concatenable(TypeEngineMixin): typically strings.""" class Comparator(TypeEngine.Comparator[_T]): + __slots__ = () + def _adapt_expression( self, op: OperatorType, @@ -143,6 +146,8 @@ class Indexable(TypeEngineMixin): """ class Comparator(TypeEngine.Comparator[_T]): + __slots__ = () + def _setup_getitem(self, index): raise NotImplementedError() @@ -174,12 +179,9 @@ class String(Concatenable, TypeEngine[str]): __visit_name__ = "string" def __init__( - # note pylance appears to require the "self" type in a constructor - # for the _T type to be correctly recognized when we send the - # class as the argument, e.g. `column("somecol", String)` self, - length=None, - collation=None, + length: Optional[int] = None, + collation: Optional[str] = None, ): """ Create a string-holding type. @@ -1508,6 +1510,8 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): ) from err class Comparator(String.Comparator[str]): + __slots__ = () + type: String def _adapt_expression( @@ -1963,7 +1967,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): TypeDecorator.Comparator[_CT], _AbstractInterval.Comparator[_CT], ): - pass + __slots__ = () comparator_factory = Comparator @@ -2385,6 +2389,8 @@ class JSON(Indexable, TypeEngine[Any]): class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]): """Define comparison operations for :class:`_types.JSON`.""" + __slots__ = () + def _setup_getitem(self, index): if not isinstance(index, str) and isinstance( index, collections_abc.Sequence @@ -2710,6 +2716,8 @@ class ARRAY( """ + __slots__ = () + def _setup_getitem(self, index): arr_type = cast(ARRAY, self.type) @@ -3221,6 +3229,8 @@ class NullType(TypeEngine[None]): return process class Comparator(TypeEngine.Comparator[_T]): + __slots__ = () + def _adapt_expression( self, op: OperatorType, diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index cbc4e9e70..c23cd04dd 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -924,7 +924,7 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): ): return COMPARE_FAILED - def compare_clauselist(self, left, right, **kw): + def compare_expression_clauselist(self, left, right, **kw): if left.operator is right.operator: if operators.is_associative(left.operator): if self._compare_unordered_sequences( @@ -938,6 +938,9 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): else: return COMPARE_FAILED + def compare_clauselist(self, left, right, **kw): + return self.compare_expression_clauselist(left, right, **kw) + def compare_binary(self, left, right, **kw): if left.operator == right.operator: if operators.is_commutative(left.operator): |
