summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/persistence.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r--lib/sqlalchemy/orm/persistence.py547
1 files changed, 281 insertions, 266 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 295d4a3d0..511a9cef0 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -18,7 +18,7 @@ import operator
from itertools import groupby
from .. import sql, util, exc as sa_exc, schema
from . import attributes, sync, exc as orm_exc, evaluator
-from .base import _state_mapper, state_str, _attr_as_key
+from .base import state_str, _attr_as_key
from ..sql import expression
from . import loading
@@ -40,32 +40,55 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
save_obj(base_mapper, [state], uowtransaction, single=True)
return
- states_to_insert, states_to_update = _organize_states_for_save(
- base_mapper,
- states,
- uowtransaction)
-
+ states_to_update = []
+ states_to_insert = []
cached_connections = _cached_connection_dict(base_mapper)
- for table, mapper in base_mapper._sorted_tables.items():
- insert = _collect_insert_commands(base_mapper, uowtransaction,
- table, states_to_insert)
-
- update = _collect_update_commands(base_mapper, uowtransaction,
- table, states_to_update)
-
- if update:
- _emit_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
-
- if insert:
- _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, insert)
+ for (state, dict_, mapper, connection,
+ has_identity,
+ row_switch, update_version_id) in _organize_states_for_save(
+ base_mapper, states, uowtransaction
+ ):
+ if has_identity or row_switch:
+ states_to_update.append(
+ (state, dict_, mapper, connection, update_version_id)
+ )
+ else:
+ states_to_insert.append(
+ (state, dict_, mapper, connection)
+ )
- _finalize_insert_update_commands(base_mapper, uowtransaction,
- states_to_insert, states_to_update)
+ for table, mapper in base_mapper._sorted_tables.items():
+ if table not in mapper._pks_by_table:
+ continue
+ insert = _collect_insert_commands(table, states_to_insert)
+
+ update = _collect_update_commands(
+ uowtransaction, table, states_to_update)
+
+ _emit_update_statements(base_mapper, uowtransaction,
+ cached_connections,
+ mapper, table, update)
+
+ _emit_insert_statements(base_mapper, uowtransaction,
+ cached_connections,
+ mapper, table, insert)
+
+ _finalize_insert_update_commands(
+ base_mapper, uowtransaction,
+ (
+ (state, state_dict, mapper, connection, False)
+ for state, state_dict, mapper, connection in states_to_insert
+ )
+ )
+ _finalize_insert_update_commands(
+ base_mapper, uowtransaction,
+ (
+ (state, state_dict, mapper, connection, True)
+ for state, state_dict, mapper, connection,
+ update_version_id in states_to_update
+ )
+ )
def post_update(base_mapper, states, uowtransaction, post_update_cols):
@@ -75,19 +98,28 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
"""
cached_connections = _cached_connection_dict(base_mapper)
- states_to_update = _organize_states_for_post_update(
+ states_to_update = list(_organize_states_for_post_update(
base_mapper,
- states, uowtransaction)
+ states, uowtransaction))
for table, mapper in base_mapper._sorted_tables.items():
+ if table not in mapper._pks_by_table:
+ continue
+
+ update = (
+ (state, state_dict, sub_mapper, connection)
+ for
+ state, state_dict, sub_mapper, connection in states_to_update
+ if table in sub_mapper._pks_by_table
+ )
+
update = _collect_post_update_commands(base_mapper, uowtransaction,
- table, states_to_update,
+ table, update,
post_update_cols)
- if update:
- _emit_post_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
+ _emit_post_update_statements(base_mapper, uowtransaction,
+ cached_connections,
+ mapper, table, update)
def delete_obj(base_mapper, states, uowtransaction):
@@ -100,24 +132,26 @@ def delete_obj(base_mapper, states, uowtransaction):
cached_connections = _cached_connection_dict(base_mapper)
- states_to_delete = _organize_states_for_delete(
+ states_to_delete = list(_organize_states_for_delete(
base_mapper,
states,
- uowtransaction)
+ uowtransaction))
table_to_mapper = base_mapper._sorted_tables
for table in reversed(list(table_to_mapper.keys())):
+ mapper = table_to_mapper[table]
+ if table not in mapper._pks_by_table:
+ continue
+
delete = _collect_delete_commands(base_mapper, uowtransaction,
table, states_to_delete)
- mapper = table_to_mapper[table]
-
_emit_delete_statements(base_mapper, uowtransaction,
cached_connections, mapper, table, delete)
- for state, state_dict, mapper, has_identity, connection \
- in states_to_delete:
+ for state, state_dict, mapper, connection, \
+ update_version_id in states_to_delete:
mapper.dispatch.after_delete(mapper, connection, state)
@@ -133,17 +167,15 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
"""
- states_to_insert = []
- states_to_update = []
-
for state, dict_, mapper, connection in _connections_for_states(
base_mapper, uowtransaction,
states):
has_identity = bool(state.key)
+
instance_key = state.key or mapper._identity_key_from_state(state)
- row_switch = None
+ row_switch = update_version_id = None
# call before_XXX extensions
if not has_identity:
@@ -180,18 +212,14 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
uowtransaction.remove_state_actions(existing)
row_switch = existing
- if not has_identity and not row_switch:
- states_to_insert.append(
- (state, dict_, mapper, connection,
- has_identity, instance_key, row_switch)
- )
- else:
- states_to_update.append(
- (state, dict_, mapper, connection,
- has_identity, instance_key, row_switch)
- )
+ if (has_identity or row_switch) and mapper.version_id_col is not None:
+ update_version_id = mapper._get_committed_state_attr_by_column(
+ row_switch if row_switch else state,
+ row_switch.dict if row_switch else dict_,
+ mapper.version_id_col)
- return states_to_insert, states_to_update
+ yield (state, dict_, mapper, connection,
+ has_identity, row_switch, update_version_id)
def _organize_states_for_post_update(base_mapper, states,
@@ -204,8 +232,7 @@ def _organize_states_for_post_update(base_mapper, states,
the execution per state.
"""
- return list(_connections_for_states(base_mapper, uowtransaction,
- states))
+ return _connections_for_states(base_mapper, uowtransaction, states)
def _organize_states_for_delete(base_mapper, states, uowtransaction):
@@ -216,72 +243,73 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
mapper, the connection to use for the execution per state.
"""
- states_to_delete = []
-
for state, dict_, mapper, connection in _connections_for_states(
base_mapper, uowtransaction,
states):
mapper.dispatch.before_delete(mapper, connection, state)
- states_to_delete.append((state, dict_, mapper,
- bool(state.key), connection))
- return states_to_delete
+ if mapper.version_id_col is not None:
+ update_version_id = \
+ mapper._get_committed_state_attr_by_column(
+ state, dict_,
+ mapper.version_id_col)
+ else:
+ update_version_id = None
+
+ yield (
+ state, dict_, mapper, connection, update_version_id)
-def _collect_insert_commands(base_mapper, uowtransaction, table,
- states_to_insert):
+def _collect_insert_commands(table, states_to_insert):
"""Identify sets of values to use in INSERT statements for a
list of states.
"""
- insert = []
- for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in states_to_insert:
+ for state, state_dict, mapper, connection in states_to_insert:
if table not in mapper._pks_by_table:
continue
- pks = mapper._pks_by_table[table]
-
params = {}
value_params = {}
- has_all_pks = True
- has_all_defaults = True
- for col in mapper._cols_by_table[table]:
- if col is mapper.version_id_col and \
- mapper.version_id_generator is not False:
- val = mapper.version_id_generator(None)
- params[col.key] = val
+ propkey_to_col = mapper._propkey_to_col[table]
+
+ for propkey in set(propkey_to_col).intersection(state_dict):
+ value = state_dict[propkey]
+ col = propkey_to_col[propkey]
+ if value is None:
+ continue
+ elif isinstance(value, sql.ClauseElement):
+ value_params[col.key] = value
else:
- # pull straight from the dict for
- # pending objects
- prop = mapper._columntoproperty[col]
- value = state_dict.get(prop.key, None)
-
- if value is None:
- if col in pks:
- has_all_pks = False
- elif col.default is None and \
- col.server_default is None:
- params[col.key] = value
- elif col.server_default is not None and \
- mapper.base_mapper.eager_defaults:
- has_all_defaults = False
-
- elif isinstance(value, sql.ClauseElement):
- value_params[col] = value
- else:
- params[col.key] = value
+ params[col.key] = value
+
+ for colkey in mapper._insert_cols_as_none[table].\
+ difference(params).difference(value_params):
+ params[colkey] = None
+
+ has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
+
+ if mapper.base_mapper.eager_defaults:
+ has_all_defaults = mapper._server_default_cols[table].\
+ issubset(params)
+ else:
+ has_all_defaults = True
+
+ if mapper.version_id_generator is not False \
+ and mapper.version_id_col is not None and \
+ mapper.version_id_col in mapper._cols_by_table[table]:
+ params[mapper.version_id_col.key] = \
+ mapper.version_id_generator(None)
- insert.append((state, state_dict, params, mapper,
- connection, value_params, has_all_pks,
- has_all_defaults))
- return insert
+ yield (
+ state, state_dict, params, mapper,
+ connection, value_params, has_all_pks,
+ has_all_defaults)
-def _collect_update_commands(base_mapper, uowtransaction,
- table, states_to_update):
+def _collect_update_commands(uowtransaction, table, states_to_update):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -293,9 +321,9 @@ def _collect_update_commands(base_mapper, uowtransaction,
"""
- update = []
- for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in states_to_update:
+ for state, state_dict, mapper, connection, \
+ update_version_id in states_to_update:
+
if table not in mapper._pks_by_table:
continue
@@ -304,98 +332,59 @@ def _collect_update_commands(base_mapper, uowtransaction,
params = {}
value_params = {}
- hasdata = hasnull = False
- for col in mapper._cols_by_table[table]:
- if col is mapper.version_id_col:
- params[col._label] = \
- mapper._get_committed_state_attr_by_column(
- row_switch or state,
- row_switch and row_switch.dict
- or state_dict,
- col)
+ propkey_to_col = mapper._propkey_to_col[table]
- prop = mapper._columntoproperty[col]
- history = state.manager[prop.key].impl.get_history(
- state, state_dict, attributes.PASSIVE_NO_INITIALIZE
- )
- if history.added:
- params[col.key] = history.added[0]
- hasdata = True
+ for propkey in set(propkey_to_col).intersection(state.committed_state):
+ value = state_dict[propkey]
+ col = propkey_to_col[propkey]
+
+ if not state.manager[propkey].impl.is_equal(
+ value, state.committed_state[propkey]):
+ if isinstance(value, sql.ClauseElement):
+ value_params[col] = value
+ else:
+ params[col.key] = value
+
+ if update_version_id is not None:
+ col = mapper.version_id_col
+ params[col._label] = update_version_id
+
+ if col.key not in params and \
+ mapper.version_id_generator is not False:
+ val = mapper.version_id_generator(update_version_id)
+ params[col.key] = val
+
+ if not (params or value_params):
+ continue
+
+ pk_params = {}
+ for col in pks:
+ propkey = mapper._columntoproperty[col].key
+ history = state.manager[propkey].impl.get_history(
+ state, state_dict, attributes.PASSIVE_OFF)
+
+ if history.added:
+ if not history.deleted or \
+ ("pk_cascaded", state, col) in \
+ uowtransaction.attributes:
+ pk_params[col._label] = history.added[0]
+ params.pop(col.key, None)
else:
- if mapper.version_id_generator is not False:
- val = mapper.version_id_generator(params[col._label])
- params[col.key] = val
-
- # HACK: check for history, in case the
- # history is only
- # in a different table than the one
- # where the version_id_col is.
- for prop in mapper._columntoproperty.values():
- history = (
- state.manager[prop.key].impl.get_history(
- state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE))
- if history.added:
- hasdata = True
+ # else, use the old value to locate the row
+ pk_params[col._label] = history.deleted[0]
+ params[col.key] = history.added[0]
else:
- prop = mapper._columntoproperty[col]
- history = state.manager[prop.key].impl.get_history(
- state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE)
- if history.added:
- if isinstance(history.added[0],
- sql.ClauseElement):
- value_params[col] = history.added[0]
- else:
- value = history.added[0]
- params[col.key] = value
-
- if col in pks:
- if history.deleted and \
- not row_switch:
- # if passive_updates and sync detected
- # this was a pk->pk sync, use the new
- # value to locate the row, since the
- # DB would already have set this
- if ("pk_cascaded", state, col) in \
- uowtransaction.attributes:
- value = history.added[0]
- params[col._label] = value
- else:
- # use the old value to
- # locate the row
- value = history.deleted[0]
- params[col._label] = value
- hasdata = True
- else:
- # row switch logic can reach us here
- # remove the pk from the update params
- # so the update doesn't
- # attempt to include the pk in the
- # update statement
- del params[col.key]
- value = history.added[0]
- params[col._label] = value
- if value is None:
- hasnull = True
- else:
- hasdata = True
- elif col in pks:
- value = state.manager[prop.key].impl.get(
- state, state_dict)
- if value is None:
- hasnull = True
- params[col._label] = value
+ pk_params[col._label] = history.unchanged[0]
- if hasdata:
- if hasnull:
+ if params or value_params:
+ if None in pk_params.values():
raise orm_exc.FlushError(
- "Can't update table "
- "using NULL for primary "
+ "Can't update table using NULL for primary "
"key value")
- update.append((state, state_dict, params, mapper,
- connection, value_params))
- return update
+ params.update(pk_params)
+ yield (
+ state, state_dict, params, mapper,
+ connection, value_params)
def _collect_post_update_commands(base_mapper, uowtransaction, table,
@@ -405,10 +394,10 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
"""
- update = []
for state, state_dict, mapper, connection in states_to_update:
- if table not in mapper._pks_by_table:
- continue
+
+ # assert table in mapper._pks_by_table
+
pks = mapper._pks_by_table[table]
params = {}
hasdata = False
@@ -417,8 +406,8 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
if col in pks:
params[col._label] = \
mapper._get_state_attr_by_column(
- state,
- state_dict, col)
+ state,
+ state_dict, col)
elif col in post_update_cols:
prop = mapper._columntoproperty[col]
@@ -430,9 +419,7 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
params[col.key] = value
hasdata = True
if hasdata:
- update.append((state, state_dict, params, mapper,
- connection))
- return update
+ yield params, connection
def _collect_delete_commands(base_mapper, uowtransaction, table,
@@ -440,33 +427,28 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
"""Identify values to use in DELETE statements for a list of
states to be deleted."""
- delete = util.defaultdict(list)
+ for state, state_dict, mapper, connection, \
+ update_version_id in states_to_delete:
- for state, state_dict, mapper, has_identity, connection \
- in states_to_delete:
- if not has_identity or table not in mapper._pks_by_table:
+ if table not in mapper._pks_by_table:
continue
params = {}
- delete[connection].append(params)
for col in mapper._pks_by_table[table]:
params[col.key] = \
value = \
mapper._get_committed_state_attr_by_column(
- state, state_dict, col)
+ state, state_dict, col)
if value is None:
raise orm_exc.FlushError(
"Can't delete from table "
"using NULL for primary "
"key value")
- if mapper.version_id_col is not None and \
+ if update_version_id is not None and \
table.c.contains_column(mapper.version_id_col):
- params[mapper.version_id_col.key] = \
- mapper._get_committed_state_attr_by_column(
- state, state_dict,
- mapper.version_id_col)
- return delete
+ params[mapper.version_id_col.key] = update_version_id
+ yield params, connection
def _emit_update_statements(base_mapper, uowtransaction,
@@ -500,41 +482,80 @@ def _emit_update_statements(base_mapper, uowtransaction,
statement = base_mapper._memo(('update', table), update_stmt)
- rows = 0
- for state, state_dict, params, mapper, \
- connection, value_params in update:
-
- if value_params:
- c = connection.execute(
- statement.values(value_params),
- params)
+ for (connection, paramkeys, hasvalue), \
+ records in groupby(
+ update,
+ lambda rec: (
+ rec[4],
+ tuple(sorted(rec[2])),
+ bool(rec[5]))):
+
+ rows = 0
+ records = list(records)
+ if hasvalue:
+ for state, state_dict, params, mapper, \
+ connection, value_params in records:
+ c = connection.execute(
+ statement.values(value_params),
+ params)
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
+ rows += c.rowcount
else:
- c = cached_connections[connection].\
- execute(statement, params)
-
- _postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- value_params)
- rows += c.rowcount
-
- if connection.dialect.supports_sane_rowcount:
- if rows != len(update):
- raise orm_exc.StaleDataError(
- "UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(update), rows))
-
- elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description,
- stacklevel=12)
+ if needs_version_id and \
+ not connection.dialect.supports_sane_multi_rowcount and \
+ connection.dialect.supports_sane_rowcount:
+ for state, state_dict, params, mapper, \
+ connection, value_params in records:
+ c = cached_connections[connection].\
+ execute(statement, params)
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
+ rows += c.rowcount
+ else:
+ multiparams = [rec[2] for rec in records]
+ c = cached_connections[connection].\
+ execute(statement, multiparams)
+
+ rows += c.rowcount
+ for state, state_dict, params, mapper, \
+ connection, value_params in records:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
+
+ if connection.dialect.supports_sane_rowcount:
+ if rows != len(records):
+ raise orm_exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to "
+ "update %d row(s); %d were matched." %
+ (table.description, len(records), rows))
+
+ elif needs_version_id:
+ util.warn("Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified." %
+ c.dialect.dialect_description,
+ stacklevel=12)
def _emit_insert_statements(base_mapper, uowtransaction,
@@ -547,7 +568,7 @@ def _emit_insert_statements(base_mapper, uowtransaction,
for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
records in groupby(insert,
lambda rec: (rec[4],
- list(rec[2].keys()),
+ tuple(sorted(rec[2].keys())),
bool(rec[5]),
rec[6], rec[7])
):
@@ -604,13 +625,7 @@ def _emit_insert_statements(base_mapper, uowtransaction,
mapper._pks_by_table[table]):
prop = mapper_rec._columntoproperty[col]
if state_dict.get(prop.key) is None:
- # TODO: would rather say:
- # state_dict[prop.key] = pk
- mapper_rec._set_state_attr_by_column(
- state,
- state_dict,
- col, pk)
-
+ state_dict[prop.key] = pk
_postfetch(
mapper_rec,
uowtransaction,
@@ -643,11 +658,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
# also group them into common (connection, cols) sets
# to support executemany().
for key, grouper in groupby(
- update, lambda rec: (rec[4], list(rec[2].keys()))
+ update, lambda rec: (rec[1], sorted(rec[0]))
):
connection = key[0]
- multiparams = [params for state, state_dict,
- params, mapper, conn in grouper]
+ multiparams = [params for params, conn in grouper]
cached_connections[connection].\
execute(statement, multiparams)
@@ -677,8 +691,15 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
return table.delete(clause)
- for connection, del_objects in delete.items():
- statement = base_mapper._memo(('delete', table), delete_stmt)
+ statement = base_mapper._memo(('delete', table), delete_stmt)
+ for connection, recs in groupby(
+ delete,
+ lambda rec: rec[1]
+ ):
+ del_objects = [
+ params
+ for params, connection in recs
+ ]
connection = cached_connections[connection]
@@ -731,15 +752,12 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
)
-def _finalize_insert_update_commands(base_mapper, uowtransaction,
- states_to_insert, states_to_update):
+def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
"""finalize state on states that have been inserted or updated,
including calling after_insert/after_update events.
"""
- for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in states_to_insert + \
- states_to_update:
+ for state, state_dict, mapper, connection, has_identity in states:
if mapper._readonly_props:
readonly = state.unmodified_intersection(
@@ -795,11 +813,11 @@ def _postfetch(mapper, uowtransaction, table,
for col in returning_cols:
if col.primary_key:
continue
- mapper._set_state_attr_by_column(state, dict_, col, row[col])
+ dict_[mapper._columntoproperty[col].key] = row[col]
for c in prefetch_cols:
if c.key in params and c in mapper._columntoproperty:
- mapper._set_state_attr_by_column(state, dict_, c, params[c.key])
+ dict_[mapper._columntoproperty[c].key] = params[c.key]
if postfetch_cols:
state._expire_attributes(state.dict,
@@ -833,17 +851,14 @@ def _connections_for_states(base_mapper, uowtransaction, states):
connection_callable = \
uowtransaction.session.connection_callable
else:
- connection = None
+ connection = uowtransaction.transaction.connection(base_mapper)
connection_callable = None
for state in _sort_states(states):
if connection_callable:
connection = connection_callable(base_mapper, state.obj())
- elif not connection:
- connection = uowtransaction.transaction.connection(
- base_mapper)
- mapper = _state_mapper(state)
+ mapper = state.manager.mapper
yield state, state.dict, mapper, connection