diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-09-26 01:17:44 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-09-26 01:17:44 +0000 |
| commit | 6201b4d88666983b883b96d22a159aa2594de94b (patch) | |
| tree | 4036c155ca7c274ea4bd12c059fd8fcd277fc026 /lib/sqlalchemy/orm | |
| parent | f81fdd9a9008a6517f89f2115765b7db9a32721b (diff) | |
| parent | a8029f5a7e3e376ec57f1614ab0294b717d53c05 (diff) | |
| download | sqlalchemy-6201b4d88666983b883b96d22a159aa2594de94b.tar.gz | |
Merge "ORM bulk insert via execute" into main
Diffstat (limited to 'lib/sqlalchemy/orm')
| -rw-r--r-- | lib/sqlalchemy/orm/bulk_persistence.py | 1459 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 173 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/evaluator.py | 76 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/identity.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/loading.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 138 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 25 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 11 |
11 files changed, 1537 insertions, 405 deletions
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 225292d17..3ed34a57a 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -15,24 +15,32 @@ specifically outside of the flush() process. from __future__ import annotations from typing import Any +from typing import cast from typing import Dict from typing import Iterable +from typing import Optional +from typing import overload from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import attributes +from . import context from . import evaluator from . import exc as orm_exc +from . import loading from . import persistence from .base import NO_VALUE from .context import AbstractORMCompileState +from .context import FromStatement +from .context import ORMFromStatementCompileState +from .context import QueryContext from .. import exc as sa_exc -from .. import sql from .. import util from ..engine import Dialect from ..engine import result as _result from ..sql import coercions +from ..sql import dml from ..sql import expression from ..sql import roles from ..sql import select @@ -48,16 +56,24 @@ from ..util.typing import Literal if TYPE_CHECKING: from .mapper import Mapper + from .session import _BindArguments from .session import ORMExecuteState + from .session import Session from .session import SessionTransaction from .state import InstanceState + from ..engine import Connection + from ..engine import cursor + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _ExecuteOptionsParameter _O = TypeVar("_O", bound=object) -_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"] +_SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"] +_DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"] +@overload def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], @@ -65,7 +81,36 @@ def _bulk_insert( isstates: bool, return_defaults: bool, render_nulls: bool, + use_orm_insert_stmt: Literal[None] = ..., + execution_options: Optional[_ExecuteOptionsParameter] = ..., ) -> None: + ... + + +@overload +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = ..., + execution_options: Optional[_ExecuteOptionsParameter] = ..., +) -> cursor.CursorResult[Any]: + ... + + +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, +) -> Optional[cursor.CursorResult[Any]]: base_mapper = mapper.base_mapper if session_transaction.session.connection_callable: @@ -81,13 +126,27 @@ def _bulk_insert( else: mappings = [state.dict for state in mappings] else: - mappings = list(mappings) + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) connection = session_transaction.connection(base_mapper) + + return_result: Optional[cursor.CursorResult[Any]] = None + for table, super_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(super_mapper): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue + is_joined_inh_supertable = super_mapper is not mapper + bookkeeping = ( + is_joined_inh_supertable + or return_defaults + or ( + use_orm_insert_stmt is not None + and bool(use_orm_insert_stmt._returning) + ) + ) + records = ( ( None, @@ -112,18 +171,25 @@ def _bulk_insert( table, ((None, mapping, mapper, connection) for mapping in mappings), bulk=True, - return_defaults=return_defaults, + return_defaults=bookkeeping, render_nulls=render_nulls, ) ) - persistence._emit_insert_statements( + result = persistence._emit_insert_statements( base_mapper, None, super_mapper, table, records, - bookkeeping=return_defaults, + bookkeeping=bookkeeping, + use_orm_insert_stmt=use_orm_insert_stmt, + execution_options=execution_options, ) + if use_orm_insert_stmt is not None: + if not use_orm_insert_stmt._returning or return_result is None: + return_result = result + elif result.returns_rows: + return_result = return_result.splice_horizontally(result) if return_defaults and isstates: identity_cls = mapper._identity_class @@ -134,14 +200,43 @@ def _bulk_insert( tuple([dict_[key] for key in identity_props]), ) + if use_orm_insert_stmt is not None: + assert return_result is not None + return return_result + +@overload def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, isstates: bool, update_changed_only: bool, + use_orm_update_stmt: Literal[None] = ..., ) -> None: + ... + + +@overload +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = ..., +) -> _result.Result[Any]: + ... + + +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = None, +) -> Optional[_result.Result[Any]]: base_mapper = mapper.base_mapper search_keys = mapper._primary_key_propkeys @@ -161,7 +256,8 @@ def _bulk_update( else: mappings = [state.dict for state in mappings] else: - mappings = list(mappings) + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) if session_transaction.session.connection_callable: raise NotImplementedError( @@ -172,7 +268,7 @@ def _bulk_update( connection = session_transaction.connection(base_mapper) for table, super_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(super_mapper): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue records = persistence._collect_update_commands( @@ -193,8 +289,8 @@ def _bulk_update( for mapping in mappings ), bulk=True, + use_orm_update_stmt=use_orm_update_stmt, ) - persistence._emit_update_statements( base_mapper, None, @@ -202,10 +298,125 @@ def _bulk_update( table, records, bookkeeping=False, + use_orm_update_stmt=use_orm_update_stmt, ) + if use_orm_update_stmt is not None: + return _result.null_result() + + +def _expand_composites(mapper, mappings): + composite_attrs = mapper.composites + if not composite_attrs: + return + + composite_keys = set(composite_attrs.keys()) + populators = { + key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() + for key in composite_keys + } + for mapping in mappings: + for key in composite_keys.intersection(mapping): + populators[key](mapping) + class ORMDMLState(AbstractORMCompileState): + is_dml_returning = True + from_statement_ctx: Optional[ORMFromStatementCompileState] = None + + @classmethod + def _get_orm_crud_kv_pairs( + cls, mapper, statement, kv_iterator, needs_to_be_cacheable + ): + + core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs + + for k, v in kv_iterator: + k = coercions.expect(roles.DMLColumnRole, k) + + if isinstance(k, str): + desc = _entity_namespace_key(mapper, k, default=NO_VALUE) + if desc is NO_VALUE: + yield ( + coercions.expect(roles.DMLColumnRole, k), + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + if needs_to_be_cacheable + else v, + ) + else: + yield from core_get_crud_kv_pairs( + statement, + desc._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + elif "entity_namespace" in k._annotations: + k_anno = k._annotations + attr = _entity_namespace_key( + k_anno["entity_namespace"], k_anno["proxy_key"] + ) + yield from core_get_crud_kv_pairs( + statement, + attr._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + else: + yield ( + k, + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ), + ) + + @classmethod + def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_multi_crud_kv_pairs( + statement, kv_iterator + ) + + return [ + dict( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, statement, value_dict.items(), False + ) + ) + for value_dict in kv_iterator + ] + + @classmethod + def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable): + assert ( + needs_to_be_cacheable + ), "no test coverage for needs_to_be_cacheable=False" + + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_crud_kv_pairs( + statement, kv_iterator, needs_to_be_cacheable + ) + + return list( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, + statement, + kv_iterator, + needs_to_be_cacheable, + ) + ) + @classmethod def get_entity_description(cls, statement): ext_info = statement.table._annotations["parententity"] @@ -250,18 +461,101 @@ class ORMDMLState(AbstractORMCompileState): ] ] + def _setup_orm_returning( + self, + compiler, + orm_level_statement, + dml_level_statement, + use_supplemental_cols=True, + dml_mapper=None, + ): + """establish ORM column handlers for an INSERT, UPDATE, or DELETE + which uses explicit returning(). + + called within compilation level create_for_statement. + + The _return_orm_returning() method then receives the Result + after the statement was executed, and applies ORM loading to the + state that we first established here. + + """ + + if orm_level_statement._returning: + + fs = FromStatement( + orm_level_statement._returning, dml_level_statement + ) + fs = fs.options(*orm_level_statement._with_options) + self.select_statement = fs + self.from_statement_ctx = ( + fsc + ) = ORMFromStatementCompileState.create_for_statement(fs, compiler) + fsc.setup_dml_returning_compile_state(dml_mapper) + + dml_level_statement = dml_level_statement._generate() + dml_level_statement._returning = () + + cols_to_return = [c for c in fsc.primary_columns if c is not None] + + # since we are splicing result sets together, make sure there + # are columns of some kind returned in each result set + if not cols_to_return: + cols_to_return.extend(dml_mapper.primary_key) + + if use_supplemental_cols: + dml_level_statement = dml_level_statement.return_defaults( + supplemental_cols=cols_to_return + ) + else: + dml_level_statement = dml_level_statement.returning( + *cols_to_return + ) + + return dml_level_statement + + @classmethod + def _return_orm_returning( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + if compile_state.from_statement_ctx: + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + querycontext = QueryContext( + compile_state.from_statement_ctx, + compile_state.select_statement, + params, + session, + load_options, + execution_options, + bind_arguments, + ) + return loading.instances(result, querycontext) + else: + return result + class BulkUDCompileState(ORMDMLState): class default_update_options(Options): - _synchronize_session: _SynchronizeSessionArgument = "evaluate" - _is_delete_using = False - _is_update_from = False - _autoflush = True - _subject_mapper = None + _dml_strategy: _DMLStrategyArgument = "auto" + _synchronize_session: _SynchronizeSessionArgument = "auto" + _can_use_returning: bool = False + _is_delete_using: bool = False + _is_update_from: bool = False + _autoflush: bool = True + _subject_mapper: Optional[Mapper[Any]] = None _resolved_values = EMPTY_DICT - _resolved_keys_as_propnames = EMPTY_DICT - _value_evaluators = EMPTY_DICT - _matched_objects = None + _eval_condition = None _matched_rows = None _refresh_identity_token = None @@ -295,19 +589,16 @@ class BulkUDCompileState(ORMDMLState): execution_options, ) = BulkUDCompileState.default_update_options.from_execution_options( "_sa_orm_update_options", - {"synchronize_session", "is_delete_using", "is_update_from"}, + { + "synchronize_session", + "is_delete_using", + "is_update_from", + "dml_strategy", + }, execution_options, statement._execution_options, ) - sync = update_options._synchronize_session - if sync is not None: - if sync not in ("evaluate", "fetch", False): - raise sa_exc.ArgumentError( - "Valid strategies for session synchronization " - "are 'evaluate', 'fetch', False" - ) - bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -318,43 +609,86 @@ class BulkUDCompileState(ORMDMLState): update_options += {"_subject_mapper": plugin_subject.mapper} + if not isinstance(params, list): + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "orm"} + elif update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "bulk"} + elif update_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + sync = update_options._synchronize_session + if sync is not None: + if sync not in ("auto", "evaluate", "fetch", False): + raise sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are 'auto', 'evaluate', 'fetch', False" + ) + if update_options._dml_strategy == "bulk" and sync == "fetch": + raise sa_exc.InvalidRequestError( + "The 'fetch' synchronization strategy is not available " + "for 'bulk' ORM updates (i.e. multiple parameter sets)" + ) + if update_options._autoflush: session._autoflush() + if update_options._dml_strategy == "orm": + + if update_options._synchronize_session == "auto": + update_options = cls._do_pre_synchronize_auto( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "auto": + update_options += {"_synchronize_session": "evaluate"} + + # indicators from the "pre exec" step that are then + # added to the DML statement, which will also be part of the cache + # key. The compile level create_for_statement() method will then + # consume these at compiler time. statement = statement._annotate( { "synchronize_session": update_options._synchronize_session, "is_delete_using": update_options._is_delete_using, "is_update_from": update_options._is_update_from, + "dml_strategy": update_options._dml_strategy, + "can_use_returning": update_options._can_use_returning, } ) - # this stage of the execution is called before the do_orm_execute event - # hook. meaning for an extension like horizontal sharding, this step - # happens before the extension splits out into multiple backends and - # runs only once. if we do pre_sync_fetch, we execute a SELECT - # statement, which the horizontal sharding extension splits amongst the - # shards and combines the results together. - - if update_options._synchronize_session == "evaluate": - update_options = cls._do_pre_synchronize_evaluate( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._synchronize_session == "fetch": - update_options = cls._do_pre_synchronize_fetch( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - return ( statement, util.immutabledict(execution_options).union( @@ -382,12 +716,30 @@ class BulkUDCompileState(ORMDMLState): # individual ones we return here. update_options = execution_options["_sa_orm_update_options"] - if update_options._synchronize_session == "evaluate": - cls._do_post_synchronize_evaluate(session, result, update_options) - elif update_options._synchronize_session == "fetch": - cls._do_post_synchronize_fetch(session, result, update_options) + if update_options._dml_strategy == "orm": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_evaluate( + session, statement, result, update_options + ) + elif update_options._synchronize_session == "fetch": + cls._do_post_synchronize_fetch( + session, statement, result, update_options + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_bulk_evaluate( + session, params, result, update_options + ) + return result - return result + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) @classmethod def _adjust_for_extra_criteria(cls, global_attributes, ext_info): @@ -473,11 +825,76 @@ class BulkUDCompileState(ORMDMLState): primary_key_convert = [ lookup[bpk] for bpk in mapper.base_mapper.primary_key ] - return [tuple(row[idx] for idx in primary_key_convert) for row in rows] @classmethod - def _do_pre_synchronize_evaluate( + def _get_matched_objects_on_criteria(cls, update_options, states): + mapper = update_options._subject_mapper + eval_condition = update_options._eval_condition + + raw_data = [ + (state.obj(), state, state.dict) + for state in states + if state.mapper.isa(mapper) and not state.expired + ] + + identity_token = update_options._refresh_identity_token + if identity_token is not None: + raw_data = [ + (obj, state, dict_) + for obj, state, dict_ in raw_data + if state.identity_token == identity_token + ] + + result = [] + for obj, state, dict_ in raw_data: + evaled_condition = eval_condition(obj) + + # caution: don't use "in ()" or == here, _EXPIRE_OBJECT + # evaluates as True for all comparisons + if ( + evaled_condition is True + or evaled_condition is evaluator._EXPIRED_OBJECT + ): + result.append( + ( + obj, + state, + dict_, + evaled_condition is evaluator._EXPIRED_OBJECT, + ) + ) + return result + + @classmethod + def _eval_condition_from_statement(cls, update_options, statement): + mapper = update_options._subject_mapper + target_cls = mapper.class_ + + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + crit = () + if statement._where_criteria: + crit += statement._where_criteria + + global_attributes = {} + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(global_attributes) + + if global_attributes: + crit += cls._adjust_for_extra_criteria(global_attributes, mapper) + + if crit: + eval_condition = evaluator_compiler.process(*crit) + else: + + def eval_condition(obj): + return True + + return eval_condition + + @classmethod + def _do_pre_synchronize_auto( cls, session, statement, @@ -486,33 +903,59 @@ class BulkUDCompileState(ORMDMLState): bind_arguments, update_options, ): - mapper = update_options._subject_mapper - target_cls = mapper.class_ + """setup auto sync strategy + + + "auto" checks if we can use "evaluate" first, then falls back + to "fetch" + + evaluate is vastly more efficient for the common case + where session is empty, only has a few objects, and the UPDATE + statement can potentially match thousands/millions of rows. - value_evaluators = resolved_keys_as_propnames = EMPTY_DICT + OTOH more complex criteria that fails to work with "evaluate" + we would hope usually correlates with fewer net rows. + + """ try: - evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - crit = () - if statement._where_criteria: - crit += statement._where_criteria + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) - global_attributes = {} - for opt in statement._with_options: - if opt._is_criteria_option: - opt.get_global_criteria(global_attributes) + except evaluator.UnevaluatableError: + pass + else: + return update_options + { + "_eval_condition": eval_condition, + "_synchronize_session": "evaluate", + } - if global_attributes: - crit += cls._adjust_for_extra_criteria( - global_attributes, mapper - ) + update_options += {"_synchronize_session": "fetch"} + return cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) - if crit: - eval_condition = evaluator_compiler.process(*crit) - else: + @classmethod + def _do_pre_synchronize_evaluate( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): - def eval_condition(obj): - return True + try: + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) except evaluator.UnevaluatableError as err: raise sa_exc.InvalidRequestError( @@ -521,52 +964,8 @@ class BulkUDCompileState(ORMDMLState): "synchronize_session execution option." % err ) from err - if statement.__visit_name__ == "lambda_element": - # ._resolved is called on every LambdaElement in order to - # generate the cache key, so this access does not add - # additional expense - effective_statement = statement._resolved - else: - effective_statement = statement - - if effective_statement.__visit_name__ == "update": - resolved_values = cls._get_resolved_values( - mapper, effective_statement - ) - value_evaluators = {} - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - for key, value in resolved_keys_as_propnames: - try: - _evaluator = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError: - pass - else: - value_evaluators[key] = _evaluator - - # TODO: detect when the where clause is a trivial primary key match. - matched_objects = [ - state.obj() - for state in session.identity_map.all_states() - if state.mapper.isa(mapper) - and not state.expired - and eval_condition(state.obj()) - and ( - update_options._refresh_identity_token is None - # TODO: coverage for the case where horizontal sharding - # invokes an update() or delete() given an explicit identity - # token up front - or state.identity_token - == update_options._refresh_identity_token - ) - ] return update_options + { - "_matched_objects": matched_objects, - "_value_evaluators": value_evaluators, - "_resolved_keys_as_propnames": resolved_keys_as_propnames, + "_eval_condition": eval_condition, } @classmethod @@ -584,12 +983,6 @@ class BulkUDCompileState(ORMDMLState): def _resolved_keys_as_propnames(cls, mapper, resolved_values): values = [] for k, v in resolved_values: - if isinstance(k, attributes.QueryableAttribute): - values.append((k.key, v)) - continue - elif hasattr(k, "__clause_element__"): - k = k.__clause_element__() - if mapper and isinstance(k, expression.ColumnElement): try: attr = mapper._columntoproperty[k] @@ -599,7 +992,8 @@ class BulkUDCompileState(ORMDMLState): values.append((attr.key, v)) else: raise sa_exc.InvalidRequestError( - "Invalid expression type: %r" % k + "Attribute name not found, can't be " + "synchronized back to objects: %r" % k ) return values @@ -622,14 +1016,43 @@ class BulkUDCompileState(ORMDMLState): ) select_stmt._where_criteria = statement._where_criteria + # conditionally run the SELECT statement for pre-fetch, testing the + # "bind" for if we can use RETURNING or not using the do_orm_execute + # event. If RETURNING is available, the do_orm_execute event + # will cancel the SELECT from being actually run. + # + # The way this is organized seems strange, why don't we just + # call can_use_returning() before invoking the statement and get + # answer?, why does this go through the whole execute phase using an + # event? Answer: because we are integrating with extensions such + # as the horizontal sharding extention that "multiplexes" an individual + # statement run through multiple engines, and it uses + # do_orm_execute() to do that. + + can_use_returning = None + def skip_for_returning(orm_context: ORMExecuteState) -> Any: bind = orm_context.session.get_bind(**orm_context.bind_arguments) - if cls.can_use_returning( + nonlocal can_use_returning + + per_bind_result = cls.can_use_returning( bind.dialect, mapper, is_update_from=update_options._is_update_from, is_delete_using=update_options._is_delete_using, - ): + ) + + if can_use_returning is not None: + if can_use_returning != per_bind_result: + raise sa_exc.InvalidRequestError( + "For synchronize_session='fetch', can't mix multiple " + "backends where some support RETURNING and others " + "don't" + ) + else: + can_use_returning = per_bind_result + + if per_bind_result: return _result.null_result() else: return None @@ -643,52 +1066,22 @@ class BulkUDCompileState(ORMDMLState): ) matched_rows = result.fetchall() - value_evaluators = EMPTY_DICT - - if statement.__visit_name__ == "lambda_element": - # ._resolved is called on every LambdaElement in order to - # generate the cache key, so this access does not add - # additional expense - effective_statement = statement._resolved - else: - effective_statement = statement - - if effective_statement.__visit_name__ == "update": - target_cls = mapper.class_ - evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - resolved_values = cls._get_resolved_values( - mapper, effective_statement - ) - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - - resolved_keys_as_propnames = cls._resolved_keys_as_propnames( - mapper, resolved_values - ) - value_evaluators = {} - for key, value in resolved_keys_as_propnames: - try: - _evaluator = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) - except evaluator.UnevaluatableError: - pass - else: - value_evaluators[key] = _evaluator - - else: - resolved_keys_as_propnames = EMPTY_DICT - return update_options + { - "_value_evaluators": value_evaluators, "_matched_rows": matched_rows, - "_resolved_keys_as_propnames": resolved_keys_as_propnames, + "_can_use_returning": can_use_returning, } @CompileState.plugin_for("orm", "insert") -class ORMInsert(ORMDMLState, InsertDMLState): +class BulkORMInsert(ORMDMLState, InsertDMLState): + class default_insert_options(Options): + _dml_strategy: _DMLStrategyArgument = "auto" + _render_nulls: bool = False + _return_defaults: bool = False + _subject_mapper: Optional[Mapper[Any]] = None + + select_statement: Optional[FromStatement] = None + @classmethod def orm_pre_session_exec( cls, @@ -699,6 +1092,16 @@ class ORMInsert(ORMDMLState, InsertDMLState): bind_arguments, is_reentrant_invoke, ): + + ( + insert_options, + execution_options, + ) = BulkORMInsert.default_insert_options.from_execution_options( + "_sa_orm_insert_options", + {"dml_strategy"}, + execution_options, + statement._execution_options, + ) bind_arguments["clause"] = statement try: plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -707,22 +1110,209 @@ class ORMInsert(ORMDMLState, InsertDMLState): else: bind_arguments["mapper"] = plugin_subject.mapper + insert_options += {"_subject_mapper": plugin_subject.mapper} + + if not params: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "orm"} + elif insert_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "bulk"} + elif insert_options._dml_strategy == "orm": + raise sa_exc.InvalidRequestError( + 'Can\'t use "orm" ORM insert strategy with a ' + "separate parameter list" + ) + + if insert_options._dml_strategy != "raw": + # for ORM object loading, like ORMContext, we have to disable + # result set adapt_to_context, because we will be generating a + # new statement with specific columns that's cached inside of + # an ORMFromStatementCompileState, which we will re-use for + # each result. + if not execution_options: + execution_options = context._orm_load_exec_options + else: + execution_options = execution_options.union( + context._orm_load_exec_options + ) + + statement = statement._annotate( + {"dml_strategy": insert_options._dml_strategy} + ) + return ( statement, - util.immutabledict(execution_options), + util.immutabledict(execution_options).union( + {"_sa_orm_insert_options": insert_options} + ), ) @classmethod - def orm_setup_cursor_result( + def orm_execute_statement( cls, - session, - statement, - params, - execution_options, - bind_arguments, - result, - ): - return result + session: Session, + statement: dml.Insert, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + insert_options = execution_options.get( + "_sa_orm_insert_options", cls.default_insert_options + ) + + if insert_options._dml_strategy not in ( + "raw", + "bulk", + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM insert strategy " + "are 'raw', 'orm', 'bulk', 'auto" + ) + + result: _result.Result[Any] + + if insert_options._dml_strategy == "raw": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return result + + if insert_options._dml_strategy == "bulk": + mapper = insert_options._subject_mapper + + if ( + statement._post_values_clause is not None + and mapper._multiple_persistence_tables + ): + raise sa_exc.InvalidRequestError( + "bulk INSERT with a 'post values' clause " + "(typically upsert) not supported for multi-table " + f"mapper {mapper}" + ) + + assert mapper is not None + assert session._transaction is not None + result = _bulk_insert( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + return_defaults=insert_options._return_defaults, + render_nulls=insert_options._render_nulls, + use_orm_insert_stmt=statement, + execution_options=execution_options, + ) + elif insert_options._dml_strategy == "orm": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + else: + raise AssertionError() + + if not bool(statement._returning): + return result + + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert: + + self = cast( + BulkORMInsert, + super().create_for_statement(statement, compiler, **kw), + ) + + if compiler is not None: + toplevel = not compiler.stack + else: + toplevel = True + if not toplevel: + return self + + mapper = statement._propagate_attrs["plugin_subject"] + dml_strategy = statement._annotations.get("dml_strategy", "raw") + if dml_strategy == "bulk": + self._setup_for_bulk_insert(compiler) + elif dml_strategy == "orm": + self._setup_for_orm_insert(compiler, mapper) + + return self + + @classmethod + def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict): + return { + col.key if col is not None else k: v + for col, k, v in ( + (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items() + ) + } + + def _setup_for_orm_insert(self, compiler, mapper): + statement = orm_level_statement = cast(dml.Insert, self.statement) + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=False, + ) + self.statement = statement + + def _setup_for_bulk_insert(self, compiler): + """establish an INSERT statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_insert_statement(). + + """ + statement = orm_level_statement = cast(dml.Insert, self.statement) + an = statement._annotations + + emit_insert_table, emit_insert_mapper = ( + an["_emit_insert_table"], + an["_emit_insert_mapper"], + ) + + statement = statement._clone() + + statement.table = emit_insert_table + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_insert_table + } + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + use_supplemental_cols=True, + dml_mapper=emit_insert_mapper, + ) + + self.statement = statement @CompileState.plugin_for("orm", "update") @@ -732,13 +1322,27 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): self = cls.__new__(cls) + dml_strategy = statement._annotations.get( + "dml_strategy", "unspecified" + ) + + if dml_strategy == "bulk": + self._setup_for_bulk_update(statement, compiler) + elif dml_strategy in ("orm", "unspecified"): + self._setup_for_orm_update(statement, compiler) + + return self + + def _setup_for_orm_update(self, statement, compiler, **kw): + orm_level_statement = statement + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper self.extra_criteria_entities = {} - self._resolved_values = cls._get_resolved_values(mapper, statement) + self._resolved_values = self._get_resolved_values(mapper, statement) extra_criteria_attributes = {} @@ -749,8 +1353,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if statement._values: self._resolved_values = dict(self._resolved_values) - new_stmt = sql.Update.__new__(sql.Update) - new_stmt.__dict__.update(statement.__dict__) + new_stmt = statement._clone() new_stmt.table = mapper.local_table # note if the statement has _multi_values, these @@ -762,7 +1365,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): elif statement._values: new_stmt._values = self._resolved_values - new_crit = cls._adjust_for_extra_criteria( + new_crit = self._adjust_for_extra_criteria( extra_criteria_attributes, mapper ) if new_crit: @@ -776,21 +1379,150 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): UpdateDMLState.__init__(self, new_stmt, compiler, **kw) - if compiler._annotations.get( + use_supplemental_cols = False + + synchronize_session = compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning( - compiler.dialect, mapper, is_multitable=self.is_multitable - ): - if new_stmt._returning: - raise sa_exc.InvalidRequestError( - "Can't use synchronize_session='fetch' " - "with explicit returning()" + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, mapper, is_multitable=self.is_multitable ) - self.statement = self.statement.returning( - *mapper.local_table.primary_key ) - return self + if synchronize_session == "fetch" and can_use_returning: + use_supplemental_cols = True + + # NOTE: we might want to RETURNING the actual columns to be + # synchronized also. however this is complicated and difficult + # to align against the behavior of "evaluate". Additionally, + # in a large number (if not the majority) of cases, we have the + # "evaluate" answer, usually a fixed value, in memory already and + # there's no need to re-fetch the same value + # over and over again. so perhaps if it could be RETURNING just + # the elements that were based on a SQL expression and not + # a constant. For now it doesn't quite seem worth it + new_stmt = new_stmt.return_defaults( + *(list(mapper.local_table.primary_key)) + ) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt + + def _setup_for_bulk_update(self, statement, compiler, **kw): + """establish an UPDATE statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_update_statement(). + + """ + statement = cast(dml.Update, statement) + an = statement._annotations + + emit_update_table, _ = ( + an["_emit_update_table"], + an["_emit_update_mapper"], + ) + + statement = statement._clone() + statement.table = emit_update_table + + UpdateDMLState.__init__(self, statement, compiler, **kw) + + if self._ordered_values: + raise sa_exc.InvalidRequestError( + "bulk ORM UPDATE does not support ordered_values() for " + "custom UPDATE statements with bulk parameter sets. Use a " + "non-bulk UPDATE statement or use values()." + ) + + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_update_table + } + self.statement = statement + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Update, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy not in ("orm", "auto", "bulk"): + raise sa_exc.ArgumentError( + "Valid strategies for ORM UPDATE strategy " + "are 'orm', 'auto', 'bulk'" + ) + + result: _result.Result[Any] + + if update_options._dml_strategy == "bulk": + if statement._where_criteria: + raise sa_exc.InvalidRequestError( + "WHERE clause with bulk ORM UPDATE not " + "supported right now. Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + mapper = update_options._subject_mapper + assert mapper is not None + assert session._transaction is not None + result = _bulk_update( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + update_changed_only=False, + use_orm_update_stmt=statement, + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + else: + return super().orm_execute_statement( + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) @classmethod def can_use_returning( @@ -827,119 +1559,80 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): return True @classmethod - def _get_crud_kv_pairs(cls, statement, kv_iterator): - plugin_subject = statement._propagate_attrs["plugin_subject"] - - core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs - - if not plugin_subject or not plugin_subject.mapper: - return core_get_crud_kv_pairs(statement, kv_iterator) - - mapper = plugin_subject.mapper - - values = [] - - for k, v in kv_iterator: - k = coercions.expect(roles.DMLColumnRole, k) + def _do_post_synchronize_bulk_evaluate( + cls, session, params, result, update_options + ): + if not params: + return - if isinstance(k, str): - desc = _entity_namespace_key(mapper, k, default=NO_VALUE) - if desc is NO_VALUE: - values.append( - ( - k, - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ), - ) - ) - else: - values.extend( - core_get_crud_kv_pairs( - statement, desc._bulk_update_tuples(v) - ) - ) - elif "entity_namespace" in k._annotations: - k_anno = k._annotations - attr = _entity_namespace_key( - k_anno["entity_namespace"], k_anno["proxy_key"] - ) - values.extend( - core_get_crud_kv_pairs( - statement, attr._bulk_update_tuples(v) - ) - ) - else: - values.append( - ( - k, - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ), - ) - ) - return values + mapper = update_options._subject_mapper + pk_keys = [prop.key for prop in mapper._identity_key_props] - @classmethod - def _do_post_synchronize_evaluate(cls, session, result, update_options): + identity_map = session.identity_map - states = set() - evaluated_keys = list(update_options._value_evaluators.keys()) - values = update_options._resolved_keys_as_propnames - attrib = set(k for k, v in values) - for obj in update_options._matched_objects: - - state, dict_ = ( - attributes.instance_state(obj), - attributes.instance_dict(obj), + for param in params: + identity_key = mapper.identity_key_from_primary_key( + (param[key] for key in pk_keys), + update_options._refresh_identity_token, ) - - # the evaluated states were gathered across all identity tokens. - # however the post_sync events are called per identity token, - # so filter. - if ( - update_options._refresh_identity_token is not None - and state.identity_token - != update_options._refresh_identity_token - ): + state = identity_map.fast_get_state(identity_key) + if not state: continue + evaluated_keys = set(param).difference(pk_keys) + + dict_ = state.dict # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: if key in dict_: - dict_[key] = update_options._value_evaluators[key](obj) + dict_[key] = param[key] state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) - to_expire = attrib.intersection(dict_).difference(to_evaluate) + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. + to_expire = evaluated_keys.intersection(dict_).difference( + to_evaluate + ) if to_expire: state._expire_attributes(dict_, to_expire) - states.add(state) - session._register_altered(states) + @classmethod + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), + ) + + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], + ) @classmethod - def _do_post_synchronize_fetch(cls, session, result, update_options): + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): target_mapper = update_options._subject_mapper - states = set() - evaluated_keys = list(update_options._value_evaluators.keys()) - - if result.returns_rows: - rows = cls._interpret_returning_rows(target_mapper, result.all()) + returned_defaults_rows = result.returned_defaults_rows + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in rows + for row in pk_rows ] else: matched_rows = update_options._matched_rows @@ -960,23 +1653,69 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if identity_key in session.identity_map ] - values = update_options._resolved_keys_as_propnames - attrib = set(k for k, v in values) + if not objs: + return - for obj in objs: - state, dict_ = ( - attributes.instance_state(obj), - attributes.instance_dict(obj), - ) + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [ + ( + obj, + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) + for obj in objs + ], + ) + + @classmethod + def _apply_update_set_values_to_objects( + cls, session, update_options, statement, matched_objects + ): + """apply values to objects derived from an update statement, e.g. + UPDATE..SET <values> + + """ + mapper = update_options._subject_mapper + target_cls = mapper.class_ + evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) + resolved_values = cls._get_resolved_values(mapper, statement) + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + value_evaluators = {} + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + + evaluated_keys = list(value_evaluators.keys()) + attrib = set(k for k, v in resolved_keys_as_propnames) + + states = set() + for obj, state, dict_ in matched_objects: to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: if key in dict_: - dict_[key] = update_options._value_evaluators[key](obj) + # only run eval for attributes that are present. + dict_[key] = value_evaluators[key](obj) + state.manager.dispatch.refresh(state, None, to_evaluate) state._commit(dict_, list(to_evaluate)) + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. to_expire = attrib.intersection(dict_).difference(to_evaluate) if to_expire: state._expire_attributes(dict_, to_expire) @@ -991,6 +1730,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): def create_for_statement(cls, statement, compiler, **kw): self = cls.__new__(cls) + orm_level_statement = statement + ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper @@ -1002,31 +1743,97 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): if opt._is_criteria_option: opt.get_global_criteria(extra_criteria_attributes) + new_stmt = statement._clone() + new_stmt.table = mapper.local_table + new_crit = cls._adjust_for_extra_criteria( extra_criteria_attributes, mapper ) if new_crit: - statement = statement.where(*new_crit) + new_stmt = new_stmt.where(*new_crit) # do this first as we need to determine if there is # DELETE..FROM - DeleteDMLState.__init__(self, statement, compiler, **kw) + DeleteDMLState.__init__(self, new_stmt, compiler, **kw) + + use_supplemental_cols = False - if compiler._annotations.get( + synchronize_session = compiler._annotations.get( "synchronize_session", None - ) == "fetch" and self.can_use_returning( - compiler.dialect, - mapper, - is_multitable=self.is_multitable, - is_delete_using=compiler._annotations.get( - "is_delete_using", False - ), - ): - self.statement = statement.returning(*statement.table.primary_key) + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, + mapper, + is_multitable=self.is_multitable, + is_delete_using=compiler._annotations.get( + "is_delete_using", False + ), + ) + ) + + if can_use_returning: + use_supplemental_cols = True + + new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) + + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt return self @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Delete, + params: _CoreAnyExecuteParams, + execution_options: _ExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + "Bulk ORM DELETE not supported right now. " + "Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + + if update_options._dml_strategy not in ( + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM DELETE strategy are 'orm', 'auto'" + ) + + return super().orm_execute_statement( + session, statement, params, execution_options, bind_arguments, conn + ) + + @classmethod def can_use_returning( cls, dialect: Dialect, @@ -1068,25 +1875,41 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): return True @classmethod - def _do_post_synchronize_evaluate(cls, session, result, update_options): - - session._remove_newly_deleted( - [ - attributes.instance_state(obj) - for obj in update_options._matched_objects - ] + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), ) + to_delete = [] + + for _, state, dict_, is_partially_expired in matched_objects: + if is_partially_expired: + state._expire(dict_, session.identity_map._modified) + else: + to_delete.append(state) + + if to_delete: + session._remove_newly_deleted(to_delete) + @classmethod - def _do_post_synchronize_fetch(cls, session, result, update_options): + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): target_mapper = update_options._subject_mapper - if result.returns_rows: - rows = cls._interpret_returning_rows(target_mapper, result.all()) + returned_defaults_rows = result.returned_defaults_rows + + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) matched_rows = [ tuple(row) + (update_options._refresh_identity_token,) - for row in rows + for row in pk_rows ] else: matched_rows = update_options._matched_rows diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index dc96f8c3c..f8c7ba714 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -73,6 +73,7 @@ if TYPE_CHECKING: from .query import Query from .session import _BindArguments from .session import Session + from ..engine import Result from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptionsParameter from ..sql._typing import _ColumnsClauseArgument @@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict( class AbstractORMCompileState(CompileState): + is_dml_returning = False + @classmethod def create_for_statement( cls, statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> AbstractORMCompileState: """Create a context for a statement given a :class:`.Compiler`. + This method is always invoked in the context of SQLCompiler.process(). + For a Select object, this would be invoked from SQLCompiler.visit_select(). For the special FromStatement object used by Query to indicate "Query.from_statement()", this is called by @@ -233,6 +238,28 @@ class AbstractORMCompileState(CompileState): raise NotImplementedError() @classmethod + def orm_execute_statement( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) -> Result: + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod def orm_setup_cursor_result( cls, session, @@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState): def __init__(self, *arg, **kw): raise NotImplementedError() + if TYPE_CHECKING: + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: + ... + def _append_dedupe_col_collection(self, obj, col_collection): dedupe = self.dedupe_columns if obj not in dedupe: @@ -333,26 +371,6 @@ class ORMCompileState(AbstractORMCompileState): return SelectState._column_naming_convention(label_style) @classmethod - def create_for_statement( - cls, - statement: Union[Select, FromStatement], - compiler: Optional[SQLCompiler], - **kw: Any, - ) -> ORMCompileState: - """Create a context for a statement given a :class:`.Compiler`. - - This method is always invoked in the context of SQLCompiler.process(). - - For a Select object, this would be invoked from - SQLCompiler.visit_select(). For the special FromStatement object used - by Query to indicate "Query.from_statement()", this is called by - FromStatement._compiler_dispatch() that would be called by - SQLCompiler.process(). - - """ - raise NotImplementedError() - - @classmethod def get_column_descriptions(cls, statement): return _column_descriptions(statement) @@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState): ) +class DMLReturningColFilter: + """an adapter used for the DML RETURNING case. + + Has a subset of the interface used by + :class:`.ORMAdapter` and is used for :class:`._QueryEntity` + instances to set up their columns as used in RETURNING for a + DML statement. + + """ + + __slots__ = ("mapper", "columns", "__weakref__") + + def __init__(self, target_mapper, immediate_dml_mapper): + if ( + immediate_dml_mapper is not None + and target_mapper.local_table + is not immediate_dml_mapper.local_table + ): + # joined inh, or in theory other kinds of multi-table mappings + self.mapper = immediate_dml_mapper + else: + # single inh, normal mappings, etc. + self.mapper = target_mapper + self.columns = self.columns = util.WeakPopulateDict( + self.adapt_check_present # type: ignore + ) + + def __call__(self, col, as_filter): + for cc in sql_util._find_columns(col): + c2 = self.adapt_check_present(cc) + if c2 is not None: + return col + else: + return None + + def adapt_check_present(self, col): + mapper = self.mapper + prop = mapper._columntoproperty.get(col, None) + if prop is None: + return None + return mapper.local_table.c.corresponding_column(col) + + @sql.base.CompileState.plugin_for("orm", "orm_from_statement") class ORMFromStatementCompileState(ORMCompileState): _from_obj_alias = None @@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState): statement_container: FromStatement requested_statement: Union[SelectBase, TextClause, UpdateBase] - dml_table: _DMLTableElement + dml_table: Optional[_DMLTableElement] = None _has_orm_entities = False multi_row_eager_loaders = False @@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState): statement_container: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> ORMFromStatementCompileState: if compiler is not None: toplevel = not compiler.stack @@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState): if statement.is_dml: self.dml_table = statement.table + self.is_dml_returning = True self._entities = [] self._polymorphic_adapters = {} @@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState): def _get_current_adapter(self): return None + def setup_dml_returning_compile_state(self, dml_mapper): + """used by BulkORMInsert (and Update / Delete?) to set up a handler + for RETURNING to return ORM objects and expressions + + """ + target_mapper = self.statement._propagate_attrs.get( + "plugin_subject", None + ) + adapter = DMLReturningColFilter(target_mapper, dml_mapper) + for entity in self._entities: + entity.setup_dml_returning_compile_state(self, adapter) + class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): """Core construct that represents a load of ORM objects from various @@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: + ) -> ORMSelectCompileState: """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) @@ -2312,6 +2386,13 @@ class _QueryEntity: def setup_compile_state(self, compile_state: ORMCompileState) -> None: raise NotImplementedError() + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + raise NotImplementedError() + def row_processor(self, context, result): raise NotImplementedError() @@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity): return _instance, self._label_name, self._extra_entities - def setup_compile_state(self, compile_state): + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + loading._setup_entity_query( + compile_state, + self.mapper, + self, + self.path, + adapter, + compile_state.primary_columns, + with_polymorphic=self._with_polymorphic_mappers, + only_load_props=compile_state.compile_options._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + def setup_compile_state(self, compile_state): adapter = self._get_entity_clauses(compile_state) single_table_crit = self.mapper._single_table_criterion @@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity): only_load_props=compile_state.compile_options._only_load_props, polymorphic_discriminator=self._polymorphic_discriminator, ) - compile_state._fallback_from_clauses.append(self.selectable) @@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity): getter, label_name, extra_entities = self._row_processor if self.translate_raw_column: extra_entities += ( - result.context.invoked_statement._raw_columns[ - self.raw_column_index - ], + context.query._raw_columns[self.raw_column_index], ) return getter, label_name, extra_entities @@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity): if self.translate_raw_column: extra_entities = self._extra_entities + ( - result.context.invoked_statement._raw_columns[ - self.raw_column_index - ], + context.query._raw_columns[self.raw_column_index], ) return getter, self._label_name, extra_entities else: @@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity): current_adapter = compile_state._get_current_adapter() if current_adapter: column = current_adapter(self.column, False) + if column is None: + return else: column = self.column @@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity): self.entity_zero ) and entity.common_parent(self.entity_zero) + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + self._fetch_column = self.column + column = adapter(self.column, False) + if column is not None: + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + def setup_compile_state(self, compile_state): current_adapter = compile_state._get_current_adapter() if current_adapter: column = current_adapter(self.column, False) + if column is None: + assert compile_state.is_dml_returning + self._fetch_column = self.column + return else: column = self.column diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 52b70b9d4..13d3b70fe 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -19,6 +19,7 @@ import operator import typing from typing import Any from typing import Callable +from typing import Dict from typing import List from typing import NoReturn from typing import Optional @@ -602,6 +603,31 @@ class Composite( def _attribute_keys(self) -> Sequence[str]: return [prop.key for prop in self.props] + def _populate_composite_bulk_save_mappings_fn( + self, + ) -> Callable[[Dict[str, Any]], None]: + + if self._generated_composite_accessor: + get_values = self._generated_composite_accessor + else: + + def get_values(val: Any) -> Tuple[Any]: + return val.__composite_values__() # type: ignore + + attrs = [prop.key for prop in self.props] + + def populate(dest_dict: Dict[str, Any]) -> None: + dest_dict.update( + { + key: val + for key, val in zip( + attrs, get_values(dest_dict.pop(self.key)) + ) + } + ) + + return populate + def get_history( self, state: InstanceState[Any], diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index b3129afdd..5af14cc00 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -9,8 +9,8 @@ from __future__ import annotations -import operator - +from .base import LoaderCallableStatus +from .base import PassiveFlag from .. import exc from .. import inspect from .. import util @@ -32,7 +32,16 @@ class _NoObject(operators.ColumnOperators): return None +class _ExpiredObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return self + + def reverse_operate(self, *arg, **kw): + return self + + _NO_OBJECT = _NoObject() +_EXPIRED_OBJECT = _ExpiredObject() class EvaluatorCompiler: @@ -73,6 +82,24 @@ class EvaluatorCompiler: f"alternate class {parentmapper.class_}" ) key = parentmapper._columntoproperty[clause].key + impl = parentmapper.class_manager[key].impl + + if impl is not None: + + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + state = inspect(obj) + dict_ = state.dict + + value = impl.get( + state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH + ) + if value is LoaderCallableStatus.PASSIVE_NO_RESULT: + return _EXPIRED_OBJECT + return value + + return get_corresponding_attr else: key = clause.key if ( @@ -85,15 +112,16 @@ class EvaluatorCompiler: "make use of the actual mapped columns in ORM-evaluated " "UPDATE / DELETE expressions." ) + else: raise UnevaluatableError(f"Cannot evaluate column: {clause}") - get_corresponding_attr = operator.attrgetter(key) - return ( - lambda obj: get_corresponding_attr(obj) - if obj is not None - else _NO_OBJECT - ) + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + return getattr(obj, key, _EXPIRED_OBJECT) + + return get_corresponding_attr def visit_tuple(self, clause): return self.visit_clauselist(clause) @@ -134,7 +162,9 @@ class EvaluatorCompiler: has_null = False for sub_evaluate in evaluators: value = sub_evaluate(obj) - if value: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value: return True has_null = has_null or value is None if has_null: @@ -147,6 +177,9 @@ class EvaluatorCompiler: def evaluate(obj): for sub_evaluate in evaluators: value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + if not value: if value is None or value is _NO_OBJECT: return None @@ -160,7 +193,9 @@ class EvaluatorCompiler: values = [] for sub_evaluate in evaluators: value = sub_evaluate(obj) - if value is None or value is _NO_OBJECT: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None or value is _NO_OBJECT: return None values.append(value) return tuple(values) @@ -183,13 +218,21 @@ class EvaluatorCompiler: def visit_is_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): - return eval_left(obj) == eval_right(obj) + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val == right_val return evaluate def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause): def evaluate(obj): - return eval_left(obj) != eval_right(obj) + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val != right_val return evaluate @@ -197,8 +240,11 @@ class EvaluatorCompiler: def evaluate(obj): left_val = eval_left(obj) right_val = eval_right(obj) - if left_val is None or right_val is None: + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif left_val is None or right_val is None: return None + return operator(eval_left(obj), eval_right(obj)) return evaluate @@ -274,7 +320,9 @@ class EvaluatorCompiler: def evaluate(obj): value = eval_inner(obj) - if value is None: + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None: return None return not value diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 63b131a78..4848f73f1 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -68,6 +68,11 @@ class IdentityMap: ) -> Optional[_O]: raise NotImplementedError() + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + raise NotImplementedError() + def keys(self) -> Iterable[_IdentityKeyType[Any]]: return self._dict.keys() @@ -206,6 +211,11 @@ class WeakInstanceDict(IdentityMap): self._dict[key] = state state._instance_dict = self._wr + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + return self._dict.get(key) + def get( self, key: _IdentityKeyType[_O], default: Optional[_O] = None ) -> Optional[_O]: diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 7317d48be..64f2542fd 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -29,7 +29,6 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.orm.context import FromStatement from . import attributes from . import exc as orm_exc from . import path_registry @@ -37,6 +36,7 @@ from .base import _DEFER_FOR_STATE from .base import _RAISE_FOR_STATE from .base import _SET_DEFERRED_EXPIRED from .base import PassiveFlag +from .context import FromStatement from .util import _none_set from .util import state_str from .. import exc as sa_exc @@ -50,6 +50,7 @@ from ..sql import util as sql_util from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import SelectState +from ..util import EMPTY_DICT if TYPE_CHECKING: from ._typing import _IdentityKeyType @@ -764,7 +765,7 @@ def _instance_processor( ) quick_populators = path.get( - context.attributes, "memoized_setups", _none_set + context.attributes, "memoized_setups", EMPTY_DICT ) todo = [] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c8df51b06..c9cf8f49b 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -854,6 +854,7 @@ class Mapper( _memoized_values: Dict[Any, Callable[[], Any]] _inheriting_mappers: util.WeakSequence[Mapper[Any]] _all_tables: Set[Table] + _polymorphic_attr_key: Optional[str] _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]] _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]] @@ -1653,6 +1654,7 @@ class Mapper( """ setter = False + polymorphic_key: Optional[str] = None if self.polymorphic_on is not None: setter = True @@ -1772,17 +1774,23 @@ class Mapper( self._set_polymorphic_identity = ( mapper._set_polymorphic_identity ) + self._polymorphic_attr_key = ( + mapper._polymorphic_attr_key + ) self._validate_polymorphic_identity = ( mapper._validate_polymorphic_identity ) else: self._set_polymorphic_identity = None + self._polymorphic_attr_key = None return if setter: def _set_polymorphic_identity(state): dict_ = state.dict + # TODO: what happens if polymorphic_on column attribute name + # does not match .key? state.get_impl(polymorphic_key).set( state, dict_, @@ -1790,6 +1798,8 @@ class Mapper( None, ) + self._polymorphic_attr_key = polymorphic_key + def _validate_polymorphic_identity(mapper, state, dict_): if ( polymorphic_key in dict_ @@ -1808,6 +1818,7 @@ class Mapper( _validate_polymorphic_identity ) else: + self._polymorphic_attr_key = None self._set_polymorphic_identity = None _validate_polymorphic_identity = None @@ -3562,6 +3573,10 @@ class Mapper( return util.LRUCache(self._compiled_cache_size) @HasMemoized.memoized_attribute + def _multiple_persistence_tables(self): + return len(self.tables) > 1 + + @HasMemoized.memoized_attribute def _sorted_tables(self): table_to_mapper: Dict[Table, Mapper[Any]] = {} diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index abd528986..dfb61c28a 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -31,6 +31,7 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. import util +from ..engine import cursor as _cursor from ..sql import operators from ..sql.elements import BooleanClauseList from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL @@ -398,6 +399,11 @@ def _collect_insert_commands( None ) + if bulk and mapper._set_polymorphic_identity: + params.setdefault( + mapper._polymorphic_attr_key, mapper.polymorphic_identity + ) + yield ( state, state_dict, @@ -411,7 +417,11 @@ def _collect_insert_commands( def _collect_update_commands( - uowtransaction, table, states_to_update, bulk=False + uowtransaction, + table, + states_to_update, + bulk=False, + use_orm_update_stmt=None, ): """Identify sets of values to use in UPDATE statements for a list of states. @@ -437,7 +447,11 @@ def _collect_update_commands( pks = mapper._pks_by_table[table] - value_params = {} + if use_orm_update_stmt is not None: + # TODO: ordered values, etc + value_params = use_orm_update_stmt._values + else: + value_params = {} propkey_to_col = mapper._propkey_to_col[table] @@ -697,6 +711,7 @@ def _emit_update_statements( table, update, bookkeeping=True, + use_orm_update_stmt=None, ): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" @@ -708,7 +723,7 @@ def _emit_update_statements( execution_options = {"compiled_cache": base_mapper._compiled_cache} - def update_stmt(): + def update_stmt(existing_stmt=None): clauses = BooleanClauseList._construct_raw(operators.and_) for col in mapper._pks_by_table[table]: @@ -725,10 +740,17 @@ def _emit_update_statements( ) ) - stmt = table.update().where(clauses) + if existing_stmt is not None: + stmt = existing_stmt.where(clauses) + else: + stmt = table.update().where(clauses) return stmt - cached_stmt = base_mapper._memo(("update", table), update_stmt) + if use_orm_update_stmt is not None: + cached_stmt = update_stmt(use_orm_update_stmt) + + else: + cached_stmt = base_mapper._memo(("update", table), update_stmt) for ( (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), @@ -747,6 +769,15 @@ def _emit_update_statements( records = list(records) statement = cached_stmt + + if use_orm_update_stmt is not None: + statement = statement._annotate( + { + "_emit_update_table": table, + "_emit_update_mapper": mapper, + } + ) + return_defaults = False if not has_all_pks: @@ -904,16 +935,35 @@ def _emit_insert_statements( table, insert, bookkeeping=True, + use_orm_insert_stmt=None, + execution_options=None, ): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" - cached_stmt = base_mapper._memo(("insert", table), table.insert) + if use_orm_insert_stmt is not None: + cached_stmt = use_orm_insert_stmt + exec_opt = util.EMPTY_DICT - execution_options = {"compiled_cache": base_mapper._compiled_cache} + # if a user query with RETURNING was passed, we definitely need + # to use RETURNING. + returning_is_required_anyway = bool(use_orm_insert_stmt._returning) + else: + returning_is_required_anyway = False + cached_stmt = base_mapper._memo(("insert", table), table.insert) + exec_opt = {"compiled_cache": base_mapper._compiled_cache} + + if execution_options: + execution_options = util.EMPTY_DICT.merge_with( + exec_opt, execution_options + ) + else: + execution_options = exec_opt + + return_result = None for ( - (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), + (connection, _, hasvalue, has_all_pks, has_all_defaults), records, ) in groupby( insert, @@ -928,17 +978,29 @@ def _emit_insert_statements( statement = cached_stmt + if use_orm_insert_stmt is not None: + statement = statement._annotate( + { + "_emit_insert_table": table, + "_emit_insert_mapper": mapper, + } + ) + if ( - not bookkeeping - or ( - has_all_defaults - or not base_mapper.eager_defaults - or not base_mapper.local_table.implicit_returning - or not connection.dialect.insert_returning + ( + not bookkeeping + or ( + has_all_defaults + or not base_mapper.eager_defaults + or not base_mapper.local_table.implicit_returning + or not connection.dialect.insert_returning + ) ) + and not returning_is_required_anyway and has_all_pks and not hasvalue ): + # the "we don't need newly generated values back" section. # here we have all the PKs, all the defaults or we don't want # to fetch them, or the dialect doesn't support RETURNING at all @@ -946,7 +1008,7 @@ def _emit_insert_statements( records = list(records) multiparams = [rec[2] for rec in records] - c = connection.execute( + result = connection.execute( statement, multiparams, execution_options=execution_options ) if bookkeeping: @@ -962,7 +1024,7 @@ def _emit_insert_statements( has_all_defaults, ), last_inserted_params, - ) in zip(records, c.context.compiled_parameters): + ) in zip(records, result.context.compiled_parameters): if state: _postfetch( mapper_rec, @@ -970,19 +1032,20 @@ def _emit_insert_statements( table, state, state_dict, - c, + result, last_inserted_params, value_params, False, - c.returned_defaults - if not c.context.executemany + result.returned_defaults + if not result.context.executemany else None, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: - # here, we need defaults and/or pk values back. + # here, we need defaults and/or pk values back or we otherwise + # know that we are using RETURNING in any case records = list(records) if ( @@ -991,6 +1054,16 @@ def _emit_insert_statements( and len(records) > 1 ): do_executemany = True + elif returning_is_required_anyway: + if connection.dialect.insert_executemany_returning: + do_executemany = True + else: + raise sa_exc.InvalidRequestError( + f"Can't use explicit RETURNING for bulk INSERT " + f"operation with " + f"{connection.dialect.dialect_description} backend; " + f"executemany is not supported with RETURNING" + ) else: do_executemany = False @@ -998,6 +1071,7 @@ def _emit_insert_statements( statement = statement.return_defaults( *mapper._server_default_cols[table] ) + if mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) elif do_executemany: @@ -1006,10 +1080,16 @@ def _emit_insert_statements( if do_executemany: multiparams = [rec[2] for rec in records] - c = connection.execute( + result = connection.execute( statement, multiparams, execution_options=execution_options ) + if use_orm_insert_stmt is not None: + if return_result is None: + return_result = result + else: + return_result = return_result.splice_vertically(result) + if bookkeeping: for ( ( @@ -1027,9 +1107,9 @@ def _emit_insert_statements( returned_defaults, ) in zip_longest( records, - c.context.compiled_parameters, - c.inserted_primary_key_rows, - c.returned_defaults_rows or (), + result.context.compiled_parameters, + result.inserted_primary_key_rows, + result.returned_defaults_rows or (), ): if inserted_primary_key is None: # this is a real problem and means that we didn't @@ -1062,7 +1142,7 @@ def _emit_insert_statements( table, state, state_dict, - c, + result, last_inserted_params, value_params, False, @@ -1071,6 +1151,8 @@ def _emit_insert_statements( else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: + assert not returning_is_required_anyway + for ( state, state_dict, @@ -1132,6 +1214,12 @@ def _emit_insert_statements( else: _postfetch_bulk_save(mapper_rec, state_dict, table) + if use_orm_insert_stmt is not None: + if return_result is None: + return _cursor.null_dml_result() + else: + return return_result + def _emit_post_update_statements( base_mapper, uowtransaction, mapper, table, update diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 6d0f055e4..4d5a98fcf 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2978,7 +2978,7 @@ class Query( ) def delete( - self, synchronize_session: _SynchronizeSessionArgument = "evaluate" + self, synchronize_session: _SynchronizeSessionArgument = "auto" ) -> int: r"""Perform a DELETE with an arbitrary WHERE clause. @@ -3042,7 +3042,7 @@ class Query( def update( self, values: Dict[_DMLColumnArgument, Any], - synchronize_session: _SynchronizeSessionArgument = "evaluate", + synchronize_session: _SynchronizeSessionArgument = "auto", update_args: Optional[Dict[Any, Any]] = None, ) -> int: r"""Perform an UPDATE with an arbitrary WHERE clause. diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index a690da0d5..64c013306 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1828,12 +1828,13 @@ class Session(_SessionClassMethods, EventTarget): statement._propagate_attrs.get("compile_state_plugin", None) == "orm" ): - # note that even without "future" mode, we need compile_state_cls = CompileState._get_plugin_class_for_plugin( statement, "orm" ) if TYPE_CHECKING: - assert isinstance(compile_state_cls, ORMCompileState) + assert isinstance( + compile_state_cls, context.AbstractORMCompileState + ) else: compile_state_cls = None @@ -1897,18 +1898,18 @@ class Session(_SessionClassMethods, EventTarget): statement, params or {}, execution_options=execution_options ) - result: Result[Any] = conn.execute( - statement, params or {}, execution_options=execution_options - ) - if compile_state_cls: - result = compile_state_cls.orm_setup_cursor_result( + result: Result[Any] = compile_state_cls.orm_execute_statement( self, statement, - params, + params or {}, execution_options, bind_arguments, - result, + conn, + ) + else: + result = conn.execute( + statement, params or {}, execution_options=execution_options ) if _scalar_result: @@ -2066,7 +2067,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: TypedReturnsRows[Tuple[_T]], - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, @@ -2078,7 +2079,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: Executable, - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, @@ -2089,7 +2090,7 @@ class Session(_SessionClassMethods, EventTarget): def scalars( self, statement: Executable, - params: Optional[_CoreSingleExecuteParams] = None, + params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 19c6493db..8652591c8 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -227,6 +227,11 @@ class ColumnLoader(LoaderStrategy): fetch = self.columns[0] if adapter: fetch = adapter.columns[fetch] + if fetch is None: + # None happens here only for dml bulk_persistence cases + # when context.DMLReturningColFilter is used + return + memoized_populators[self.parent_property] = fetch def init_class_attribute(self, mapper): @@ -318,6 +323,12 @@ class ExpressionColumnLoader(ColumnLoader): fetch = columns[0] if adapter: fetch = adapter.columns[fetch] + if fetch is None: + # None is not expected to be the result of any + # adapter implementation here, however there may be theoretical + # usages of returning() with context.DMLReturningColFilter + return + memoized_populators[self.parent_property] = fetch def create_row_processor( |
