diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 156 |
1 files changed, 117 insertions, 39 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 88524dc49..cbe7bde33 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -960,6 +960,7 @@ def _emit_update_statements( c.context.compiled_parameters[0], value_params, True, + c.returned_defaults, ) rows += c.rowcount check_rowcount = assert_singlerow @@ -992,6 +993,7 @@ def _emit_update_statements( c.context.compiled_parameters[0], value_params, True, + c.returned_defaults, ) rows += c.rowcount else: @@ -1028,6 +1030,9 @@ def _emit_update_statements( c.context.compiled_parameters[0], value_params, True, + c.returned_defaults + if not c.context.executemany + else None, ) if check_rowcount: @@ -1086,7 +1091,10 @@ def _emit_insert_statements( and has_all_pks and not hasvalue ): - + # the "we don't need newly generated values back" section. + # here we have all the PKs, all the defaults or we don't want + # to fetch them, or the dialect doesn't support RETURNING at all + # so we have to post-fetch / use lastrowid anyway. records = list(records) multiparams = [rec[2] for rec in records] @@ -1116,63 +1124,132 @@ def _emit_insert_statements( last_inserted_params, value_params, False, + c.returned_defaults + if not c.context.executemany + else None, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: + # here, we need defaults and/or pk values back. + + records = list(records) + if ( + not hasvalue + and connection.dialect.insert_executemany_returning + and len(records) > 1 + ): + do_executemany = True + else: + do_executemany = False + if not has_all_defaults and base_mapper.eager_defaults: statement = statement.return_defaults() elif mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) + elif do_executemany: + statement = statement.return_defaults(*table.primary_key) - for ( - state, - state_dict, - params, - mapper_rec, - connection, - value_params, - has_all_pks, - has_all_defaults, - ) in records: + if do_executemany: + multiparams = [rec[2] for rec in records] - if value_params: - result = connection.execute( - statement.values(value_params), params - ) - else: - result = cached_connections[connection].execute( - statement, params - ) + c = cached_connections[connection].execute( + statement, multiparams + ) + if bookkeeping: + for ( + ( + state, + state_dict, + params, + mapper_rec, + conn, + value_params, + has_all_pks, + has_all_defaults, + ), + last_inserted_params, + inserted_primary_key, + returned_defaults, + ) in util.zip_longest( + records, + c.context.compiled_parameters, + c.inserted_primary_key_rows, + c.returned_defaults_rows or (), + ): + for pk, col in zip( + inserted_primary_key, mapper._pks_by_table[table], + ): + prop = mapper_rec._columntoproperty[col] + if state_dict.get(prop.key) is None: + state_dict[prop.key] = pk + + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params, + False, + returned_defaults, + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) + else: + for ( + state, + state_dict, + params, + mapper_rec, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) in records: + + if value_params: + result = connection.execute( + statement.values(value_params), params + ) + else: + result = cached_connections[connection].execute( + statement, params + ) - primary_key = result.inserted_primary_key - if primary_key is not None: - # set primary key attributes + primary_key = result.inserted_primary_key + assert primary_key for pk, col in zip( primary_key, mapper._pks_by_table[table] ): prop = mapper_rec._columntoproperty[col] - if pk is not None and ( + if ( col in value_params or state_dict.get(prop.key) is None ): state_dict[prop.key] = pk - if bookkeeping: - if state: - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - result, - result.context.compiled_parameters[0], - value_params, - False, - ) - else: - _postfetch_bulk_save(mapper_rec, state_dict, table) + if bookkeeping: + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + result.context.compiled_parameters[0], + value_params, + False, + result.returned_defaults + if not result.context.executemany + else None, + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) def _emit_post_update_statements( @@ -1507,6 +1584,7 @@ def _postfetch( params, value_params, isupdate, + returned_defaults, ): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that @@ -1527,7 +1605,7 @@ def _postfetch( load_evt_attrs = [] if returning_cols: - row = result.returned_defaults + row = returned_defaults if row is not None: for row_value, col in zip(row, returning_cols): # pk cols returned from insert are handled |
