diff options
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 37 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 16 |
2 files changed, 35 insertions, 18 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 198eeb46f..2a697a6f9 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -23,7 +23,8 @@ from ..sql import expression from . import loading -def _bulk_insert(mapper, mappings, session_transaction, isstates): +def _bulk_insert( + mapper, mappings, session_transaction, isstates, return_defaults): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) @@ -34,7 +35,11 @@ def _bulk_insert(mapper, mappings, session_transaction, isstates): "not supported in bulk_insert()") if isstates: - mappings = [state.dict for state in mappings] + if return_defaults: + states = [(state, state.dict) for state in mappings] + mappings = [dict_ for (state, dict_) in states] + else: + mappings = [state.dict for state in mappings] else: mappings = list(mappings) @@ -44,22 +49,30 @@ def _bulk_insert(mapper, mappings, session_transaction, isstates): continue records = ( - (None, None, params, super_mapper, - connection, value_params, True, True) + (None, state_dict, params, super_mapper, + connection, value_params, has_all_pks, has_all_defaults) for state, state_dict, params, mp, conn, value_params, has_all_pks, has_all_defaults in _collect_insert_commands(table, ( (None, mapping, mapper, connection) for mapping in mappings), - bulk=True + bulk=True, return_defaults=return_defaults ) ) - _emit_insert_statements(base_mapper, None, cached_connections, super_mapper, table, records, - bookkeeping=False) + bookkeeping=return_defaults) + + if return_defaults and isstates: + identity_cls = mapper._identity_class + identity_props = [p.key for p in mapper._identity_key_props] + for state, dict_ in states: + state.key = ( + identity_cls, + tuple([dict_[key] for key in identity_props]) + ) def _bulk_update(mapper, mappings, session_transaction, isstates): @@ -341,7 +354,9 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): state, dict_, mapper, connection, update_version_id) -def _collect_insert_commands(table, states_to_insert, bulk=False): +def _collect_insert_commands( + table, states_to_insert, + bulk=False, return_defaults=False): """Identify sets of values to use in INSERT statements for a list of states. @@ -370,6 +385,7 @@ def _collect_insert_commands(table, states_to_insert, bulk=False): difference(params).difference(value_params): params[colkey] = None + if not bulk or return_defaults: has_all_pks = mapper._pk_keys_by_table[table].issubset(params) if mapper.base_mapper.eager_defaults: @@ -884,9 +900,8 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): toload_now.extend(state._unloaded_non_object) elif mapper.version_id_col is not None and \ mapper.version_id_generator is False: - prop = mapper._columntoproperty[mapper.version_id_col] - if prop.key in state.unloaded: - toload_now.extend([prop.key]) + if mapper._version_id_prop.key in state.unloaded: + toload_now.extend([mapper._version_id_prop.key]) if toload_now: state.key = base_mapper._identity_key_from_state(state) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index e075b9c71..1611688b0 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -2036,20 +2036,22 @@ class Session(_SessionClassMethods): with util.safe_reraise(): transaction.rollback(_capture_exception=True) - def bulk_save_objects(self, objects): + def bulk_save_objects(self, objects, return_defaults=False): for (mapper, isupdate), states in itertools.groupby( (attributes.instance_state(obj) for obj in objects), lambda state: (state.mapper, state.key is not None) ): - self._bulk_save_mappings(mapper, states, isupdate, True) + self._bulk_save_mappings( + mapper, states, isupdate, True, return_defaults) - def bulk_insert_mappings(self, mapper, mappings): - self._bulk_save_mappings(mapper, mappings, False, False) + def bulk_insert_mappings(self, mapper, mappings, return_defaults=False): + self._bulk_save_mappings(mapper, mappings, False, False, return_defaults) def bulk_update_mappings(self, mapper, mappings): - self._bulk_save_mappings(mapper, mappings, True, False) + self._bulk_save_mappings(mapper, mappings, True, False, False) - def _bulk_save_mappings(self, mapper, mappings, isupdate, isstates): + def _bulk_save_mappings( + self, mapper, mappings, isupdate, isstates, return_defaults): mapper = _class_to_mapper(mapper) self._flushing = True @@ -2061,7 +2063,7 @@ class Session(_SessionClassMethods): mapper, mappings, transaction, isstates) else: persistence._bulk_insert( - mapper, mappings, transaction, isstates) + mapper, mappings, transaction, isstates, return_defaults) transaction.commit() except: |