summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-12 13:52:31 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-13 17:19:31 -0400
commit428262a2d5374613f4a4cf925bbd9e94e0e34acc (patch)
tree9f71ec4a09d3ea584b3e399085254fb278049a6f /lib/sqlalchemy
parenta45e2284dad17fbbba3bea9d5e5304aab21c8c94 (diff)
downloadsqlalchemy-428262a2d5374613f4a4cf925bbd9e94e0e34acc.tar.gz
implement multi-element expression constructs
Improved the construction of SQL binary expressions to allow for very long expressions against the same associative operator without special steps needed in order to avoid high memory use and excess recursion depth. A particular binary operation ``A op B`` can now be joined against another element ``op C`` and the resulting structure will be "flattened" so that the representation as well as SQL compilation does not require recursion. To implement this more cleanly, the biggest change here is that column-oriented lists of things are broken away from ClauseList in a new class ExpressionClauseList, that also forms the basis of BooleanClauseList. ClauseList is still used for the generic "comma-separated list" of things such as Tuple and things like ORDER BY, as well as in some API endpoints. Also adds __slots__ to the TypeEngine-bound Comparator classes. Still can't really do __slots__ on ClauseElement. Fixes: #7744 Change-Id: I81a8ceb6f8f3bb0fe52d58f3cb42e4b6c2bc9018
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py5
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py7
-rw-r--r--lib/sqlalchemy/dialects/postgresql/array.py32
-rw-r--r--lib/sqlalchemy/orm/evaluator.py3
-rw-r--r--lib/sqlalchemy/orm/persistence.py12
-rw-r--r--lib/sqlalchemy/sql/compiler.py18
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py37
-rw-r--r--lib/sqlalchemy/sql/elements.py227
-rw-r--r--lib/sqlalchemy/sql/expression.py1
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py22
-rw-r--r--lib/sqlalchemy/sql/traversals.py5
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py25
12 files changed, 320 insertions, 74 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 35428b659..2bacaaf33 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -1833,6 +1833,11 @@ class MSSQLCompiler(compiler.SQLCompiler):
def visit_char_length_func(self, fn, **kw):
return "LEN%s" % self.function_argspec(fn, **kw)
+ def visit_concat_op_expression_clauselist(
+ self, clauselist, operator, **kw
+ ):
+ return " + ".join(self.process(elem, **kw) for elem in clauselist)
+
def visit_concat_op_binary(self, binary, operator, **kw):
return "%s + %s" % (
self.process(binary.left, **kw),
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 25f4c6945..b53e55abf 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1322,6 +1322,13 @@ class MySQLCompiler(compiler.SQLCompiler):
return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses)
+ def visit_concat_op_expression_clauselist(
+ self, clauselist, operator, **kw
+ ):
+ return "concat(%s)" % (
+ ", ".join(self.process(elem, **kw) for elem in clauselist.clauses)
+ )
+
def visit_concat_op_binary(self, binary, operator, **kw):
return "concat(%s, %s)" % (
self.process(binary.left, **kw),
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
index 74643c4d9..7eec7b86f 100644
--- a/lib/sqlalchemy/dialects/postgresql/array.py
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -5,14 +5,19 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
+
import re
+from typing import Any
+from typing import TypeVar
from ... import types as sqltypes
from ... import util
-from ...sql import coercions
from ...sql import expression
from ...sql import operators
-from ...sql import roles
+
+
+_T = TypeVar("_T", bound=Any)
def Any(other, arrexpr, operator=operators.eq):
@@ -33,7 +38,7 @@ def All(other, arrexpr, operator=operators.eq):
return arrexpr.all(other, operator)
-class array(expression.ClauseList, expression.ColumnElement):
+class array(expression.ExpressionClauseList[_T]):
"""A PostgreSQL ARRAY literal.
@@ -90,16 +95,19 @@ class array(expression.ClauseList, expression.ColumnElement):
inherit_cache = True
def __init__(self, clauses, **kw):
- clauses = [
- coercions.expect(roles.ExpressionElementRole, c) for c in clauses
- ]
-
- self._type_tuple = [arg.type for arg in clauses]
- main_type = kw.pop(
- "type_",
- self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE,
+
+ type_arg = kw.pop("type_", None)
+ super(array, self).__init__(operators.comma_op, *clauses, **kw)
+
+ self._type_tuple = [arg.type for arg in self.clauses]
+
+ main_type = (
+ type_arg
+ if type_arg is not None
+ else self._type_tuple[0]
+ if self._type_tuple
+ else sqltypes.NULLTYPE
)
- super(array, self).__init__(*clauses, **kw)
if isinstance(main_type, ARRAY):
self.type = ARRAY(
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py
index 453fc8903..1b3340dc5 100644
--- a/lib/sqlalchemy/orm/evaluator.py
+++ b/lib/sqlalchemy/orm/evaluator.py
@@ -94,6 +94,9 @@ class EvaluatorCompiler:
def visit_tuple(self, clause):
return self.visit_clauselist(clause)
+ def visit_expression_clauselist(self, clause):
+ return self.visit_clauselist(clause)
+
def visit_clauselist(self, clause):
evaluators = [self.process(clause) for clause in clause.clauses]
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 93b49ab25..7298d3630 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -884,12 +884,12 @@ def _emit_update_statements(
clauses = BooleanClauseList._construct_raw(operators.and_)
for col in mapper._pks_by_table[table]:
- clauses.clauses.append(
+ clauses._append_inplace(
col == sql.bindparam(col._label, type_=col.type)
)
if needs_version_id:
- clauses.clauses.append(
+ clauses._append_inplace(
mapper.version_id_col
== sql.bindparam(
mapper.version_id_col._label,
@@ -1316,12 +1316,12 @@ def _emit_post_update_statements(
clauses = BooleanClauseList._construct_raw(operators.and_)
for col in mapper._pks_by_table[table]:
- clauses.clauses.append(
+ clauses._append_inplace(
col == sql.bindparam(col._label, type_=col.type)
)
if needs_version_id:
- clauses.clauses.append(
+ clauses._append_inplace(
mapper.version_id_col
== sql.bindparam(
mapper.version_id_col._label,
@@ -1437,12 +1437,12 @@ def _emit_delete_statements(
clauses = BooleanClauseList._construct_raw(operators.and_)
for col in mapper._pks_by_table[table]:
- clauses.clauses.append(
+ clauses._append_inplace(
col == sql.bindparam(col.key, type_=col.type)
)
if need_version_id:
- clauses.clauses.append(
+ clauses._append_inplace(
mapper.version_id_col
== sql.bindparam(
mapper.version_id_col.key, type_=mapper.version_id_col.type
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 522a0bd4a..9c074db33 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -2013,6 +2013,24 @@ class SQLCompiler(Compiled):
return self._generate_delimited_list(clauselist.clauses, sep, **kw)
+ def visit_expression_clauselist(self, clauselist, **kw):
+ operator_ = clauselist.operator
+
+ disp = self._get_operator_dispatch(
+ operator_, "expression_clauselist", None
+ )
+ if disp:
+ return disp(clauselist, operator_, **kw)
+
+ try:
+ opstring = OPERATORS[operator_]
+ except KeyError as err:
+ raise exc.UnsupportedCompilationError(self, operator_) from err
+ else:
+ return self._generate_delimited_list(
+ clauselist.clauses, opstring, **kw
+ )
+
def visit_case(self, clause, **kwargs):
x = "CASE "
if clause.value is not None:
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 944a0a5ce..512fca8d0 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -27,11 +27,12 @@ from . import type_api
from .elements import and_
from .elements import BinaryExpression
from .elements import ClauseElement
-from .elements import ClauseList
from .elements import CollationClause
from .elements import CollectionAggregate
+from .elements import ExpressionClauseList
from .elements import False_
from .elements import Null
+from .elements import OperatorExpression
from .elements import or_
from .elements import True_
from .elements import UnaryExpression
@@ -56,11 +57,9 @@ def _boolean_compare(
reverse: bool = False,
_python_is_types: Tuple[Type[Any], ...] = (type(None), bool),
_any_all_expr: bool = False,
- result_type: Optional[
- Union[Type[TypeEngine[bool]], TypeEngine[bool]]
- ] = None,
+ result_type: Optional[TypeEngine[bool]] = None,
**kwargs: Any,
-) -> BinaryExpression[bool]:
+) -> OperatorExpression[bool]:
if result_type is None:
result_type = type_api.BOOLEANTYPE
@@ -71,7 +70,7 @@ def _boolean_compare(
if op in (operators.eq, operators.ne) and isinstance(
obj, (bool, True_, False_)
):
- return BinaryExpression(
+ return OperatorExpression._construct_for_op(
expr,
coercions.expect(roles.ConstExprRole, obj),
op,
@@ -83,7 +82,7 @@ def _boolean_compare(
operators.is_distinct_from,
operators.is_not_distinct_from,
):
- return BinaryExpression(
+ return OperatorExpression._construct_for_op(
expr,
coercions.expect(roles.ConstExprRole, obj),
op,
@@ -98,7 +97,7 @@ def _boolean_compare(
else:
# all other None uses IS, IS NOT
if op in (operators.eq, operators.is_):
- return BinaryExpression(
+ return OperatorExpression._construct_for_op(
expr,
coercions.expect(roles.ConstExprRole, obj),
operators.is_,
@@ -106,7 +105,7 @@ def _boolean_compare(
type_=result_type,
)
elif op in (operators.ne, operators.is_not):
- return BinaryExpression(
+ return OperatorExpression._construct_for_op(
expr,
coercions.expect(roles.ConstExprRole, obj),
operators.is_not,
@@ -125,7 +124,7 @@ def _boolean_compare(
)
if reverse:
- return BinaryExpression(
+ return OperatorExpression._construct_for_op(
obj,
expr,
op,
@@ -134,7 +133,7 @@ def _boolean_compare(
modifiers=kwargs,
)
else:
- return BinaryExpression(
+ return OperatorExpression._construct_for_op(
expr,
obj,
op,
@@ -169,11 +168,9 @@ def _binary_operate(
obj: roles.BinaryElementRole[Any],
*,
reverse: bool = False,
- result_type: Optional[
- Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
- ] = None,
+ result_type: Optional[TypeEngine[_T]] = None,
**kw: Any,
-) -> BinaryExpression[_T]:
+) -> OperatorExpression[_T]:
coerced_obj = coercions.expect(
roles.BinaryElementRole, obj, expr=expr, operator=op
@@ -189,7 +186,9 @@ def _binary_operate(
op, right.comparator
)
- return BinaryExpression(left, right, op, type_=result_type, modifiers=kw)
+ return OperatorExpression._construct_for_op(
+ left, right, op, type_=result_type, modifiers=kw
+ )
def _conjunction_operate(
@@ -311,7 +310,9 @@ def _between_impl(
"""See :meth:`.ColumnOperators.between`."""
return BinaryExpression(
expr,
- ClauseList(
+ ExpressionClauseList._construct_for_list(
+ operators.and_,
+ type_api.NULLTYPE,
coercions.expect(
roles.BinaryElementRole,
cleft,
@@ -324,9 +325,7 @@ def _between_impl(
expr=expr,
operator=operators.and_,
),
- operator=operators.and_,
group=False,
- group_contents=False,
),
op,
negate=operators.not_between_op
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 805758283..d47d138f7 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -1323,7 +1323,11 @@ class ColumnElement(
if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
return AsBoolean(self, operators.is_false, operators.is_true)
else:
- return cast("UnaryExpression[_T]", super()._negate())
+ grouped = self.self_group(against=operators.inv)
+ assert isinstance(grouped, ColumnElement)
+ return UnaryExpression(
+ grouped, operator=operators.inv, wraps_column_expression=True
+ )
type: TypeEngine[_T]
@@ -2501,6 +2505,8 @@ class ClauseList(
__visit_name__ = "clauselist"
+ # this is used only by the ORM in a legacy use case for
+ # composite attributes
_is_clause_list = True
_traverse_internals: _TraverseInternalsType = [
@@ -2516,18 +2522,14 @@ class ClauseList(
operator: OperatorType = operators.comma_op,
group: bool = True,
group_contents: bool = True,
- _flatten_sub_clauses: bool = False,
_literal_as_text_role: Type[roles.SQLRole] = roles.WhereHavingRole,
):
self.operator = operator
self.group = group
self.group_contents = group_contents
clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses
- if _flatten_sub_clauses:
- clauses_iterator = util.flatten_iterator(clauses_iterator)
-
- self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
+ self._text_converter_role = text_converter_role
if self.group_contents:
self.clauses = [
@@ -2594,8 +2596,176 @@ class ClauseList(
return self
-class BooleanClauseList(ClauseList, ColumnElement[bool]):
- __visit_name__ = "clauselist"
+class OperatorExpression(ColumnElement[_T]):
+ """base for expressions that contain an operator and operands
+
+ .. versionadded:: 2.0
+
+ """
+
+ operator: OperatorType
+ type: TypeEngine[_T]
+
+ group: bool = True
+
+ @property
+ def is_comparison(self):
+ return operators.is_comparison(self.operator)
+
+ def self_group(self, against=None):
+ if (
+ self.group
+ and operators.is_precedent(self.operator, against)
+ or (
+ # a negate against a non-boolean operator
+ # doesn't make too much sense but we should
+ # group for that
+ against is operators.inv
+ and not operators.is_boolean(self.operator)
+ )
+ ):
+ return Grouping(self)
+ else:
+ return self
+
+ @property
+ def _flattened_operator_clauses(
+ self,
+ ) -> typing_Tuple[ColumnElement[Any], ...]:
+ raise NotImplementedError()
+
+ @classmethod
+ def _construct_for_op(
+ cls,
+ left: ColumnElement[Any],
+ right: ColumnElement[Any],
+ op: OperatorType,
+ *,
+ type_: TypeEngine[_T],
+ negate: Optional[OperatorType] = None,
+ modifiers: Optional[Mapping[str, Any]] = None,
+ ) -> OperatorExpression[_T]:
+
+ if operators.is_associative(op):
+ assert (
+ negate is None
+ ), f"negate not supported for associative operator {op}"
+
+ multi = False
+ if getattr(
+ left, "operator", None
+ ) is op and type_._compare_type_affinity(left.type):
+ multi = True
+ left_flattened = left._flattened_operator_clauses
+ else:
+ left_flattened = (left,)
+
+ if getattr(
+ right, "operator", None
+ ) is op and type_._compare_type_affinity(right.type):
+ multi = True
+ right_flattened = right._flattened_operator_clauses
+ else:
+ right_flattened = (right,)
+
+ if multi:
+ return ExpressionClauseList._construct_for_list(
+ op, type_, *(left_flattened + right_flattened)
+ )
+
+ return BinaryExpression(
+ left, right, op, type_=type_, negate=negate, modifiers=modifiers
+ )
+
+
+class ExpressionClauseList(OperatorExpression[_T]):
+ """Describe a list of clauses, separated by an operator,
+ in a column expression context.
+
+ :class:`.ExpressionClauseList` differs from :class:`.ClauseList` in that
+ it represents a column-oriented DQL expression only, not an open ended
+ list of anything comma separated.
+
+ .. versionadded:: 2.0
+
+ """
+
+ __visit_name__ = "expression_clauselist"
+
+ _traverse_internals: _TraverseInternalsType = [
+ ("clauses", InternalTraversal.dp_clauseelement_tuple),
+ ("operator", InternalTraversal.dp_operator),
+ ]
+
+ clauses: typing_Tuple[ColumnElement[Any], ...]
+
+ group: bool
+
+ def __init__(
+ self,
+ operator: OperatorType,
+ *clauses: _ColumnExpressionArgument[Any],
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ ):
+ self.operator = operator
+
+ self.clauses = tuple(
+ coercions.expect(
+ roles.ExpressionElementRole, clause, apply_propagate_attrs=self
+ )
+ for clause in clauses
+ )
+ self._is_implicitly_boolean = operators.is_boolean(self.operator)
+ self.type = type_api.to_instance(type_) # type: ignore
+
+ @property
+ def _flattened_operator_clauses(
+ self,
+ ) -> typing_Tuple[ColumnElement[Any], ...]:
+ return self.clauses
+
+ def __iter__(self) -> Iterator[ColumnElement[Any]]:
+ return iter(self.clauses)
+
+ def __len__(self) -> int:
+ return len(self.clauses)
+
+ @property
+ def _select_iterable(self) -> _SelectIterable:
+ return (self,)
+
+ @util.ro_non_memoized_property
+ def _from_objects(self) -> List[FromClause]:
+ return list(itertools.chain(*[c._from_objects for c in self.clauses]))
+
+ def _append_inplace(self, clause: ColumnElement[Any]) -> None:
+ self.clauses += (clause,)
+
+ @classmethod
+ def _construct_for_list(
+ cls,
+ operator: OperatorType,
+ type_: TypeEngine[_T],
+ *clauses: ColumnElement[Any],
+ group: bool = True,
+ ) -> ExpressionClauseList[_T]:
+ self = cls.__new__(cls)
+ self.group = group
+ self.clauses = clauses
+ self.operator = operator
+ self.type = type_
+ return self
+
+ def _negate(self) -> Any:
+ grouped = self.self_group(against=operators.inv)
+ assert isinstance(grouped, ColumnElement)
+ return UnaryExpression(
+ grouped, operator=operators.inv, wraps_column_expression=True
+ )
+
+
+class BooleanClauseList(ExpressionClauseList[bool]):
+ __visit_name__ = "expression_clauselist"
inherit_cache = True
def __init__(self, *arg, **kw):
@@ -2668,7 +2838,15 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
if lcc > 1:
# multiple elements. Return regular BooleanClauseList
# which will link elements against the operator.
- return cls._construct_raw(operator, convert_clauses) # type: ignore # noqa: E501
+
+ flattened_clauses = itertools.chain.from_iterable(
+ (c for c in to_flat._flattened_operator_clauses)
+ if getattr(to_flat, "operator", None) is operator
+ else (to_flat,)
+ for to_flat in convert_clauses
+ )
+
+ return cls._construct_raw(operator, flattened_clauses) # type: ignore # noqa: E501
elif lcc == 1:
# just one element. return it as a single boolean element,
# not a list and discard the operator.
@@ -2726,10 +2904,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
clauses: Optional[Sequence[ColumnElement[Any]]] = None,
) -> BooleanClauseList:
self = cls.__new__(cls)
- self.clauses = list(clauses) if clauses else []
+ self.clauses = tuple(clauses) if clauses else ()
self.group = True
self.operator = operator
- self.group_contents = True
self.type = type_api.BOOLEANTYPE
self._is_implicitly_boolean = True
return self
@@ -2768,9 +2945,6 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
else:
return super(BooleanClauseList, self).self_group(against=against)
- def _negate(self):
- return ClauseList._negate(self)
-
and_ = BooleanClauseList.and_
or_ = BooleanClauseList.or_
@@ -3357,7 +3531,7 @@ class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]):
return AsBoolean(self.element, self.negate, self.operator)
-class BinaryExpression(ColumnElement[_T]):
+class BinaryExpression(OperatorExpression[_T]):
"""Represent an expression that is ``LEFT <operator> RIGHT``.
A :class:`.BinaryExpression` is generated automatically
@@ -3394,12 +3568,12 @@ class BinaryExpression(ColumnElement[_T]):
modifiers: Optional[Mapping[str, Any]]
left: ColumnElement[Any]
- right: Union[ColumnElement[Any], ClauseList]
+ right: ColumnElement[Any]
def __init__(
self,
left: ColumnElement[Any],
- right: Union[ColumnElement[Any], ClauseList],
+ right: ColumnElement[Any],
operator: OperatorType,
type_: Optional[_TypeEngineArgument[_T]] = None,
negate: Optional[OperatorType] = None,
@@ -3427,6 +3601,12 @@ class BinaryExpression(ColumnElement[_T]):
else:
self.modifiers = modifiers
+ @property
+ def _flattened_operator_clauses(
+ self,
+ ) -> typing_Tuple[ColumnElement[Any], ...]:
+ return (self.left, self.right)
+
def __bool__(self):
"""Implement Python-side "bool" for BinaryExpression as a
simple "identity" check for the left and right attributes,
@@ -3465,8 +3645,6 @@ class BinaryExpression(ColumnElement[_T]):
else:
raise TypeError("Boolean value of this clause is not defined")
- __nonzero__ = __bool__
-
if typing.TYPE_CHECKING:
def __invert__(
@@ -3474,21 +3652,10 @@ class BinaryExpression(ColumnElement[_T]):
) -> "BinaryExpression[_T]":
...
- @property
- def is_comparison(self):
- return operators.is_comparison(self.operator)
-
@util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
return self.left._from_objects + self.right._from_objects
- def self_group(self, against=None):
-
- if operators.is_precedent(self.operator, against):
- return Grouping(self)
- else:
- return self
-
def _negate(self):
if self.negate is not None:
return BinaryExpression(
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 455e74f7b..d08bbf4eb 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -81,6 +81,7 @@ from .elements import ClauseList as ClauseList
from .elements import CollectionAggregate as CollectionAggregate
from .elements import ColumnClause as ColumnClause
from .elements import ColumnElement as ColumnElement
+from .elements import ExpressionClauseList as ExpressionClauseList
from .elements import Extract as Extract
from .elements import False_ as False_
from .elements import FunctionFilter as FunctionFilter
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 803e85654..8d98f893f 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -84,6 +84,7 @@ class HasExpressionLookup(TypeEngineMixin):
raise NotImplementedError()
class Comparator(TypeEngine.Comparator[_CT]):
+ __slots__ = ()
_blank_dict = util.EMPTY_DICT
@@ -114,6 +115,8 @@ class Concatenable(TypeEngineMixin):
typically strings."""
class Comparator(TypeEngine.Comparator[_T]):
+ __slots__ = ()
+
def _adapt_expression(
self,
op: OperatorType,
@@ -143,6 +146,8 @@ class Indexable(TypeEngineMixin):
"""
class Comparator(TypeEngine.Comparator[_T]):
+ __slots__ = ()
+
def _setup_getitem(self, index):
raise NotImplementedError()
@@ -174,12 +179,9 @@ class String(Concatenable, TypeEngine[str]):
__visit_name__ = "string"
def __init__(
- # note pylance appears to require the "self" type in a constructor
- # for the _T type to be correctly recognized when we send the
- # class as the argument, e.g. `column("somecol", String)`
self,
- length=None,
- collation=None,
+ length: Optional[int] = None,
+ collation: Optional[str] = None,
):
"""
Create a string-holding type.
@@ -1508,6 +1510,8 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
) from err
class Comparator(String.Comparator[str]):
+ __slots__ = ()
+
type: String
def _adapt_expression(
@@ -1963,7 +1967,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]):
TypeDecorator.Comparator[_CT],
_AbstractInterval.Comparator[_CT],
):
- pass
+ __slots__ = ()
comparator_factory = Comparator
@@ -2385,6 +2389,8 @@ class JSON(Indexable, TypeEngine[Any]):
class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]):
"""Define comparison operations for :class:`_types.JSON`."""
+ __slots__ = ()
+
def _setup_getitem(self, index):
if not isinstance(index, str) and isinstance(
index, collections_abc.Sequence
@@ -2710,6 +2716,8 @@ class ARRAY(
"""
+ __slots__ = ()
+
def _setup_getitem(self, index):
arr_type = cast(ARRAY, self.type)
@@ -3221,6 +3229,8 @@ class NullType(TypeEngine[None]):
return process
class Comparator(TypeEngine.Comparator[_T]):
+ __slots__ = ()
+
def _adapt_expression(
self,
op: OperatorType,
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index cbc4e9e70..c23cd04dd 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -924,7 +924,7 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
):
return COMPARE_FAILED
- def compare_clauselist(self, left, right, **kw):
+ def compare_expression_clauselist(self, left, right, **kw):
if left.operator is right.operator:
if operators.is_associative(left.operator):
if self._compare_unordered_sequences(
@@ -938,6 +938,9 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
else:
return COMPARE_FAILED
+ def compare_clauselist(self, left, right, **kw):
+ return self.compare_expression_clauselist(left, right, **kw)
+
def compare_binary(self, left, right, **kw):
if left.operator == right.operator:
if operators.is_commutative(left.operator):
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index cc14dd9c4..25fe844c3 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -320,6 +320,31 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
data = r"backslash one \ backslash two \\ end"
literal_round_trip(String(40), [data], [data])
+ def test_concatenate_binary(self, connection):
+ """dialects with special string concatenation operators should
+ implement visit_concat_op_binary() and visit_concat_op_clauselist()
+ in their compiler.
+
+ .. versionchanged:: 2.0 visit_concat_op_clauselist() is also needed
+ for dialects to override the string concatenation operator.
+
+ """
+ eq_(connection.scalar(select(literal("a") + "b")), "ab")
+
+ def test_concatenate_clauselist(self, connection):
+ """dialects with special string concatenation operators should
+ implement visit_concat_op_binary() and visit_concat_op_clauselist()
+ in their compiler.
+
+ .. versionchanged:: 2.0 visit_concat_op_clauselist() is also needed
+ for dialects to override the string concatenation operator.
+
+ """
+ eq_(
+ connection.scalar(select(literal("a") + "b" + "c" + "d" + "e")),
+ "abcde",
+ )
+
class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
compare = None