Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r--  lib/sqlalchemy/orm/persistence.py | 1186
1 file changed, 746 insertions(+), 440 deletions(-)
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 7f9b7db0c..dc86a60e5 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -25,8 +25,13 @@ from . import loading
def _bulk_insert(
- mapper, mappings, session_transaction, isstates, return_defaults,
- render_nulls):
+ mapper,
+ mappings,
+ session_transaction,
+ isstates,
+ return_defaults,
+ render_nulls,
+):
base_mapper = mapper.base_mapper
cached_connections = _cached_connection_dict(base_mapper)
@@ -34,7 +39,8 @@ def _bulk_insert(
if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
- "not supported in bulk_insert()")
+ "not supported in bulk_insert()"
+ )
if isstates:
if return_defaults:
@@ -51,22 +57,33 @@ def _bulk_insert(
continue
records = (
- (None, state_dict, params, mapper,
- connection, value_params, has_all_pks, has_all_defaults)
- for
- state, state_dict, params, mp,
- conn, value_params, has_all_pks,
- has_all_defaults in _collect_insert_commands(table, (
- (None, mapping, mapper, connection)
- for mapping in mappings),
- bulk=True, return_defaults=return_defaults,
- render_nulls=render_nulls
+ (
+ None,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ )
+ for state, state_dict, params, mp, conn, value_params, has_all_pks, has_all_defaults in _collect_insert_commands(
+ table,
+ ((None, mapping, mapper, connection) for mapping in mappings),
+ bulk=True,
+ return_defaults=return_defaults,
+ render_nulls=render_nulls,
)
)
- _emit_insert_statements(base_mapper, None,
- cached_connections,
- super_mapper, table, records,
- bookkeeping=return_defaults)
+ _emit_insert_statements(
+ base_mapper,
+ None,
+ cached_connections,
+ super_mapper,
+ table,
+ records,
+ bookkeeping=return_defaults,
+ )
if return_defaults and isstates:
identity_cls = mapper._identity_class
@@ -74,12 +91,13 @@ def _bulk_insert(
for state, dict_ in states:
state.key = (
identity_cls,
- tuple([dict_[key] for key in identity_props])
+ tuple([dict_[key] for key in identity_props]),
)
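
_bulk_insert() is the internal engine behind the public Session.bulk_insert_mappings() API. A minimal usage sketch of the path exercised by this hunk, assuming a hypothetical User model and an in-memory SQLite engine:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):  # hypothetical model for illustration
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)

    # Each mapping dict becomes one record in _collect_insert_commands();
    # return_defaults=True makes the statements run row-at-a-time so
    # server-generated primary keys can be captured.
    session.bulk_insert_mappings(
        User, [{"name": "a"}, {"name": "b"}], return_defaults=True
    )
    session.commit()
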
-def _bulk_update(mapper, mappings, session_transaction,
- isstates, update_changed_only):
+def _bulk_update(
+ mapper, mappings, session_transaction, isstates, update_changed_only
+):
base_mapper = mapper.base_mapper
cached_connections = _cached_connection_dict(base_mapper)
@@ -91,9 +109,8 @@ def _bulk_update(mapper, mappings, session_transaction,
def _changed_dict(mapper, state):
return dict(
(k, v)
- for k, v in state.dict.items() if k in state.committed_state or k
- in search_keys
-
+ for k, v in state.dict.items()
+ if k in state.committed_state or k in search_keys
)
if isstates:
@@ -107,7 +124,8 @@ def _bulk_update(mapper, mappings, session_transaction,
if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
- "not supported in bulk_update()")
+ "not supported in bulk_update()"
+ )
connection = session_transaction.connection(base_mapper)
@@ -115,21 +133,38 @@ def _bulk_update(mapper, mappings, session_transaction,
if not mapper.isa(super_mapper):
continue
- records = _collect_update_commands(None, table, (
- (None, mapping, mapper, connection,
- (mapping[mapper._version_id_prop.key]
- if mapper._version_id_prop else None))
- for mapping in mappings
- ), bulk=True)
+ records = _collect_update_commands(
+ None,
+ table,
+ (
+ (
+ None,
+ mapping,
+ mapper,
+ connection,
+ (
+ mapping[mapper._version_id_prop.key]
+ if mapper._version_id_prop
+ else None
+ ),
+ )
+ for mapping in mappings
+ ),
+ bulk=True,
+ )
- _emit_update_statements(base_mapper, None,
- cached_connections,
- super_mapper, table, records,
- bookkeeping=False)
+ _emit_update_statements(
+ base_mapper,
+ None,
+ cached_connections,
+ super_mapper,
+ table,
+ records,
+ bookkeeping=False,
+ )
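
Its sibling _bulk_update() backs Session.bulk_update_mappings(); each mapping must include the primary key, which _collect_update_commands() routes into the WHERE clause. A short sketch reusing the hypothetical User model above:

    # "id" is the primary key and becomes the WHERE criterion;
    # the remaining keys become SET parameters.
    session.bulk_update_mappings(
        User,
        [{"id": 1, "name": "a2"}, {"id": 2, "name": "b2"}],
    )
    session.commit()
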
-def save_obj(
- base_mapper, states, uowtransaction, single=False):
+def save_obj(base_mapper, states, uowtransaction, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list
of objects.
@@ -150,19 +185,21 @@ def save_obj(
states_to_insert = []
cached_connections = _cached_connection_dict(base_mapper)
- for (state, dict_, mapper, connection,
- has_identity,
- row_switch, update_version_id) in _organize_states_for_save(
- base_mapper, states, uowtransaction
- ):
+ for (
+ state,
+ dict_,
+ mapper,
+ connection,
+ has_identity,
+ row_switch,
+ update_version_id,
+ ) in _organize_states_for_save(base_mapper, states, uowtransaction):
if has_identity or row_switch:
states_to_update.append(
(state, dict_, mapper, connection, update_version_id)
)
else:
- states_to_insert.append(
- (state, dict_, mapper, connection)
- )
+ states_to_insert.append((state, dict_, mapper, connection))
for table, mapper in base_mapper._sorted_tables.items():
if table not in mapper._pks_by_table:
@@ -170,18 +207,30 @@ def save_obj(
insert = _collect_insert_commands(table, states_to_insert)
update = _collect_update_commands(
- uowtransaction, table, states_to_update)
+ uowtransaction, table, states_to_update
+ )
- _emit_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
+ _emit_update_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ update,
+ )
- _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, insert)
+ _emit_insert_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ insert,
+ )
_finalize_insert_update_commands(
- base_mapper, uowtransaction,
+ base_mapper,
+ uowtransaction,
chain(
(
(state, state_dict, mapper, connection, False)
@@ -189,10 +238,9 @@ def save_obj(
),
(
(state, state_dict, mapper, connection, True)
- for state, state_dict, mapper, connection,
- update_version_id in states_to_update
- )
- )
+ for state, state_dict, mapper, connection, update_version_id in states_to_update
+ ),
+ ),
)
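
save_obj() is never called by user code; the unit of work invokes it during flush for every pending and dirty object, walking base_mapper._sorted_tables so that joined-inheritance rows are written in dependency order. A hedged sketch of what reaches it, reusing the session from above:

    session.add(User(name="c"))           # pending -> states_to_insert
    user = session.query(User).first()
    user.name = "renamed"                 # dirty   -> states_to_update
    # flush() drives save_obj(), which collects per-table INSERT and
    # UPDATE parameter sets and hands them to the _emit_* helpers.
    session.flush()
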
@@ -203,9 +251,9 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
"""
cached_connections = _cached_connection_dict(base_mapper)
- states_to_update = list(_organize_states_for_post_update(
- base_mapper,
- states, uowtransaction))
+ states_to_update = list(
+ _organize_states_for_post_update(base_mapper, states, uowtransaction)
+ )
for table, mapper in base_mapper._sorted_tables.items():
if table not in mapper._pks_by_table:
@@ -213,25 +261,32 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
update = (
(
- state, state_dict, sub_mapper, connection,
+ state,
+ state_dict,
+ sub_mapper,
+ connection,
mapper._get_committed_state_attr_by_column(
state, state_dict, mapper.version_id_col
- ) if mapper.version_id_col is not None else None
+ )
+ if mapper.version_id_col is not None
+ else None,
)
- for
- state, state_dict, sub_mapper, connection in states_to_update
+ for state, state_dict, sub_mapper, connection in states_to_update
if table in sub_mapper._pks_by_table
)
update = _collect_post_update_commands(
- base_mapper, uowtransaction,
- table, update,
- post_update_cols
+ base_mapper, uowtransaction, table, update, post_update_cols
)
- _emit_post_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
+ _emit_post_update_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ update,
+ )
def delete_obj(base_mapper, states, uowtransaction):
@@ -244,10 +299,9 @@ def delete_obj(base_mapper, states, uowtransaction):
cached_connections = _cached_connection_dict(base_mapper)
- states_to_delete = list(_organize_states_for_delete(
- base_mapper,
- states,
- uowtransaction))
+ states_to_delete = list(
+ _organize_states_for_delete(base_mapper, states, uowtransaction)
+ )
table_to_mapper = base_mapper._sorted_tables
@@ -258,14 +312,26 @@ def delete_obj(base_mapper, states, uowtransaction):
elif mapper.inherits and mapper.passive_deletes:
continue
- delete = _collect_delete_commands(base_mapper, uowtransaction,
- table, states_to_delete)
+ delete = _collect_delete_commands(
+ base_mapper, uowtransaction, table, states_to_delete
+ )
- _emit_delete_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, delete)
+ _emit_delete_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ delete,
+ )
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_delete:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_delete:
mapper.dispatch.after_delete(mapper, connection, state)
@@ -282,8 +348,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
"""
for state, dict_, mapper, connection in _connections_for_states(
- base_mapper, uowtransaction,
- states):
+ base_mapper, uowtransaction, states
+ ):
has_identity = bool(state.key)
@@ -304,25 +370,29 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
# no instance_key attached to it), and another instance
# with the same identity key already exists as persistent.
# convert to an UPDATE if so.
- if not has_identity and \
- instance_key in uowtransaction.session.identity_map:
- instance = \
- uowtransaction.session.identity_map[instance_key]
+ if (
+ not has_identity
+ and instance_key in uowtransaction.session.identity_map
+ ):
+ instance = uowtransaction.session.identity_map[instance_key]
existing = attributes.instance_state(instance)
if not uowtransaction.was_already_deleted(existing):
if not uowtransaction.is_deleted(existing):
raise orm_exc.FlushError(
"New instance %s with identity key %s conflicts "
- "with persistent instance %s" %
- (state_str(state), instance_key,
- state_str(existing)))
+ "with persistent instance %s"
+ % (state_str(state), instance_key, state_str(existing))
+ )
base_mapper._log_debug(
"detected row switch for identity %s. "
"will update %s, remove %s from "
- "transaction", instance_key,
- state_str(state), state_str(existing))
+ "transaction",
+ instance_key,
+ state_str(state),
+ state_str(existing),
+ )
# remove the "delete" flag from the existing element
uowtransaction.remove_state_actions(existing)
@@ -332,14 +402,21 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
update_version_id = mapper._get_committed_state_attr_by_column(
row_switch if row_switch else state,
row_switch.dict if row_switch else dict_,
- mapper.version_id_col)
+ mapper.version_id_col,
+ )
- yield (state, dict_, mapper, connection,
- has_identity, row_switch, update_version_id)
+ yield (
+ state,
+ dict_,
+ mapper,
+ connection,
+ has_identity,
+ row_switch,
+ update_version_id,
+ )
-def _organize_states_for_post_update(base_mapper, states,
- uowtransaction):
+def _organize_states_for_post_update(base_mapper, states, uowtransaction):
"""Make an initial pass across a set of states for UPDATE
corresponding to post_update.
@@ -360,26 +437,28 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
"""
for state, dict_, mapper, connection in _connections_for_states(
- base_mapper, uowtransaction,
- states):
+ base_mapper, uowtransaction, states
+ ):
mapper.dispatch.before_delete(mapper, connection, state)
if mapper.version_id_col is not None:
- update_version_id = \
- mapper._get_committed_state_attr_by_column(
- state, dict_,
- mapper.version_id_col)
+ update_version_id = mapper._get_committed_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ )
else:
update_version_id = None
- yield (
- state, dict_, mapper, connection, update_version_id)
+ yield (state, dict_, mapper, connection, update_version_id)
def _collect_insert_commands(
- table, states_to_insert,
- bulk=False, return_defaults=False, render_nulls=False):
+ table,
+ states_to_insert,
+ bulk=False,
+ return_defaults=False,
+ render_nulls=False,
+):
"""Identify sets of values to use in INSERT statements for a
list of states.
@@ -400,10 +479,16 @@ def _collect_insert_commands(
col = propkey_to_col[propkey]
if value is None and col not in eval_none and not render_nulls:
continue
- elif not bulk and hasattr(value, '__clause_element__') or \
- isinstance(value, sql.ClauseElement):
- value_params[col.key] = value.__clause_element__() \
- if hasattr(value, '__clause_element__') else value
+ elif (
+ not bulk
+ and hasattr(value, "__clause_element__")
+ or isinstance(value, sql.ClauseElement)
+ ):
+ value_params[col.key] = (
+ value.__clause_element__()
+ if hasattr(value, "__clause_element__")
+ else value
+ )
else:
params[col.key] = value
@@ -414,8 +499,11 @@ def _collect_insert_commands(
# which might be worth removing, as it should not be necessary
# and also produces confusion, given that "missing" and None
# now have distinct meanings
- for colkey in mapper._insert_cols_as_none[table].\
- difference(params).difference(value_params):
+ for colkey in (
+ mapper._insert_cols_as_none[table]
+ .difference(params)
+ .difference(value_params)
+ ):
params[colkey] = None
if not bulk or return_defaults:
@@ -424,28 +512,38 @@ def _collect_insert_commands(
has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
if mapper.base_mapper.eager_defaults:
- has_all_defaults = mapper._server_default_cols[table].\
- issubset(params)
+ has_all_defaults = mapper._server_default_cols[table].issubset(
+ params
+ )
else:
has_all_defaults = True
else:
has_all_defaults = has_all_pks = True
- if mapper.version_id_generator is not False \
- and mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
- params[mapper.version_id_col.key] = \
- mapper.version_id_generator(None)
+ if (
+ mapper.version_id_generator is not False
+ and mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ params[mapper.version_id_col.key] = mapper.version_id_generator(
+ None
+ )
yield (
- state, state_dict, params, mapper,
- connection, value_params, has_all_pks,
- has_all_defaults)
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ )
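
The __clause_element__ / ClauseElement branch above handles attribute values that are SQL expressions rather than plain Python values: they are diverted into value_params and later rendered through statement.values() instead of being bound as executemany parameters. A small sketch, again with the hypothetical User model:

    from sqlalchemy import func

    # func.upper(...) is a ClauseElement, so it lands in value_params
    # and forces a single-row execute for this record.
    session.add(User(name=func.upper("eve")))
    session.flush()
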
def _collect_update_commands(
- uowtransaction, table, states_to_update,
- bulk=False):
+ uowtransaction, table, states_to_update, bulk=False
+):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -457,8 +555,13 @@ def _collect_update_commands(
"""
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_update:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update:
if table not in mapper._pks_by_table:
continue
@@ -474,36 +577,48 @@ def _collect_update_commands(
# look at mapper attribute keys for pk
params = dict(
(propkey_to_col[propkey].key, state_dict[propkey])
- for propkey in
- set(propkey_to_col).intersection(state_dict).difference(
- mapper._pk_attr_keys_by_table[table])
+ for propkey in set(propkey_to_col)
+ .intersection(state_dict)
+ .difference(mapper._pk_attr_keys_by_table[table])
)
has_all_defaults = True
else:
params = {}
for propkey in set(propkey_to_col).intersection(
- state.committed_state):
+ state.committed_state
+ ):
value = state_dict[propkey]
col = propkey_to_col[propkey]
- if hasattr(value, '__clause_element__') or \
- isinstance(value, sql.ClauseElement):
- value_params[col] = value.__clause_element__() \
- if hasattr(value, '__clause_element__') else value
+ if hasattr(value, "__clause_element__") or isinstance(
+ value, sql.ClauseElement
+ ):
+ value_params[col] = (
+ value.__clause_element__()
+ if hasattr(value, "__clause_element__")
+ else value
+ )
# guard against values that generate non-__nonzero__
# objects for __eq__()
- elif state.manager[propkey].impl.is_equal(
- value, state.committed_state[propkey]) is not True:
+ elif (
+ state.manager[propkey].impl.is_equal(
+ value, state.committed_state[propkey]
+ )
+ is not True
+ ):
params[col.key] = value
if mapper.base_mapper.eager_defaults:
- has_all_defaults = mapper._server_onupdate_default_cols[table].\
- issubset(params)
+ has_all_defaults = mapper._server_onupdate_default_cols[
+ table
+ ].issubset(params)
else:
has_all_defaults = True
- if update_version_id is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
if not bulk and not (params or value_params):
# HACK: check for history in other tables, in case the
@@ -511,10 +626,9 @@ def _collect_update_commands(
# where the version_id_col is. This logic was lost
# from 0.9 -> 1.0.0 and restored in 1.0.6.
for prop in mapper._columntoproperty.values():
- history = (
- state.manager[prop.key].impl.get_history(
- state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE))
+ history = state.manager[prop.key].impl.get_history(
+ state, state_dict, attributes.PASSIVE_NO_INITIALIZE
+ )
if history.added:
break
else:
@@ -525,8 +639,9 @@ def _collect_update_commands(
no_params = not params and not value_params
params[col._label] = update_version_id
- if (bulk or col.key not in params) and \
- mapper.version_id_generator is not False:
+ if (
+ bulk or col.key not in params
+ ) and mapper.version_id_generator is not False:
val = mapper.version_id_generator(update_version_id)
params[col.key] = val
elif mapper.version_id_generator is False and no_params:
@@ -545,9 +660,9 @@ def _collect_update_commands(
# look at mapper attribute keys for pk
pk_params = dict(
(propkey_to_col[propkey]._label, state_dict.get(propkey))
- for propkey in
- set(propkey_to_col).
- intersection(mapper._pk_attr_keys_by_table[table])
+ for propkey in set(propkey_to_col).intersection(
+ mapper._pk_attr_keys_by_table[table]
+ )
)
else:
pk_params = {}
@@ -555,12 +670,15 @@ def _collect_update_commands(
propkey = mapper._columntoproperty[col].key
history = state.manager[propkey].impl.get_history(
- state, state_dict, attributes.PASSIVE_OFF)
+ state, state_dict, attributes.PASSIVE_OFF
+ )
if history.added:
- if not history.deleted or \
- ("pk_cascaded", state, col) in \
- uowtransaction.attributes:
+ if (
+ not history.deleted
+ or ("pk_cascaded", state, col)
+ in uowtransaction.attributes
+ ):
pk_params[col._label] = history.added[0]
params.pop(col.key, None)
else:
@@ -573,24 +691,38 @@ def _collect_update_commands(
if pk_params[col._label] is None:
raise orm_exc.FlushError(
"Can't update table %s using NULL for primary "
- "key value on column %s" % (table, col))
+ "key value on column %s" % (table, col)
+ )
if params or value_params:
params.update(pk_params)
yield (
- state, state_dict, params, mapper,
- connection, value_params, has_all_defaults, has_all_pks)
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ )
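
The update_version_id handling above implements optimistic concurrency for mappings that configure a version column. A hedged sketch of such a mapping (model name hypothetical):

    class VersionedUser(Base):  # hypothetical versioned model
        __tablename__ = "versioned_user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        version_id = Column(Integer, nullable=False)
        # With this set, UPDATEs bind the committed version under
        # col._label for the WHERE clause and the newly generated
        # value under col.key for the SET clause; a rowcount mismatch
        # later surfaces as StaleDataError.
        __mapper_args__ = {"version_id_col": version_id}
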
-def _collect_post_update_commands(base_mapper, uowtransaction, table,
- states_to_update, post_update_cols):
+def _collect_post_update_commands(
+ base_mapper, uowtransaction, table, states_to_update, post_update_cols
+):
"""Identify sets of values to use in UPDATE statements for a
list of states within a post_update operation.
"""
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_update:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update:
# assert table in mapper._pks_by_table
@@ -600,100 +732,128 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
for col in mapper._cols_by_table[table]:
if col in pks:
- params[col._label] = \
- mapper._get_state_attr_by_column(
- state,
- state_dict, col, passive=attributes.PASSIVE_OFF)
+ params[col._label] = mapper._get_state_attr_by_column(
+ state, state_dict, col, passive=attributes.PASSIVE_OFF
+ )
elif col in post_update_cols or col.onupdate is not None:
prop = mapper._columntoproperty[col]
history = state.manager[prop.key].impl.get_history(
- state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE)
+ state, state_dict, attributes.PASSIVE_NO_INITIALIZE
+ )
if history.added:
value = history.added[0]
params[col.key] = value
hasdata = True
if hasdata:
- if update_version_id is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
col = mapper.version_id_col
params[col._label] = update_version_id
- if bool(state.key) and col.key not in params and \
- mapper.version_id_generator is not False:
+ if (
+ bool(state.key)
+ and col.key not in params
+ and mapper.version_id_generator is not False
+ ):
val = mapper.version_id_generator(update_version_id)
params[col.key] = val
yield state, state_dict, mapper, connection, params
-def _collect_delete_commands(base_mapper, uowtransaction, table,
- states_to_delete):
+def _collect_delete_commands(
+ base_mapper, uowtransaction, table, states_to_delete
+):
"""Identify values to use in DELETE statements for a list of
states to be deleted."""
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_delete:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_delete:
if table not in mapper._pks_by_table:
continue
params = {}
for col in mapper._pks_by_table[table]:
- params[col.key] = \
- value = \
- mapper._get_committed_state_attr_by_column(
- state, state_dict, col)
+ params[
+ col.key
+ ] = value = mapper._get_committed_state_attr_by_column(
+ state, state_dict, col
+ )
if value is None:
raise orm_exc.FlushError(
"Can't delete from table %s "
"using NULL for primary "
- "key value on column %s" % (table, col))
+ "key value on column %s" % (table, col)
+ )
- if update_version_id is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
params[mapper.version_id_col.key] = update_version_id
yield params, connection
-def _emit_update_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, update,
- bookkeeping=True):
+def _emit_update_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ update,
+ bookkeeping=True,
+):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
- needs_version_id = mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]
+ needs_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
+ clause.clauses.append(
+ col == sql.bindparam(col._label, type_=col.type)
+ )
if needs_version_id:
clause.clauses.append(
- mapper.version_id_col == sql.bindparam(
+ mapper.version_id_col
+ == sql.bindparam(
mapper.version_id_col._label,
- type_=mapper.version_id_col.type))
+ type_=mapper.version_id_col.type,
+ )
+ )
stmt = table.update(clause)
return stmt
- cached_stmt = base_mapper._memo(('update', table), update_stmt)
-
- for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \
- records in groupby(
- update,
- lambda rec: (
- rec[4], # connection
- set(rec[2]), # set of parameter keys
- bool(rec[5]), # whether or not we have "value" parameters
- rec[6], # has_all_defaults
- rec[7] # has all pks
- )
+ cached_stmt = base_mapper._memo(("update", table), update_stmt)
+
+ for (
+ (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
+ records,
+ ) in groupby(
+ update,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # set of parameter keys
+ bool(rec[5]), # whether or not we have "value" parameters
+ rec[6], # has_all_defaults
+ rec[7], # has all pks
+ ),
):
rows = 0
records = list(records)
@@ -704,8 +864,11 @@ def _emit_update_statements(base_mapper, uowtransaction,
if not has_all_pks:
statement = statement.return_defaults()
return_defaults = True
- elif bookkeeping and not has_all_defaults and \
- mapper.base_mapper.eager_defaults:
+ elif (
+ bookkeeping
+ and not has_all_defaults
+ and mapper.base_mapper.eager_defaults
+ ):
statement = statement.return_defaults()
return_defaults = True
elif mapper.version_id_col is not None:
@@ -718,17 +881,24 @@ def _emit_update_statements(base_mapper, uowtransaction,
else connection.dialect.supports_sane_rowcount_returning
)
- assert_multirow = assert_singlerow and \
- connection.dialect.supports_sane_multi_rowcount
+ assert_multirow = (
+ assert_singlerow
+ and connection.dialect.supports_sane_multi_rowcount
+ )
allow_multirow = has_all_defaults and not needs_version_id
if hasvalue:
- for state, state_dict, params, mapper, \
- connection, value_params, \
- has_all_defaults, has_all_pks in records:
- c = connection.execute(
- statement.values(value_params),
- params)
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ c = connection.execute(statement.values(value_params), params)
if bookkeeping:
_postfetch(
mapper,
@@ -738,17 +908,26 @@ def _emit_update_statements(base_mapper, uowtransaction,
state_dict,
c,
c.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
rows += c.rowcount
check_rowcount = assert_singlerow
else:
if not allow_multirow:
check_rowcount = assert_singlerow
- for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults, \
- has_all_pks in records:
- c = cached_connections[connection].\
- execute(statement, params)
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ c = cached_connections[connection].execute(
+ statement, params
+ )
# TODO: why with bookkeeping=False?
if bookkeeping:
@@ -760,24 +939,32 @@ def _emit_update_statements(base_mapper, uowtransaction,
state_dict,
c,
c.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
rows += c.rowcount
else:
multiparams = [rec[2] for rec in records]
check_rowcount = assert_multirow or (
- assert_singlerow and
- len(multiparams) == 1
+ assert_singlerow and len(multiparams) == 1
)
- c = cached_connections[connection].\
- execute(statement, multiparams)
+ c = cached_connections[connection].execute(
+ statement, multiparams
+ )
rows += c.rowcount
- for state, state_dict, params, mapper, \
- connection, value_params, \
- has_all_defaults, has_all_pks in records:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
if bookkeeping:
_postfetch(
mapper,
@@ -787,59 +974,85 @@ def _emit_update_statements(base_mapper, uowtransaction,
state_dict,
c,
c.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
if check_rowcount:
if rows != len(records):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(records), rows))
+ "update %d row(s); %d were matched."
+ % (table.description, len(records), rows)
+ )
elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description)
+ util.warn(
+ "Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified."
+ % c.dialect.dialect_description
+ )
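
_emit_update_statements() (and _emit_insert_statements() below) batch records with itertools.groupby keyed on connection and parameter-key set, so homogeneous runs can go through a single executemany call. groupby only merges adjacent records, so batching depends on the order the _collect_* generators yield them in; a minimal stdlib illustration of that assumption:

    from itertools import groupby

    records = [
        ("conn1", {"id": 1, "name": "a"}),
        ("conn1", {"id": 2, "name": "b"}),  # same key set: same batch
        ("conn1", {"id": 3}),               # key set differs: new batch
    ]

    for (conn, keys), recs in groupby(
        records, lambda rec: (rec[0], frozenset(rec[1]))
    ):
        multiparams = [params for _, params in recs]
        print(conn, sorted(keys), len(multiparams))
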
-def _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, insert,
- bookkeeping=True):
+def _emit_insert_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ insert,
+ bookkeeping=True,
+):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
- cached_stmt = base_mapper._memo(('insert', table), table.insert)
-
- for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
- records in groupby(
- insert,
- lambda rec: (
- rec[4], # connection
- set(rec[2]), # parameter keys
- bool(rec[5]), # whether we have "value" parameters
- rec[6],
- rec[7])):
+ cached_stmt = base_mapper._memo(("insert", table), table.insert)
+
+ for (
+ (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
+ records,
+ ) in groupby(
+ insert,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # parameter keys
+ bool(rec[5]), # whether we have "value" parameters
+ rec[6],
+ rec[7],
+ ),
+ ):
statement = cached_stmt
- if not bookkeeping or \
- (
- has_all_defaults
- or not base_mapper.eager_defaults
- or not connection.dialect.implicit_returning
- ) and has_all_pks and not hasvalue:
+ if (
+ not bookkeeping
+ or (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not connection.dialect.implicit_returning
+ )
+ and has_all_pks
+ and not hasvalue
+ ):
records = list(records)
multiparams = [rec[2] for rec in records]
- c = cached_connections[connection].\
- execute(statement, multiparams)
+ c = cached_connections[connection].execute(statement, multiparams)
if bookkeeping:
- for (state, state_dict, params, mapper_rec,
- conn, value_params, has_all_pks, has_all_defaults), \
- last_inserted_params in \
- zip(records, c.context.compiled_parameters):
+ for (
+ (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ conn,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ),
+ last_inserted_params,
+ ) in zip(records, c.context.compiled_parameters):
if state:
_postfetch(
mapper_rec,
@@ -849,7 +1062,8 @@ def _emit_insert_statements(base_mapper, uowtransaction,
state_dict,
c,
last_inserted_params,
- value_params)
+ value_params,
+ )
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
@@ -859,24 +1073,33 @@ def _emit_insert_statements(base_mapper, uowtransaction,
elif mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
- for state, state_dict, params, mapper_rec, \
- connection, value_params, \
- has_all_pks, has_all_defaults in records:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ) in records:
if value_params:
result = connection.execute(
- statement.values(value_params),
- params)
+ statement.values(value_params), params
+ )
else:
- result = cached_connections[connection].\
- execute(statement, params)
+ result = cached_connections[connection].execute(
+ statement, params
+ )
primary_key = result.context.inserted_primary_key
if primary_key is not None:
# set primary key attributes
- for pk, col in zip(primary_key,
- mapper._pks_by_table[table]):
+ for pk, col in zip(
+ primary_key, mapper._pks_by_table[table]
+ ):
prop = mapper_rec._columntoproperty[col]
if state_dict.get(prop.key) is None:
state_dict[prop.key] = pk
@@ -890,31 +1113,39 @@ def _emit_insert_statements(base_mapper, uowtransaction,
state_dict,
result,
result.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
-def _emit_post_update_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, update):
+def _emit_post_update_statements(
+ base_mapper, uowtransaction, cached_connections, mapper, table, update
+):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_post_update_commands()."""
- needs_version_id = mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]
+ needs_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
+ clause.clauses.append(
+ col == sql.bindparam(col._label, type_=col.type)
+ )
if needs_version_id:
clause.clauses.append(
- mapper.version_id_col == sql.bindparam(
+ mapper.version_id_col
+ == sql.bindparam(
mapper.version_id_col._label,
- type_=mapper.version_id_col.type))
+ type_=mapper.version_id_col.type,
+ )
+ )
stmt = table.update(clause)
@@ -923,17 +1154,15 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
return stmt
- statement = base_mapper._memo(('post_update', table), update_stmt)
+ statement = base_mapper._memo(("post_update", table), update_stmt)
# execute each UPDATE in the order according to the original
# list of states to guarantee row access order, but
# also group them into common (connection, cols) sets
# to support executemany().
for key, records in groupby(
- update, lambda rec: (
- rec[3], # connection
- set(rec[4]), # parameter keys
- )
+ update,
+ lambda rec: (rec[3], set(rec[4])), # connection # parameter keys
):
rows = 0
@@ -945,84 +1174,96 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
if mapper.version_id_col is None
else connection.dialect.supports_sane_rowcount_returning
)
- assert_multirow = assert_singlerow and \
- connection.dialect.supports_sane_multi_rowcount
+ assert_multirow = (
+ assert_singlerow
+ and connection.dialect.supports_sane_multi_rowcount
+ )
allow_multirow = not needs_version_id or assert_multirow
-
if not allow_multirow:
check_rowcount = assert_singlerow
- for state, state_dict, mapper_rec, \
- connection, params in records:
- c = cached_connections[connection].\
- execute(statement, params)
+ for state, state_dict, mapper_rec, connection, params in records:
+ c = cached_connections[connection].execute(statement, params)
_postfetch_post_update(
- mapper_rec, uowtransaction, table, state, state_dict,
- c, c.context.compiled_parameters[0])
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ )
rows += c.rowcount
else:
multiparams = [
- params for
- state, state_dict, mapper_rec, conn, params in records]
+ params
+ for state, state_dict, mapper_rec, conn, params in records
+ ]
check_rowcount = assert_multirow or (
- assert_singlerow and
- len(multiparams) == 1
+ assert_singlerow and len(multiparams) == 1
)
- c = cached_connections[connection].\
- execute(statement, multiparams)
+ c = cached_connections[connection].execute(statement, multiparams)
rows += c.rowcount
- for state, state_dict, mapper_rec, \
- connection, params in records:
+ for state, state_dict, mapper_rec, connection, params in records:
_postfetch_post_update(
- mapper_rec, uowtransaction, table, state, state_dict,
- c, c.context.compiled_parameters[0])
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ )
if check_rowcount:
if rows != len(records):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(records), rows))
+ "update %d row(s); %d were matched."
+ % (table.description, len(records), rows)
+ )
elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description)
+ util.warn(
+ "Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified."
+ % c.dialect.dialect_description
+ )
-def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
- mapper, table, delete):
+def _emit_delete_statements(
+ base_mapper, uowtransaction, cached_connections, mapper, table, delete
+):
"""Emit DELETE statements corresponding to value lists collected
by _collect_delete_commands()."""
- need_version_id = mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]
+ need_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
def delete_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(
- col == sql.bindparam(col.key, type_=col.type))
+ col == sql.bindparam(col.key, type_=col.type)
+ )
if need_version_id:
clause.clauses.append(
- mapper.version_id_col ==
- sql.bindparam(
- mapper.version_id_col.key,
- type_=mapper.version_id_col.type
+ mapper.version_id_col
+ == sql.bindparam(
+ mapper.version_id_col.key, type_=mapper.version_id_col.type
)
)
return table.delete(clause)
- statement = base_mapper._memo(('delete', table), delete_stmt)
- for connection, recs in groupby(
- delete,
- lambda rec: rec[1] # connection
- ):
+ statement = base_mapper._memo(("delete", table), delete_stmt)
+ for connection, recs in groupby(delete, lambda rec: rec[1]): # connection
del_objects = [params for params, connection in recs]
connection = cached_connections[connection]
@@ -1049,9 +1290,10 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
else:
util.warn(
"Dialect %s does not support deleted rowcount "
- "- versioning cannot be verified." %
- connection.dialect.dialect_description,
- stacklevel=12)
+ "- versioning cannot be verified."
+ % connection.dialect.dialect_description,
+ stacklevel=12,
+ )
connection.execute(statement, del_objects)
else:
c = connection.execute(statement, del_objects)
@@ -1061,23 +1303,26 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
rows_matched = c.rowcount
- if base_mapper.confirm_deleted_rows and \
- rows_matched > -1 and expected != rows_matched:
+ if (
+ base_mapper.confirm_deleted_rows
+ and rows_matched > -1
+ and expected != rows_matched
+ ):
if only_warn:
util.warn(
"DELETE statement on table '%s' expected to "
"delete %d row(s); %d were matched. Please set "
"confirm_deleted_rows=False within the mapper "
- "configuration to prevent this warning." %
- (table.description, expected, rows_matched)
+ "configuration to prevent this warning."
+ % (table.description, expected, rows_matched)
)
else:
raise orm_exc.StaleDataError(
"DELETE statement on table '%s' expected to "
"delete %d row(s); %d were matched. Please set "
"confirm_deleted_rows=False within the mapper "
- "configuration to prevent this warning." %
- (table.description, expected, rows_matched)
+ "configuration to prevent this warning."
+ % (table.description, expected, rows_matched)
)
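
As the warning text above advises, the rowcount check can be relaxed per mapper when a backend cannot report accurate DELETE counts. A hedged declarative sketch (model name hypothetical):

    class Widget(Base):  # hypothetical model
        __tablename__ = "widget"
        id = Column(Integer, primary_key=True)
        # Skips the expected-vs-matched DELETE rowcount check entirely,
        # so a mismatch no longer warns or raises StaleDataError.
        __mapper_args__ = {"confirm_deleted_rows": False}
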
@@ -1091,13 +1336,16 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
if mapper._readonly_props:
readonly = state.unmodified_intersection(
[
- p.key for p in mapper._readonly_props
+ p.key
+ for p in mapper._readonly_props
if (
- p.expire_on_flush and
- (not p.deferred or p.key in state.dict)
- ) or (
- not p.expire_on_flush and
- not p.deferred and p.key not in state.dict
+ p.expire_on_flush
+ and (not p.deferred or p.key in state.dict)
+ )
+ or (
+ not p.expire_on_flush
+ and not p.deferred
+ and p.key not in state.dict
)
]
)
@@ -1112,11 +1360,14 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
if base_mapper.eager_defaults:
toload_now.extend(
state._unloaded_non_object.intersection(
- mapper._server_default_plus_onupdate_propkeys)
+ mapper._server_default_plus_onupdate_propkeys
+ )
)
- if mapper.version_id_col is not None and \
- mapper.version_id_generator is False:
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_generator is False
+ ):
if mapper._version_id_prop.key in state.unloaded:
toload_now.extend([mapper._version_id_prop.key])
@@ -1124,8 +1375,10 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
state.key = base_mapper._identity_key_from_state(state)
loading.load_on_ident(
uowtransaction.session.query(mapper),
- state.key, refresh_state=state,
- only_load_props=toload_now)
+ state.key,
+ refresh_state=state,
+ only_load_props=toload_now,
+ )
# call after_XXX extensions
if not has_identity:
@@ -1133,23 +1386,29 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
else:
mapper.dispatch.after_update(mapper, connection, state)
- if mapper.version_id_generator is False and \
- mapper.version_id_col is not None:
+ if (
+ mapper.version_id_generator is False
+ and mapper.version_id_col is not None
+ ):
if state_dict[mapper._version_id_prop.key] is None:
raise orm_exc.FlushError(
- "Instance does not contain a non-NULL version value")
+ "Instance does not contain a non-NULL version value"
+ )
-def _postfetch_post_update(mapper, uowtransaction, table,
- state, dict_, result, params):
+def _postfetch_post_update(
+ mapper, uowtransaction, table, state, dict_, result, params
+):
if uowtransaction.is_deleted(state):
return
prefetch_cols = result.context.compiled.prefetch
postfetch_cols = result.context.compiled.postfetch
- if mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
@@ -1164,18 +1423,23 @@ def _postfetch_post_update(mapper, uowtransaction, table,
if refresh_flush and load_evt_attrs:
mapper.class_manager.dispatch.refresh_flush(
- state, uowtransaction, load_evt_attrs)
+ state, uowtransaction, load_evt_attrs
+ )
if postfetch_cols:
- state._expire_attributes(state.dict,
- [mapper._columntoproperty[c].key
- for c in postfetch_cols if c in
- mapper._columntoproperty]
- )
+ state._expire_attributes(
+ state.dict,
+ [
+ mapper._columntoproperty[c].key
+ for c in postfetch_cols
+ if c in mapper._columntoproperty
+ ],
+ )
-def _postfetch(mapper, uowtransaction, table,
- state, dict_, result, params, value_params):
+def _postfetch(
+ mapper, uowtransaction, table, state, dict_, result, params, value_params
+):
"""Expire attributes in need of newly persisted database state,
after an INSERT or UPDATE statement has proceeded for that
state."""
@@ -1184,8 +1448,10 @@ def _postfetch(mapper, uowtransaction, table,
postfetch_cols = result.context.compiled.postfetch
returning_cols = result.context.compiled.returning
- if mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
@@ -1219,23 +1485,32 @@ def _postfetch(mapper, uowtransaction, table,
if refresh_flush and load_evt_attrs:
mapper.class_manager.dispatch.refresh_flush(
- state, uowtransaction, load_evt_attrs)
+ state, uowtransaction, load_evt_attrs
+ )
if postfetch_cols:
- state._expire_attributes(state.dict,
- [mapper._columntoproperty[c].key
- for c in postfetch_cols if c in
- mapper._columntoproperty]
- )
+ state._expire_attributes(
+ state.dict,
+ [
+ mapper._columntoproperty[c].key
+ for c in postfetch_cols
+ if c in mapper._columntoproperty
+ ],
+ )
# synchronize newly inserted ids from one table to the next
# TODO: this still goes a little too often. would be nice to
# have definitive list of "columns that changed" here
for m, equated_pairs in mapper._table_to_equated[table]:
- sync.populate(state, m, state, m,
- equated_pairs,
- uowtransaction,
- mapper.passive_updates)
+ sync.populate(
+ state,
+ m,
+ state,
+ m,
+ equated_pairs,
+ uowtransaction,
+ mapper.passive_updates,
+ )
def _postfetch_bulk_save(mapper, dict_, table):
@@ -1255,8 +1530,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
# organize individual states with the connection
# to use for update
if uowtransaction.session.connection_callable:
- connection_callable = \
- uowtransaction.session.connection_callable
+ connection_callable = uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(base_mapper)
connection_callable = None
@@ -1275,7 +1549,8 @@ def _cached_connection_dict(base_mapper):
return util.PopulateDict(
lambda conn: conn.execution_options(
compiled_cache=base_mapper._compiled_cache
- ))
+ )
+ )
def _sort_states(states):
@@ -1287,9 +1562,12 @@ def _sort_states(states):
except TypeError as err:
raise sa_exc.InvalidRequestError(
"Could not sort objects by primary key; primary key "
- "values must be sortable in Python (was: %s)" % err)
- return sorted(pending, key=operator.attrgetter("insert_order")) + \
- persistent_sorted
+ "values must be sortable in Python (was: %s)" % err
+ )
+ return (
+ sorted(pending, key=operator.attrgetter("insert_order"))
+ + persistent_sorted
+ )
class BulkUD(object):
@@ -1302,21 +1580,22 @@ class BulkUD(object):
def _validate_query_state(self):
for attr, methname, notset, op in (
- ('_limit', 'limit()', None, operator.is_),
- ('_offset', 'offset()', None, operator.is_),
- ('_order_by', 'order_by()', False, operator.is_),
- ('_group_by', 'group_by()', False, operator.is_),
- ('_distinct', 'distinct()', False, operator.is_),
+ ("_limit", "limit()", None, operator.is_),
+ ("_offset", "offset()", None, operator.is_),
+ ("_order_by", "order_by()", False, operator.is_),
+ ("_group_by", "group_by()", False, operator.is_),
+ ("_distinct", "distinct()", False, operator.is_),
(
- '_from_obj',
- 'join(), outerjoin(), select_from(), or from_self()',
- (), operator.eq)
+ "_from_obj",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
):
if not op(getattr(self.query, attr), notset):
raise sa_exc.InvalidRequestError(
"Can't call Query.update() or Query.delete() "
- "when %s has been called" %
- (methname, )
+ "when %s has been called" % (methname,)
)
@property
@@ -1330,8 +1609,8 @@ class BulkUD(object):
except KeyError:
raise sa_exc.ArgumentError(
"Valid strategies for session synchronization "
- "are %s" % (", ".join(sorted(repr(x)
- for x in lookup))))
+ "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
+ )
else:
return klass(*arg)
@@ -1400,9 +1679,9 @@ class BulkEvaluate(BulkUD):
try:
evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
if query.whereclause is not None:
- eval_condition = evaluator_compiler.process(
- query.whereclause)
+ eval_condition = evaluator_compiler.process(query.whereclause)
else:
+
def eval_condition(obj):
return True
@@ -1411,15 +1690,20 @@ class BulkEvaluate(BulkUD):
except evaluator.UnevaluatableError as err:
raise sa_exc.InvalidRequestError(
'Could not evaluate current criteria in Python: "%s". '
- 'Specify \'fetch\' or False for the '
- 'synchronize_session parameter.' % err)
+ "Specify 'fetch' or False for the "
+ "synchronize_session parameter." % err
+ )
# TODO: detect when the where clause is a trivial primary key match
self.matched_objects = [
- obj for (cls, pk, identity_token), obj in
- query.session.identity_map.items()
- if issubclass(cls, target_cls) and
- eval_condition(obj)]
+ obj
+ for (
+ cls,
+ pk,
+ identity_token,
+ ), obj in query.session.identity_map.items()
+ if issubclass(cls, target_cls) and eval_condition(obj)
+ ]
class BulkFetch(BulkUD):
@@ -1430,11 +1714,11 @@ class BulkFetch(BulkUD):
session = query.session
context = query._compile_context()
select_stmt = context.statement.with_only_columns(
- self.primary_table.primary_key)
+ self.primary_table.primary_key
+ )
self.matched_rows = session.execute(
- select_stmt,
- mapper=self.mapper,
- params=query._params).fetchall()
+ select_stmt, mapper=self.mapper, params=query._params
+ ).fetchall()
class BulkUpdate(BulkUD):
@@ -1447,18 +1731,26 @@ class BulkUpdate(BulkUD):
@classmethod
def factory(cls, query, synchronize_session, values, update_kwargs):
- return BulkUD._factory({
- "evaluate": BulkUpdateEvaluate,
- "fetch": BulkUpdateFetch,
- False: BulkUpdate
- }, synchronize_session, query, values, update_kwargs)
+ return BulkUD._factory(
+ {
+ "evaluate": BulkUpdateEvaluate,
+ "fetch": BulkUpdateFetch,
+ False: BulkUpdate,
+ },
+ synchronize_session,
+ query,
+ values,
+ update_kwargs,
+ )
@property
def _resolved_values(self):
values = []
for k, v in (
- self.values.items() if hasattr(self.values, 'items')
- else self.values):
+ self.values.items()
+ if hasattr(self.values, "items")
+ else self.values
+ ):
if self.mapper:
if isinstance(k, util.string_types):
desc = _entity_descriptor(self.mapper, k)
@@ -1478,7 +1770,7 @@ class BulkUpdate(BulkUD):
if isinstance(k, attributes.QueryableAttribute):
values.append((k.key, v))
continue
- elif hasattr(k, '__clause_element__'):
+ elif hasattr(k, "__clause_element__"):
k = k.__clause_element__()
if self.mapper and isinstance(k, expression.ColumnElement):
@@ -1490,18 +1782,22 @@ class BulkUpdate(BulkUD):
values.append((attr.key, v))
else:
raise sa_exc.InvalidRequestError(
- "Invalid expression type: %r" % k)
+ "Invalid expression type: %r" % k
+ )
return values
def _do_exec(self):
values = self._resolved_values
- if not self.update_kwargs.get('preserve_parameter_order', False):
+ if not self.update_kwargs.get("preserve_parameter_order", False):
values = dict(values)
- update_stmt = sql.update(self.primary_table,
- self.context.whereclause, values,
- **self.update_kwargs)
+ update_stmt = sql.update(
+ self.primary_table,
+ self.context.whereclause,
+ values,
+ **self.update_kwargs
+ )
self._execute_stmt(update_stmt)
@@ -1518,15 +1814,18 @@ class BulkDelete(BulkUD):
@classmethod
def factory(cls, query, synchronize_session):
- return BulkUD._factory({
- "evaluate": BulkDeleteEvaluate,
- "fetch": BulkDeleteFetch,
- False: BulkDelete
- }, synchronize_session, query)
+ return BulkUD._factory(
+ {
+ "evaluate": BulkDeleteEvaluate,
+ "fetch": BulkDeleteFetch,
+ False: BulkDelete,
+ },
+ synchronize_session,
+ query,
+ )
def _do_exec(self):
- delete_stmt = sql.delete(self.primary_table,
- self.context.whereclause)
+ delete_stmt = sql.delete(self.primary_table, self.context.whereclause)
self._execute_stmt(delete_stmt)
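
The BulkUD hierarchy implements Query.update() and Query.delete(); the factory() classmethods above map the synchronize_session argument onto a strategy subclass. A short usage sketch with the hypothetical User model:

    # "evaluate" -> Bulk*Evaluate: match in-session objects in Python
    # "fetch"    -> Bulk*Fetch:    SELECT the matched primary keys first
    # False      -> plain Bulk*:   no session synchronization at all
    session.query(User).filter(User.name == "a2").update(
        {User.name: "a3"}, synchronize_session="evaluate"
    )
    session.query(User).filter(User.name == "b2").delete(
        synchronize_session=False
    )
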
@@ -1544,32 +1843,33 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
values = self._resolved_values_keys_as_propnames
for key, value in values:
self.value_evaluators[key] = evaluator_compiler.process(
- expression._literal_as_binds(value))
+ expression._literal_as_binds(value)
+ )
def _do_post_synchronize(self):
session = self.query.session
states = set()
evaluated_keys = list(self.value_evaluators.keys())
for obj in self.matched_objects:
- state, dict_ = attributes.instance_state(obj),\
- attributes.instance_dict(obj)
+ state, dict_ = (
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
# only evaluate unmodified attributes
- to_evaluate = state.unmodified.intersection(
- evaluated_keys)
+ to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
dict_[key] = self.value_evaluators[key](obj)
- state.manager.dispatch.refresh(
- state, None, to_evaluate)
+ state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
# expire attributes with pending changes
# (there was no autoflush, so they are overwritten)
- state._expire_attributes(dict_,
- set(evaluated_keys).
- difference(to_evaluate))
+ state._expire_attributes(
+ dict_, set(evaluated_keys).difference(to_evaluate)
+ )
states.add(state)
session._register_altered(states)
@@ -1580,8 +1880,8 @@ class BulkDeleteEvaluate(BulkEvaluate, BulkDelete):
def _do_post_synchronize(self):
self.query.session._remove_newly_deleted(
- [attributes.instance_state(obj)
- for obj in self.matched_objects])
+ [attributes.instance_state(obj) for obj in self.matched_objects]
+ )
class BulkUpdateFetch(BulkFetch, BulkUpdate):
@@ -1592,15 +1892,18 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate):
session = self.query.session
target_mapper = self.query._mapper_zero()
- states = set([
- attributes.instance_state(session.identity_map[identity_key])
- for identity_key in [
- target_mapper.identity_key_from_primary_key(
- list(primary_key))
- for primary_key in self.matched_rows
+ states = set(
+ [
+ attributes.instance_state(session.identity_map[identity_key])
+ for identity_key in [
+ target_mapper.identity_key_from_primary_key(
+ list(primary_key)
+ )
+ for primary_key in self.matched_rows
+ ]
+ if identity_key in session.identity_map
]
- if identity_key in session.identity_map
- ])
+ )
values = self._resolved_values_keys_as_propnames
attrib = set(k for k, v in values)
@@ -1622,10 +1925,13 @@ class BulkDeleteFetch(BulkFetch, BulkDelete):
# TODO: inline this and call remove_newly_deleted
# once
identity_key = target_mapper.identity_key_from_primary_key(
- list(primary_key))
+ list(primary_key)
+ )
if identity_key in session.identity_map:
session._remove_newly_deleted(
- [attributes.instance_state(
- session.identity_map[identity_key]
- )]
+ [
+ attributes.instance_state(
+ session.identity_map[identity_key]
+ )
+ ]
)