diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2023-04-27 00:13:00 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2023-04-27 00:13:00 +0000 |
| commit | 11535752b94acb41ff684cf8d9c745038addc447 (patch) | |
| tree | fec970fe35228d33c45280ad558ed9bc251b5208 /lib/sqlalchemy | |
| parent | c89c2b3d9a18bd0eb4c8ace50ef875101c9f4b70 (diff) | |
| parent | 8ec396873c9bbfcc4416e55b5f9d8653554a1df0 (diff) | |
| download | sqlalchemy-11535752b94acb41ff684cf8d9c745038addc447.tar.gz | |
Merge "support parameters in all ORM insert modes" into main
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/bulk_persistence.py | 77 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 31 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 22 |
3 files changed, 107 insertions, 23 deletions
diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 8388d3980..cb416d69e 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -150,6 +150,23 @@ def _bulk_insert( for table, super_mapper in mappers_to_run: + # find bindparams in the statement. For bulk, we don't really know if + # a key in the params applies to a different table since we are + # potentially inserting for multiple tables here; looking at the + # bindparam() is a lot more direct. in most cases this will + # use _generate_cache_key() which is memoized, although in practice + # the ultimate statement that's executed is probably not the same + # object so that memoization might not matter much. + extra_bp_names = ( + [ + b.key + for b in use_orm_insert_stmt._get_embedded_bindparams() + if b.key in mappings[0] + ] + if use_orm_insert_stmt is not None + else () + ) + records = ( ( None, @@ -176,6 +193,7 @@ def _bulk_insert( bulk=True, return_defaults=bookkeeping, render_nulls=render_nulls, + include_bulk_keys=extra_bp_names, ) ) @@ -218,6 +236,7 @@ def _bulk_update( isstates: bool, update_changed_only: bool, use_orm_update_stmt: Literal[None] = ..., + enable_check_rowcount: bool = True, ) -> None: ... @@ -230,6 +249,7 @@ def _bulk_update( isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = ..., + enable_check_rowcount: bool = True, ) -> _result.Result[Any]: ... @@ -241,6 +261,7 @@ def _bulk_update( isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = None, + enable_check_rowcount: bool = True, ) -> Optional[_result.Result[Any]]: base_mapper = mapper.base_mapper @@ -272,6 +293,18 @@ def _bulk_update( connection = session_transaction.connection(base_mapper) + # find bindparams in the statement. see _bulk_insert for similar + # notes for the insert case + extra_bp_names = ( + [ + b.key + for b in use_orm_update_stmt._get_embedded_bindparams() + if b.key in mappings[0] + ] + if use_orm_update_stmt is not None + else () + ) + for table, super_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: continue @@ -295,6 +328,7 @@ def _bulk_update( ), bulk=True, use_orm_update_stmt=use_orm_update_stmt, + include_bulk_keys=extra_bp_names, ) persistence._emit_update_statements( base_mapper, @@ -304,6 +338,7 @@ def _bulk_update( records, bookkeeping=False, use_orm_update_stmt=use_orm_update_stmt, + enable_check_rowcount=enable_check_rowcount, ) if use_orm_update_stmt is not None: @@ -588,6 +623,7 @@ class BulkUDCompileState(ORMDMLState): is_multitable: bool = False, is_update_from: bool = False, is_delete_using: bool = False, + is_executemany: bool = False, ) -> bool: raise NotImplementedError() @@ -639,11 +675,6 @@ class BulkUDCompileState(ORMDMLState): else: if update_options._dml_strategy == "auto": update_options += {"_dml_strategy": "bulk"} - elif update_options._dml_strategy == "orm": - raise sa_exc.InvalidRequestError( - 'Can\'t use "orm" ORM insert strategy with a ' - "separate parameter list" - ) sync = update_options._synchronize_session if sync is not None: @@ -1062,6 +1093,7 @@ class BulkUDCompileState(ORMDMLState): mapper, is_update_from=update_options._is_update_from, is_delete_using=update_options._is_delete_using, + is_executemany=orm_context.is_executemany, ) if can_use_returning is not None: @@ -1071,6 +1103,12 @@ class BulkUDCompileState(ORMDMLState): "backends where some support RETURNING and others " "don't" ) + elif orm_context.is_executemany and not per_bind_result: + raise sa_exc.InvalidRequestError( + "For synchronize_session='fetch', can't use multiple " + "parameter sets in ORM mode, which this backend does not " + "support with RETURNING" + ) else: can_use_returning = per_bind_result @@ -1146,11 +1184,6 @@ class BulkORMInsert(ORMDMLState, InsertDMLState): else: if insert_options._dml_strategy == "auto": insert_options += {"_dml_strategy": "bulk"} - elif insert_options._dml_strategy == "orm": - raise sa_exc.InvalidRequestError( - 'Can\'t use "orm" ORM insert strategy with a ' - "separate parameter list" - ) if insert_options._dml_strategy != "raw": # for ORM object loading, like ORMContext, we have to disable @@ -1512,12 +1545,20 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): result: _result.Result[Any] if update_options._dml_strategy == "bulk": - if statement._where_criteria: + enable_check_rowcount = not statement._where_criteria + + assert update_options._synchronize_session != "fetch" + + if ( + statement._where_criteria + and update_options._synchronize_session == "evaluate" + ): raise sa_exc.InvalidRequestError( - "WHERE clause with bulk ORM UPDATE not " - "supported right now. Statement may be invoked at the " - "Core level using " - "session.connection().execute(stmt, parameters)" + "bulk synchronize of persistent objects not supported " + "when using bulk update with additional WHERE " + "criteria right now. add synchronize_session=None " + "execution option to bypass synchronize of persistent " + "objects." ) mapper = update_options._subject_mapper assert mapper is not None @@ -1532,6 +1573,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): isstates=False, update_changed_only=False, use_orm_update_stmt=statement, + enable_check_rowcount=enable_check_rowcount, ) return cls.orm_setup_cursor_result( session, @@ -1560,6 +1602,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): is_multitable: bool = False, is_update_from: bool = False, is_delete_using: bool = False, + is_executemany: bool = False, ) -> bool: # normal answer for "should we use RETURNING" at all. @@ -1569,6 +1612,9 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): if not normal_answer: return False + if is_executemany: + return dialect.update_executemany_returning + # these workarounds are currently hypothetical for UPDATE, # unlike DELETE where they impact MariaDB if is_update_from: @@ -1869,6 +1915,7 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): is_multitable: bool = False, is_update_from: bool = False, is_delete_using: bool = False, + is_executemany: bool = False, ) -> bool: # normal answer for "should we use RETURNING" at all. diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 1af55df00..6fa338ced 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -326,9 +326,11 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): def _collect_insert_commands( table, states_to_insert, + *, bulk=False, return_defaults=False, render_nulls=False, + include_bulk_keys=(), ): """Identify sets of values to use in INSERT statements for a list of states. @@ -401,10 +403,14 @@ def _collect_insert_commands( None ) - if bulk and mapper._set_polymorphic_identity: - params.setdefault( - mapper._polymorphic_attr_key, mapper.polymorphic_identity - ) + if bulk: + if mapper._set_polymorphic_identity: + params.setdefault( + mapper._polymorphic_attr_key, mapper.polymorphic_identity + ) + + if include_bulk_keys: + params.update((k, state_dict[k]) for k in include_bulk_keys) yield ( state, @@ -422,8 +428,10 @@ def _collect_update_commands( uowtransaction, table, states_to_update, + *, bulk=False, use_orm_update_stmt=None, + include_bulk_keys=(), ): """Identify sets of values to use in UPDATE statements for a list of states. @@ -581,6 +589,9 @@ def _collect_update_commands( "key value on column %s" % (table, col) ) + if include_bulk_keys: + params.update((k, state_dict[k]) for k in include_bulk_keys) + if params or value_params: params.update(pk_params) yield ( @@ -712,8 +723,10 @@ def _emit_update_statements( mapper, table, update, + *, bookkeeping=True, use_orm_update_stmt=None, + enable_check_rowcount=True, ): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" @@ -847,10 +860,10 @@ def _emit_update_statements( c.returned_defaults, ) rows += c.rowcount - check_rowcount = assert_singlerow + check_rowcount = enable_check_rowcount and assert_singlerow else: if not allow_executemany: - check_rowcount = assert_singlerow + check_rowcount = enable_check_rowcount and assert_singlerow for ( state, state_dict, @@ -883,8 +896,9 @@ def _emit_update_statements( else: multiparams = [rec[2] for rec in records] - check_rowcount = assert_multirow or ( - assert_singlerow and len(multiparams) == 1 + check_rowcount = enable_check_rowcount and ( + assert_multirow + or (assert_singlerow and len(multiparams) == 1) ) c = connection.execute( @@ -941,6 +955,7 @@ def _emit_insert_statements( mapper, table, insert, + *, bookkeeping=True, use_orm_insert_stmt=None, execution_options=None, diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ff47ec79d..2e32da754 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -502,6 +502,28 @@ class ClauseElement( connection, distilled_params, execution_options ).scalar() + def _get_embedded_bindparams(self) -> Sequence[BindParameter[Any]]: + """Return the list of :class:`.BindParameter` objects embedded in the + object. + + This accomplishes the same purpose as ``visitors.traverse()`` or + similar would provide, however by making use of the cache key + it takes advantage of memoization of the key to result in fewer + net method calls, assuming the statement is also going to be + executed. + + """ + + key = self._generate_cache_key() + if key is None: + bindparams: List[BindParameter[Any]] = [] + + traverse(self, {}, {"bindparam": bindparams.append}) + return bindparams + + else: + return key.bindparams + def unique_params( self, __optionaldict: Optional[Dict[str, Any]] = None, |
