diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/_psycopg_common.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/array.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/decl_api.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/evaluator.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/instrumentation.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 37 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 227 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 81 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 25 |
17 files changed, 378 insertions, 125 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 35428b659..2bacaaf33 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1833,6 +1833,11 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_char_length_func(self, fn, **kw): return "LEN%s" % self.function_argspec(fn, **kw) + def visit_concat_op_expression_clauselist( + self, clauselist, operator, **kw + ): + return " + ".join(self.process(elem, **kw) for elem in clauselist) + def visit_concat_op_binary(self, binary, operator, **kw): return "%s + %s" % ( self.process(binary.left, **kw), diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 25f4c6945..b53e55abf 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1322,6 +1322,13 @@ class MySQLCompiler(compiler.SQLCompiler): return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses) + def visit_concat_op_expression_clauselist( + self, clauselist, operator, **kw + ): + return "concat(%s)" % ( + ", ".join(self.process(elem, **kw) for elem in clauselist.clauses) + ) + def visit_concat_op_binary(self, binary, operator, **kw): return "concat(%s, %s)" % ( self.process(binary.left, **kw), diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py index 7f936fefb..e7d5e77c3 100644 --- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -170,15 +170,17 @@ class _PGDialect_common_psycopg(PGDialect): def do_ping(self, dbapi_connection): cursor = None + before_autocommit = dbapi_connection.autocommit try: - self._do_autocommit(dbapi_connection, True) + if not before_autocommit: + self._do_autocommit(dbapi_connection, True) cursor = dbapi_connection.cursor() try: cursor.execute(self._dialect_specific_select_one) finally: cursor.close() - if not dbapi_connection.closed: - self._do_autocommit(dbapi_connection, False) + if not before_autocommit and not dbapi_connection.closed: + self._do_autocommit(dbapi_connection, before_autocommit) except self.dbapi.Error as err: if self.is_disconnect(err, dbapi_connection, cursor): return False diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 74643c4d9..7eec7b86f 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -5,14 +5,19 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import re +from typing import Any +from typing import TypeVar from ... import types as sqltypes from ... import util -from ...sql import coercions from ...sql import expression from ...sql import operators -from ...sql import roles + + +_T = TypeVar("_T", bound=Any) def Any(other, arrexpr, operator=operators.eq): @@ -33,7 +38,7 @@ def All(other, arrexpr, operator=operators.eq): return arrexpr.all(other, operator) -class array(expression.ClauseList, expression.ColumnElement): +class array(expression.ExpressionClauseList[_T]): """A PostgreSQL ARRAY literal. @@ -90,16 +95,19 @@ class array(expression.ClauseList, expression.ColumnElement): inherit_cache = True def __init__(self, clauses, **kw): - clauses = [ - coercions.expect(roles.ExpressionElementRole, c) for c in clauses - ] - - self._type_tuple = [arg.type for arg in clauses] - main_type = kw.pop( - "type_", - self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE, + + type_arg = kw.pop("type_", None) + super(array, self).__init__(operators.comma_op, *clauses, **kw) + + self._type_tuple = [arg.type for arg in self.clauses] + + main_type = ( + type_arg + if type_arg is not None + else self._type_tuple[0] + if self._type_tuple + else sqltypes.NULLTYPE ) - super(array, self).__init__(*clauses, **kw) if isinstance(main_type, ARRAY): self.type = ARRAY( diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 9207221df..b811d1cab 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -385,12 +385,14 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): != self._psycopg_TransactionStatus.IDLE ): dbapi_conn.rollback() - before = dbapi_conn.autocommit + before_autocommit = dbapi_conn.autocommit try: - self._do_autocommit(dbapi_conn, True) + if not before_autocommit: + self._do_autocommit(dbapi_conn, True) dbapi_conn.execute(command) finally: - self._do_autocommit(dbapi_conn, before) + if not before_autocommit: + self._do_autocommit(dbapi_conn, before_autocommit) def do_rollback_twophase( self, connection, xid, is_prepared=True, recover=False diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 4a699f63b..70507015b 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -109,6 +109,10 @@ class DeclarativeMeta( def __init__( cls, classname: Any, bases: Any, dict_: Any, **kw: Any ) -> None: + # use cls.__dict__, which can be modified by an + # __init_subclass__() method (#7900) + dict_ = cls.__dict__ + # early-consume registry from the initial declarative base, # assign privately to not conflict with subclass attributes named # "registry" @@ -293,7 +297,8 @@ class declared_attr(interfaces._MappedAttribute[_T]): # here, we are inside of the declarative scan. use the registry # that is tracking the values of these attributes. - declarative_scan = manager.declarative_scan + declarative_scan = manager.declarative_scan() + assert declarative_scan is not None reg = declarative_scan.declared_attr_reg if self in reg: diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 3fb8af80c..804d05ce1 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -161,7 +161,13 @@ def _check_declared_props_nocascade(obj, name, cls): class _MapperConfig: - __slots__ = ("cls", "classname", "properties", "declared_attr_reg") + __slots__ = ( + "cls", + "classname", + "properties", + "declared_attr_reg", + "__weakref__", + ) @classmethod def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw): @@ -311,13 +317,15 @@ class _ClassScanMapperConfig(_MapperConfig): mapper_kw, ): + # grab class dict before the instrumentation manager has been added. + # reduces cycles + self.clsdict_view = ( + util.immutabledict(dict_) if dict_ else util.EMPTY_DICT + ) super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw) self.registry = registry self.persist_selectable = None - self.clsdict_view = ( - util.immutabledict(dict_) if dict_ else util.EMPTY_DICT - ) self.collected_attributes = {} self.collected_annotations: Dict[str, Tuple[Any, bool]] = {} self.declared_columns = util.OrderedSet() diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 453fc8903..1b3340dc5 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -94,6 +94,9 @@ class EvaluatorCompiler: def visit_tuple(self, clause): return self.visit_clauselist(clause) + def visit_expression_clauselist(self, clause): + return self.visit_clauselist(clause) + def visit_clauselist(self, clause): evaluators = [self.process(clause) for clause in clause.clauses] diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index a5dc305d2..0d4b630da 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -39,6 +39,7 @@ from typing import Optional from typing import Set from typing import TYPE_CHECKING from typing import TypeVar +import weakref from . import base from . import collections @@ -167,7 +168,7 @@ class ClassManager( if registry: registry._add_manager(self) if declarative_scan: - self.declarative_scan = declarative_scan + self.declarative_scan = weakref.ref(declarative_scan) if expired_attribute_loader: self.expired_attribute_loader = expired_attribute_loader diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 18aa9945f..d41c4ebb8 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -884,12 +884,12 @@ def _emit_update_statements( clauses = BooleanClauseList._construct_raw(operators.and_) for col in mapper._pks_by_table[table]: - clauses.clauses.append( + clauses._append_inplace( col == sql.bindparam(col._label, type_=col.type) ) if needs_version_id: - clauses.clauses.append( + clauses._append_inplace( mapper.version_id_col == sql.bindparam( mapper.version_id_col._label, @@ -1316,12 +1316,12 @@ def _emit_post_update_statements( clauses = BooleanClauseList._construct_raw(operators.and_) for col in mapper._pks_by_table[table]: - clauses.clauses.append( + clauses._append_inplace( col == sql.bindparam(col._label, type_=col.type) ) if needs_version_id: - clauses.clauses.append( + clauses._append_inplace( mapper.version_id_col == sql.bindparam( mapper.version_id_col._label, @@ -1437,12 +1437,12 @@ def _emit_delete_statements( clauses = BooleanClauseList._construct_raw(operators.and_) for col in mapper._pks_by_table[table]: - clauses.clauses.append( + clauses._append_inplace( col == sql.bindparam(col.key, type_=col.type) ) if need_version_id: - clauses.clauses.append( + clauses._append_inplace( mapper.version_id_col == sql.bindparam( mapper.version_id_col.key, type_=mapper.version_id_col.type @@ -2209,6 +2209,14 @@ class ORMInsert(ORMDMLState, InsertDMLState): bind_arguments, is_reentrant_invoke, ): + bind_arguments["clause"] = statement + try: + plugin_subject = statement._propagate_attrs["plugin_subject"] + except KeyError: + assert False, "statement had 'orm' plugin but no plugin_subject" + else: + bind_arguments["mapper"] = plugin_subject.mapper + return ( statement, util.immutabledict(execution_options), diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1498c8341..938be0f81 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2013,6 +2013,24 @@ class SQLCompiler(Compiled): return self._generate_delimited_list(clauselist.clauses, sep, **kw) + def visit_expression_clauselist(self, clauselist, **kw): + operator_ = clauselist.operator + + disp = self._get_operator_dispatch( + operator_, "expression_clauselist", None + ) + if disp: + return disp(clauselist, operator_, **kw) + + try: + opstring = OPERATORS[operator_] + except KeyError as err: + raise exc.UnsupportedCompilationError(self, operator_) from err + else: + return self._generate_delimited_list( + clauselist.clauses, opstring, **kw + ) + def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 944a0a5ce..512fca8d0 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -27,11 +27,12 @@ from . import type_api from .elements import and_ from .elements import BinaryExpression from .elements import ClauseElement -from .elements import ClauseList from .elements import CollationClause from .elements import CollectionAggregate +from .elements import ExpressionClauseList from .elements import False_ from .elements import Null +from .elements import OperatorExpression from .elements import or_ from .elements import True_ from .elements import UnaryExpression @@ -56,11 +57,9 @@ def _boolean_compare( reverse: bool = False, _python_is_types: Tuple[Type[Any], ...] = (type(None), bool), _any_all_expr: bool = False, - result_type: Optional[ - Union[Type[TypeEngine[bool]], TypeEngine[bool]] - ] = None, + result_type: Optional[TypeEngine[bool]] = None, **kwargs: Any, -) -> BinaryExpression[bool]: +) -> OperatorExpression[bool]: if result_type is None: result_type = type_api.BOOLEANTYPE @@ -71,7 +70,7 @@ def _boolean_compare( if op in (operators.eq, operators.ne) and isinstance( obj, (bool, True_, False_) ): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), op, @@ -83,7 +82,7 @@ def _boolean_compare( operators.is_distinct_from, operators.is_not_distinct_from, ): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), op, @@ -98,7 +97,7 @@ def _boolean_compare( else: # all other None uses IS, IS NOT if op in (operators.eq, operators.is_): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), operators.is_, @@ -106,7 +105,7 @@ def _boolean_compare( type_=result_type, ) elif op in (operators.ne, operators.is_not): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), operators.is_not, @@ -125,7 +124,7 @@ def _boolean_compare( ) if reverse: - return BinaryExpression( + return OperatorExpression._construct_for_op( obj, expr, op, @@ -134,7 +133,7 @@ def _boolean_compare( modifiers=kwargs, ) else: - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, obj, op, @@ -169,11 +168,9 @@ def _binary_operate( obj: roles.BinaryElementRole[Any], *, reverse: bool = False, - result_type: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] - ] = None, + result_type: Optional[TypeEngine[_T]] = None, **kw: Any, -) -> BinaryExpression[_T]: +) -> OperatorExpression[_T]: coerced_obj = coercions.expect( roles.BinaryElementRole, obj, expr=expr, operator=op @@ -189,7 +186,9 @@ def _binary_operate( op, right.comparator ) - return BinaryExpression(left, right, op, type_=result_type, modifiers=kw) + return OperatorExpression._construct_for_op( + left, right, op, type_=result_type, modifiers=kw + ) def _conjunction_operate( @@ -311,7 +310,9 @@ def _between_impl( """See :meth:`.ColumnOperators.between`.""" return BinaryExpression( expr, - ClauseList( + ExpressionClauseList._construct_for_list( + operators.and_, + type_api.NULLTYPE, coercions.expect( roles.BinaryElementRole, cleft, @@ -324,9 +325,7 @@ def _between_impl( expr=expr, operator=operators.and_, ), - operator=operators.and_, group=False, - group_contents=False, ), op, negate=operators.not_between_op diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 20938fd5a..ea0fa7996 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1340,7 +1340,11 @@ class ColumnElement( if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return AsBoolean(self, operators.is_false, operators.is_true) else: - return cast("UnaryExpression[_T]", super()._negate()) + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression( + grouped, operator=operators.inv, wraps_column_expression=True + ) type: TypeEngine[_T] @@ -2519,6 +2523,8 @@ class ClauseList( __visit_name__ = "clauselist" + # this is used only by the ORM in a legacy use case for + # composite attributes _is_clause_list = True _traverse_internals: _TraverseInternalsType = [ @@ -2534,18 +2540,14 @@ class ClauseList( operator: OperatorType = operators.comma_op, group: bool = True, group_contents: bool = True, - _flatten_sub_clauses: bool = False, _literal_as_text_role: Type[roles.SQLRole] = roles.WhereHavingRole, ): self.operator = operator self.group = group self.group_contents = group_contents clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses - if _flatten_sub_clauses: - clauses_iterator = util.flatten_iterator(clauses_iterator) - - self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role text_converter_role: Type[roles.SQLRole] = _literal_as_text_role + self._text_converter_role = text_converter_role if self.group_contents: self.clauses = [ @@ -2612,8 +2614,176 @@ class ClauseList( return self -class BooleanClauseList(ClauseList, ColumnElement[bool]): - __visit_name__ = "clauselist" +class OperatorExpression(ColumnElement[_T]): + """base for expressions that contain an operator and operands + + .. versionadded:: 2.0 + + """ + + operator: OperatorType + type: TypeEngine[_T] + + group: bool = True + + @property + def is_comparison(self): + return operators.is_comparison(self.operator) + + def self_group(self, against=None): + if ( + self.group + and operators.is_precedent(self.operator, against) + or ( + # a negate against a non-boolean operator + # doesn't make too much sense but we should + # group for that + against is operators.inv + and not operators.is_boolean(self.operator) + ) + ): + return Grouping(self) + else: + return self + + @property + def _flattened_operator_clauses( + self, + ) -> typing_Tuple[ColumnElement[Any], ...]: + raise NotImplementedError() + + @classmethod + def _construct_for_op( + cls, + left: ColumnElement[Any], + right: ColumnElement[Any], + op: OperatorType, + *, + type_: TypeEngine[_T], + negate: Optional[OperatorType] = None, + modifiers: Optional[Mapping[str, Any]] = None, + ) -> OperatorExpression[_T]: + + if operators.is_associative(op): + assert ( + negate is None + ), f"negate not supported for associative operator {op}" + + multi = False + if getattr( + left, "operator", None + ) is op and type_._compare_type_affinity(left.type): + multi = True + left_flattened = left._flattened_operator_clauses + else: + left_flattened = (left,) + + if getattr( + right, "operator", None + ) is op and type_._compare_type_affinity(right.type): + multi = True + right_flattened = right._flattened_operator_clauses + else: + right_flattened = (right,) + + if multi: + return ExpressionClauseList._construct_for_list( + op, type_, *(left_flattened + right_flattened) + ) + + return BinaryExpression( + left, right, op, type_=type_, negate=negate, modifiers=modifiers + ) + + +class ExpressionClauseList(OperatorExpression[_T]): + """Describe a list of clauses, separated by an operator, + in a column expression context. + + :class:`.ExpressionClauseList` differs from :class:`.ClauseList` in that + it represents a column-oriented DQL expression only, not an open ended + list of anything comma separated. + + .. versionadded:: 2.0 + + """ + + __visit_name__ = "expression_clauselist" + + _traverse_internals: _TraverseInternalsType = [ + ("clauses", InternalTraversal.dp_clauseelement_tuple), + ("operator", InternalTraversal.dp_operator), + ] + + clauses: typing_Tuple[ColumnElement[Any], ...] + + group: bool + + def __init__( + self, + operator: OperatorType, + *clauses: _ColumnExpressionArgument[Any], + type_: Optional[_TypeEngineArgument[_T]] = None, + ): + self.operator = operator + + self.clauses = tuple( + coercions.expect( + roles.ExpressionElementRole, clause, apply_propagate_attrs=self + ) + for clause in clauses + ) + self._is_implicitly_boolean = operators.is_boolean(self.operator) + self.type = type_api.to_instance(type_) # type: ignore + + @property + def _flattened_operator_clauses( + self, + ) -> typing_Tuple[ColumnElement[Any], ...]: + return self.clauses + + def __iter__(self) -> Iterator[ColumnElement[Any]]: + return iter(self.clauses) + + def __len__(self) -> int: + return len(self.clauses) + + @property + def _select_iterable(self) -> _SelectIterable: + return (self,) + + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + return list(itertools.chain(*[c._from_objects for c in self.clauses])) + + def _append_inplace(self, clause: ColumnElement[Any]) -> None: + self.clauses += (clause,) + + @classmethod + def _construct_for_list( + cls, + operator: OperatorType, + type_: TypeEngine[_T], + *clauses: ColumnElement[Any], + group: bool = True, + ) -> ExpressionClauseList[_T]: + self = cls.__new__(cls) + self.group = group + self.clauses = clauses + self.operator = operator + self.type = type_ + return self + + def _negate(self) -> Any: + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression( + grouped, operator=operators.inv, wraps_column_expression=True + ) + + +class BooleanClauseList(ExpressionClauseList[bool]): + __visit_name__ = "expression_clauselist" inherit_cache = True def __init__(self, *arg, **kw): @@ -2686,7 +2856,15 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): if lcc > 1: # multiple elements. Return regular BooleanClauseList # which will link elements against the operator. - return cls._construct_raw(operator, convert_clauses) # type: ignore # noqa: E501 + + flattened_clauses = itertools.chain.from_iterable( + (c for c in to_flat._flattened_operator_clauses) + if getattr(to_flat, "operator", None) is operator + else (to_flat,) + for to_flat in convert_clauses + ) + + return cls._construct_raw(operator, flattened_clauses) # type: ignore # noqa: E501 elif lcc == 1: # just one element. return it as a single boolean element, # not a list and discard the operator. @@ -2744,10 +2922,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): clauses: Optional[Sequence[ColumnElement[Any]]] = None, ) -> BooleanClauseList: self = cls.__new__(cls) - self.clauses = list(clauses) if clauses else [] + self.clauses = tuple(clauses) if clauses else () self.group = True self.operator = operator - self.group_contents = True self.type = type_api.BOOLEANTYPE self._is_implicitly_boolean = True return self @@ -2786,9 +2963,6 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): else: return super(BooleanClauseList, self).self_group(against=against) - def _negate(self): - return ClauseList._negate(self) - and_ = BooleanClauseList.and_ or_ = BooleanClauseList.or_ @@ -3375,7 +3549,7 @@ class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]): return AsBoolean(self.element, self.negate, self.operator) -class BinaryExpression(ColumnElement[_T]): +class BinaryExpression(OperatorExpression[_T]): """Represent an expression that is ``LEFT <operator> RIGHT``. A :class:`.BinaryExpression` is generated automatically @@ -3412,12 +3586,12 @@ class BinaryExpression(ColumnElement[_T]): modifiers: Optional[Mapping[str, Any]] left: ColumnElement[Any] - right: Union[ColumnElement[Any], ClauseList] + right: ColumnElement[Any] def __init__( self, left: ColumnElement[Any], - right: Union[ColumnElement[Any], ClauseList], + right: ColumnElement[Any], operator: OperatorType, type_: Optional[_TypeEngineArgument[_T]] = None, negate: Optional[OperatorType] = None, @@ -3445,6 +3619,12 @@ class BinaryExpression(ColumnElement[_T]): else: self.modifiers = modifiers + @property + def _flattened_operator_clauses( + self, + ) -> typing_Tuple[ColumnElement[Any], ...]: + return (self.left, self.right) + def __bool__(self): """Implement Python-side "bool" for BinaryExpression as a simple "identity" check for the left and right attributes, @@ -3483,8 +3663,6 @@ class BinaryExpression(ColumnElement[_T]): else: raise TypeError("Boolean value of this clause is not defined") - __nonzero__ = __bool__ - if typing.TYPE_CHECKING: def __invert__( @@ -3492,21 +3670,10 @@ class BinaryExpression(ColumnElement[_T]): ) -> "BinaryExpression[_T]": ... - @property - def is_comparison(self): - return operators.is_comparison(self.operator) - @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return self.left._from_objects + self.right._from_objects - def self_group(self, against=None): - - if operators.is_precedent(self.operator, against): - return Grouping(self) - else: - return self - def _negate(self): if self.negate is not None: return BinaryExpression( diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 455e74f7b..d08bbf4eb 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -81,6 +81,7 @@ from .elements import ClauseList as ClauseList from .elements import CollectionAggregate as CollectionAggregate from .elements import ColumnClause as ColumnClause from .elements import ColumnElement as ColumnElement +from .elements import ExpressionClauseList as ExpressionClauseList from .elements import Extract as Extract from .elements import False_ as False_ from .elements import FunctionFilter as FunctionFilter diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 72e658db8..64d6ea81b 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -88,6 +88,7 @@ class HasExpressionLookup(TypeEngineMixin): raise NotImplementedError() class Comparator(TypeEngine.Comparator[_CT]): + __slots__ = () _blank_dict = util.EMPTY_DICT @@ -118,6 +119,8 @@ class Concatenable(TypeEngineMixin): typically strings.""" class Comparator(TypeEngine.Comparator[_T]): + __slots__ = () + def _adapt_expression( self, op: OperatorType, @@ -147,6 +150,8 @@ class Indexable(TypeEngineMixin): """ class Comparator(TypeEngine.Comparator[_T]): + __slots__ = () + def _setup_getitem(self, index): raise NotImplementedError() @@ -178,9 +183,6 @@ class String(Concatenable, TypeEngine[str]): __visit_name__ = "string" def __init__( - # note pylance appears to require the "self" type in a constructor - # for the _T type to be correctly recognized when we send the - # class as the argument, e.g. `column("somecol", String)` self, length: Optional[int] = None, collation: Optional[str] = None, @@ -416,41 +418,34 @@ _N = TypeVar("_N", bound=Union[decimal.Decimal, float]) class Numeric(HasExpressionLookup, TypeEngine[_N]): - """A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``. + """Base for non-integer numeric types, such as + ``NUMERIC``, ``FLOAT``, ``DECIMAL``, and other variants. - This type returns Python ``decimal.Decimal`` objects by default, unless - the :paramref:`.Numeric.asdecimal` flag is set to False, in which case - they are coerced to Python ``float`` objects. + The :class:`.Numeric` datatype when used directly will render DDL + corresponding to precision numerics if available, such as + ``NUMERIC(precision, scale)``. The :class:`.Float` subclass will + attempt to render a floating-point datatype such as ``FLOAT(precision)``. - .. note:: + :class:`.Numeric` returns Python ``decimal.Decimal`` objects by default, + based on the default value of ``True`` for the + :paramref:`.Numeric.asdecimal` parameter. If this parameter is set to + False, returned values are coerced to Python ``float`` objects. - The :class:`.Numeric` type is designed to receive data from a database - type that is explicitly known to be a decimal type - (e.g. ``DECIMAL``, ``NUMERIC``, others) and not a floating point - type (e.g. ``FLOAT``, ``REAL``, others). - If the database column on the server is in fact a floating-point - type, such as ``FLOAT`` or ``REAL``, use the :class:`.Float` - type or a subclass, otherwise numeric coercion between - ``float``/``Decimal`` may or may not function as expected. + The :class:`.Float` subtype, being more specific to floating point, + defaults the :paramref:`.Float.asdecimal` flag to False so that the + default Python datatype is ``float``. .. note:: - The Python ``decimal.Decimal`` class is generally slow - performing; cPython 3.3 has now switched to use the `cdecimal - <https://pypi.org/project/cdecimal/>`_ library natively. For - older Python versions, the ``cdecimal`` library can be patched - into any application where it will replace the ``decimal`` - library fully, however this needs to be applied globally and - before any other modules have been imported, as follows:: - - import sys - import cdecimal - sys.modules["decimal"] = cdecimal - - Note that the ``cdecimal`` and ``decimal`` libraries are **not - compatible with each other**, so patching ``cdecimal`` at the - global level is the only way it can be used effectively with - various DBAPIs that hardcode to import the ``decimal`` library. + When using a :class:`.Numeric` datatype against a database type that + returns Python floating point values to the driver, the accuracy of the + decimal conversion indicated by :paramref:`.Numeric.asdecimal` may be + limited. The behavior of specific numeric/floating point datatypes + is a product of the SQL datatype in use, the Python :term:`DBAPI` + in use, as well as strategies that may be present within + the SQLAlchemy dialect in use. Users requiring specific precision/ + scale are encouraged to experiment with the available datatypes + in order to determine the best results. """ @@ -490,8 +485,6 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]): value of ".scale" as the default for decimal_return_scale, if not otherwise specified. - .. versionadded:: 0.9.0 - When using the ``Numeric`` type, care should be taken to ensure that the asdecimal setting is appropriate for the DBAPI in use - when Numeric applies a conversion from Decimal->float or float-> @@ -589,16 +582,6 @@ class Float(Numeric[_N]): :paramref:`.Float.asdecimal` flag is set to True, in which case they are coerced to ``decimal.Decimal`` objects. - .. note:: - - The :class:`.Float` type is designed to receive data from a database - type that is explicitly known to be a floating point type - (e.g. ``FLOAT``, ``REAL``, others) - and not a decimal type (e.g. ``DECIMAL``, ``NUMERIC``, others). - If the database column on the server is in fact a Numeric - type, such as ``DECIMAL`` or ``NUMERIC``, use the :class:`.Numeric` - type or a subclass, otherwise numeric coercion between - ``float``/``Decimal`` may or may not function as expected. """ @@ -1514,6 +1497,8 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): ) from err class Comparator(String.Comparator[str]): + __slots__ = () + type: String def _adapt_expression( @@ -1976,7 +1961,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): TypeDecorator.Comparator[_CT], _AbstractInterval.Comparator[_CT], ): - pass + __slots__ = () comparator_factory = Comparator @@ -2398,6 +2383,8 @@ class JSON(Indexable, TypeEngine[Any]): class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]): """Define comparison operations for :class:`_types.JSON`.""" + __slots__ = () + def _setup_getitem(self, index): if not isinstance(index, str) and isinstance( index, collections_abc.Sequence @@ -2726,6 +2713,8 @@ class ARRAY( """ + __slots__ = () + def _setup_getitem(self, index): arr_type = cast(ARRAY, self.type) @@ -3243,6 +3232,8 @@ class NullType(TypeEngine[None]): return process class Comparator(TypeEngine.Comparator[_T]): + __slots__ = () + def _adapt_expression( self, op: OperatorType, diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index cbc4e9e70..c23cd04dd 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -924,7 +924,7 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): ): return COMPARE_FAILED - def compare_clauselist(self, left, right, **kw): + def compare_expression_clauselist(self, left, right, **kw): if left.operator is right.operator: if operators.is_associative(left.operator): if self._compare_unordered_sequences( @@ -938,6 +938,9 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): else: return COMPARE_FAILED + def compare_clauselist(self, left, right, **kw): + return self.compare_expression_clauselist(left, right, **kw) + def compare_binary(self, left, right, **kw): if left.operator == right.operator: if operators.is_commutative(left.operator): diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index cc14dd9c4..25fe844c3 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -320,6 +320,31 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): data = r"backslash one \ backslash two \\ end" literal_round_trip(String(40), [data], [data]) + def test_concatenate_binary(self, connection): + """dialects with special string concatenation operators should + implement visit_concat_op_binary() and visit_concat_op_clauselist() + in their compiler. + + .. versionchanged:: 2.0 visit_concat_op_clauselist() is also needed + for dialects to override the string concatenation operator. + + """ + eq_(connection.scalar(select(literal("a") + "b")), "ab") + + def test_concatenate_clauselist(self, connection): + """dialects with special string concatenation operators should + implement visit_concat_op_binary() and visit_concat_op_clauselist() + in their compiler. + + .. versionchanged:: 2.0 visit_concat_op_clauselist() is also needed + for dialects to override the string concatenation operator. + + """ + eq_( + connection.scalar(select(literal("a") + "b" + "c" + "d" + "e")), + "abcde", + ) + class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): compare = None |
