From a8029f5a7e3e376ec57f1614ab0294b717d53c05 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 7 Aug 2022 12:14:19 -0400 Subject: ORM bulk insert via execute * ORM Insert now includes "bulk" mode that will run essentially the same process as session.bulk_insert_mappings; interprets the given list of values as ORM attributes for key names * ORM UPDATE has a similar feature, without RETURNING support, for session.bulk_update_mappings * Added support for upserts to do RETURNING ORM objects as well * ORM UPDATE/DELETE with list of parameters + WHERE criteria is a not implemented; use connection * ORM UPDATE/DELETE defaults to "auto" synchronize_session; use fetch if RETURNING is present, evaluate if not, as "fetch" is much more efficient (no expired object SELECT problem) and less error prone if RETURNING is available UPDATE: however this is inefficient! please continue to use evaluate for simple cases, auto can move to fetch if criteria not evaluable * "Evaluate" criteria will now not preemptively unexpire and SELECT attributes that were individually expired. Instead, if evaluation of the criteria indicates that the necessary attrs were expired, we expire the object completely (delete) or expire the SET attrs unconditionally (update). This keeps the object in the same unloaded state where it will refresh those attrs on the next pass, for this generally unusual case. (originally #5664) * Core change! update/delete rowcount comes from len(rows) if RETURNING was used. SQLite at least otherwise did not support this. 
adjusted test_rowcount accordingly * ORM DELETE with a list of parameters at all is also a not implemented as this would imply "bulk", and there is no bulk_delete_mappings (could be, but we don't have that) * ORM insert().values() with single or multi-values translates key names based on ORM attribute names * ORM returning() implemented for insert, update, delete; explicit returning clauses now interpret rows in an ORM context, with support for qualifying loader options as well * session.bulk_insert_mappings() assigns polymorphic identity if not set. * explicit RETURNING + synchronize_session='fetch' is now supported with UPDATE and DELETE. * expanded return_defaults() to work with DELETE also. * added support for composite attributes to be present in the dictionaries used by bulk_insert_mappings and bulk_update_mappings, which is also the new ORM bulk insert/update feature, that will expand the composite values into their individual mapped attributes the way they'd be on a mapped instance. * bulk UPDATE supports "synchronize_session=evaluate", is the default. this does not apply to session.bulk_update_mappings, just the new version * both bulk UPDATE and bulk INSERT, the latter with or without RETURNING, support *heterogeneous* parameter sets. session.bulk_insert/update_mappings did this, so this feature is maintained. now cursor result can be both horizontally and vertically spliced :) This is now a long story with a lot of options, which in itself is a problem to be able to document all of this in some way that makes sense. 
raising exceptions for use cases we haven't supported is pretty important here too, the tradition of letting unsupported things just not work is likely not a good idea at this point, though there are still many cases that aren't easily avoidable Fixes: #8360 Fixes: #7864 Fixes: #7865 Change-Id: Idf28379f8705e403a3c6a937f6a798a042ef2540 --- lib/sqlalchemy/orm/bulk_persistence.py | 1465 +++++++++++++++++++++++++------- lib/sqlalchemy/orm/context.py | 173 +++- lib/sqlalchemy/orm/descriptor_props.py | 26 + lib/sqlalchemy/orm/evaluator.py | 76 +- lib/sqlalchemy/orm/identity.py | 10 + lib/sqlalchemy/orm/loading.py | 5 +- lib/sqlalchemy/orm/mapper.py | 15 + lib/sqlalchemy/orm/persistence.py | 138 ++- lib/sqlalchemy/orm/query.py | 4 +- lib/sqlalchemy/orm/session.py | 25 +- lib/sqlalchemy/orm/strategies.py | 11 + 11 files changed, 1540 insertions(+), 408 deletions(-) (limited to 'lib/sqlalchemy/orm') diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 225292d17..3ed34a57a 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -15,24 +15,32 @@ specifically outside of the flush() process. from __future__ import annotations from typing import Any +from typing import cast from typing import Dict from typing import Iterable +from typing import Optional +from typing import overload from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import attributes +from . import context from . import evaluator from . import exc as orm_exc +from . import loading from . import persistence from .base import NO_VALUE from .context import AbstractORMCompileState +from .context import FromStatement +from .context import ORMFromStatementCompileState +from .context import QueryContext from .. import exc as sa_exc -from .. import sql from .. 
import util from ..engine import Dialect from ..engine import result as _result from ..sql import coercions +from ..sql import dml from ..sql import expression from ..sql import roles from ..sql import select @@ -48,16 +56,24 @@ from ..util.typing import Literal if TYPE_CHECKING: from .mapper import Mapper + from .session import _BindArguments from .session import ORMExecuteState + from .session import Session from .session import SessionTransaction from .state import InstanceState + from ..engine import Connection + from ..engine import cursor + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _ExecuteOptionsParameter _O = TypeVar("_O", bound=object) -_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"] +_SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"] +_DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"] +@overload def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], @@ -65,7 +81,36 @@ def _bulk_insert( isstates: bool, return_defaults: bool, render_nulls: bool, + use_orm_insert_stmt: Literal[None] = ..., + execution_options: Optional[_ExecuteOptionsParameter] = ..., ) -> None: + ... + + +@overload +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = ..., + execution_options: Optional[_ExecuteOptionsParameter] = ..., +) -> cursor.CursorResult[Any]: + ... 
+ + +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, +) -> Optional[cursor.CursorResult[Any]]: base_mapper = mapper.base_mapper if session_transaction.session.connection_callable: @@ -81,13 +126,27 @@ def _bulk_insert( else: mappings = [state.dict for state in mappings] else: - mappings = list(mappings) + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) connection = session_transaction.connection(base_mapper) + + return_result: Optional[cursor.CursorResult[Any]] = None + for table, super_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(super_mapper): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue + is_joined_inh_supertable = super_mapper is not mapper + bookkeeping = ( + is_joined_inh_supertable + or return_defaults + or ( + use_orm_insert_stmt is not None + and bool(use_orm_insert_stmt._returning) + ) + ) + records = ( ( None, @@ -112,18 +171,25 @@ def _bulk_insert( table, ((None, mapping, mapper, connection) for mapping in mappings), bulk=True, - return_defaults=return_defaults, + return_defaults=bookkeeping, render_nulls=render_nulls, ) ) - persistence._emit_insert_statements( + result = persistence._emit_insert_statements( base_mapper, None, super_mapper, table, records, - bookkeeping=return_defaults, + bookkeeping=bookkeeping, + use_orm_insert_stmt=use_orm_insert_stmt, + execution_options=execution_options, ) + if use_orm_insert_stmt is not None: + if not use_orm_insert_stmt._returning or return_result is None: + return_result = result + elif result.returns_rows: + return_result = return_result.splice_horizontally(result) if return_defaults and isstates: identity_cls = mapper._identity_class @@ 
-134,14 +200,43 @@ def _bulk_insert( tuple([dict_[key] for key in identity_props]), ) + if use_orm_insert_stmt is not None: + assert return_result is not None + return return_result + +@overload def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, isstates: bool, update_changed_only: bool, + use_orm_update_stmt: Literal[None] = ..., ) -> None: + ... + + +@overload +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = ..., +) -> _result.Result[Any]: + ... + + +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = None, +) -> Optional[_result.Result[Any]]: base_mapper = mapper.base_mapper search_keys = mapper._primary_key_propkeys @@ -161,7 +256,8 @@ def _bulk_update( else: mappings = [state.dict for state in mappings] else: - mappings = list(mappings) + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) if session_transaction.session.connection_callable: raise NotImplementedError( @@ -172,7 +268,7 @@ def _bulk_update( connection = session_transaction.connection(base_mapper) for table, super_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(super_mapper): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue records = persistence._collect_update_commands( @@ -193,8 +289,8 @@ def _bulk_update( for mapping in mappings ), bulk=True, + use_orm_update_stmt=use_orm_update_stmt, ) - persistence._emit_update_statements( base_mapper, None, @@ -202,10 +298,125 @@ def _bulk_update( table, records, bookkeeping=False, + 
use_orm_update_stmt=use_orm_update_stmt, ) + if use_orm_update_stmt is not None: + return _result.null_result() + + +def _expand_composites(mapper, mappings): + composite_attrs = mapper.composites + if not composite_attrs: + return + + composite_keys = set(composite_attrs.keys()) + populators = { + key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() + for key in composite_keys + } + for mapping in mappings: + for key in composite_keys.intersection(mapping): + populators[key](mapping) + class ORMDMLState(AbstractORMCompileState): + is_dml_returning = True + from_statement_ctx: Optional[ORMFromStatementCompileState] = None + + @classmethod + def _get_orm_crud_kv_pairs( + cls, mapper, statement, kv_iterator, needs_to_be_cacheable + ): + + core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs + + for k, v in kv_iterator: + k = coercions.expect(roles.DMLColumnRole, k) + + if isinstance(k, str): + desc = _entity_namespace_key(mapper, k, default=NO_VALUE) + if desc is NO_VALUE: + yield ( + coercions.expect(roles.DMLColumnRole, k), + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + if needs_to_be_cacheable + else v, + ) + else: + yield from core_get_crud_kv_pairs( + statement, + desc._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + elif "entity_namespace" in k._annotations: + k_anno = k._annotations + attr = _entity_namespace_key( + k_anno["entity_namespace"], k_anno["proxy_key"] + ) + yield from core_get_crud_kv_pairs( + statement, + attr._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + else: + yield ( + k, + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ), + ) + + @classmethod + def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return 
UpdateDMLState._get_multi_crud_kv_pairs( + statement, kv_iterator + ) + + return [ + dict( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, statement, value_dict.items(), False + ) + ) + for value_dict in kv_iterator + ] + + @classmethod + def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable): + assert ( + needs_to_be_cacheable + ), "no test coverage for needs_to_be_cacheable=False" + + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_crud_kv_pairs( + statement, kv_iterator, needs_to_be_cacheable + ) + + return list( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, + statement, + kv_iterator, + needs_to_be_cacheable, + ) + ) + @classmethod def get_entity_description(cls, statement): ext_info = statement.table._annotations["parententity"] @@ -250,18 +461,101 @@ class ORMDMLState(AbstractORMCompileState): ] ] + def _setup_orm_returning( + self, + compiler, + orm_level_statement, + dml_level_statement, + use_supplemental_cols=True, + dml_mapper=None, + ): + """establish ORM column handlers for an INSERT, UPDATE, or DELETE + which uses explicit returning(). + + called within compilation level create_for_statement. + + The _return_orm_returning() method then receives the Result + after the statement was executed, and applies ORM loading to the + state that we first established here. 
+ + """ + + if orm_level_statement._returning: + + fs = FromStatement( + orm_level_statement._returning, dml_level_statement + ) + fs = fs.options(*orm_level_statement._with_options) + self.select_statement = fs + self.from_statement_ctx = ( + fsc + ) = ORMFromStatementCompileState.create_for_statement(fs, compiler) + fsc.setup_dml_returning_compile_state(dml_mapper) + + dml_level_statement = dml_level_statement._generate() + dml_level_statement._returning = () + + cols_to_return = [c for c in fsc.primary_columns if c is not None] + + # since we are splicing result sets together, make sure there + # are columns of some kind returned in each result set + if not cols_to_return: + cols_to_return.extend(dml_mapper.primary_key) + + if use_supplemental_cols: + dml_level_statement = dml_level_statement.return_defaults( + supplemental_cols=cols_to_return + ) + else: + dml_level_statement = dml_level_statement.returning( + *cols_to_return + ) + + return dml_level_statement + + @classmethod + def _return_orm_returning( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + if compile_state.from_statement_ctx: + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + querycontext = QueryContext( + compile_state.from_statement_ctx, + compile_state.select_statement, + params, + session, + load_options, + execution_options, + bind_arguments, + ) + return loading.instances(result, querycontext) + else: + return result + class BulkUDCompileState(ORMDMLState): class default_update_options(Options): - _synchronize_session: _SynchronizeSessionArgument = "evaluate" - _is_delete_using = False - _is_update_from = False - _autoflush = True - _subject_mapper = None + _dml_strategy: _DMLStrategyArgument = "auto" + _synchronize_session: _SynchronizeSessionArgument = "auto" + _can_use_returning: bool = 
False + _is_delete_using: bool = False + _is_update_from: bool = False + _autoflush: bool = True + _subject_mapper: Optional[Mapper[Any]] = None _resolved_values = EMPTY_DICT - _resolved_keys_as_propnames = EMPTY_DICT - _value_evaluators = EMPTY_DICT - _matched_objects = None + _eval_condition = None _matched_rows = None _refresh_identity_token = None @@ -295,19 +589,16 @@ class BulkUDCompileState(ORMDMLState): execution_options, ) = BulkUDCompileState.default_update_options.from_execution_options( "_sa_orm_update_options", - {"synchronize_session", "is_delete_using", "is_update_from"}, + { + "synchronize_session", + "is_delete_using", + "is_update_from", + "dml_strategy", + }, execution_options, statement._execution_options, ) - sync = update_options._synchronize_session - if sync is not None: - if sync not in ("evaluate", "fetch", False): - raise sa_exc.ArgumentError( - "Valid strategies for session synchronization " - "are 'evaluate', 'fetch', False" - ) - bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -318,43 +609,86 @@ class BulkUDCompileState(ORMDMLState): update_options += {"_subject_mapper": plugin_subject.mapper} + if not isinstance(params, list): + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "orm"} + elif update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "bulk"} + elif update_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + sync = update_options._synchronize_session + if sync is not None: + if sync not in ("auto", "evaluate", "fetch", False): + raise sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are 'auto', 'evaluate', 'fetch', 
False" + ) + if update_options._dml_strategy == "bulk" and sync == "fetch": + raise sa_exc.InvalidRequestError( + "The 'fetch' synchronization strategy is not available " + "for 'bulk' ORM updates (i.e. multiple parameter sets)" + ) + if update_options._autoflush: session._autoflush() + if update_options._dml_strategy == "orm": + + if update_options._synchronize_session == "auto": + update_options = cls._do_pre_synchronize_auto( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "auto": + update_options += {"_synchronize_session": "evaluate"} + + # indicators from the "pre exec" step that are then + # added to the DML statement, which will also be part of the cache + # key. The compile level create_for_statement() method will then + # consume these at compiler time. statement = statement._annotate( { "synchronize_session": update_options._synchronize_session, "is_delete_using": update_options._is_delete_using, "is_update_from": update_options._is_update_from, + "dml_strategy": update_options._dml_strategy, + "can_use_returning": update_options._can_use_returning, } ) - # this stage of the execution is called before the do_orm_execute event - # hook. meaning for an extension like horizontal sharding, this step - # happens before the extension splits out into multiple backends and - # runs only once. 
if we do pre_sync_fetch, we execute a SELECT - # statement, which the horizontal sharding extension splits amongst the - # shards and combines the results together. - - if update_options._synchronize_session == "evaluate": - update_options = cls._do_pre_synchronize_evaluate( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._synchronize_session == "fetch": - update_options = cls._do_pre_synchronize_fetch( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - return ( statement, util.immutabledict(execution_options).union( @@ -382,12 +716,30 @@ class BulkUDCompileState(ORMDMLState): # individual ones we return here. update_options = execution_options["_sa_orm_update_options"] - if update_options._synchronize_session == "evaluate": - cls._do_post_synchronize_evaluate(session, result, update_options) - elif update_options._synchronize_session == "fetch": - cls._do_post_synchronize_fetch(session, result, update_options) + if update_options._dml_strategy == "orm": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_evaluate( + session, statement, result, update_options + ) + elif update_options._synchronize_session == "fetch": + cls._do_post_synchronize_fetch( + session, statement, result, update_options + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_bulk_evaluate( + session, params, result, update_options + ) + return result - return result + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) @classmethod def _adjust_for_extra_criteria(cls, global_attributes, ext_info): @@ -473,11 +825,76 @@ class BulkUDCompileState(ORMDMLState): primary_key_convert = [ lookup[bpk] for bpk in mapper.base_mapper.primary_key ] - return [tuple(row[idx] for idx in primary_key_convert) for row in 
rows] @classmethod - def _do_pre_synchronize_evaluate( + def _get_matched_objects_on_criteria(cls, update_options, states): + mapper = update_options._subject_mapper + eval_condition = update_options._eval_condition + + raw_data = [ + (state.obj(), state, state.dict) + for state in states + if state.mapper.isa(mapper) and not state.expired + ] + + identity_token = update_options._refresh_identity_token + if identity_token is not None: + raw_data = [ + (obj, state, dict_) + for obj, state, dict_ in raw_data + if state.identity_token == identity_token + ] + + result = [] + for obj, state, dict_ in raw_data: + evaled_condition = eval_condition(obj) + + # caution: don't use "in ()" or == here, _EXPIRE_OBJECT + # evaluates as True for all comparisons + if ( + evaled_condition is True + or evaled_condition is evaluator._EXPIRED_OBJECT + ): + result.append( + ( + obj, + state, + dict_, + evaled_condition is evaluator._EXPIRED_OBJECT, + ) + ) + return result + + @classmethod + def _eval_condition_from_statement(cls, update_options, statement): + mapper = update_options._subject_mapper + target_cls = mapper.class_ + + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + crit = () + if statement._where_criteria: + crit += statement._where_criteria + + global_attributes = {} + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(global_attributes) + + if global_attributes: + crit += cls._adjust_for_extra_criteria(global_attributes, mapper) + + if crit: + eval_condition = evaluator_compiler.process(*crit) + else: + + def eval_condition(obj): + return True + + return eval_condition + + @classmethod + def _do_pre_synchronize_auto( cls, session, statement, @@ -486,33 +903,59 @@ class BulkUDCompileState(ORMDMLState): bind_arguments, update_options, ): - mapper = update_options._subject_mapper - target_cls = mapper.class_ + """setup auto sync strategy + + + "auto" checks if we can use "evaluate" first, then falls back + to "fetch" - 
value_evaluators = resolved_keys_as_propnames = EMPTY_DICT + evaluate is vastly more efficient for the common case + where session is empty, only has a few objects, and the UPDATE + statement can potentially match thousands/millions of rows. + + OTOH more complex criteria that fails to work with "evaluate" + we would hope usually correlates with fewer net rows. + + """ try: - evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - crit = () - if statement._where_criteria: - crit += statement._where_criteria + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) - global_attributes = {} - for opt in statement._with_options: - if opt._is_criteria_option: - opt.get_global_criteria(global_attributes) + except evaluator.UnevaluatableError: + pass + else: + return update_options + { + "_eval_condition": eval_condition, + "_synchronize_session": "evaluate", + } - if global_attributes: - crit += cls._adjust_for_extra_criteria( - global_attributes, mapper - ) + update_options += {"_synchronize_session": "fetch"} + return cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) - if crit: - eval_condition = evaluator_compiler.process(*crit) - else: + @classmethod + def _do_pre_synchronize_evaluate( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): - def eval_condition(obj): - return True + try: + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) except evaluator.UnevaluatableError as err: raise sa_exc.InvalidRequestError( @@ -521,52 +964,8 @@ class BulkUDCompileState(ORMDMLState): "synchronize_session execution option." 
% err ) from err - if statement.__visit_name__ == "lambda_element": - # ._resolved is called on every LambdaElement in order to - # generate the cache key, so this access does not add - # additional expense - effective_statement = statement._resolved - else: - effective_statement = statement - - if effective_statement.__visit_name__ == "update": - resolved_values = cls._get_resolved_values( - mapper, effective_statement - ) - value_evaluators = {} - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - for key, value in resolved_keys_as_propnames: - try: - _evaluator = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError: - pass - else: - value_evaluators[key] = _evaluator - - # TODO: detect when the where clause is a trivial primary key match. - matched_objects = [ - state.obj() - for state in session.identity_map.all_states() - if state.mapper.isa(mapper) - and not state.expired - and eval_condition(state.obj()) - and ( - update_options._refresh_identity_token is None - # TODO: coverage for the case where horizontal sharding - # invokes an update() or delete() given an explicit identity - # token up front - or state.identity_token - == update_options._refresh_identity_token - ) - ] return update_options + { - "_matched_objects": matched_objects, - "_value_evaluators": value_evaluators, - "_resolved_keys_as_propnames": resolved_keys_as_propnames, + "_eval_condition": eval_condition, } @classmethod @@ -584,12 +983,6 @@ class BulkUDCompileState(ORMDMLState): def _resolved_keys_as_propnames(cls, mapper, resolved_values): values = [] for k, v in resolved_values: - if isinstance(k, attributes.QueryableAttribute): - values.append((k.key, v)) - continue - elif hasattr(k, "__clause_element__"): - k = k.__clause_element__() - if mapper and isinstance(k, expression.ColumnElement): try: attr = mapper._columntoproperty[k] @@ -599,7 +992,8 @@ class 
BulkUDCompileState(ORMDMLState): values.append((attr.key, v)) else: raise sa_exc.InvalidRequestError( - "Invalid expression type: %r" % k + "Attribute name not found, can't be " + "synchronized back to objects: %r" % k ) return values @@ -622,17 +1016,46 @@ class BulkUDCompileState(ORMDMLState): ) select_stmt._where_criteria = statement._where_criteria + # conditionally run the SELECT statement for pre-fetch, testing the + # "bind" for if we can use RETURNING or not using the do_orm_execute + # event. If RETURNING is available, the do_orm_execute event + # will cancel the SELECT from being actually run. + # + # The way this is organized seems strange, why don't we just + # call can_use_returning() before invoking the statement and get + # answer?, why does this go through the whole execute phase using an + # event? Answer: because we are integrating with extensions such + # as the horizontal sharding extention that "multiplexes" an individual + # statement run through multiple engines, and it uses + # do_orm_execute() to do that. 
+ + can_use_returning = None + def skip_for_returning(orm_context: ORMExecuteState) -> Any: bind = orm_context.session.get_bind(**orm_context.bind_arguments) - if cls.can_use_returning( + nonlocal can_use_returning + + per_bind_result = cls.can_use_returning( bind.dialect, mapper, is_update_from=update_options._is_update_from, is_delete_using=update_options._is_delete_using, - ): - return _result.null_result() - else: - return None + ) + + if can_use_returning is not None: + if can_use_returning != per_bind_result: + raise sa_exc.InvalidRequestError( + "For synchronize_session='fetch', can't mix multiple " + "backends where some support RETURNING and others " + "don't" + ) + else: + can_use_returning = per_bind_result + + if per_bind_result: + return _result.null_result() + else: + return None result = session.execute( select_stmt, @@ -643,52 +1066,22 @@ class BulkUDCompileState(ORMDMLState): ) matched_rows = result.fetchall() - value_evaluators = EMPTY_DICT - - if statement.__visit_name__ == "lambda_element": - # ._resolved is called on every LambdaElement in order to - # generate the cache key, so this access does not add - # additional expense - effective_statement = statement._resolved - else: - effective_statement = statement - - if effective_statement.__visit_name__ == "update": - target_cls = mapper.class_ - evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - resolved_values = cls._get_resolved_values( - mapper, effective_statement - ) - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - value_evaluators = {} - for key, value in resolved_keys_as_propnames: - try: - _evaluator = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError: - pass - else: - value_evaluators[key] = _evaluator - - else: - resolved_keys_as_propnames = EMPTY_DICT - 
return update_options + { - "_value_evaluators": value_evaluators, "_matched_rows": matched_rows, - "_resolved_keys_as_propnames": resolved_keys_as_propnames, + "_can_use_returning": can_use_returning, } @CompileState.plugin_for("orm", "insert") -class ORMInsert(ORMDMLState, InsertDMLState): +class BulkORMInsert(ORMDMLState, InsertDMLState): + class default_insert_options(Options): + _dml_strategy: _DMLStrategyArgument = "auto" + _render_nulls: bool = False + _return_defaults: bool = False + _subject_mapper: Optional[Mapper[Any]] = None + + select_statement: Optional[FromStatement] = None + @classmethod def orm_pre_session_exec( cls, @@ -699,6 +1092,16 @@ class ORMInsert(ORMDMLState, InsertDMLState): bind_arguments, is_reentrant_invoke, ): + + ( + insert_options, + execution_options, + ) = BulkORMInsert.default_insert_options.from_execution_options( + "_sa_orm_insert_options", + {"dml_strategy"}, + execution_options, + statement._execution_options, + ) bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -707,22 +1110,209 @@ class ORMInsert(ORMDMLState, InsertDMLState): else: bind_arguments["mapper"] = plugin_subject.mapper + insert_options += {"_subject_mapper": plugin_subject.mapper} + + if not params: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "orm"} + elif insert_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "bulk"} + elif insert_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + if insert_options._dml_strategy != "raw": + # for ORM object loading, like ORMContext, we have to disable + # result set adapt_to_context, because we will be generating a + # new statement with 
specific columns that's cached inside of + # an ORMFromStatementCompileState, which we will re-use for + # each result. + if not execution_options: + execution_options = context._orm_load_exec_options + else: + execution_options = execution_options.union( + context._orm_load_exec_options + ) + + statement = statement._annotate( + {"dml_strategy": insert_options._dml_strategy} + ) + return ( statement, - util.immutabledict(execution_options), + util.immutabledict(execution_options).union( + {"_sa_orm_insert_options": insert_options} + ), ) @classmethod - def orm_setup_cursor_result( + def orm_execute_statement( cls, - session, - statement, - params, - execution_options, - bind_arguments, - result, - ): - return result + session: Session, + statement: dml.Insert, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + insert_options = execution_options.get( + "_sa_orm_insert_options", cls.default_insert_options + ) + + if insert_options._dml_strategy not in ( + "raw", + "bulk", + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM insert strategy " + "are 'raw', 'orm', 'bulk', 'auto" + ) + + result: _result.Result[Any] + + if insert_options._dml_strategy == "raw": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return result + + if insert_options._dml_strategy == "bulk": + mapper = insert_options._subject_mapper + + if ( + statement._post_values_clause is not None + and mapper._multiple_persistence_tables + ): + raise sa_exc.InvalidRequestError( + "bulk INSERT with a 'post values' clause " + "(typically upsert) not supported for multi-table " + f"mapper {mapper}" + ) + + assert mapper is not None + assert session._transaction is not None + result = _bulk_insert( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + 
isstates=False, + return_defaults=insert_options._return_defaults, + render_nulls=insert_options._render_nulls, + use_orm_insert_stmt=statement, + execution_options=execution_options, + ) + elif insert_options._dml_strategy == "orm": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + else: + raise AssertionError() + + if not bool(statement._returning): + return result + + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert: + + self = cast( + BulkORMInsert, + super().create_for_statement(statement, compiler, **kw), + ) + + if compiler is not None: + toplevel = not compiler.stack + else: + toplevel = True + if not toplevel: + return self + + mapper = statement._propagate_attrs["plugin_subject"] + dml_strategy = statement._annotations.get("dml_strategy", "raw") + if dml_strategy == "bulk": + self._setup_for_bulk_insert(compiler) + elif dml_strategy == "orm": + self._setup_for_orm_insert(compiler, mapper) + + return self + + @classmethod + def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict): + return { + col.key if col is not None else k: v + for col, k, v in ( + (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items() + ) + } + + def _setup_for_orm_insert(self, compiler, mapper): + statement = orm_level_statement = cast(dml.Insert, self.statement) + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=False, + ) + self.statement = statement + + def _setup_for_bulk_insert(self, compiler): + """establish an INSERT statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_insert_statement(). 
+ + """ + statement = orm_level_statement = cast(dml.Insert, self.statement) + an = statement._annotations + + emit_insert_table, emit_insert_mapper = ( + an["_emit_insert_table"], + an["_emit_insert_mapper"], + ) + + statement = statement._clone() + + statement.table = emit_insert_table + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_insert_table + } + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=True, + dml_mapper=emit_insert_mapper, + ) + + self.statement = statement @CompileState.plugin_for("orm", "update") @@ -732,13 +1322,27 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): self = cls.__new__(cls) + dml_strategy = statement._annotations.get( + "dml_strategy", "unspecified" + ) + + if dml_strategy == "bulk": + self._setup_for_bulk_update(statement, compiler) + elif dml_strategy in ("orm", "unspecified"): + self._setup_for_orm_update(statement, compiler) + + return self + + def _setup_for_orm_update(self, statement, compiler, **kw): + orm_level_statement = statement + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper self.extra_criteria_entities = {} - self._resolved_values = cls._get_resolved_values(mapper, statement) + self._resolved_values = self._get_resolved_values(mapper, statement) extra_criteria_attributes = {} @@ -749,8 +1353,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if statement._values: self._resolved_values = dict(self._resolved_values) - new_stmt = sql.Update.__new__(sql.Update) - new_stmt.__dict__.update(statement.__dict__) + new_stmt = statement._clone() new_stmt.table = mapper.local_table # note if the statement has _multi_values, these @@ -762,7 +1365,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): elif statement._values: new_stmt._values = self._resolved_values - new_crit = 
cls._adjust_for_extra_criteria( + new_crit = self._adjust_for_extra_criteria( extra_criteria_attributes, mapper ) if new_crit: @@ -776,21 +1379,150 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): UpdateDMLState.__init__(self, new_stmt, compiler, **kw) - if compiler._annotations.get( + use_supplemental_cols = False + + synchronize_session = compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning( - compiler.dialect, mapper, is_multitable=self.is_multitable - ): - if new_stmt._returning: - raise sa_exc.InvalidRequestError( - "Can't use synchronize_session='fetch' " - "with explicit returning()" + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, mapper, is_multitable=self.is_multitable ) - self.statement = self.statement.returning( - *mapper.local_table.primary_key ) - return self + if synchronize_session == "fetch" and can_use_returning: + use_supplemental_cols = True + + # NOTE: we might want to RETURNING the actual columns to be + # synchronized also. however this is complicated and difficult + # to align against the behavior of "evaluate". Additionally, + # in a large number (if not the majority) of cases, we have the + # "evaluate" answer, usually a fixed value, in memory already and + # there's no need to re-fetch the same value + # over and over again. so perhaps if it could be RETURNING just + # the elements that were based on a SQL expression and not + # a constant. 
For now it doesn't quite seem worth it + new_stmt = new_stmt.return_defaults( + *(list(mapper.local_table.primary_key)) + ) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt + + def _setup_for_bulk_update(self, statement, compiler, **kw): + """establish an UPDATE statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_update_statement(). + + """ + statement = cast(dml.Update, statement) + an = statement._annotations + + emit_update_table, _ = ( + an["_emit_update_table"], + an["_emit_update_mapper"], + ) + + statement = statement._clone() + statement.table = emit_update_table + + UpdateDMLState.__init__(self, statement, compiler, **kw) + + if self._ordered_values: + raise sa_exc.InvalidRequestError( + "bulk ORM UPDATE does not support ordered_values() for " + "custom UPDATE statements with bulk parameter sets. Use a " + "non-bulk UPDATE statement or use values()." 
+ ) + + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_update_table + } + self.statement = statement + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Update, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy not in ("orm", "auto", "bulk"): + raise sa_exc.ArgumentError( + "Valid strategies for ORM UPDATE strategy " + "are 'orm', 'auto', 'bulk'" + ) + + result: _result.Result[Any] + + if update_options._dml_strategy == "bulk": + if statement._where_criteria: + raise sa_exc.InvalidRequestError( + "WHERE clause with bulk ORM UPDATE not " + "supported right now. Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + mapper = update_options._subject_mapper + assert mapper is not None + assert session._transaction is not None + result = _bulk_update( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + update_changed_only=False, + use_orm_update_stmt=statement, + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + else: + return super().orm_execute_statement( + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) @classmethod def can_use_returning( @@ -827,119 +1559,80 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): return True @classmethod - def _get_crud_kv_pairs(cls, statement, kv_iterator): - plugin_subject = statement._propagate_attrs["plugin_subject"] - - core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs - - if not 
plugin_subject or not plugin_subject.mapper: - return core_get_crud_kv_pairs(statement, kv_iterator) - - mapper = plugin_subject.mapper - - values = [] - - for k, v in kv_iterator: - k = coercions.expect(roles.DMLColumnRole, k) + def _do_post_synchronize_bulk_evaluate( + cls, session, params, result, update_options + ): + if not params: + return - if isinstance(k, str): - desc = _entity_namespace_key(mapper, k, default=NO_VALUE) - if desc is NO_VALUE: - values.append( - ( - k, - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ), - ) - ) - else: - values.extend( - core_get_crud_kv_pairs( - statement, desc._bulk_update_tuples(v) - ) - ) - elif "entity_namespace" in k._annotations: - k_anno = k._annotations - attr = _entity_namespace_key( - k_anno["entity_namespace"], k_anno["proxy_key"] - ) - values.extend( - core_get_crud_kv_pairs( - statement, attr._bulk_update_tuples(v) - ) - ) - else: - values.append( - ( - k, - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ), - ) - ) - return values + mapper = update_options._subject_mapper + pk_keys = [prop.key for prop in mapper._identity_key_props] - @classmethod - def _do_post_synchronize_evaluate(cls, session, result, update_options): + identity_map = session.identity_map - states = set() - evaluated_keys = list(update_options._value_evaluators.keys()) - values = update_options._resolved_keys_as_propnames - attrib = set(k for k, v in values) - for obj in update_options._matched_objects: - - state, dict_ = ( - attributes.instance_state(obj), - attributes.instance_dict(obj), + for param in params: + identity_key = mapper.identity_key_from_primary_key( + (param[key] for key in pk_keys), + update_options._refresh_identity_token, ) - - # the evaluated states were gathered across all identity tokens. - # however the post_sync events are called per identity token, - # so filter. 
- if ( - update_options._refresh_identity_token is not None - and state.identity_token - != update_options._refresh_identity_token - ): + state = identity_map.fast_get_state(identity_key) + if not state: continue + evaluated_keys = set(param).difference(pk_keys) + + dict_ = state.dict # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: if key in dict_: - dict_[key] = update_options._value_evaluators[key](obj) + dict_[key] = param[key] state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) - to_expire = attrib.intersection(dict_).difference(to_evaluate) + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. + to_expire = evaluated_keys.intersection(dict_).difference( + to_evaluate + ) if to_expire: state._expire_attributes(dict_, to_expire) - states.add(state) - session._register_altered(states) + @classmethod + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), + ) + + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], + ) @classmethod - def _do_post_synchronize_fetch(cls, session, result, update_options): + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): target_mapper = update_options._subject_mapper - states = set() - evaluated_keys = list(update_options._value_evaluators.keys()) - - if result.returns_rows: - rows = cls._interpret_returning_rows(target_mapper, result.all()) + returned_defaults_rows = result.returned_defaults_rows + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) matched_rows = [ 
tuple(row) + (update_options._refresh_identity_token,) - for row in rows + for row in pk_rows ] else: matched_rows = update_options._matched_rows @@ -960,23 +1653,69 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if identity_key in session.identity_map ] - values = update_options._resolved_keys_as_propnames - attrib = set(k for k, v in values) + if not objs: + return - for obj in objs: - state, dict_ = ( - attributes.instance_state(obj), - attributes.instance_dict(obj), - ) + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [ + ( + obj, + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) + for obj in objs + ], + ) + + @classmethod + def _apply_update_set_values_to_objects( + cls, session, update_options, statement, matched_objects + ): + """apply values to objects derived from an update statement, e.g. + UPDATE..SET + + """ + mapper = update_options._subject_mapper + target_cls = mapper.class_ + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + resolved_values = cls._get_resolved_values(mapper, statement) + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + value_evaluators = {} + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + + evaluated_keys = list(value_evaluators.keys()) + attrib = set(k for k, v in resolved_keys_as_propnames) + + states = set() + for obj, state, dict_ in matched_objects: to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: if key in dict_: - dict_[key] = update_options._value_evaluators[key](obj) + # only run eval for attributes that are present. 
+ dict_[key] = value_evaluators[key](obj) + state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. to_expire = attrib.intersection(dict_).difference(to_evaluate) if to_expire: state._expire_attributes(dict_, to_expire) @@ -991,6 +1730,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) + orm_level_statement = statement + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper @@ -1002,30 +1743,96 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): if opt._is_criteria_option: opt.get_global_criteria(extra_criteria_attributes) + new_stmt = statement._clone() + new_stmt.table = mapper.local_table + new_crit = cls._adjust_for_extra_criteria( extra_criteria_attributes, mapper ) if new_crit: - statement = statement.where(*new_crit) + new_stmt = new_stmt.where(*new_crit) # do this first as we need to determine if there is # DELETE..FROM - DeleteDMLState.__init__(self, statement, compiler, **kw) + DeleteDMLState.__init__(self, new_stmt, compiler, **kw) - if compiler._annotations.get( + use_supplemental_cols = False + + synchronize_session = compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning( - compiler.dialect, - mapper, - is_multitable=self.is_multitable, - is_delete_using=compiler._annotations.get( - "is_delete_using", False - ), - ): - self.statement = statement.returning(*statement.table.primary_key) + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # 
unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, + mapper, + is_multitable=self.is_multitable, + is_delete_using=compiler._annotations.get( + "is_delete_using", False + ), + ) + ) + + if can_use_returning: + use_supplemental_cols = True + + new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt return self + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Delete, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + "Bulk ORM DELETE not supported right now. 
" + "Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + + if update_options._dml_strategy not in ( + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM DELETE strategy are 'orm', 'auto'" + ) + + return super().orm_execute_statement( + session, statement, params, execution_options, bind_arguments, conn + ) + @classmethod def can_use_returning( cls, @@ -1068,25 +1875,41 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): return True @classmethod - def _do_post_synchronize_evaluate(cls, session, result, update_options): - - session._remove_newly_deleted( - [ - attributes.instance_state(obj) - for obj in update_options._matched_objects - ] + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), ) + to_delete = [] + + for _, state, dict_, is_partially_expired in matched_objects: + if is_partially_expired: + state._expire(dict_, session.identity_map._modified) + else: + to_delete.append(state) + + if to_delete: + session._remove_newly_deleted(to_delete) + @classmethod - def _do_post_synchronize_fetch(cls, session, result, update_options): + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): target_mapper = update_options._subject_mapper - if result.returns_rows: - rows = cls._interpret_returning_rows(target_mapper, result.all()) + returned_defaults_rows = result.returned_defaults_rows + + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in rows + for row in pk_rows ] else: matched_rows = update_options._matched_rows diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index dc96f8c3c..f8c7ba714 100644 --- 
a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -73,6 +73,7 @@ if TYPE_CHECKING: from .query import Query from .session import _BindArguments from .session import Session + from ..engine import Result from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptionsParameter from ..sql._typing import _ColumnsClauseArgument @@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict( class AbstractORMCompileState(CompileState): + is_dml_returning = False + @classmethod def create_for_statement( cls, statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> AbstractORMCompileState: """Create a context for a statement given a :class:`.Compiler`. + This method is always invoked in the context of SQLCompiler.process(). + For a Select object, this would be invoked from SQLCompiler.visit_select(). For the special FromStatement object used by Query to indicate "Query.from_statement()", this is called by @@ -232,6 +237,28 @@ class AbstractORMCompileState(CompileState): ): raise NotImplementedError() + @classmethod + def orm_execute_statement( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) -> Result: + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + @classmethod def orm_setup_cursor_result( cls, @@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState): def __init__(self, *arg, **kw): raise NotImplementedError() + if TYPE_CHECKING: + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: + ... 
+ def _append_dedupe_col_collection(self, obj, col_collection): dedupe = self.dedupe_columns if obj not in dedupe: @@ -332,26 +370,6 @@ class ORMCompileState(AbstractORMCompileState): else: return SelectState._column_naming_convention(label_style) - @classmethod - def create_for_statement( - cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], - **kw: Any, - ) -> ORMCompileState: - """Create a context for a statement given a :class:`.Compiler`. - - This method is always invoked in the context of SQLCompiler.process(). - - For a Select object, this would be invoked from - SQLCompiler.visit_select(). For the special FromStatement object used - by Query to indicate "Query.from_statement()", this is called by - FromStatement._compiler_dispatch() that would be called by - SQLCompiler.process(). - - """ - raise NotImplementedError() - @classmethod def get_column_descriptions(cls, statement): return _column_descriptions(statement) @@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState): ) +class DMLReturningColFilter: + """an adapter used for the DML RETURNING case. + + Has a subset of the interface used by + :class:`.ORMAdapter` and is used for :class:`._QueryEntity` + instances to set up their columns as used in RETURNING for a + DML statement. + + """ + + __slots__ = ("mapper", "columns", "__weakref__") + + def __init__(self, target_mapper, immediate_dml_mapper): + if ( + immediate_dml_mapper is not None + and target_mapper.local_table + is not immediate_dml_mapper.local_table + ): + # joined inh, or in theory other kinds of multi-table mappings + self.mapper = immediate_dml_mapper + else: + # single inh, normal mappings, etc. 
+ self.mapper = target_mapper + self.columns = self.columns = util.WeakPopulateDict( + self.adapt_check_present # type: ignore + ) + + def __call__(self, col, as_filter): + for cc in sql_util._find_columns(col): + c2 = self.adapt_check_present(cc) + if c2 is not None: + return col + else: + return None + + def adapt_check_present(self, col): + mapper = self.mapper + prop = mapper._columntoproperty.get(col, None) + if prop is None: + return None + return mapper.local_table.c.corresponding_column(col) + + @sql.base.CompileState.plugin_for("orm", "orm_from_statement") class ORMFromStatementCompileState(ORMCompileState): _from_obj_alias = None @@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState): statement_container: FromStatement requested_statement: Union[SelectBase, TextClause, UpdateBase] - dml_table: _DMLTableElement + dml_table: Optional[_DMLTableElement] = None _has_orm_entities = False multi_row_eager_loaders = False @@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState): statement_container: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> ORMFromStatementCompileState: if compiler is not None: toplevel = not compiler.stack @@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState): if statement.is_dml: self.dml_table = statement.table + self.is_dml_returning = True self._entities = [] self._polymorphic_adapters = {} @@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState): def _get_current_adapter(self): return None + def setup_dml_returning_compile_state(self, dml_mapper): + """used by BulkORMInsert (and Update / Delete?) 
to set up a handler + for RETURNING to return ORM objects and expressions + + """ + target_mapper = self.statement._propagate_attrs.get( + "plugin_subject", None + ) + adapter = DMLReturningColFilter(target_mapper, dml_mapper) + for entity in self._entities: + entity.setup_dml_returning_compile_state(self, adapter) + class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): """Core construct that represents a load of ORM objects from various @@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> ORMSelectCompileState: """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) @@ -2312,6 +2386,13 @@ class _QueryEntity: def setup_compile_state(self, compile_state: ORMCompileState) -> None: raise NotImplementedError() + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + raise NotImplementedError() + def row_processor(self, context, result): raise NotImplementedError() @@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity): return _instance, self._label_name, self._extra_entities - def setup_compile_state(self, compile_state): + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + loading._setup_entity_query( + compile_state, + self.mapper, + self, + self.path, + adapter, + compile_state.primary_columns, + with_polymorphic=self._with_polymorphic_mappers, + only_load_props=compile_state.compile_options._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + def setup_compile_state(self, compile_state): adapter = self._get_entity_clauses(compile_state) single_table_crit = self.mapper._single_table_criterion @@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity): 
only_load_props=compile_state.compile_options._only_load_props, polymorphic_discriminator=self._polymorphic_discriminator, ) - compile_state._fallback_from_clauses.append(self.selectable) @@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity): getter, label_name, extra_entities = self._row_processor if self.translate_raw_column: extra_entities += ( - result.context.invoked_statement._raw_columns[ - self.raw_column_index - ], + context.query._raw_columns[self.raw_column_index], ) return getter, label_name, extra_entities @@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity): if self.translate_raw_column: extra_entities = self._extra_entities + ( - result.context.invoked_statement._raw_columns[ - self.raw_column_index - ], + context.query._raw_columns[self.raw_column_index], ) return getter, self._label_name, extra_entities else: @@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity): current_adapter = compile_state._get_current_adapter() if current_adapter: column = current_adapter(self.column, False) + if column is None: + return else: column = self.column @@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity): self.entity_zero ) and entity.common_parent(self.entity_zero) + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + self._fetch_column = self.column + column = adapter(self.column, False) + if column is not None: + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + def setup_compile_state(self, compile_state): current_adapter = compile_state._get_current_adapter() if current_adapter: column = current_adapter(self.column, False) + if column is None: + assert compile_state.is_dml_returning + self._fetch_column = self.column + return else: column = self.column diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 52b70b9d4..13d3b70fe 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py 
+++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -19,6 +19,7 @@ import operator import typing from typing import Any from typing import Callable +from typing import Dict from typing import List from typing import NoReturn from typing import Optional @@ -602,6 +603,31 @@ class Composite( def _attribute_keys(self) -> Sequence[str]: return [prop.key for prop in self.props] + def _populate_composite_bulk_save_mappings_fn( + self, + ) -> Callable[[Dict[str, Any]], None]: + + if self._generated_composite_accessor: + get_values = self._generated_composite_accessor + else: + + def get_values(val: Any) -> Tuple[Any]: + return val.__composite_values__() # type: ignore + + attrs = [prop.key for prop in self.props] + + def populate(dest_dict: Dict[str, Any]) -> None: + dest_dict.update( + { + key: val + for key, val in zip( + attrs, get_values(dest_dict.pop(self.key)) + ) + } + ) + + return populate + def get_history( self, state: InstanceState[Any], diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index b3129afdd..5af14cc00 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -9,8 +9,8 @@ from __future__ import annotations -import operator - +from .base import LoaderCallableStatus +from .base import PassiveFlag from .. import exc from .. import inspect from .. 
import util @@ -32,7 +32,16 @@ class _NoObject(operators.ColumnOperators): return None +class _ExpiredObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return self + + def reverse_operate(self, *arg, **kw): + return self + + _NO_OBJECT = _NoObject() +_EXPIRED_OBJECT = _ExpiredObject() class EvaluatorCompiler: @@ -73,6 +82,24 @@ class EvaluatorCompiler: f"alternate class {parentmapper.class_}" ) key = parentmapper._columntoproperty[clause].key + impl = parentmapper.class_manager[key].impl + + if impl is not None: + + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + state = inspect(obj) + dict_ = state.dict + + value = impl.get( + state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH + ) + if value is LoaderCallableStatus.PASSIVE_NO_RESULT: + return _EXPIRED_OBJECT + return value + + return get_corresponding_attr else: key = clause.key if ( @@ -85,15 +112,16 @@ class EvaluatorCompiler: "make use of the actual mapped columns in ORM-evaluated " "UPDATE / DELETE expressions." 
) + else: raise UnevaluatableError(f"Cannot evaluate column: {clause}") - get_corresponding_attr = operator.attrgetter(key) - return ( - lambda obj: get_corresponding_attr(obj) - if obj is not None - else _NO_OBJECT - ) + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + return getattr(obj, key, _EXPIRED_OBJECT) + + return get_corresponding_attr def visit_tuple(self, clause): return self.visit_clauselist(clause) @@ -134,7 +162,9 @@ class EvaluatorCompiler: has_null = False for sub_evaluate in evaluators: value = sub_evaluate(obj) - if value: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value: return True has_null = has_null or value is None if has_null: @@ -147,6 +177,9 @@ class EvaluatorCompiler: def evaluate(obj): for sub_evaluate in evaluators: value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + if not value: if value is None or value is _NO_OBJECT: return None @@ -160,7 +193,9 @@ class EvaluatorCompiler: values = [] for sub_evaluate in evaluators: value = sub_evaluate(obj) - if value is None or value is _NO_OBJECT: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None or value is _NO_OBJECT: return None values.append(value) return tuple(values) @@ -183,13 +218,21 @@ class EvaluatorCompiler: def visit_is_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): - return eval_left(obj) == eval_right(obj) + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val == right_val return evaluate def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): - return eval_left(obj) != eval_right(obj) + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val != right_val return evaluate @@ -197,8 +240,11 @@ class 
EvaluatorCompiler: def evaluate(obj): left_val = eval_left(obj) right_val = eval_right(obj) - if left_val is None or right_val is None: + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif left_val is None or right_val is None: return None + return operator(eval_left(obj), eval_right(obj)) return evaluate @@ -274,7 +320,9 @@ class EvaluatorCompiler: def evaluate(obj): value = eval_inner(obj) - if value is None: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None: return None return not value diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 63b131a78..4848f73f1 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -68,6 +68,11 @@ class IdentityMap: ) -> Optional[_O]: raise NotImplementedError() + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + raise NotImplementedError() + def keys(self) -> Iterable[_IdentityKeyType[Any]]: return self._dict.keys() @@ -206,6 +211,11 @@ class WeakInstanceDict(IdentityMap): self._dict[key] = state state._instance_dict = self._wr + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + return self._dict.get(key) + def get( self, key: _IdentityKeyType[_O], default: Optional[_O] = None ) -> Optional[_O]: diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 7317d48be..64f2542fd 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -29,7 +29,6 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.orm.context import FromStatement from . import attributes from . import exc as orm_exc from . 
import path_registry @@ -37,6 +36,7 @@ from .base import _DEFER_FOR_STATE from .base import _RAISE_FOR_STATE from .base import _SET_DEFERRED_EXPIRED from .base import PassiveFlag +from .context import FromStatement from .util import _none_set from .util import state_str from .. import exc as sa_exc @@ -50,6 +50,7 @@ from ..sql import util as sql_util from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import SelectState +from ..util import EMPTY_DICT if TYPE_CHECKING: from ._typing import _IdentityKeyType @@ -764,7 +765,7 @@ def _instance_processor( ) quick_populators = path.get( - context.attributes, "memoized_setups", _none_set + context.attributes, "memoized_setups", EMPTY_DICT ) todo = [] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c8df51b06..c9cf8f49b 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -854,6 +854,7 @@ class Mapper( _memoized_values: Dict[Any, Callable[[], Any]] _inheriting_mappers: util.WeakSequence[Mapper[Any]] _all_tables: Set[Table] + _polymorphic_attr_key: Optional[str] _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]] _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]] @@ -1653,6 +1654,7 @@ class Mapper( """ setter = False + polymorphic_key: Optional[str] = None if self.polymorphic_on is not None: setter = True @@ -1772,17 +1774,23 @@ class Mapper( self._set_polymorphic_identity = ( mapper._set_polymorphic_identity ) + self._polymorphic_attr_key = ( + mapper._polymorphic_attr_key + ) self._validate_polymorphic_identity = ( mapper._validate_polymorphic_identity ) else: self._set_polymorphic_identity = None + self._polymorphic_attr_key = None return if setter: def _set_polymorphic_identity(state): dict_ = state.dict + # TODO: what happens if polymorphic_on column attribute name + # does not match .key? 
state.get_impl(polymorphic_key).set( state, dict_, @@ -1790,6 +1798,8 @@ class Mapper( None, ) + self._polymorphic_attr_key = polymorphic_key + def _validate_polymorphic_identity(mapper, state, dict_): if ( polymorphic_key in dict_ @@ -1808,6 +1818,7 @@ class Mapper( _validate_polymorphic_identity ) else: + self._polymorphic_attr_key = None self._set_polymorphic_identity = None _validate_polymorphic_identity = None @@ -3561,6 +3572,10 @@ class Mapper( def _compiled_cache(self): return util.LRUCache(self._compiled_cache_size) + @HasMemoized.memoized_attribute + def _multiple_persistence_tables(self): + return len(self.tables) > 1 + @HasMemoized.memoized_attribute def _sorted_tables(self): table_to_mapper: Dict[Table, Mapper[Any]] = {} diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index abd528986..dfb61c28a 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -31,6 +31,7 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. import util +from ..engine import cursor as _cursor from ..sql import operators from ..sql.elements import BooleanClauseList from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL @@ -398,6 +399,11 @@ def _collect_insert_commands( None ) + if bulk and mapper._set_polymorphic_identity: + params.setdefault( + mapper._polymorphic_attr_key, mapper.polymorphic_identity + ) + yield ( state, state_dict, @@ -411,7 +417,11 @@ def _collect_insert_commands( def _collect_update_commands( - uowtransaction, table, states_to_update, bulk=False + uowtransaction, + table, + states_to_update, + bulk=False, + use_orm_update_stmt=None, ): """Identify sets of values to use in UPDATE statements for a list of states. 
@@ -437,7 +447,11 @@ def _collect_update_commands( pks = mapper._pks_by_table[table] - value_params = {} + if use_orm_update_stmt is not None: + # TODO: ordered values, etc + value_params = use_orm_update_stmt._values + else: + value_params = {} propkey_to_col = mapper._propkey_to_col[table] @@ -697,6 +711,7 @@ def _emit_update_statements( table, update, bookkeeping=True, + use_orm_update_stmt=None, ): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" @@ -708,7 +723,7 @@ def _emit_update_statements( execution_options = {"compiled_cache": base_mapper._compiled_cache} - def update_stmt(): + def update_stmt(existing_stmt=None): clauses = BooleanClauseList._construct_raw(operators.and_) for col in mapper._pks_by_table[table]: @@ -725,10 +740,17 @@ def _emit_update_statements( ) ) - stmt = table.update().where(clauses) + if existing_stmt is not None: + stmt = existing_stmt.where(clauses) + else: + stmt = table.update().where(clauses) return stmt - cached_stmt = base_mapper._memo(("update", table), update_stmt) + if use_orm_update_stmt is not None: + cached_stmt = update_stmt(use_orm_update_stmt) + + else: + cached_stmt = base_mapper._memo(("update", table), update_stmt) for ( (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), @@ -747,6 +769,15 @@ def _emit_update_statements( records = list(records) statement = cached_stmt + + if use_orm_update_stmt is not None: + statement = statement._annotate( + { + "_emit_update_table": table, + "_emit_update_mapper": mapper, + } + ) + return_defaults = False if not has_all_pks: @@ -904,16 +935,35 @@ def _emit_insert_statements( table, insert, bookkeeping=True, + use_orm_insert_stmt=None, + execution_options=None, ): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" - cached_stmt = base_mapper._memo(("insert", table), table.insert) + if use_orm_insert_stmt is not None: + cached_stmt = use_orm_insert_stmt + exec_opt = 
util.EMPTY_DICT - execution_options = {"compiled_cache": base_mapper._compiled_cache} + # if a user query with RETURNING was passed, we definitely need + # to use RETURNING. + returning_is_required_anyway = bool(use_orm_insert_stmt._returning) + else: + returning_is_required_anyway = False + cached_stmt = base_mapper._memo(("insert", table), table.insert) + exec_opt = {"compiled_cache": base_mapper._compiled_cache} + + if execution_options: + execution_options = util.EMPTY_DICT.merge_with( + exec_opt, execution_options + ) + else: + execution_options = exec_opt + + return_result = None for ( - (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), + (connection, _, hasvalue, has_all_pks, has_all_defaults), records, ) in groupby( insert, @@ -928,17 +978,29 @@ def _emit_insert_statements( statement = cached_stmt + if use_orm_insert_stmt is not None: + statement = statement._annotate( + { + "_emit_insert_table": table, + "_emit_insert_mapper": mapper, + } + ) + if ( - not bookkeeping - or ( - has_all_defaults - or not base_mapper.eager_defaults - or not base_mapper.local_table.implicit_returning - or not connection.dialect.insert_returning + ( + not bookkeeping + or ( + has_all_defaults + or not base_mapper.eager_defaults + or not base_mapper.local_table.implicit_returning + or not connection.dialect.insert_returning + ) ) + and not returning_is_required_anyway and has_all_pks and not hasvalue ): + # the "we don't need newly generated values back" section. 
# here we have all the PKs, all the defaults or we don't want # to fetch them, or the dialect doesn't support RETURNING at all @@ -946,7 +1008,7 @@ def _emit_insert_statements( records = list(records) multiparams = [rec[2] for rec in records] - c = connection.execute( + result = connection.execute( statement, multiparams, execution_options=execution_options ) if bookkeeping: @@ -962,7 +1024,7 @@ def _emit_insert_statements( has_all_defaults, ), last_inserted_params, - ) in zip(records, c.context.compiled_parameters): + ) in zip(records, result.context.compiled_parameters): if state: _postfetch( mapper_rec, @@ -970,19 +1032,20 @@ def _emit_insert_statements( table, state, state_dict, - c, + result, last_inserted_params, value_params, False, - c.returned_defaults - if not c.context.executemany + result.returned_defaults + if not result.context.executemany else None, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: - # here, we need defaults and/or pk values back. + # here, we need defaults and/or pk values back or we otherwise + # know that we are using RETURNING in any case records = list(records) if ( @@ -991,6 +1054,16 @@ def _emit_insert_statements( and len(records) > 1 ): do_executemany = True + elif returning_is_required_anyway: + if connection.dialect.insert_executemany_returning: + do_executemany = True + else: + raise sa_exc.InvalidRequestError( + f"Can't use explicit RETURNING for bulk INSERT " + f"operation with " + f"{connection.dialect.dialect_description} backend; " + f"executemany is not supported with RETURNING" + ) else: do_executemany = False @@ -998,6 +1071,7 @@ def _emit_insert_statements( statement = statement.return_defaults( *mapper._server_default_cols[table] ) + if mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) elif do_executemany: @@ -1006,10 +1080,16 @@ def _emit_insert_statements( if do_executemany: multiparams = [rec[2] for rec in records] - c = connection.execute( + result 
= connection.execute( statement, multiparams, execution_options=execution_options ) + if use_orm_insert_stmt is not None: + if return_result is None: + return_result = result + else: + return_result = return_result.splice_vertically(result) + if bookkeeping: for ( ( @@ -1027,9 +1107,9 @@ def _emit_insert_statements( returned_defaults, ) in zip_longest( records, - c.context.compiled_parameters, - c.inserted_primary_key_rows, - c.returned_defaults_rows or (), + result.context.compiled_parameters, + result.inserted_primary_key_rows, + result.returned_defaults_rows or (), ): if inserted_primary_key is None: # this is a real problem and means that we didn't @@ -1062,7 +1142,7 @@ def _emit_insert_statements( table, state, state_dict, - c, + result, last_inserted_params, value_params, False, @@ -1071,6 +1151,8 @@ def _emit_insert_statements( else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: + assert not returning_is_required_anyway + for ( state, state_dict, @@ -1132,6 +1214,12 @@ def _emit_insert_statements( else: _postfetch_bulk_save(mapper_rec, state_dict, table) + if use_orm_insert_stmt is not None: + if return_result is None: + return _cursor.null_dml_result() + else: + return return_result + def _emit_post_update_statements( base_mapper, uowtransaction, mapper, table, update diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 6d0f055e4..4d5a98fcf 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2978,7 +2978,7 @@ class Query( ) def delete( - self, synchronize_session: _SynchronizeSessionArgument = "evaluate" + self, synchronize_session: _SynchronizeSessionArgument = "auto" ) -> int: r"""Perform a DELETE with an arbitrary WHERE clause. 
@@ -3042,7 +3042,7 @@ class Query( def update( self, values: Dict[_DMLColumnArgument, Any], - synchronize_session: _SynchronizeSessionArgument = "evaluate", + synchronize_session: _SynchronizeSessionArgument = "auto", update_args: Optional[Dict[Any, Any]] = None, ) -> int: r"""Perform an UPDATE with an arbitrary WHERE clause. diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index a690da0d5..64c013306 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1828,12 +1828,13 @@ class Session(_SessionClassMethods, EventTarget): statement._propagate_attrs.get("compile_state_plugin", None) == "orm" ): - # note that even without "future" mode, we need compile_state_cls = CompileState._get_plugin_class_for_plugin( statement, "orm" ) if TYPE_CHECKING: - assert isinstance(compile_state_cls, ORMCompileState) + assert isinstance( + compile_state_cls, context.AbstractORMCompileState + ) else: compile_state_cls = None @@ -1897,18 +1898,18 @@ class Session(_SessionClassMethods, EventTarget): statement, params or {}, execution_options=execution_options ) - result: Result[Any] = conn.execute( - statement, params or {}, execution_options=execution_options - ) - if compile_state_cls: - result = compile_state_cls.orm_setup_cursor_result( + result: Result[Any] = compile_state_cls.orm_execute_statement( self, statement, - params, + params or {}, execution_options, bind_arguments, - result, + conn, + ) + else: + result = conn.execute( + statement, params or {}, execution_options=execution_options ) if _scalar_result: @@ -2066,7 +2067,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: TypedReturnsRows[Tuple[_T]], - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, @@ -2078,7 +2079,7 @@ class Session(_SessionClassMethods, EventTarget): def 
scalars( self, statement: Executable, - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, @@ -2089,7 +2090,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: Executable, - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 19c6493db..8652591c8 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -227,6 +227,11 @@ class ColumnLoader(LoaderStrategy): fetch = self.columns[0] if adapter: fetch = adapter.columns[fetch] + if fetch is None: + # None happens here only for dml bulk_persistence cases + # when context.DMLReturningColFilter is used + return + memoized_populators[self.parent_property] = fetch def init_class_attribute(self, mapper): @@ -318,6 +323,12 @@ class ExpressionColumnLoader(ColumnLoader): fetch = columns[0] if adapter: fetch = adapter.columns[fetch] + if fetch is None: + # None is not expected to be the result of any + # adapter implementation here, however there may be theoretical + # usages of returning() with context.DMLReturningColFilter + return + memoized_populators[self.parent_property] = fetch def create_row_processor( -- cgit v1.2.1