summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-08-20 14:24:45 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-08-20 14:24:45 -0400
commit85e75ebcee15f216ace71628f1e491e36663d5c8 (patch)
tree30e41b994471fa3a0e1bbf792ad70c1d78b15eb1
parent92b0ad0fef0b9ee3d54767cf17e2baf1fd1546da (diff)
downloadsqlalchemy-85e75ebcee15f216ace71628f1e491e36663d5c8.tar.gz
- factor out determination of current version id out of
_collect_update_commands and _collect_delete_commands
-rw-r--r--lib/sqlalchemy/orm/persistence.py110
1 files changed, 55 insertions, 55 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 37b696d0f..511a9cef0 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -45,38 +45,26 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
cached_connections = _cached_connection_dict(base_mapper)
for (state, dict_, mapper, connection,
- has_identity, row_switch) in _organize_states_for_save(
+ has_identity,
+ row_switch, update_version_id) in _organize_states_for_save(
base_mapper, states, uowtransaction
):
if has_identity or row_switch:
states_to_update.append(
- (state, dict_, mapper, connection,
- has_identity, row_switch)
+ (state, dict_, mapper, connection, update_version_id)
)
else:
states_to_insert.append(
- (state, dict_, mapper, connection,
- has_identity, row_switch)
+ (state, dict_, mapper, connection)
)
for table, mapper in base_mapper._sorted_tables.items():
if table not in mapper._pks_by_table:
continue
- insert = (
- (state, state_dict, sub_mapper, connection)
- for state, state_dict, sub_mapper, connection, has_identity,
- row_switch in states_to_insert
- if table in sub_mapper._pks_by_table
- )
- insert = _collect_insert_commands(table, insert)
+ insert = _collect_insert_commands(table, states_to_insert)
- update = (
- (state, state_dict, sub_mapper, connection, row_switch)
- for state, state_dict, sub_mapper, connection, has_identity,
- row_switch in states_to_update
- if table in sub_mapper._pks_by_table
- )
- update = _collect_update_commands(uowtransaction, table, update)
+ update = _collect_update_commands(
+ uowtransaction, table, states_to_update)
_emit_update_statements(base_mapper, uowtransaction,
cached_connections,
@@ -89,9 +77,16 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
_finalize_insert_update_commands(
base_mapper, uowtransaction,
(
- (state, state_dict, mapper, connection, has_identity)
- for state, state_dict, mapper, connection, has_identity,
- row_switch in states_to_insert + states_to_update
+ (state, state_dict, mapper, connection, False)
+ for state, state_dict, mapper, connection in states_to_insert
+ )
+ )
+ _finalize_insert_update_commands(
+ base_mapper, uowtransaction,
+ (
+ (state, state_dict, mapper, connection, True)
+ for state, state_dict, mapper, connection,
+ update_version_id in states_to_update
)
)
@@ -149,21 +144,14 @@ def delete_obj(base_mapper, states, uowtransaction):
if table not in mapper._pks_by_table:
continue
- delete = (
- (state, state_dict, sub_mapper, connection)
- for state, state_dict, sub_mapper, has_identity, connection
- in states_to_delete if table in sub_mapper._pks_by_table
- and has_identity
- )
-
delete = _collect_delete_commands(base_mapper, uowtransaction,
- table, delete)
+ table, states_to_delete)
_emit_delete_statements(base_mapper, uowtransaction,
cached_connections, mapper, table, delete)
- for state, state_dict, mapper, has_identity, connection \
- in states_to_delete:
+ for state, state_dict, mapper, connection, \
+ update_version_id in states_to_delete:
mapper.dispatch.after_delete(mapper, connection, state)
@@ -187,7 +175,7 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
instance_key = state.key or mapper._identity_key_from_state(state)
- row_switch = None
+ row_switch = update_version_id = None
# call before_XXX extensions
if not has_identity:
@@ -224,8 +212,14 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
uowtransaction.remove_state_actions(existing)
row_switch = existing
+ if (has_identity or row_switch) and mapper.version_id_col is not None:
+ update_version_id = mapper._get_committed_state_attr_by_column(
+ row_switch if row_switch else state,
+ row_switch.dict if row_switch else dict_,
+ mapper.version_id_col)
+
yield (state, dict_, mapper, connection,
- has_identity, row_switch)
+ has_identity, row_switch, update_version_id)
def _organize_states_for_post_update(base_mapper, states,
@@ -255,7 +249,16 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
mapper.dispatch.before_delete(mapper, connection, state)
- yield state, dict_, mapper, bool(state.key), connection
+ if mapper.version_id_col is not None:
+ update_version_id = \
+ mapper._get_committed_state_attr_by_column(
+ state, dict_,
+ mapper.version_id_col)
+ else:
+ update_version_id = None
+
+ yield (
+ state, dict_, mapper, connection, update_version_id)
def _collect_insert_commands(table, states_to_insert):
@@ -264,8 +267,8 @@ def _collect_insert_commands(table, states_to_insert):
"""
for state, state_dict, mapper, connection in states_to_insert:
-
- # assert table in mapper._pks_by_table
+ if table not in mapper._pks_by_table:
+ continue
params = {}
value_params = {}
@@ -318,9 +321,11 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
"""
- for state, state_dict, mapper, connection, row_switch in states_to_update:
+ for state, state_dict, mapper, connection, \
+ update_version_id in states_to_update:
- # assert table in mapper._pks_by_table
+ if table not in mapper._pks_by_table:
+ continue
pks = mapper._pks_by_table[table]
@@ -340,17 +345,13 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
else:
params[col.key] = value
- if mapper.version_id_col is not None:
+ if update_version_id is not None:
col = mapper.version_id_col
- params[col._label] = \
- mapper._get_committed_state_attr_by_column(
- row_switch if row_switch else state,
- row_switch.dict if row_switch else state_dict,
- col)
+ params[col._label] = update_version_id
if col.key not in params and \
mapper.version_id_generator is not False:
- val = mapper.version_id_generator(params[col._label])
+ val = mapper.version_id_generator(update_version_id)
params[col.key] = val
if not (params or value_params):
@@ -364,7 +365,8 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
if history.added:
if not history.deleted or \
- ("pk_cascaded", state, col) in uowtransaction.attributes:
+ ("pk_cascaded", state, col) in \
+ uowtransaction.attributes:
pk_params[col._label] = history.added[0]
params.pop(col.key, None)
else:
@@ -374,7 +376,6 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
else:
pk_params[col._label] = history.unchanged[0]
-
if params or value_params:
if None in pk_params.values():
raise orm_exc.FlushError(
@@ -426,9 +427,11 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
"""Identify values to use in DELETE statements for a list of
states to be deleted."""
- for state, state_dict, mapper, connection in states_to_delete:
+ for state, state_dict, mapper, connection, \
+ update_version_id in states_to_delete:
- # assert table in mapper._pks_by_table
+ if table not in mapper._pks_by_table:
+ continue
params = {}
for col in mapper._pks_by_table[table]:
@@ -442,12 +445,9 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
"using NULL for primary "
"key value")
- if mapper.version_id_col is not None and \
+ if update_version_id is not None and \
table.c.contains_column(mapper.version_id_col):
- params[mapper.version_id_col.key] = \
- mapper._get_committed_state_attr_by_column(
- state, state_dict,
- mapper.version_id_col)
+ params[mapper.version_id_col.key] = update_version_id
yield params, connection