summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm')
-rw-r--r--lib/sqlalchemy/orm/bulk_persistence.py1459
-rw-r--r--lib/sqlalchemy/orm/context.py173
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py26
-rw-r--r--lib/sqlalchemy/orm/evaluator.py76
-rw-r--r--lib/sqlalchemy/orm/identity.py10
-rw-r--r--lib/sqlalchemy/orm/loading.py5
-rw-r--r--lib/sqlalchemy/orm/mapper.py15
-rw-r--r--lib/sqlalchemy/orm/persistence.py138
-rw-r--r--lib/sqlalchemy/orm/query.py4
-rw-r--r--lib/sqlalchemy/orm/session.py25
-rw-r--r--lib/sqlalchemy/orm/strategies.py11
11 files changed, 1537 insertions, 405 deletions
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py
index 225292d17..3ed34a57a 100644
--- a/lib/sqlalchemy/orm/bulk_persistence.py
+++ b/lib/sqlalchemy/orm/bulk_persistence.py
@@ -15,24 +15,32 @@ specifically outside of the flush() process.
from __future__ import annotations
from typing import Any
+from typing import cast
from typing import Dict
from typing import Iterable
+from typing import Optional
+from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import attributes
+from . import context
from . import evaluator
from . import exc as orm_exc
+from . import loading
from . import persistence
from .base import NO_VALUE
from .context import AbstractORMCompileState
+from .context import FromStatement
+from .context import ORMFromStatementCompileState
+from .context import QueryContext
from .. import exc as sa_exc
-from .. import sql
from .. import util
from ..engine import Dialect
from ..engine import result as _result
from ..sql import coercions
+from ..sql import dml
from ..sql import expression
from ..sql import roles
from ..sql import select
@@ -48,16 +56,24 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from .mapper import Mapper
+ from .session import _BindArguments
from .session import ORMExecuteState
+ from .session import Session
from .session import SessionTransaction
from .state import InstanceState
+ from ..engine import Connection
+ from ..engine import cursor
+ from ..engine.interfaces import _CoreAnyExecuteParams
+ from ..engine.interfaces import _ExecuteOptionsParameter
_O = TypeVar("_O", bound=object)
-_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"]
+_SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"]
+_DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"]
+@overload
def _bulk_insert(
mapper: Mapper[_O],
mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
@@ -65,7 +81,36 @@ def _bulk_insert(
isstates: bool,
return_defaults: bool,
render_nulls: bool,
+ use_orm_insert_stmt: Literal[None] = ...,
+ execution_options: Optional[_ExecuteOptionsParameter] = ...,
) -> None:
+ ...
+
+
+@overload
+def _bulk_insert(
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ return_defaults: bool,
+ render_nulls: bool,
+ use_orm_insert_stmt: Optional[dml.Insert] = ...,
+ execution_options: Optional[_ExecuteOptionsParameter] = ...,
+) -> cursor.CursorResult[Any]:
+ ...
+
+
+def _bulk_insert(
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ return_defaults: bool,
+ render_nulls: bool,
+ use_orm_insert_stmt: Optional[dml.Insert] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+) -> Optional[cursor.CursorResult[Any]]:
base_mapper = mapper.base_mapper
if session_transaction.session.connection_callable:
@@ -81,13 +126,27 @@ def _bulk_insert(
else:
mappings = [state.dict for state in mappings]
else:
- mappings = list(mappings)
+ mappings = [dict(m) for m in mappings]
+ _expand_composites(mapper, mappings)
connection = session_transaction.connection(base_mapper)
+
+ return_result: Optional[cursor.CursorResult[Any]] = None
+
for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper):
+ if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
continue
+ is_joined_inh_supertable = super_mapper is not mapper
+ bookkeeping = (
+ is_joined_inh_supertable
+ or return_defaults
+ or (
+ use_orm_insert_stmt is not None
+ and bool(use_orm_insert_stmt._returning)
+ )
+ )
+
records = (
(
None,
@@ -112,18 +171,25 @@ def _bulk_insert(
table,
((None, mapping, mapper, connection) for mapping in mappings),
bulk=True,
- return_defaults=return_defaults,
+ return_defaults=bookkeeping,
render_nulls=render_nulls,
)
)
- persistence._emit_insert_statements(
+ result = persistence._emit_insert_statements(
base_mapper,
None,
super_mapper,
table,
records,
- bookkeeping=return_defaults,
+ bookkeeping=bookkeeping,
+ use_orm_insert_stmt=use_orm_insert_stmt,
+ execution_options=execution_options,
)
+ if use_orm_insert_stmt is not None:
+ if not use_orm_insert_stmt._returning or return_result is None:
+ return_result = result
+ elif result.returns_rows:
+ return_result = return_result.splice_horizontally(result)
if return_defaults and isstates:
identity_cls = mapper._identity_class
@@ -134,14 +200,43 @@ def _bulk_insert(
tuple([dict_[key] for key in identity_props]),
)
+ if use_orm_insert_stmt is not None:
+ assert return_result is not None
+ return return_result
+
+@overload
def _bulk_update(
mapper: Mapper[Any],
mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
session_transaction: SessionTransaction,
isstates: bool,
update_changed_only: bool,
+ use_orm_update_stmt: Literal[None] = ...,
) -> None:
+ ...
+
+
+@overload
+def _bulk_update(
+ mapper: Mapper[Any],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ update_changed_only: bool,
+ use_orm_update_stmt: Optional[dml.Update] = ...,
+) -> _result.Result[Any]:
+ ...
+
+
+def _bulk_update(
+ mapper: Mapper[Any],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ update_changed_only: bool,
+ use_orm_update_stmt: Optional[dml.Update] = None,
+) -> Optional[_result.Result[Any]]:
base_mapper = mapper.base_mapper
search_keys = mapper._primary_key_propkeys
@@ -161,7 +256,8 @@ def _bulk_update(
else:
mappings = [state.dict for state in mappings]
else:
- mappings = list(mappings)
+ mappings = [dict(m) for m in mappings]
+ _expand_composites(mapper, mappings)
if session_transaction.session.connection_callable:
raise NotImplementedError(
@@ -172,7 +268,7 @@ def _bulk_update(
connection = session_transaction.connection(base_mapper)
for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper):
+ if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
continue
records = persistence._collect_update_commands(
@@ -193,8 +289,8 @@ def _bulk_update(
for mapping in mappings
),
bulk=True,
+ use_orm_update_stmt=use_orm_update_stmt,
)
-
persistence._emit_update_statements(
base_mapper,
None,
@@ -202,10 +298,125 @@ def _bulk_update(
table,
records,
bookkeeping=False,
+ use_orm_update_stmt=use_orm_update_stmt,
)
+ if use_orm_update_stmt is not None:
+ return _result.null_result()
+
+
+def _expand_composites(mapper, mappings):
+ composite_attrs = mapper.composites
+ if not composite_attrs:
+ return
+
+ composite_keys = set(composite_attrs.keys())
+ populators = {
+ key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
+ for key in composite_keys
+ }
+ for mapping in mappings:
+ for key in composite_keys.intersection(mapping):
+ populators[key](mapping)
+
class ORMDMLState(AbstractORMCompileState):
+ is_dml_returning = True
+ from_statement_ctx: Optional[ORMFromStatementCompileState] = None
+
+ @classmethod
+ def _get_orm_crud_kv_pairs(
+ cls, mapper, statement, kv_iterator, needs_to_be_cacheable
+ ):
+
+ core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
+
+ for k, v in kv_iterator:
+ k = coercions.expect(roles.DMLColumnRole, k)
+
+ if isinstance(k, str):
+ desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
+ if desc is NO_VALUE:
+ yield (
+ coercions.expect(roles.DMLColumnRole, k),
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ )
+ if needs_to_be_cacheable
+ else v,
+ )
+ else:
+ yield from core_get_crud_kv_pairs(
+ statement,
+ desc._bulk_update_tuples(v),
+ needs_to_be_cacheable,
+ )
+ elif "entity_namespace" in k._annotations:
+ k_anno = k._annotations
+ attr = _entity_namespace_key(
+ k_anno["entity_namespace"], k_anno["proxy_key"]
+ )
+ yield from core_get_crud_kv_pairs(
+ statement,
+ attr._bulk_update_tuples(v),
+ needs_to_be_cacheable,
+ )
+ else:
+ yield (
+ k,
+ v
+ if not needs_to_be_cacheable
+ else coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ ),
+ )
+
+ @classmethod
+ def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+ if not plugin_subject or not plugin_subject.mapper:
+ return UpdateDMLState._get_multi_crud_kv_pairs(
+ statement, kv_iterator
+ )
+
+ return [
+ dict(
+ cls._get_orm_crud_kv_pairs(
+ plugin_subject.mapper, statement, value_dict.items(), False
+ )
+ )
+ for value_dict in kv_iterator
+ ]
+
+ @classmethod
+ def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
+ assert (
+ needs_to_be_cacheable
+ ), "no test coverage for needs_to_be_cacheable=False"
+
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+ if not plugin_subject or not plugin_subject.mapper:
+ return UpdateDMLState._get_crud_kv_pairs(
+ statement, kv_iterator, needs_to_be_cacheable
+ )
+
+ return list(
+ cls._get_orm_crud_kv_pairs(
+ plugin_subject.mapper,
+ statement,
+ kv_iterator,
+ needs_to_be_cacheable,
+ )
+ )
+
@classmethod
def get_entity_description(cls, statement):
ext_info = statement.table._annotations["parententity"]
@@ -250,18 +461,101 @@ class ORMDMLState(AbstractORMCompileState):
]
]
+ def _setup_orm_returning(
+ self,
+ compiler,
+ orm_level_statement,
+ dml_level_statement,
+ use_supplemental_cols=True,
+ dml_mapper=None,
+ ):
+ """establish ORM column handlers for an INSERT, UPDATE, or DELETE
+ which uses explicit returning().
+
+ called within compilation level create_for_statement.
+
+ The _return_orm_returning() method then receives the Result
+ after the statement was executed, and applies ORM loading to the
+ state that we first established here.
+
+ """
+
+ if orm_level_statement._returning:
+
+ fs = FromStatement(
+ orm_level_statement._returning, dml_level_statement
+ )
+ fs = fs.options(*orm_level_statement._with_options)
+ self.select_statement = fs
+ self.from_statement_ctx = (
+ fsc
+ ) = ORMFromStatementCompileState.create_for_statement(fs, compiler)
+ fsc.setup_dml_returning_compile_state(dml_mapper)
+
+ dml_level_statement = dml_level_statement._generate()
+ dml_level_statement._returning = ()
+
+ cols_to_return = [c for c in fsc.primary_columns if c is not None]
+
+ # since we are splicing result sets together, make sure there
+ # are columns of some kind returned in each result set
+ if not cols_to_return:
+ cols_to_return.extend(dml_mapper.primary_key)
+
+ if use_supplemental_cols:
+ dml_level_statement = dml_level_statement.return_defaults(
+ supplemental_cols=cols_to_return
+ )
+ else:
+ dml_level_statement = dml_level_statement.returning(
+ *cols_to_return
+ )
+
+ return dml_level_statement
+
+ @classmethod
+ def _return_orm_returning(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+
+ execution_context = result.context
+ compile_state = execution_context.compiled.compile_state
+
+ if compile_state.from_statement_ctx:
+ load_options = execution_options.get(
+ "_sa_orm_load_options", QueryContext.default_load_options
+ )
+ querycontext = QueryContext(
+ compile_state.from_statement_ctx,
+ compile_state.select_statement,
+ params,
+ session,
+ load_options,
+ execution_options,
+ bind_arguments,
+ )
+ return loading.instances(result, querycontext)
+ else:
+ return result
+
class BulkUDCompileState(ORMDMLState):
class default_update_options(Options):
- _synchronize_session: _SynchronizeSessionArgument = "evaluate"
- _is_delete_using = False
- _is_update_from = False
- _autoflush = True
- _subject_mapper = None
+ _dml_strategy: _DMLStrategyArgument = "auto"
+ _synchronize_session: _SynchronizeSessionArgument = "auto"
+ _can_use_returning: bool = False
+ _is_delete_using: bool = False
+ _is_update_from: bool = False
+ _autoflush: bool = True
+ _subject_mapper: Optional[Mapper[Any]] = None
_resolved_values = EMPTY_DICT
- _resolved_keys_as_propnames = EMPTY_DICT
- _value_evaluators = EMPTY_DICT
- _matched_objects = None
+ _eval_condition = None
_matched_rows = None
_refresh_identity_token = None
@@ -295,19 +589,16 @@ class BulkUDCompileState(ORMDMLState):
execution_options,
) = BulkUDCompileState.default_update_options.from_execution_options(
"_sa_orm_update_options",
- {"synchronize_session", "is_delete_using", "is_update_from"},
+ {
+ "synchronize_session",
+ "is_delete_using",
+ "is_update_from",
+ "dml_strategy",
+ },
execution_options,
statement._execution_options,
)
- sync = update_options._synchronize_session
- if sync is not None:
- if sync not in ("evaluate", "fetch", False):
- raise sa_exc.ArgumentError(
- "Valid strategies for session synchronization "
- "are 'evaluate', 'fetch', False"
- )
-
bind_arguments["clause"] = statement
try:
plugin_subject = statement._propagate_attrs["plugin_subject"]
@@ -318,43 +609,86 @@ class BulkUDCompileState(ORMDMLState):
update_options += {"_subject_mapper": plugin_subject.mapper}
+ if not isinstance(params, list):
+ if update_options._dml_strategy == "auto":
+ update_options += {"_dml_strategy": "orm"}
+ elif update_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+                    'Can\'t use "bulk" ORM update strategy without '
+ "passing separate parameters"
+ )
+ else:
+ if update_options._dml_strategy == "auto":
+ update_options += {"_dml_strategy": "bulk"}
+ elif update_options._dml_strategy == "orm":
+ raise sa_exc.InvalidRequestError(
+                    'Can\'t use "orm" ORM update strategy with a '
+ "separate parameter list"
+ )
+
+ sync = update_options._synchronize_session
+ if sync is not None:
+ if sync not in ("auto", "evaluate", "fetch", False):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for session synchronization "
+ "are 'auto', 'evaluate', 'fetch', False"
+ )
+ if update_options._dml_strategy == "bulk" and sync == "fetch":
+ raise sa_exc.InvalidRequestError(
+ "The 'fetch' synchronization strategy is not available "
+ "for 'bulk' ORM updates (i.e. multiple parameter sets)"
+ )
+
if update_options._autoflush:
session._autoflush()
+ if update_options._dml_strategy == "orm":
+
+ if update_options._synchronize_session == "auto":
+ update_options = cls._do_pre_synchronize_auto(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "evaluate":
+ update_options = cls._do_pre_synchronize_evaluate(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "fetch":
+ update_options = cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._dml_strategy == "bulk":
+ if update_options._synchronize_session == "auto":
+ update_options += {"_synchronize_session": "evaluate"}
+
+ # indicators from the "pre exec" step that are then
+ # added to the DML statement, which will also be part of the cache
+ # key. The compile level create_for_statement() method will then
+ # consume these at compiler time.
statement = statement._annotate(
{
"synchronize_session": update_options._synchronize_session,
"is_delete_using": update_options._is_delete_using,
"is_update_from": update_options._is_update_from,
+ "dml_strategy": update_options._dml_strategy,
+ "can_use_returning": update_options._can_use_returning,
}
)
- # this stage of the execution is called before the do_orm_execute event
- # hook. meaning for an extension like horizontal sharding, this step
- # happens before the extension splits out into multiple backends and
- # runs only once. if we do pre_sync_fetch, we execute a SELECT
- # statement, which the horizontal sharding extension splits amongst the
- # shards and combines the results together.
-
- if update_options._synchronize_session == "evaluate":
- update_options = cls._do_pre_synchronize_evaluate(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
- elif update_options._synchronize_session == "fetch":
- update_options = cls._do_pre_synchronize_fetch(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
-
return (
statement,
util.immutabledict(execution_options).union(
@@ -382,12 +716,30 @@ class BulkUDCompileState(ORMDMLState):
# individual ones we return here.
update_options = execution_options["_sa_orm_update_options"]
- if update_options._synchronize_session == "evaluate":
- cls._do_post_synchronize_evaluate(session, result, update_options)
- elif update_options._synchronize_session == "fetch":
- cls._do_post_synchronize_fetch(session, result, update_options)
+ if update_options._dml_strategy == "orm":
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_evaluate(
+ session, statement, result, update_options
+ )
+ elif update_options._synchronize_session == "fetch":
+ cls._do_post_synchronize_fetch(
+ session, statement, result, update_options
+ )
+ elif update_options._dml_strategy == "bulk":
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_bulk_evaluate(
+ session, params, result, update_options
+ )
+ return result
- return result
+ return cls._return_orm_returning(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
@classmethod
def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
@@ -473,11 +825,76 @@ class BulkUDCompileState(ORMDMLState):
primary_key_convert = [
lookup[bpk] for bpk in mapper.base_mapper.primary_key
]
-
return [tuple(row[idx] for idx in primary_key_convert) for row in rows]
@classmethod
- def _do_pre_synchronize_evaluate(
+ def _get_matched_objects_on_criteria(cls, update_options, states):
+ mapper = update_options._subject_mapper
+ eval_condition = update_options._eval_condition
+
+ raw_data = [
+ (state.obj(), state, state.dict)
+ for state in states
+ if state.mapper.isa(mapper) and not state.expired
+ ]
+
+ identity_token = update_options._refresh_identity_token
+ if identity_token is not None:
+ raw_data = [
+ (obj, state, dict_)
+ for obj, state, dict_ in raw_data
+ if state.identity_token == identity_token
+ ]
+
+ result = []
+ for obj, state, dict_ in raw_data:
+ evaled_condition = eval_condition(obj)
+
+        # caution: don't use "in ()" or == here, _EXPIRED_OBJECT
+ # evaluates as True for all comparisons
+ if (
+ evaled_condition is True
+ or evaled_condition is evaluator._EXPIRED_OBJECT
+ ):
+ result.append(
+ (
+ obj,
+ state,
+ dict_,
+ evaled_condition is evaluator._EXPIRED_OBJECT,
+ )
+ )
+ return result
+
+ @classmethod
+ def _eval_condition_from_statement(cls, update_options, statement):
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
+
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ crit = ()
+ if statement._where_criteria:
+ crit += statement._where_criteria
+
+ global_attributes = {}
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(global_attributes)
+
+ if global_attributes:
+ crit += cls._adjust_for_extra_criteria(global_attributes, mapper)
+
+ if crit:
+ eval_condition = evaluator_compiler.process(*crit)
+ else:
+
+ def eval_condition(obj):
+ return True
+
+ return eval_condition
+
+ @classmethod
+ def _do_pre_synchronize_auto(
cls,
session,
statement,
@@ -486,33 +903,59 @@ class BulkUDCompileState(ORMDMLState):
bind_arguments,
update_options,
):
- mapper = update_options._subject_mapper
- target_cls = mapper.class_
+ """setup auto sync strategy
+
+
+ "auto" checks if we can use "evaluate" first, then falls back
+ to "fetch"
+
+ evaluate is vastly more efficient for the common case
+ where session is empty, only has a few objects, and the UPDATE
+ statement can potentially match thousands/millions of rows.
- value_evaluators = resolved_keys_as_propnames = EMPTY_DICT
+ OTOH more complex criteria that fails to work with "evaluate"
+ we would hope usually correlates with fewer net rows.
+
+ """
try:
- evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- crit = ()
- if statement._where_criteria:
- crit += statement._where_criteria
+ eval_condition = cls._eval_condition_from_statement(
+ update_options, statement
+ )
- global_attributes = {}
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(global_attributes)
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ return update_options + {
+ "_eval_condition": eval_condition,
+ "_synchronize_session": "evaluate",
+ }
- if global_attributes:
- crit += cls._adjust_for_extra_criteria(
- global_attributes, mapper
- )
+ update_options += {"_synchronize_session": "fetch"}
+ return cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
- if crit:
- eval_condition = evaluator_compiler.process(*crit)
- else:
+ @classmethod
+ def _do_pre_synchronize_evaluate(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
- def eval_condition(obj):
- return True
+ try:
+ eval_condition = cls._eval_condition_from_statement(
+ update_options, statement
+ )
except evaluator.UnevaluatableError as err:
raise sa_exc.InvalidRequestError(
@@ -521,52 +964,8 @@ class BulkUDCompileState(ORMDMLState):
"synchronize_session execution option." % err
) from err
- if statement.__visit_name__ == "lambda_element":
- # ._resolved is called on every LambdaElement in order to
- # generate the cache key, so this access does not add
- # additional expense
- effective_statement = statement._resolved
- else:
- effective_statement = statement
-
- if effective_statement.__visit_name__ == "update":
- resolved_values = cls._get_resolved_values(
- mapper, effective_statement
- )
- value_evaluators = {}
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
-
- # TODO: detect when the where clause is a trivial primary key match.
- matched_objects = [
- state.obj()
- for state in session.identity_map.all_states()
- if state.mapper.isa(mapper)
- and not state.expired
- and eval_condition(state.obj())
- and (
- update_options._refresh_identity_token is None
- # TODO: coverage for the case where horizontal sharding
- # invokes an update() or delete() given an explicit identity
- # token up front
- or state.identity_token
- == update_options._refresh_identity_token
- )
- ]
return update_options + {
- "_matched_objects": matched_objects,
- "_value_evaluators": value_evaluators,
- "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ "_eval_condition": eval_condition,
}
@classmethod
@@ -584,12 +983,6 @@ class BulkUDCompileState(ORMDMLState):
def _resolved_keys_as_propnames(cls, mapper, resolved_values):
values = []
for k, v in resolved_values:
- if isinstance(k, attributes.QueryableAttribute):
- values.append((k.key, v))
- continue
- elif hasattr(k, "__clause_element__"):
- k = k.__clause_element__()
-
if mapper and isinstance(k, expression.ColumnElement):
try:
attr = mapper._columntoproperty[k]
@@ -599,7 +992,8 @@ class BulkUDCompileState(ORMDMLState):
values.append((attr.key, v))
else:
raise sa_exc.InvalidRequestError(
- "Invalid expression type: %r" % k
+ "Attribute name not found, can't be "
+ "synchronized back to objects: %r" % k
)
return values
@@ -622,14 +1016,43 @@ class BulkUDCompileState(ORMDMLState):
)
select_stmt._where_criteria = statement._where_criteria
+ # conditionally run the SELECT statement for pre-fetch, testing the
+ # "bind" for if we can use RETURNING or not using the do_orm_execute
+ # event. If RETURNING is available, the do_orm_execute event
+ # will cancel the SELECT from being actually run.
+ #
+ # The way this is organized seems strange, why don't we just
+ # call can_use_returning() before invoking the statement and get
+    # the answer? Why does this go through the whole execute phase using an
+ # event? Answer: because we are integrating with extensions such
+    # as the horizontal sharding extension that "multiplexes" an individual
+ # statement run through multiple engines, and it uses
+ # do_orm_execute() to do that.
+
+ can_use_returning = None
+
def skip_for_returning(orm_context: ORMExecuteState) -> Any:
bind = orm_context.session.get_bind(**orm_context.bind_arguments)
- if cls.can_use_returning(
+ nonlocal can_use_returning
+
+ per_bind_result = cls.can_use_returning(
bind.dialect,
mapper,
is_update_from=update_options._is_update_from,
is_delete_using=update_options._is_delete_using,
- ):
+ )
+
+ if can_use_returning is not None:
+ if can_use_returning != per_bind_result:
+ raise sa_exc.InvalidRequestError(
+ "For synchronize_session='fetch', can't mix multiple "
+ "backends where some support RETURNING and others "
+ "don't"
+ )
+ else:
+ can_use_returning = per_bind_result
+
+ if per_bind_result:
return _result.null_result()
else:
return None
@@ -643,52 +1066,22 @@ class BulkUDCompileState(ORMDMLState):
)
matched_rows = result.fetchall()
- value_evaluators = EMPTY_DICT
-
- if statement.__visit_name__ == "lambda_element":
- # ._resolved is called on every LambdaElement in order to
- # generate the cache key, so this access does not add
- # additional expense
- effective_statement = statement._resolved
- else:
- effective_statement = statement
-
- if effective_statement.__visit_name__ == "update":
- target_cls = mapper.class_
- evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- resolved_values = cls._get_resolved_values(
- mapper, effective_statement
- )
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
-
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- value_evaluators = {}
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
-
- else:
- resolved_keys_as_propnames = EMPTY_DICT
-
return update_options + {
- "_value_evaluators": value_evaluators,
"_matched_rows": matched_rows,
- "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ "_can_use_returning": can_use_returning,
}
@CompileState.plugin_for("orm", "insert")
-class ORMInsert(ORMDMLState, InsertDMLState):
+class BulkORMInsert(ORMDMLState, InsertDMLState):
+ class default_insert_options(Options):
+ _dml_strategy: _DMLStrategyArgument = "auto"
+ _render_nulls: bool = False
+ _return_defaults: bool = False
+ _subject_mapper: Optional[Mapper[Any]] = None
+
+ select_statement: Optional[FromStatement] = None
+
@classmethod
def orm_pre_session_exec(
cls,
@@ -699,6 +1092,16 @@ class ORMInsert(ORMDMLState, InsertDMLState):
bind_arguments,
is_reentrant_invoke,
):
+
+ (
+ insert_options,
+ execution_options,
+ ) = BulkORMInsert.default_insert_options.from_execution_options(
+ "_sa_orm_insert_options",
+ {"dml_strategy"},
+ execution_options,
+ statement._execution_options,
+ )
bind_arguments["clause"] = statement
try:
plugin_subject = statement._propagate_attrs["plugin_subject"]
@@ -707,22 +1110,209 @@ class ORMInsert(ORMDMLState, InsertDMLState):
else:
bind_arguments["mapper"] = plugin_subject.mapper
+ insert_options += {"_subject_mapper": plugin_subject.mapper}
+
+ if not params:
+ if insert_options._dml_strategy == "auto":
+ insert_options += {"_dml_strategy": "orm"}
+ elif insert_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "bulk" ORM insert strategy without '
+ "passing separate parameters"
+ )
+ else:
+ if insert_options._dml_strategy == "auto":
+ insert_options += {"_dml_strategy": "bulk"}
+ elif insert_options._dml_strategy == "orm":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "orm" ORM insert strategy with a '
+ "separate parameter list"
+ )
+
+ if insert_options._dml_strategy != "raw":
+ # for ORM object loading, like ORMContext, we have to disable
+ # result set adapt_to_context, because we will be generating a
+ # new statement with specific columns that's cached inside of
+ # an ORMFromStatementCompileState, which we will re-use for
+ # each result.
+ if not execution_options:
+ execution_options = context._orm_load_exec_options
+ else:
+ execution_options = execution_options.union(
+ context._orm_load_exec_options
+ )
+
+ statement = statement._annotate(
+ {"dml_strategy": insert_options._dml_strategy}
+ )
+
return (
statement,
- util.immutabledict(execution_options),
+ util.immutabledict(execution_options).union(
+ {"_sa_orm_insert_options": insert_options}
+ ),
)
@classmethod
- def orm_setup_cursor_result(
+ def orm_execute_statement(
cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- result,
- ):
- return result
+ session: Session,
+ statement: dml.Insert,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ insert_options = execution_options.get(
+ "_sa_orm_insert_options", cls.default_insert_options
+ )
+
+ if insert_options._dml_strategy not in (
+ "raw",
+ "bulk",
+ "orm",
+ "auto",
+ ):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM insert strategy "
+                "are 'raw', 'orm', 'bulk', 'auto'"
+ )
+
+ result: _result.Result[Any]
+
+ if insert_options._dml_strategy == "raw":
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ return result
+
+ if insert_options._dml_strategy == "bulk":
+ mapper = insert_options._subject_mapper
+
+ if (
+ statement._post_values_clause is not None
+ and mapper._multiple_persistence_tables
+ ):
+ raise sa_exc.InvalidRequestError(
+ "bulk INSERT with a 'post values' clause "
+ "(typically upsert) not supported for multi-table "
+ f"mapper {mapper}"
+ )
+
+ assert mapper is not None
+ assert session._transaction is not None
+ result = _bulk_insert(
+ mapper,
+ cast(
+ "Iterable[Dict[str, Any]]",
+ [params] if isinstance(params, dict) else params,
+ ),
+ session._transaction,
+ isstates=False,
+ return_defaults=insert_options._return_defaults,
+ render_nulls=insert_options._render_nulls,
+ use_orm_insert_stmt=statement,
+ execution_options=execution_options,
+ )
+ elif insert_options._dml_strategy == "orm":
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ else:
+ raise AssertionError()
+
+ if not bool(statement._returning):
+ return result
+
+ return cls._return_orm_returning(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
+
+ self = cast(
+ BulkORMInsert,
+ super().create_for_statement(statement, compiler, **kw),
+ )
+
+ if compiler is not None:
+ toplevel = not compiler.stack
+ else:
+ toplevel = True
+ if not toplevel:
+ return self
+
+ mapper = statement._propagate_attrs["plugin_subject"]
+ dml_strategy = statement._annotations.get("dml_strategy", "raw")
+ if dml_strategy == "bulk":
+ self._setup_for_bulk_insert(compiler)
+ elif dml_strategy == "orm":
+ self._setup_for_orm_insert(compiler, mapper)
+
+ return self
+
+ @classmethod
+ def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
+ return {
+ col.key if col is not None else k: v
+ for col, k, v in (
+ (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
+ )
+ }
+
+ def _setup_for_orm_insert(self, compiler, mapper):
+ statement = orm_level_statement = cast(dml.Insert, self.statement)
+
+ statement = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ statement,
+ use_supplemental_cols=False,
+ )
+ self.statement = statement
+
+ def _setup_for_bulk_insert(self, compiler):
+ """establish an INSERT statement within the context of
+ bulk insert.
+
+ This method will be within the "conn.execute()" call that is invoked
+ by persistence._emit_insert_statement().
+
+ """
+ statement = orm_level_statement = cast(dml.Insert, self.statement)
+ an = statement._annotations
+
+ emit_insert_table, emit_insert_mapper = (
+ an["_emit_insert_table"],
+ an["_emit_insert_mapper"],
+ )
+
+ statement = statement._clone()
+
+ statement.table = emit_insert_table
+ if self._dict_parameters:
+ self._dict_parameters = {
+ col: val
+ for col, val in self._dict_parameters.items()
+ if col.table is emit_insert_table
+ }
+
+ statement = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ statement,
+ use_supplemental_cols=True,
+ dml_mapper=emit_insert_mapper,
+ )
+
+ self.statement = statement
@CompileState.plugin_for("orm", "update")
@@ -732,13 +1322,27 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
self = cls.__new__(cls)
+ dml_strategy = statement._annotations.get(
+ "dml_strategy", "unspecified"
+ )
+
+ if dml_strategy == "bulk":
+ self._setup_for_bulk_update(statement, compiler)
+ elif dml_strategy in ("orm", "unspecified"):
+ self._setup_for_orm_update(statement, compiler)
+
+ return self
+
+ def _setup_for_orm_update(self, statement, compiler, **kw):
+ orm_level_statement = statement
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
self.extra_criteria_entities = {}
- self._resolved_values = cls._get_resolved_values(mapper, statement)
+ self._resolved_values = self._get_resolved_values(mapper, statement)
extra_criteria_attributes = {}
@@ -749,8 +1353,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
if statement._values:
self._resolved_values = dict(self._resolved_values)
- new_stmt = sql.Update.__new__(sql.Update)
- new_stmt.__dict__.update(statement.__dict__)
+ new_stmt = statement._clone()
new_stmt.table = mapper.local_table
# note if the statement has _multi_values, these
@@ -762,7 +1365,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
elif statement._values:
new_stmt._values = self._resolved_values
- new_crit = cls._adjust_for_extra_criteria(
+ new_crit = self._adjust_for_extra_criteria(
extra_criteria_attributes, mapper
)
if new_crit:
@@ -776,21 +1379,150 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
- if compiler._annotations.get(
+ use_supplemental_cols = False
+
+ synchronize_session = compiler._annotations.get(
"synchronize_session", None
- ) == "fetch" and self.can_use_returning(
- compiler.dialect, mapper, is_multitable=self.is_multitable
- ):
- if new_stmt._returning:
- raise sa_exc.InvalidRequestError(
- "Can't use synchronize_session='fetch' "
- "with explicit returning()"
+ )
+ can_use_returning = compiler._annotations.get(
+ "can_use_returning", None
+ )
+ if can_use_returning is not False:
+ # even though pre_exec has determined basic
+ # can_use_returning for the dialect, if we are to use
+ # RETURNING we need to run can_use_returning() at this level
+ # unconditionally because is_delete_using was not known
+ # at the pre_exec level
+ can_use_returning = (
+ synchronize_session == "fetch"
+ and self.can_use_returning(
+ compiler.dialect, mapper, is_multitable=self.is_multitable
)
- self.statement = self.statement.returning(
- *mapper.local_table.primary_key
)
- return self
+ if synchronize_session == "fetch" and can_use_returning:
+ use_supplemental_cols = True
+
+ # NOTE: we might want to RETURNING the actual columns to be
+ # synchronized also. however this is complicated and difficult
+ # to align against the behavior of "evaluate". Additionally,
+ # in a large number (if not the majority) of cases, we have the
+ # "evaluate" answer, usually a fixed value, in memory already and
+ # there's no need to re-fetch the same value
+ # over and over again. so perhaps if it could be RETURNING just
+ # the elements that were based on a SQL expression and not
+ # a constant. For now it doesn't quite seem worth it
+ new_stmt = new_stmt.return_defaults(
+ *(list(mapper.local_table.primary_key))
+ )
+
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ use_supplemental_cols=use_supplemental_cols,
+ )
+
+ self.statement = new_stmt
+
+ def _setup_for_bulk_update(self, statement, compiler, **kw):
+ """establish an UPDATE statement within the context of
+        bulk update.
+
+        This method will be called within the "conn.execute()" call that is invoked
+ by persistence._emit_update_statement().
+
+ """
+ statement = cast(dml.Update, statement)
+ an = statement._annotations
+
+ emit_update_table, _ = (
+ an["_emit_update_table"],
+ an["_emit_update_mapper"],
+ )
+
+ statement = statement._clone()
+ statement.table = emit_update_table
+
+ UpdateDMLState.__init__(self, statement, compiler, **kw)
+
+ if self._ordered_values:
+ raise sa_exc.InvalidRequestError(
+ "bulk ORM UPDATE does not support ordered_values() for "
+ "custom UPDATE statements with bulk parameter sets. Use a "
+ "non-bulk UPDATE statement or use values()."
+ )
+
+ if self._dict_parameters:
+ self._dict_parameters = {
+ col: val
+ for col, val in self._dict_parameters.items()
+ if col.table is emit_update_table
+ }
+ self.statement = statement
+
+ @classmethod
+ def orm_execute_statement(
+ cls,
+ session: Session,
+ statement: dml.Update,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ update_options = execution_options.get(
+ "_sa_orm_update_options", cls.default_update_options
+ )
+
+ if update_options._dml_strategy not in ("orm", "auto", "bulk"):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM UPDATE strategy "
+ "are 'orm', 'auto', 'bulk'"
+ )
+
+ result: _result.Result[Any]
+
+ if update_options._dml_strategy == "bulk":
+ if statement._where_criteria:
+ raise sa_exc.InvalidRequestError(
+ "WHERE clause with bulk ORM UPDATE not "
+ "supported right now. Statement may be invoked at the "
+ "Core level using "
+ "session.connection().execute(stmt, parameters)"
+ )
+ mapper = update_options._subject_mapper
+ assert mapper is not None
+ assert session._transaction is not None
+ result = _bulk_update(
+ mapper,
+ cast(
+ "Iterable[Dict[str, Any]]",
+ [params] if isinstance(params, dict) else params,
+ ),
+ session._transaction,
+ isstates=False,
+ update_changed_only=False,
+ use_orm_update_stmt=statement,
+ )
+ return cls.orm_setup_cursor_result(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+ else:
+ return super().orm_execute_statement(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ conn,
+ )
@classmethod
def can_use_returning(
@@ -827,119 +1559,80 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
return True
@classmethod
- def _get_crud_kv_pairs(cls, statement, kv_iterator):
- plugin_subject = statement._propagate_attrs["plugin_subject"]
-
- core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
-
- if not plugin_subject or not plugin_subject.mapper:
- return core_get_crud_kv_pairs(statement, kv_iterator)
-
- mapper = plugin_subject.mapper
-
- values = []
-
- for k, v in kv_iterator:
- k = coercions.expect(roles.DMLColumnRole, k)
+ def _do_post_synchronize_bulk_evaluate(
+ cls, session, params, result, update_options
+ ):
+ if not params:
+ return
- if isinstance(k, str):
- desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
- if desc is NO_VALUE:
- values.append(
- (
- k,
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- ),
- )
- )
- else:
- values.extend(
- core_get_crud_kv_pairs(
- statement, desc._bulk_update_tuples(v)
- )
- )
- elif "entity_namespace" in k._annotations:
- k_anno = k._annotations
- attr = _entity_namespace_key(
- k_anno["entity_namespace"], k_anno["proxy_key"]
- )
- values.extend(
- core_get_crud_kv_pairs(
- statement, attr._bulk_update_tuples(v)
- )
- )
- else:
- values.append(
- (
- k,
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- ),
- )
- )
- return values
+ mapper = update_options._subject_mapper
+ pk_keys = [prop.key for prop in mapper._identity_key_props]
- @classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
+ identity_map = session.identity_map
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
- for obj in update_options._matched_objects:
-
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
+ for param in params:
+ identity_key = mapper.identity_key_from_primary_key(
+ (param[key] for key in pk_keys),
+ update_options._refresh_identity_token,
)
-
- # the evaluated states were gathered across all identity tokens.
- # however the post_sync events are called per identity token,
- # so filter.
- if (
- update_options._refresh_identity_token is not None
- and state.identity_token
- != update_options._refresh_identity_token
- ):
+ state = identity_map.fast_get_state(identity_key)
+ if not state:
continue
+ evaluated_keys = set(param).difference(pk_keys)
+
+ dict_ = state.dict
# only evaluate unmodified attributes
to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
+ dict_[key] = param[key]
state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
- to_expire = attrib.intersection(dict_).difference(to_evaluate)
+ # attributes that were formerly modified instead get expired.
+ # this only gets hit if the session had pending changes
+ # and autoflush were set to False.
+ to_expire = evaluated_keys.intersection(dict_).difference(
+ to_evaluate
+ )
if to_expire:
state._expire_attributes(dict_, to_expire)
- states.add(state)
- session._register_altered(states)
+ @classmethod
+ def _do_post_synchronize_evaluate(
+ cls, session, statement, result, update_options
+ ):
+
+ matched_objects = cls._get_matched_objects_on_criteria(
+ update_options,
+ session.identity_map.all_states(),
+ )
+
+ cls._apply_update_set_values_to_objects(
+ session,
+ update_options,
+ statement,
+ [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
+ )
@classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
+ def _do_post_synchronize_fetch(
+ cls, session, statement, result, update_options
+ ):
target_mapper = update_options._subject_mapper
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
-
- if result.returns_rows:
- rows = cls._interpret_returning_rows(target_mapper, result.all())
+ returned_defaults_rows = result.returned_defaults_rows
+ if returned_defaults_rows:
+ pk_rows = cls._interpret_returning_rows(
+ target_mapper, returned_defaults_rows
+ )
matched_rows = [
tuple(row) + (update_options._refresh_identity_token,)
- for row in rows
+ for row in pk_rows
]
else:
matched_rows = update_options._matched_rows
@@ -960,23 +1653,69 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
if identity_key in session.identity_map
]
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
+ if not objs:
+ return
- for obj in objs:
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
- )
+ cls._apply_update_set_values_to_objects(
+ session,
+ update_options,
+ statement,
+ [
+ (
+ obj,
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
+ for obj in objs
+ ],
+ )
+
+ @classmethod
+ def _apply_update_set_values_to_objects(
+ cls, session, update_options, statement, matched_objects
+ ):
+ """apply values to objects derived from an update statement, e.g.
+ UPDATE..SET <values>
+
+ """
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ resolved_values = cls._get_resolved_values(mapper, statement)
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ value_evaluators = {}
+ for key, value in resolved_keys_as_propnames:
+ try:
+ _evaluator = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ value_evaluators[key] = _evaluator
+
+ evaluated_keys = list(value_evaluators.keys())
+ attrib = set(k for k, v in resolved_keys_as_propnames)
+
+ states = set()
+ for obj, state, dict_ in matched_objects:
to_evaluate = state.unmodified.intersection(evaluated_keys)
+
for key in to_evaluate:
if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
+ # only run eval for attributes that are present.
+ dict_[key] = value_evaluators[key](obj)
+
state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
+ # attributes that were formerly modified instead get expired.
+ # this only gets hit if the session had pending changes
+ # and autoflush were set to False.
to_expire = attrib.intersection(dict_).difference(to_evaluate)
if to_expire:
state._expire_attributes(dict_, to_expire)
@@ -991,6 +1730,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
+ orm_level_statement = statement
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
@@ -1002,31 +1743,97 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
if opt._is_criteria_option:
opt.get_global_criteria(extra_criteria_attributes)
+ new_stmt = statement._clone()
+ new_stmt.table = mapper.local_table
+
new_crit = cls._adjust_for_extra_criteria(
extra_criteria_attributes, mapper
)
if new_crit:
- statement = statement.where(*new_crit)
+ new_stmt = new_stmt.where(*new_crit)
# do this first as we need to determine if there is
# DELETE..FROM
- DeleteDMLState.__init__(self, statement, compiler, **kw)
+ DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
+
+ use_supplemental_cols = False
- if compiler._annotations.get(
+ synchronize_session = compiler._annotations.get(
"synchronize_session", None
- ) == "fetch" and self.can_use_returning(
- compiler.dialect,
- mapper,
- is_multitable=self.is_multitable,
- is_delete_using=compiler._annotations.get(
- "is_delete_using", False
- ),
- ):
- self.statement = statement.returning(*statement.table.primary_key)
+ )
+ can_use_returning = compiler._annotations.get(
+ "can_use_returning", None
+ )
+ if can_use_returning is not False:
+ # even though pre_exec has determined basic
+ # can_use_returning for the dialect, if we are to use
+ # RETURNING we need to run can_use_returning() at this level
+ # unconditionally because is_delete_using was not known
+ # at the pre_exec level
+ can_use_returning = (
+ synchronize_session == "fetch"
+ and self.can_use_returning(
+ compiler.dialect,
+ mapper,
+ is_multitable=self.is_multitable,
+ is_delete_using=compiler._annotations.get(
+ "is_delete_using", False
+ ),
+ )
+ )
+
+ if can_use_returning:
+ use_supplemental_cols = True
+
+ new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
+
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ use_supplemental_cols=use_supplemental_cols,
+ )
+
+ self.statement = new_stmt
return self
@classmethod
+ def orm_execute_statement(
+ cls,
+ session: Session,
+ statement: dml.Delete,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ update_options = execution_options.get(
+ "_sa_orm_update_options", cls.default_update_options
+ )
+
+ if update_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+ "Bulk ORM DELETE not supported right now. "
+ "Statement may be invoked at the "
+ "Core level using "
+ "session.connection().execute(stmt, parameters)"
+ )
+
+ if update_options._dml_strategy not in (
+ "orm",
+ "auto",
+ ):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM DELETE strategy are 'orm', 'auto'"
+ )
+
+ return super().orm_execute_statement(
+ session, statement, params, execution_options, bind_arguments, conn
+ )
+
+ @classmethod
def can_use_returning(
cls,
dialect: Dialect,
@@ -1068,25 +1875,41 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
return True
@classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
-
- session._remove_newly_deleted(
- [
- attributes.instance_state(obj)
- for obj in update_options._matched_objects
- ]
+ def _do_post_synchronize_evaluate(
+ cls, session, statement, result, update_options
+ ):
+ matched_objects = cls._get_matched_objects_on_criteria(
+ update_options,
+ session.identity_map.all_states(),
)
+ to_delete = []
+
+ for _, state, dict_, is_partially_expired in matched_objects:
+ if is_partially_expired:
+ state._expire(dict_, session.identity_map._modified)
+ else:
+ to_delete.append(state)
+
+ if to_delete:
+ session._remove_newly_deleted(to_delete)
+
@classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
+ def _do_post_synchronize_fetch(
+ cls, session, statement, result, update_options
+ ):
target_mapper = update_options._subject_mapper
- if result.returns_rows:
- rows = cls._interpret_returning_rows(target_mapper, result.all())
+ returned_defaults_rows = result.returned_defaults_rows
+
+ if returned_defaults_rows:
+ pk_rows = cls._interpret_returning_rows(
+ target_mapper, returned_defaults_rows
+ )
matched_rows = [
tuple(row) + (update_options._refresh_identity_token,)
- for row in rows
+ for row in pk_rows
]
else:
matched_rows = update_options._matched_rows
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index dc96f8c3c..f8c7ba714 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
from .query import Query
from .session import _BindArguments
from .session import Session
+ from ..engine import Result
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptionsParameter
from ..sql._typing import _ColumnsClauseArgument
@@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict(
class AbstractORMCompileState(CompileState):
+ is_dml_returning = False
+
@classmethod
def create_for_statement(
cls,
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> AbstractORMCompileState:
"""Create a context for a statement given a :class:`.Compiler`.
+
This method is always invoked in the context of SQLCompiler.process().
+
For a Select object, this would be invoked from
SQLCompiler.visit_select(). For the special FromStatement object used
by Query to indicate "Query.from_statement()", this is called by
@@ -233,6 +238,28 @@ class AbstractORMCompileState(CompileState):
raise NotImplementedError()
@classmethod
+ def orm_execute_statement(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ conn,
+ ) -> Result:
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ return cls.orm_setup_cursor_result(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+
+ @classmethod
def orm_setup_cursor_result(
cls,
session,
@@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState):
def __init__(self, *arg, **kw):
raise NotImplementedError()
+ if TYPE_CHECKING:
+
+ @classmethod
+ def create_for_statement(
+ cls,
+ statement: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
+ ...
+
def _append_dedupe_col_collection(self, obj, col_collection):
dedupe = self.dedupe_columns
if obj not in dedupe:
@@ -333,26 +371,6 @@ class ORMCompileState(AbstractORMCompileState):
return SelectState._column_naming_convention(label_style)
@classmethod
- def create_for_statement(
- cls,
- statement: Union[Select, FromStatement],
- compiler: Optional[SQLCompiler],
- **kw: Any,
- ) -> ORMCompileState:
- """Create a context for a statement given a :class:`.Compiler`.
-
- This method is always invoked in the context of SQLCompiler.process().
-
- For a Select object, this would be invoked from
- SQLCompiler.visit_select(). For the special FromStatement object used
- by Query to indicate "Query.from_statement()", this is called by
- FromStatement._compiler_dispatch() that would be called by
- SQLCompiler.process().
-
- """
- raise NotImplementedError()
-
- @classmethod
def get_column_descriptions(cls, statement):
return _column_descriptions(statement)
@@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState):
)
+class DMLReturningColFilter:
+ """an adapter used for the DML RETURNING case.
+
+ Has a subset of the interface used by
+ :class:`.ORMAdapter` and is used for :class:`._QueryEntity`
+ instances to set up their columns as used in RETURNING for a
+ DML statement.
+
+ """
+
+ __slots__ = ("mapper", "columns", "__weakref__")
+
+ def __init__(self, target_mapper, immediate_dml_mapper):
+ if (
+ immediate_dml_mapper is not None
+ and target_mapper.local_table
+ is not immediate_dml_mapper.local_table
+ ):
+ # joined inh, or in theory other kinds of multi-table mappings
+ self.mapper = immediate_dml_mapper
+ else:
+ # single inh, normal mappings, etc.
+ self.mapper = target_mapper
+ self.columns = self.columns = util.WeakPopulateDict(
+ self.adapt_check_present # type: ignore
+ )
+
+ def __call__(self, col, as_filter):
+ for cc in sql_util._find_columns(col):
+ c2 = self.adapt_check_present(cc)
+ if c2 is not None:
+ return col
+ else:
+ return None
+
+ def adapt_check_present(self, col):
+ mapper = self.mapper
+ prop = mapper._columntoproperty.get(col, None)
+ if prop is None:
+ return None
+ return mapper.local_table.c.corresponding_column(col)
+
+
@sql.base.CompileState.plugin_for("orm", "orm_from_statement")
class ORMFromStatementCompileState(ORMCompileState):
_from_obj_alias = None
@@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState):
statement_container: FromStatement
requested_statement: Union[SelectBase, TextClause, UpdateBase]
- dml_table: _DMLTableElement
+ dml_table: Optional[_DMLTableElement] = None
_has_orm_entities = False
multi_row_eager_loaders = False
@@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState):
statement_container: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> ORMFromStatementCompileState:
if compiler is not None:
toplevel = not compiler.stack
@@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState):
if statement.is_dml:
self.dml_table = statement.table
+ self.is_dml_returning = True
self._entities = []
self._polymorphic_adapters = {}
@@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState):
def _get_current_adapter(self):
return None
+ def setup_dml_returning_compile_state(self, dml_mapper):
+ """used by BulkORMInsert (and Update / Delete?) to set up a handler
+ for RETURNING to return ORM objects and expressions
+
+ """
+ target_mapper = self.statement._propagate_attrs.get(
+ "plugin_subject", None
+ )
+ adapter = DMLReturningColFilter(target_mapper, dml_mapper)
+ for entity in self._entities:
+ entity.setup_dml_returning_compile_state(self, adapter)
+
class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
"""Core construct that represents a load of ORM objects from various
@@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> ORMSelectCompileState:
"""compiler hook, we arrive here from compiler.visit_select() only."""
self = cls.__new__(cls)
@@ -2312,6 +2386,13 @@ class _QueryEntity:
def setup_compile_state(self, compile_state: ORMCompileState) -> None:
raise NotImplementedError()
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ raise NotImplementedError()
+
def row_processor(self, context, result):
raise NotImplementedError()
@@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity):
return _instance, self._label_name, self._extra_entities
- def setup_compile_state(self, compile_state):
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ loading._setup_entity_query(
+ compile_state,
+ self.mapper,
+ self,
+ self.path,
+ adapter,
+ compile_state.primary_columns,
+ with_polymorphic=self._with_polymorphic_mappers,
+ only_load_props=compile_state.compile_options._only_load_props,
+ polymorphic_discriminator=self._polymorphic_discriminator,
+ )
+ def setup_compile_state(self, compile_state):
adapter = self._get_entity_clauses(compile_state)
single_table_crit = self.mapper._single_table_criterion
@@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity):
only_load_props=compile_state.compile_options._only_load_props,
polymorphic_discriminator=self._polymorphic_discriminator,
)
-
compile_state._fallback_from_clauses.append(self.selectable)
@@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity):
getter, label_name, extra_entities = self._row_processor
if self.translate_raw_column:
extra_entities += (
- result.context.invoked_statement._raw_columns[
- self.raw_column_index
- ],
+ context.query._raw_columns[self.raw_column_index],
)
return getter, label_name, extra_entities
@@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity):
if self.translate_raw_column:
extra_entities = self._extra_entities + (
- result.context.invoked_statement._raw_columns[
- self.raw_column_index
- ],
+ context.query._raw_columns[self.raw_column_index],
)
return getter, self._label_name, extra_entities
else:
@@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity):
current_adapter = compile_state._get_current_adapter()
if current_adapter:
column = current_adapter(self.column, False)
+ if column is None:
+ return
else:
column = self.column
@@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity):
self.entity_zero
) and entity.common_parent(self.entity_zero)
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ self._fetch_column = self.column
+ column = adapter(self.column, False)
+ if column is not None:
+ compile_state.dedupe_columns.add(column)
+ compile_state.primary_columns.append(column)
+
def setup_compile_state(self, compile_state):
current_adapter = compile_state._get_current_adapter()
if current_adapter:
column = current_adapter(self.column, False)
+ if column is None:
+ assert compile_state.is_dml_returning
+ self._fetch_column = self.column
+ return
else:
column = self.column
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 52b70b9d4..13d3b70fe 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -19,6 +19,7 @@ import operator
import typing
from typing import Any
from typing import Callable
+from typing import Dict
from typing import List
from typing import NoReturn
from typing import Optional
@@ -602,6 +603,31 @@ class Composite(
def _attribute_keys(self) -> Sequence[str]:
return [prop.key for prop in self.props]
+ def _populate_composite_bulk_save_mappings_fn(
+ self,
+ ) -> Callable[[Dict[str, Any]], None]:
+
+ if self._generated_composite_accessor:
+ get_values = self._generated_composite_accessor
+ else:
+
+ def get_values(val: Any) -> Tuple[Any]:
+ return val.__composite_values__() # type: ignore
+
+ attrs = [prop.key for prop in self.props]
+
+ def populate(dest_dict: Dict[str, Any]) -> None:
+ dest_dict.update(
+ {
+ key: val
+ for key, val in zip(
+ attrs, get_values(dest_dict.pop(self.key))
+ )
+ }
+ )
+
+ return populate
+
def get_history(
self,
state: InstanceState[Any],
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py
index b3129afdd..5af14cc00 100644
--- a/lib/sqlalchemy/orm/evaluator.py
+++ b/lib/sqlalchemy/orm/evaluator.py
@@ -9,8 +9,8 @@
from __future__ import annotations
-import operator
-
+from .base import LoaderCallableStatus
+from .base import PassiveFlag
from .. import exc
from .. import inspect
from .. import util
@@ -32,7 +32,16 @@ class _NoObject(operators.ColumnOperators):
return None
+class _ExpiredObject(operators.ColumnOperators):
+ def operate(self, *arg, **kw):
+ return self
+
+ def reverse_operate(self, *arg, **kw):
+ return self
+
+
_NO_OBJECT = _NoObject()
+_EXPIRED_OBJECT = _ExpiredObject()
class EvaluatorCompiler:
@@ -73,6 +82,24 @@ class EvaluatorCompiler:
f"alternate class {parentmapper.class_}"
)
key = parentmapper._columntoproperty[clause].key
+ impl = parentmapper.class_manager[key].impl
+
+ if impl is not None:
+
+ def get_corresponding_attr(obj):
+ if obj is None:
+ return _NO_OBJECT
+ state = inspect(obj)
+ dict_ = state.dict
+
+ value = impl.get(
+ state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
+ )
+ if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
+ return _EXPIRED_OBJECT
+ return value
+
+ return get_corresponding_attr
else:
key = clause.key
if (
@@ -85,15 +112,16 @@ class EvaluatorCompiler:
"make use of the actual mapped columns in ORM-evaluated "
"UPDATE / DELETE expressions."
)
+
else:
raise UnevaluatableError(f"Cannot evaluate column: {clause}")
- get_corresponding_attr = operator.attrgetter(key)
- return (
- lambda obj: get_corresponding_attr(obj)
- if obj is not None
- else _NO_OBJECT
- )
+ def get_corresponding_attr(obj):
+ if obj is None:
+ return _NO_OBJECT
+ return getattr(obj, key, _EXPIRED_OBJECT)
+
+ return get_corresponding_attr
def visit_tuple(self, clause):
return self.visit_clauselist(clause)
@@ -134,7 +162,9 @@ class EvaluatorCompiler:
has_null = False
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
- if value:
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif value:
return True
has_null = has_null or value is None
if has_null:
@@ -147,6 +177,9 @@ class EvaluatorCompiler:
def evaluate(obj):
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+
if not value:
if value is None or value is _NO_OBJECT:
return None
@@ -160,7 +193,9 @@ class EvaluatorCompiler:
values = []
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
- if value is None or value is _NO_OBJECT:
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif value is None or value is _NO_OBJECT:
return None
values.append(value)
return tuple(values)
@@ -183,13 +218,21 @@ class EvaluatorCompiler:
def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
- return eval_left(obj) == eval_right(obj)
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ return left_val == right_val
return evaluate
def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
- return eval_left(obj) != eval_right(obj)
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ return left_val != right_val
return evaluate
@@ -197,8 +240,11 @@ class EvaluatorCompiler:
def evaluate(obj):
left_val = eval_left(obj)
right_val = eval_right(obj)
- if left_val is None or right_val is None:
+ if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif left_val is None or right_val is None:
return None
+
return operator(eval_left(obj), eval_right(obj))
return evaluate
@@ -274,7 +320,9 @@ class EvaluatorCompiler:
def evaluate(obj):
value = eval_inner(obj)
- if value is None:
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif value is None:
return None
return not value
diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py
index 63b131a78..4848f73f1 100644
--- a/lib/sqlalchemy/orm/identity.py
+++ b/lib/sqlalchemy/orm/identity.py
@@ -68,6 +68,11 @@ class IdentityMap:
) -> Optional[_O]:
raise NotImplementedError()
+ def fast_get_state(
+ self, key: _IdentityKeyType[_O]
+ ) -> Optional[InstanceState[_O]]:
+ raise NotImplementedError()
+
def keys(self) -> Iterable[_IdentityKeyType[Any]]:
return self._dict.keys()
@@ -206,6 +211,11 @@ class WeakInstanceDict(IdentityMap):
self._dict[key] = state
state._instance_dict = self._wr
+ def fast_get_state(
+ self, key: _IdentityKeyType[_O]
+ ) -> Optional[InstanceState[_O]]:
+ return self._dict.get(key)
+
def get(
self, key: _IdentityKeyType[_O], default: Optional[_O] = None
) -> Optional[_O]:
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 7317d48be..64f2542fd 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -29,7 +29,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from sqlalchemy.orm.context import FromStatement
from . import attributes
from . import exc as orm_exc
from . import path_registry
@@ -37,6 +36,7 @@ from .base import _DEFER_FOR_STATE
from .base import _RAISE_FOR_STATE
from .base import _SET_DEFERRED_EXPIRED
from .base import PassiveFlag
+from .context import FromStatement
from .util import _none_set
from .util import state_str
from .. import exc as sa_exc
@@ -50,6 +50,7 @@ from ..sql import util as sql_util
from ..sql.selectable import ForUpdateArg
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import SelectState
+from ..util import EMPTY_DICT
if TYPE_CHECKING:
from ._typing import _IdentityKeyType
@@ -764,7 +765,7 @@ def _instance_processor(
)
quick_populators = path.get(
- context.attributes, "memoized_setups", _none_set
+ context.attributes, "memoized_setups", EMPTY_DICT
)
todo = []
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index c8df51b06..c9cf8f49b 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -854,6 +854,7 @@ class Mapper(
_memoized_values: Dict[Any, Callable[[], Any]]
_inheriting_mappers: util.WeakSequence[Mapper[Any]]
_all_tables: Set[Table]
+ _polymorphic_attr_key: Optional[str]
_pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]]
_cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]]
@@ -1653,6 +1654,7 @@ class Mapper(
"""
setter = False
+ polymorphic_key: Optional[str] = None
if self.polymorphic_on is not None:
setter = True
@@ -1772,17 +1774,23 @@ class Mapper(
self._set_polymorphic_identity = (
mapper._set_polymorphic_identity
)
+ self._polymorphic_attr_key = (
+ mapper._polymorphic_attr_key
+ )
self._validate_polymorphic_identity = (
mapper._validate_polymorphic_identity
)
else:
self._set_polymorphic_identity = None
+ self._polymorphic_attr_key = None
return
if setter:
def _set_polymorphic_identity(state):
dict_ = state.dict
+ # TODO: what happens if polymorphic_on column attribute name
+ # does not match .key?
state.get_impl(polymorphic_key).set(
state,
dict_,
@@ -1790,6 +1798,8 @@ class Mapper(
None,
)
+ self._polymorphic_attr_key = polymorphic_key
+
def _validate_polymorphic_identity(mapper, state, dict_):
if (
polymorphic_key in dict_
@@ -1808,6 +1818,7 @@ class Mapper(
_validate_polymorphic_identity
)
else:
+ self._polymorphic_attr_key = None
self._set_polymorphic_identity = None
_validate_polymorphic_identity = None
@@ -3562,6 +3573,10 @@ class Mapper(
return util.LRUCache(self._compiled_cache_size)
@HasMemoized.memoized_attribute
+ def _multiple_persistence_tables(self):
+ return len(self.tables) > 1
+
+ @HasMemoized.memoized_attribute
def _sorted_tables(self):
table_to_mapper: Dict[Table, Mapper[Any]] = {}
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index abd528986..dfb61c28a 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -31,6 +31,7 @@ from .. import exc as sa_exc
from .. import future
from .. import sql
from .. import util
+from ..engine import cursor as _cursor
from ..sql import operators
from ..sql.elements import BooleanClauseList
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
@@ -398,6 +399,11 @@ def _collect_insert_commands(
None
)
+ if bulk and mapper._set_polymorphic_identity:
+ params.setdefault(
+ mapper._polymorphic_attr_key, mapper.polymorphic_identity
+ )
+
yield (
state,
state_dict,
@@ -411,7 +417,11 @@ def _collect_insert_commands(
def _collect_update_commands(
- uowtransaction, table, states_to_update, bulk=False
+ uowtransaction,
+ table,
+ states_to_update,
+ bulk=False,
+ use_orm_update_stmt=None,
):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -437,7 +447,11 @@ def _collect_update_commands(
pks = mapper._pks_by_table[table]
- value_params = {}
+ if use_orm_update_stmt is not None:
+ # TODO: ordered values, etc
+ value_params = use_orm_update_stmt._values
+ else:
+ value_params = {}
propkey_to_col = mapper._propkey_to_col[table]
@@ -697,6 +711,7 @@ def _emit_update_statements(
table,
update,
bookkeeping=True,
+ use_orm_update_stmt=None,
):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
@@ -708,7 +723,7 @@ def _emit_update_statements(
execution_options = {"compiled_cache": base_mapper._compiled_cache}
- def update_stmt():
+ def update_stmt(existing_stmt=None):
clauses = BooleanClauseList._construct_raw(operators.and_)
for col in mapper._pks_by_table[table]:
@@ -725,10 +740,17 @@ def _emit_update_statements(
)
)
- stmt = table.update().where(clauses)
+ if existing_stmt is not None:
+ stmt = existing_stmt.where(clauses)
+ else:
+ stmt = table.update().where(clauses)
return stmt
- cached_stmt = base_mapper._memo(("update", table), update_stmt)
+ if use_orm_update_stmt is not None:
+ cached_stmt = update_stmt(use_orm_update_stmt)
+
+ else:
+ cached_stmt = base_mapper._memo(("update", table), update_stmt)
for (
(connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
@@ -747,6 +769,15 @@ def _emit_update_statements(
records = list(records)
statement = cached_stmt
+
+ if use_orm_update_stmt is not None:
+ statement = statement._annotate(
+ {
+ "_emit_update_table": table,
+ "_emit_update_mapper": mapper,
+ }
+ )
+
return_defaults = False
if not has_all_pks:
@@ -904,16 +935,35 @@ def _emit_insert_statements(
table,
insert,
bookkeeping=True,
+ use_orm_insert_stmt=None,
+ execution_options=None,
):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
- cached_stmt = base_mapper._memo(("insert", table), table.insert)
+ if use_orm_insert_stmt is not None:
+ cached_stmt = use_orm_insert_stmt
+ exec_opt = util.EMPTY_DICT
- execution_options = {"compiled_cache": base_mapper._compiled_cache}
+ # if a user query with RETURNING was passed, we definitely need
+ # to use RETURNING.
+ returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
+ else:
+ returning_is_required_anyway = False
+ cached_stmt = base_mapper._memo(("insert", table), table.insert)
+ exec_opt = {"compiled_cache": base_mapper._compiled_cache}
+
+ if execution_options:
+ execution_options = util.EMPTY_DICT.merge_with(
+ exec_opt, execution_options
+ )
+ else:
+ execution_options = exec_opt
+
+ return_result = None
for (
- (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
+ (connection, _, hasvalue, has_all_pks, has_all_defaults),
records,
) in groupby(
insert,
@@ -928,17 +978,29 @@ def _emit_insert_statements(
statement = cached_stmt
+ if use_orm_insert_stmt is not None:
+ statement = statement._annotate(
+ {
+ "_emit_insert_table": table,
+ "_emit_insert_mapper": mapper,
+ }
+ )
+
if (
- not bookkeeping
- or (
- has_all_defaults
- or not base_mapper.eager_defaults
- or not base_mapper.local_table.implicit_returning
- or not connection.dialect.insert_returning
+ (
+ not bookkeeping
+ or (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not base_mapper.local_table.implicit_returning
+ or not connection.dialect.insert_returning
+ )
)
+ and not returning_is_required_anyway
and has_all_pks
and not hasvalue
):
+
# the "we don't need newly generated values back" section.
# here we have all the PKs, all the defaults or we don't want
# to fetch them, or the dialect doesn't support RETURNING at all
@@ -946,7 +1008,7 @@ def _emit_insert_statements(
records = list(records)
multiparams = [rec[2] for rec in records]
- c = connection.execute(
+ result = connection.execute(
statement, multiparams, execution_options=execution_options
)
if bookkeeping:
@@ -962,7 +1024,7 @@ def _emit_insert_statements(
has_all_defaults,
),
last_inserted_params,
- ) in zip(records, c.context.compiled_parameters):
+ ) in zip(records, result.context.compiled_parameters):
if state:
_postfetch(
mapper_rec,
@@ -970,19 +1032,20 @@ def _emit_insert_statements(
table,
state,
state_dict,
- c,
+ result,
last_inserted_params,
value_params,
False,
- c.returned_defaults
- if not c.context.executemany
+ result.returned_defaults
+ if not result.context.executemany
else None,
)
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
- # here, we need defaults and/or pk values back.
+ # here, we need defaults and/or pk values back or we otherwise
+ # know that we are using RETURNING in any case
records = list(records)
if (
@@ -991,6 +1054,16 @@ def _emit_insert_statements(
and len(records) > 1
):
do_executemany = True
+ elif returning_is_required_anyway:
+ if connection.dialect.insert_executemany_returning:
+ do_executemany = True
+ else:
+ raise sa_exc.InvalidRequestError(
+ f"Can't use explicit RETURNING for bulk INSERT "
+ f"operation with "
+ f"{connection.dialect.dialect_description} backend; "
+ f"executemany is not supported with RETURNING"
+ )
else:
do_executemany = False
@@ -998,6 +1071,7 @@ def _emit_insert_statements(
statement = statement.return_defaults(
*mapper._server_default_cols[table]
)
+
if mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
elif do_executemany:
@@ -1006,10 +1080,16 @@ def _emit_insert_statements(
if do_executemany:
multiparams = [rec[2] for rec in records]
- c = connection.execute(
+ result = connection.execute(
statement, multiparams, execution_options=execution_options
)
+ if use_orm_insert_stmt is not None:
+ if return_result is None:
+ return_result = result
+ else:
+ return_result = return_result.splice_vertically(result)
+
if bookkeeping:
for (
(
@@ -1027,9 +1107,9 @@ def _emit_insert_statements(
returned_defaults,
) in zip_longest(
records,
- c.context.compiled_parameters,
- c.inserted_primary_key_rows,
- c.returned_defaults_rows or (),
+ result.context.compiled_parameters,
+ result.inserted_primary_key_rows,
+ result.returned_defaults_rows or (),
):
if inserted_primary_key is None:
# this is a real problem and means that we didn't
@@ -1062,7 +1142,7 @@ def _emit_insert_statements(
table,
state,
state_dict,
- c,
+ result,
last_inserted_params,
value_params,
False,
@@ -1071,6 +1151,8 @@ def _emit_insert_statements(
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
+ assert not returning_is_required_anyway
+
for (
state,
state_dict,
@@ -1132,6 +1214,12 @@ def _emit_insert_statements(
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
+ if use_orm_insert_stmt is not None:
+ if return_result is None:
+ return _cursor.null_dml_result()
+ else:
+ return return_result
+
def _emit_post_update_statements(
base_mapper, uowtransaction, mapper, table, update
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 6d0f055e4..4d5a98fcf 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -2978,7 +2978,7 @@ class Query(
)
def delete(
- self, synchronize_session: _SynchronizeSessionArgument = "evaluate"
+ self, synchronize_session: _SynchronizeSessionArgument = "auto"
) -> int:
r"""Perform a DELETE with an arbitrary WHERE clause.
@@ -3042,7 +3042,7 @@ class Query(
def update(
self,
values: Dict[_DMLColumnArgument, Any],
- synchronize_session: _SynchronizeSessionArgument = "evaluate",
+ synchronize_session: _SynchronizeSessionArgument = "auto",
update_args: Optional[Dict[Any, Any]] = None,
) -> int:
r"""Perform an UPDATE with an arbitrary WHERE clause.
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index a690da0d5..64c013306 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -1828,12 +1828,13 @@ class Session(_SessionClassMethods, EventTarget):
statement._propagate_attrs.get("compile_state_plugin", None)
== "orm"
):
- # note that even without "future" mode, we need
compile_state_cls = CompileState._get_plugin_class_for_plugin(
statement, "orm"
)
if TYPE_CHECKING:
- assert isinstance(compile_state_cls, ORMCompileState)
+ assert isinstance(
+ compile_state_cls, context.AbstractORMCompileState
+ )
else:
compile_state_cls = None
@@ -1897,18 +1898,18 @@ class Session(_SessionClassMethods, EventTarget):
statement, params or {}, execution_options=execution_options
)
- result: Result[Any] = conn.execute(
- statement, params or {}, execution_options=execution_options
- )
-
if compile_state_cls:
- result = compile_state_cls.orm_setup_cursor_result(
+ result: Result[Any] = compile_state_cls.orm_execute_statement(
self,
statement,
- params,
+ params or {},
execution_options,
bind_arguments,
- result,
+ conn,
+ )
+ else:
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
)
if _scalar_result:
@@ -2066,7 +2067,7 @@ class Session(_SessionClassMethods, EventTarget):
def scalars(
self,
statement: TypedReturnsRows[Tuple[_T]],
- params: Optional[_CoreSingleExecuteParams] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
@@ -2078,7 +2079,7 @@ class Session(_SessionClassMethods, EventTarget):
def scalars(
self,
statement: Executable,
- params: Optional[_CoreSingleExecuteParams] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
@@ -2089,7 +2090,7 @@ class Session(_SessionClassMethods, EventTarget):
def scalars(
self,
statement: Executable,
- params: Optional[_CoreSingleExecuteParams] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 19c6493db..8652591c8 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -227,6 +227,11 @@ class ColumnLoader(LoaderStrategy):
fetch = self.columns[0]
if adapter:
fetch = adapter.columns[fetch]
+ if fetch is None:
+ # None happens here only for dml bulk_persistence cases
+ # when context.DMLReturningColFilter is used
+ return
+
memoized_populators[self.parent_property] = fetch
def init_class_attribute(self, mapper):
@@ -318,6 +323,12 @@ class ExpressionColumnLoader(ColumnLoader):
fetch = columns[0]
if adapter:
fetch = adapter.columns[fetch]
+ if fetch is None:
+ # None is not expected to be the result of any
+ # adapter implementation here, however there may be theoretical
+ # usages of returning() with context.DMLReturningColFilter
+ return
+
memoized_populators[self.parent_property] = fetch
def create_row_processor(