summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/persistence.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-01-06 01:14:26 -0500
committermike bayer <mike_mp@zzzcomputing.com>2019-01-06 17:34:50 +0000
commit1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch)
tree28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/orm/persistence.py
parent404e69426b05a82d905cbb3ad33adafccddb00dd (diff)
downloadsqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits applied at all. The black run will format code consistently, however in some cases that are prevalent in SQLAlchemy code it produces too-long lines. The too-long lines will be resolved in the following commit that will resolve all remaining flake8 issues including shadowed builtins, long lines, import order, unused imports, duplicate imports, and docstring issues. Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r--lib/sqlalchemy/orm/persistence.py1186
1 files changed, 746 insertions, 440 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 7f9b7db0c..dc86a60e5 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -25,8 +25,13 @@ from . import loading
def _bulk_insert(
- mapper, mappings, session_transaction, isstates, return_defaults,
- render_nulls):
+ mapper,
+ mappings,
+ session_transaction,
+ isstates,
+ return_defaults,
+ render_nulls,
+):
base_mapper = mapper.base_mapper
cached_connections = _cached_connection_dict(base_mapper)
@@ -34,7 +39,8 @@ def _bulk_insert(
if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
- "not supported in bulk_insert()")
+ "not supported in bulk_insert()"
+ )
if isstates:
if return_defaults:
@@ -51,22 +57,33 @@ def _bulk_insert(
continue
records = (
- (None, state_dict, params, mapper,
- connection, value_params, has_all_pks, has_all_defaults)
- for
- state, state_dict, params, mp,
- conn, value_params, has_all_pks,
- has_all_defaults in _collect_insert_commands(table, (
- (None, mapping, mapper, connection)
- for mapping in mappings),
- bulk=True, return_defaults=return_defaults,
- render_nulls=render_nulls
+ (
+ None,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ )
+ for state, state_dict, params, mp, conn, value_params, has_all_pks, has_all_defaults in _collect_insert_commands(
+ table,
+ ((None, mapping, mapper, connection) for mapping in mappings),
+ bulk=True,
+ return_defaults=return_defaults,
+ render_nulls=render_nulls,
)
)
- _emit_insert_statements(base_mapper, None,
- cached_connections,
- super_mapper, table, records,
- bookkeeping=return_defaults)
+ _emit_insert_statements(
+ base_mapper,
+ None,
+ cached_connections,
+ super_mapper,
+ table,
+ records,
+ bookkeeping=return_defaults,
+ )
if return_defaults and isstates:
identity_cls = mapper._identity_class
@@ -74,12 +91,13 @@ def _bulk_insert(
for state, dict_ in states:
state.key = (
identity_cls,
- tuple([dict_[key] for key in identity_props])
+ tuple([dict_[key] for key in identity_props]),
)
-def _bulk_update(mapper, mappings, session_transaction,
- isstates, update_changed_only):
+def _bulk_update(
+ mapper, mappings, session_transaction, isstates, update_changed_only
+):
base_mapper = mapper.base_mapper
cached_connections = _cached_connection_dict(base_mapper)
@@ -91,9 +109,8 @@ def _bulk_update(mapper, mappings, session_transaction,
def _changed_dict(mapper, state):
return dict(
(k, v)
- for k, v in state.dict.items() if k in state.committed_state or k
- in search_keys
-
+ for k, v in state.dict.items()
+ if k in state.committed_state or k in search_keys
)
if isstates:
@@ -107,7 +124,8 @@ def _bulk_update(mapper, mappings, session_transaction,
if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
- "not supported in bulk_update()")
+ "not supported in bulk_update()"
+ )
connection = session_transaction.connection(base_mapper)
@@ -115,21 +133,38 @@ def _bulk_update(mapper, mappings, session_transaction,
if not mapper.isa(super_mapper):
continue
- records = _collect_update_commands(None, table, (
- (None, mapping, mapper, connection,
- (mapping[mapper._version_id_prop.key]
- if mapper._version_id_prop else None))
- for mapping in mappings
- ), bulk=True)
+ records = _collect_update_commands(
+ None,
+ table,
+ (
+ (
+ None,
+ mapping,
+ mapper,
+ connection,
+ (
+ mapping[mapper._version_id_prop.key]
+ if mapper._version_id_prop
+ else None
+ ),
+ )
+ for mapping in mappings
+ ),
+ bulk=True,
+ )
- _emit_update_statements(base_mapper, None,
- cached_connections,
- super_mapper, table, records,
- bookkeeping=False)
+ _emit_update_statements(
+ base_mapper,
+ None,
+ cached_connections,
+ super_mapper,
+ table,
+ records,
+ bookkeeping=False,
+ )
-def save_obj(
- base_mapper, states, uowtransaction, single=False):
+def save_obj(base_mapper, states, uowtransaction, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list
of objects.
@@ -150,19 +185,21 @@ def save_obj(
states_to_insert = []
cached_connections = _cached_connection_dict(base_mapper)
- for (state, dict_, mapper, connection,
- has_identity,
- row_switch, update_version_id) in _organize_states_for_save(
- base_mapper, states, uowtransaction
- ):
+ for (
+ state,
+ dict_,
+ mapper,
+ connection,
+ has_identity,
+ row_switch,
+ update_version_id,
+ ) in _organize_states_for_save(base_mapper, states, uowtransaction):
if has_identity or row_switch:
states_to_update.append(
(state, dict_, mapper, connection, update_version_id)
)
else:
- states_to_insert.append(
- (state, dict_, mapper, connection)
- )
+ states_to_insert.append((state, dict_, mapper, connection))
for table, mapper in base_mapper._sorted_tables.items():
if table not in mapper._pks_by_table:
@@ -170,18 +207,30 @@ def save_obj(
insert = _collect_insert_commands(table, states_to_insert)
update = _collect_update_commands(
- uowtransaction, table, states_to_update)
+ uowtransaction, table, states_to_update
+ )
- _emit_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
+ _emit_update_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ update,
+ )
- _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, insert)
+ _emit_insert_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ insert,
+ )
_finalize_insert_update_commands(
- base_mapper, uowtransaction,
+ base_mapper,
+ uowtransaction,
chain(
(
(state, state_dict, mapper, connection, False)
@@ -189,10 +238,9 @@ def save_obj(
),
(
(state, state_dict, mapper, connection, True)
- for state, state_dict, mapper, connection,
- update_version_id in states_to_update
- )
- )
+ for state, state_dict, mapper, connection, update_version_id in states_to_update
+ ),
+ ),
)
@@ -203,9 +251,9 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
"""
cached_connections = _cached_connection_dict(base_mapper)
- states_to_update = list(_organize_states_for_post_update(
- base_mapper,
- states, uowtransaction))
+ states_to_update = list(
+ _organize_states_for_post_update(base_mapper, states, uowtransaction)
+ )
for table, mapper in base_mapper._sorted_tables.items():
if table not in mapper._pks_by_table:
@@ -213,25 +261,32 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
update = (
(
- state, state_dict, sub_mapper, connection,
+ state,
+ state_dict,
+ sub_mapper,
+ connection,
mapper._get_committed_state_attr_by_column(
state, state_dict, mapper.version_id_col
- ) if mapper.version_id_col is not None else None
+ )
+ if mapper.version_id_col is not None
+ else None,
)
- for
- state, state_dict, sub_mapper, connection in states_to_update
+ for state, state_dict, sub_mapper, connection in states_to_update
if table in sub_mapper._pks_by_table
)
update = _collect_post_update_commands(
- base_mapper, uowtransaction,
- table, update,
- post_update_cols
+ base_mapper, uowtransaction, table, update, post_update_cols
)
- _emit_post_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
+ _emit_post_update_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ update,
+ )
def delete_obj(base_mapper, states, uowtransaction):
@@ -244,10 +299,9 @@ def delete_obj(base_mapper, states, uowtransaction):
cached_connections = _cached_connection_dict(base_mapper)
- states_to_delete = list(_organize_states_for_delete(
- base_mapper,
- states,
- uowtransaction))
+ states_to_delete = list(
+ _organize_states_for_delete(base_mapper, states, uowtransaction)
+ )
table_to_mapper = base_mapper._sorted_tables
@@ -258,14 +312,26 @@ def delete_obj(base_mapper, states, uowtransaction):
elif mapper.inherits and mapper.passive_deletes:
continue
- delete = _collect_delete_commands(base_mapper, uowtransaction,
- table, states_to_delete)
+ delete = _collect_delete_commands(
+ base_mapper, uowtransaction, table, states_to_delete
+ )
- _emit_delete_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, delete)
+ _emit_delete_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ delete,
+ )
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_delete:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_delete:
mapper.dispatch.after_delete(mapper, connection, state)
@@ -282,8 +348,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
"""
for state, dict_, mapper, connection in _connections_for_states(
- base_mapper, uowtransaction,
- states):
+ base_mapper, uowtransaction, states
+ ):
has_identity = bool(state.key)
@@ -304,25 +370,29 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
# no instance_key attached to it), and another instance
# with the same identity key already exists as persistent.
# convert to an UPDATE if so.
- if not has_identity and \
- instance_key in uowtransaction.session.identity_map:
- instance = \
- uowtransaction.session.identity_map[instance_key]
+ if (
+ not has_identity
+ and instance_key in uowtransaction.session.identity_map
+ ):
+ instance = uowtransaction.session.identity_map[instance_key]
existing = attributes.instance_state(instance)
if not uowtransaction.was_already_deleted(existing):
if not uowtransaction.is_deleted(existing):
raise orm_exc.FlushError(
"New instance %s with identity key %s conflicts "
- "with persistent instance %s" %
- (state_str(state), instance_key,
- state_str(existing)))
+ "with persistent instance %s"
+ % (state_str(state), instance_key, state_str(existing))
+ )
base_mapper._log_debug(
"detected row switch for identity %s. "
"will update %s, remove %s from "
- "transaction", instance_key,
- state_str(state), state_str(existing))
+ "transaction",
+ instance_key,
+ state_str(state),
+ state_str(existing),
+ )
# remove the "delete" flag from the existing element
uowtransaction.remove_state_actions(existing)
@@ -332,14 +402,21 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
update_version_id = mapper._get_committed_state_attr_by_column(
row_switch if row_switch else state,
row_switch.dict if row_switch else dict_,
- mapper.version_id_col)
+ mapper.version_id_col,
+ )
- yield (state, dict_, mapper, connection,
- has_identity, row_switch, update_version_id)
+ yield (
+ state,
+ dict_,
+ mapper,
+ connection,
+ has_identity,
+ row_switch,
+ update_version_id,
+ )
-def _organize_states_for_post_update(base_mapper, states,
- uowtransaction):
+def _organize_states_for_post_update(base_mapper, states, uowtransaction):
"""Make an initial pass across a set of states for UPDATE
corresponding to post_update.
@@ -360,26 +437,28 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
"""
for state, dict_, mapper, connection in _connections_for_states(
- base_mapper, uowtransaction,
- states):
+ base_mapper, uowtransaction, states
+ ):
mapper.dispatch.before_delete(mapper, connection, state)
if mapper.version_id_col is not None:
- update_version_id = \
- mapper._get_committed_state_attr_by_column(
- state, dict_,
- mapper.version_id_col)
+ update_version_id = mapper._get_committed_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ )
else:
update_version_id = None
- yield (
- state, dict_, mapper, connection, update_version_id)
+ yield (state, dict_, mapper, connection, update_version_id)
def _collect_insert_commands(
- table, states_to_insert,
- bulk=False, return_defaults=False, render_nulls=False):
+ table,
+ states_to_insert,
+ bulk=False,
+ return_defaults=False,
+ render_nulls=False,
+):
"""Identify sets of values to use in INSERT statements for a
list of states.
@@ -400,10 +479,16 @@ def _collect_insert_commands(
col = propkey_to_col[propkey]
if value is None and col not in eval_none and not render_nulls:
continue
- elif not bulk and hasattr(value, '__clause_element__') or \
- isinstance(value, sql.ClauseElement):
- value_params[col.key] = value.__clause_element__() \
- if hasattr(value, '__clause_element__') else value
+ elif (
+ not bulk
+ and hasattr(value, "__clause_element__")
+ or isinstance(value, sql.ClauseElement)
+ ):
+ value_params[col.key] = (
+ value.__clause_element__()
+ if hasattr(value, "__clause_element__")
+ else value
+ )
else:
params[col.key] = value
@@ -414,8 +499,11 @@ def _collect_insert_commands(
# which might be worth removing, as it should not be necessary
# and also produces confusion, given that "missing" and None
# now have distinct meanings
- for colkey in mapper._insert_cols_as_none[table].\
- difference(params).difference(value_params):
+ for colkey in (
+ mapper._insert_cols_as_none[table]
+ .difference(params)
+ .difference(value_params)
+ ):
params[colkey] = None
if not bulk or return_defaults:
@@ -424,28 +512,38 @@ def _collect_insert_commands(
has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
if mapper.base_mapper.eager_defaults:
- has_all_defaults = mapper._server_default_cols[table].\
- issubset(params)
+ has_all_defaults = mapper._server_default_cols[table].issubset(
+ params
+ )
else:
has_all_defaults = True
else:
has_all_defaults = has_all_pks = True
- if mapper.version_id_generator is not False \
- and mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
- params[mapper.version_id_col.key] = \
- mapper.version_id_generator(None)
+ if (
+ mapper.version_id_generator is not False
+ and mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ params[mapper.version_id_col.key] = mapper.version_id_generator(
+ None
+ )
yield (
- state, state_dict, params, mapper,
- connection, value_params, has_all_pks,
- has_all_defaults)
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ )
def _collect_update_commands(
- uowtransaction, table, states_to_update,
- bulk=False):
+ uowtransaction, table, states_to_update, bulk=False
+):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -457,8 +555,13 @@ def _collect_update_commands(
"""
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_update:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update:
if table not in mapper._pks_by_table:
continue
@@ -474,36 +577,48 @@ def _collect_update_commands(
# look at mapper attribute keys for pk
params = dict(
(propkey_to_col[propkey].key, state_dict[propkey])
- for propkey in
- set(propkey_to_col).intersection(state_dict).difference(
- mapper._pk_attr_keys_by_table[table])
+ for propkey in set(propkey_to_col)
+ .intersection(state_dict)
+ .difference(mapper._pk_attr_keys_by_table[table])
)
has_all_defaults = True
else:
params = {}
for propkey in set(propkey_to_col).intersection(
- state.committed_state):
+ state.committed_state
+ ):
value = state_dict[propkey]
col = propkey_to_col[propkey]
- if hasattr(value, '__clause_element__') or \
- isinstance(value, sql.ClauseElement):
- value_params[col] = value.__clause_element__() \
- if hasattr(value, '__clause_element__') else value
+ if hasattr(value, "__clause_element__") or isinstance(
+ value, sql.ClauseElement
+ ):
+ value_params[col] = (
+ value.__clause_element__()
+ if hasattr(value, "__clause_element__")
+ else value
+ )
# guard against values that generate non-__nonzero__
# objects for __eq__()
- elif state.manager[propkey].impl.is_equal(
- value, state.committed_state[propkey]) is not True:
+ elif (
+ state.manager[propkey].impl.is_equal(
+ value, state.committed_state[propkey]
+ )
+ is not True
+ ):
params[col.key] = value
if mapper.base_mapper.eager_defaults:
- has_all_defaults = mapper._server_onupdate_default_cols[table].\
- issubset(params)
+ has_all_defaults = mapper._server_onupdate_default_cols[
+ table
+ ].issubset(params)
else:
has_all_defaults = True
- if update_version_id is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
if not bulk and not (params or value_params):
# HACK: check for history in other tables, in case the
@@ -511,10 +626,9 @@ def _collect_update_commands(
# where the version_id_col is. This logic was lost
# from 0.9 -> 1.0.0 and restored in 1.0.6.
for prop in mapper._columntoproperty.values():
- history = (
- state.manager[prop.key].impl.get_history(
- state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE))
+ history = state.manager[prop.key].impl.get_history(
+ state, state_dict, attributes.PASSIVE_NO_INITIALIZE
+ )
if history.added:
break
else:
@@ -525,8 +639,9 @@ def _collect_update_commands(
no_params = not params and not value_params
params[col._label] = update_version_id
- if (bulk or col.key not in params) and \
- mapper.version_id_generator is not False:
+ if (
+ bulk or col.key not in params
+ ) and mapper.version_id_generator is not False:
val = mapper.version_id_generator(update_version_id)
params[col.key] = val
elif mapper.version_id_generator is False and no_params:
@@ -545,9 +660,9 @@ def _collect_update_commands(
# look at mapper attribute keys for pk
pk_params = dict(
(propkey_to_col[propkey]._label, state_dict.get(propkey))
- for propkey in
- set(propkey_to_col).
- intersection(mapper._pk_attr_keys_by_table[table])
+ for propkey in set(propkey_to_col).intersection(
+ mapper._pk_attr_keys_by_table[table]
+ )
)
else:
pk_params = {}
@@ -555,12 +670,15 @@ def _collect_update_commands(
propkey = mapper._columntoproperty[col].key
history = state.manager[propkey].impl.get_history(
- state, state_dict, attributes.PASSIVE_OFF)
+ state, state_dict, attributes.PASSIVE_OFF
+ )
if history.added:
- if not history.deleted or \
- ("pk_cascaded", state, col) in \
- uowtransaction.attributes:
+ if (
+ not history.deleted
+ or ("pk_cascaded", state, col)
+ in uowtransaction.attributes
+ ):
pk_params[col._label] = history.added[0]
params.pop(col.key, None)
else:
@@ -573,24 +691,38 @@ def _collect_update_commands(
if pk_params[col._label] is None:
raise orm_exc.FlushError(
"Can't update table %s using NULL for primary "
- "key value on column %s" % (table, col))
+ "key value on column %s" % (table, col)
+ )
if params or value_params:
params.update(pk_params)
yield (
- state, state_dict, params, mapper,
- connection, value_params, has_all_defaults, has_all_pks)
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ )
-def _collect_post_update_commands(base_mapper, uowtransaction, table,
- states_to_update, post_update_cols):
+def _collect_post_update_commands(
+ base_mapper, uowtransaction, table, states_to_update, post_update_cols
+):
"""Identify sets of values to use in UPDATE statements for a
list of states within a post_update operation.
"""
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_update:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update:
# assert table in mapper._pks_by_table
@@ -600,100 +732,128 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
for col in mapper._cols_by_table[table]:
if col in pks:
- params[col._label] = \
- mapper._get_state_attr_by_column(
- state,
- state_dict, col, passive=attributes.PASSIVE_OFF)
+ params[col._label] = mapper._get_state_attr_by_column(
+ state, state_dict, col, passive=attributes.PASSIVE_OFF
+ )
elif col in post_update_cols or col.onupdate is not None:
prop = mapper._columntoproperty[col]
history = state.manager[prop.key].impl.get_history(
- state, state_dict,
- attributes.PASSIVE_NO_INITIALIZE)
+ state, state_dict, attributes.PASSIVE_NO_INITIALIZE
+ )
if history.added:
value = history.added[0]
params[col.key] = value
hasdata = True
if hasdata:
- if update_version_id is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
col = mapper.version_id_col
params[col._label] = update_version_id
- if bool(state.key) and col.key not in params and \
- mapper.version_id_generator is not False:
+ if (
+ bool(state.key)
+ and col.key not in params
+ and mapper.version_id_generator is not False
+ ):
val = mapper.version_id_generator(update_version_id)
params[col.key] = val
yield state, state_dict, mapper, connection, params
-def _collect_delete_commands(base_mapper, uowtransaction, table,
- states_to_delete):
+def _collect_delete_commands(
+ base_mapper, uowtransaction, table, states_to_delete
+):
"""Identify values to use in DELETE statements for a list of
states to be deleted."""
- for state, state_dict, mapper, connection, \
- update_version_id in states_to_delete:
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_delete:
if table not in mapper._pks_by_table:
continue
params = {}
for col in mapper._pks_by_table[table]:
- params[col.key] = \
- value = \
- mapper._get_committed_state_attr_by_column(
- state, state_dict, col)
+ params[
+ col.key
+ ] = value = mapper._get_committed_state_attr_by_column(
+ state, state_dict, col
+ )
if value is None:
raise orm_exc.FlushError(
"Can't delete from table %s "
"using NULL for primary "
- "key value on column %s" % (table, col))
+ "key value on column %s" % (table, col)
+ )
- if update_version_id is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
params[mapper.version_id_col.key] = update_version_id
yield params, connection
-def _emit_update_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, update,
- bookkeeping=True):
+def _emit_update_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ update,
+ bookkeeping=True,
+):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
- needs_version_id = mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]
+ needs_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
+ clause.clauses.append(
+ col == sql.bindparam(col._label, type_=col.type)
+ )
if needs_version_id:
clause.clauses.append(
- mapper.version_id_col == sql.bindparam(
+ mapper.version_id_col
+ == sql.bindparam(
mapper.version_id_col._label,
- type_=mapper.version_id_col.type))
+ type_=mapper.version_id_col.type,
+ )
+ )
stmt = table.update(clause)
return stmt
- cached_stmt = base_mapper._memo(('update', table), update_stmt)
-
- for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \
- records in groupby(
- update,
- lambda rec: (
- rec[4], # connection
- set(rec[2]), # set of parameter keys
- bool(rec[5]), # whether or not we have "value" parameters
- rec[6], # has_all_defaults
- rec[7] # has all pks
- )
+ cached_stmt = base_mapper._memo(("update", table), update_stmt)
+
+ for (
+ (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
+ records,
+ ) in groupby(
+ update,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # set of parameter keys
+ bool(rec[5]), # whether or not we have "value" parameters
+ rec[6], # has_all_defaults
+ rec[7], # has all pks
+ ),
):
rows = 0
records = list(records)
@@ -704,8 +864,11 @@ def _emit_update_statements(base_mapper, uowtransaction,
if not has_all_pks:
statement = statement.return_defaults()
return_defaults = True
- elif bookkeeping and not has_all_defaults and \
- mapper.base_mapper.eager_defaults:
+ elif (
+ bookkeeping
+ and not has_all_defaults
+ and mapper.base_mapper.eager_defaults
+ ):
statement = statement.return_defaults()
return_defaults = True
elif mapper.version_id_col is not None:
@@ -718,17 +881,24 @@ def _emit_update_statements(base_mapper, uowtransaction,
else connection.dialect.supports_sane_rowcount_returning
)
- assert_multirow = assert_singlerow and \
- connection.dialect.supports_sane_multi_rowcount
+ assert_multirow = (
+ assert_singlerow
+ and connection.dialect.supports_sane_multi_rowcount
+ )
allow_multirow = has_all_defaults and not needs_version_id
if hasvalue:
- for state, state_dict, params, mapper, \
- connection, value_params, \
- has_all_defaults, has_all_pks in records:
- c = connection.execute(
- statement.values(value_params),
- params)
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ c = connection.execute(statement.values(value_params), params)
if bookkeeping:
_postfetch(
mapper,
@@ -738,17 +908,26 @@ def _emit_update_statements(base_mapper, uowtransaction,
state_dict,
c,
c.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
rows += c.rowcount
check_rowcount = assert_singlerow
else:
if not allow_multirow:
check_rowcount = assert_singlerow
- for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults, \
- has_all_pks in records:
- c = cached_connections[connection].\
- execute(statement, params)
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ c = cached_connections[connection].execute(
+ statement, params
+ )
# TODO: why with bookkeeping=False?
if bookkeeping:
@@ -760,24 +939,32 @@ def _emit_update_statements(base_mapper, uowtransaction,
state_dict,
c,
c.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
rows += c.rowcount
else:
multiparams = [rec[2] for rec in records]
check_rowcount = assert_multirow or (
- assert_singlerow and
- len(multiparams) == 1
+ assert_singlerow and len(multiparams) == 1
)
- c = cached_connections[connection].\
- execute(statement, multiparams)
+ c = cached_connections[connection].execute(
+ statement, multiparams
+ )
rows += c.rowcount
- for state, state_dict, params, mapper, \
- connection, value_params, \
- has_all_defaults, has_all_pks in records:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
if bookkeeping:
_postfetch(
mapper,
@@ -787,59 +974,85 @@ def _emit_update_statements(base_mapper, uowtransaction,
state_dict,
c,
c.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
if check_rowcount:
if rows != len(records):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(records), rows))
+ "update %d row(s); %d were matched."
+ % (table.description, len(records), rows)
+ )
elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description)
+ util.warn(
+ "Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified."
+ % c.dialect.dialect_description
+ )
-def _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, insert,
- bookkeeping=True):
+def _emit_insert_statements(
+ base_mapper,
+ uowtransaction,
+ cached_connections,
+ mapper,
+ table,
+ insert,
+ bookkeeping=True,
+):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
- cached_stmt = base_mapper._memo(('insert', table), table.insert)
-
- for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
- records in groupby(
- insert,
- lambda rec: (
- rec[4], # connection
- set(rec[2]), # parameter keys
- bool(rec[5]), # whether we have "value" parameters
- rec[6],
- rec[7])):
+ cached_stmt = base_mapper._memo(("insert", table), table.insert)
+
+ for (
+ (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
+ records,
+ ) in groupby(
+ insert,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # parameter keys
+ bool(rec[5]), # whether we have "value" parameters
+ rec[6],
+ rec[7],
+ ),
+ ):
statement = cached_stmt
- if not bookkeeping or \
- (
- has_all_defaults
- or not base_mapper.eager_defaults
- or not connection.dialect.implicit_returning
- ) and has_all_pks and not hasvalue:
+ if (
+ not bookkeeping
+ or (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not connection.dialect.implicit_returning
+ )
+ and has_all_pks
+ and not hasvalue
+ ):
records = list(records)
multiparams = [rec[2] for rec in records]
- c = cached_connections[connection].\
- execute(statement, multiparams)
+ c = cached_connections[connection].execute(statement, multiparams)
if bookkeeping:
- for (state, state_dict, params, mapper_rec,
- conn, value_params, has_all_pks, has_all_defaults), \
- last_inserted_params in \
- zip(records, c.context.compiled_parameters):
+ for (
+ (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ conn,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ),
+ last_inserted_params,
+ ) in zip(records, c.context.compiled_parameters):
if state:
_postfetch(
mapper_rec,
@@ -849,7 +1062,8 @@ def _emit_insert_statements(base_mapper, uowtransaction,
state_dict,
c,
last_inserted_params,
- value_params)
+ value_params,
+ )
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
@@ -859,24 +1073,33 @@ def _emit_insert_statements(base_mapper, uowtransaction,
elif mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
- for state, state_dict, params, mapper_rec, \
- connection, value_params, \
- has_all_pks, has_all_defaults in records:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ) in records:
if value_params:
result = connection.execute(
- statement.values(value_params),
- params)
+ statement.values(value_params), params
+ )
else:
- result = cached_connections[connection].\
- execute(statement, params)
+ result = cached_connections[connection].execute(
+ statement, params
+ )
primary_key = result.context.inserted_primary_key
if primary_key is not None:
# set primary key attributes
- for pk, col in zip(primary_key,
- mapper._pks_by_table[table]):
+ for pk, col in zip(
+ primary_key, mapper._pks_by_table[table]
+ ):
prop = mapper_rec._columntoproperty[col]
if state_dict.get(prop.key) is None:
state_dict[prop.key] = pk
@@ -890,31 +1113,39 @@ def _emit_insert_statements(base_mapper, uowtransaction,
state_dict,
result,
result.context.compiled_parameters[0],
- value_params)
+ value_params,
+ )
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
-def _emit_post_update_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, update):
+def _emit_post_update_statements(
+ base_mapper, uowtransaction, cached_connections, mapper, table, update
+):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_post_update_commands()."""
- needs_version_id = mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]
+ needs_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
def update_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
+ clause.clauses.append(
+ col == sql.bindparam(col._label, type_=col.type)
+ )
if needs_version_id:
clause.clauses.append(
- mapper.version_id_col == sql.bindparam(
+ mapper.version_id_col
+ == sql.bindparam(
mapper.version_id_col._label,
- type_=mapper.version_id_col.type))
+ type_=mapper.version_id_col.type,
+ )
+ )
stmt = table.update(clause)
@@ -923,17 +1154,15 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
return stmt
- statement = base_mapper._memo(('post_update', table), update_stmt)
+ statement = base_mapper._memo(("post_update", table), update_stmt)
# execute each UPDATE in the order according to the original
# list of states to guarantee row access order, but
# also group them into common (connection, cols) sets
# to support executemany().
for key, records in groupby(
- update, lambda rec: (
- rec[3], # connection
- set(rec[4]), # parameter keys
- )
+ update,
+ lambda rec: (rec[3], set(rec[4])), # connection # parameter keys
):
rows = 0
@@ -945,84 +1174,96 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
if mapper.version_id_col is None
else connection.dialect.supports_sane_rowcount_returning
)
- assert_multirow = assert_singlerow and \
- connection.dialect.supports_sane_multi_rowcount
+ assert_multirow = (
+ assert_singlerow
+ and connection.dialect.supports_sane_multi_rowcount
+ )
allow_multirow = not needs_version_id or assert_multirow
-
if not allow_multirow:
check_rowcount = assert_singlerow
- for state, state_dict, mapper_rec, \
- connection, params in records:
- c = cached_connections[connection].\
- execute(statement, params)
+ for state, state_dict, mapper_rec, connection, params in records:
+ c = cached_connections[connection].execute(statement, params)
_postfetch_post_update(
- mapper_rec, uowtransaction, table, state, state_dict,
- c, c.context.compiled_parameters[0])
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ )
rows += c.rowcount
else:
multiparams = [
- params for
- state, state_dict, mapper_rec, conn, params in records]
+ params
+ for state, state_dict, mapper_rec, conn, params in records
+ ]
check_rowcount = assert_multirow or (
- assert_singlerow and
- len(multiparams) == 1
+ assert_singlerow and len(multiparams) == 1
)
- c = cached_connections[connection].\
- execute(statement, multiparams)
+ c = cached_connections[connection].execute(statement, multiparams)
rows += c.rowcount
- for state, state_dict, mapper_rec, \
- connection, params in records:
+ for state, state_dict, mapper_rec, connection, params in records:
_postfetch_post_update(
- mapper_rec, uowtransaction, table, state, state_dict,
- c, c.context.compiled_parameters[0])
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ )
if check_rowcount:
if rows != len(records):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(records), rows))
+ "update %d row(s); %d were matched."
+ % (table.description, len(records), rows)
+ )
elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description)
+ util.warn(
+ "Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified."
+ % c.dialect.dialect_description
+ )
-def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
- mapper, table, delete):
+def _emit_delete_statements(
+ base_mapper, uowtransaction, cached_connections, mapper, table, delete
+):
"""Emit DELETE statements corresponding to value lists collected
by _collect_delete_commands()."""
- need_version_id = mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]
+ need_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
def delete_stmt():
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(
- col == sql.bindparam(col.key, type_=col.type))
+ col == sql.bindparam(col.key, type_=col.type)
+ )
if need_version_id:
clause.clauses.append(
- mapper.version_id_col ==
- sql.bindparam(
- mapper.version_id_col.key,
- type_=mapper.version_id_col.type
+ mapper.version_id_col
+ == sql.bindparam(
+ mapper.version_id_col.key, type_=mapper.version_id_col.type
)
)
return table.delete(clause)
- statement = base_mapper._memo(('delete', table), delete_stmt)
- for connection, recs in groupby(
- delete,
- lambda rec: rec[1] # connection
- ):
+ statement = base_mapper._memo(("delete", table), delete_stmt)
+ for connection, recs in groupby(delete, lambda rec: rec[1]): # connection
del_objects = [params for params, connection in recs]
connection = cached_connections[connection]
@@ -1049,9 +1290,10 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
else:
util.warn(
"Dialect %s does not support deleted rowcount "
- "- versioning cannot be verified." %
- connection.dialect.dialect_description,
- stacklevel=12)
+ "- versioning cannot be verified."
+ % connection.dialect.dialect_description,
+ stacklevel=12,
+ )
connection.execute(statement, del_objects)
else:
c = connection.execute(statement, del_objects)
@@ -1061,23 +1303,26 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
rows_matched = c.rowcount
- if base_mapper.confirm_deleted_rows and \
- rows_matched > -1 and expected != rows_matched:
+ if (
+ base_mapper.confirm_deleted_rows
+ and rows_matched > -1
+ and expected != rows_matched
+ ):
if only_warn:
util.warn(
"DELETE statement on table '%s' expected to "
"delete %d row(s); %d were matched. Please set "
"confirm_deleted_rows=False within the mapper "
- "configuration to prevent this warning." %
- (table.description, expected, rows_matched)
+ "configuration to prevent this warning."
+ % (table.description, expected, rows_matched)
)
else:
raise orm_exc.StaleDataError(
"DELETE statement on table '%s' expected to "
"delete %d row(s); %d were matched. Please set "
"confirm_deleted_rows=False within the mapper "
- "configuration to prevent this warning." %
- (table.description, expected, rows_matched)
+ "configuration to prevent this warning."
+ % (table.description, expected, rows_matched)
)
@@ -1091,13 +1336,16 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
if mapper._readonly_props:
readonly = state.unmodified_intersection(
[
- p.key for p in mapper._readonly_props
+ p.key
+ for p in mapper._readonly_props
if (
- p.expire_on_flush and
- (not p.deferred or p.key in state.dict)
- ) or (
- not p.expire_on_flush and
- not p.deferred and p.key not in state.dict
+ p.expire_on_flush
+ and (not p.deferred or p.key in state.dict)
+ )
+ or (
+ not p.expire_on_flush
+ and not p.deferred
+ and p.key not in state.dict
)
]
)
@@ -1112,11 +1360,14 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
if base_mapper.eager_defaults:
toload_now.extend(
state._unloaded_non_object.intersection(
- mapper._server_default_plus_onupdate_propkeys)
+ mapper._server_default_plus_onupdate_propkeys
+ )
)
- if mapper.version_id_col is not None and \
- mapper.version_id_generator is False:
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_generator is False
+ ):
if mapper._version_id_prop.key in state.unloaded:
toload_now.extend([mapper._version_id_prop.key])
@@ -1124,8 +1375,10 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
state.key = base_mapper._identity_key_from_state(state)
loading.load_on_ident(
uowtransaction.session.query(mapper),
- state.key, refresh_state=state,
- only_load_props=toload_now)
+ state.key,
+ refresh_state=state,
+ only_load_props=toload_now,
+ )
# call after_XXX extensions
if not has_identity:
@@ -1133,23 +1386,29 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
else:
mapper.dispatch.after_update(mapper, connection, state)
- if mapper.version_id_generator is False and \
- mapper.version_id_col is not None:
+ if (
+ mapper.version_id_generator is False
+ and mapper.version_id_col is not None
+ ):
if state_dict[mapper._version_id_prop.key] is None:
raise orm_exc.FlushError(
- "Instance does not contain a non-NULL version value")
+ "Instance does not contain a non-NULL version value"
+ )
-def _postfetch_post_update(mapper, uowtransaction, table,
- state, dict_, result, params):
+def _postfetch_post_update(
+ mapper, uowtransaction, table, state, dict_, result, params
+):
if uowtransaction.is_deleted(state):
return
prefetch_cols = result.context.compiled.prefetch
postfetch_cols = result.context.compiled.postfetch
- if mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
@@ -1164,18 +1423,23 @@ def _postfetch_post_update(mapper, uowtransaction, table,
if refresh_flush and load_evt_attrs:
mapper.class_manager.dispatch.refresh_flush(
- state, uowtransaction, load_evt_attrs)
+ state, uowtransaction, load_evt_attrs
+ )
if postfetch_cols:
- state._expire_attributes(state.dict,
- [mapper._columntoproperty[c].key
- for c in postfetch_cols if c in
- mapper._columntoproperty]
- )
+ state._expire_attributes(
+ state.dict,
+ [
+ mapper._columntoproperty[c].key
+ for c in postfetch_cols
+ if c in mapper._columntoproperty
+ ],
+ )
-def _postfetch(mapper, uowtransaction, table,
- state, dict_, result, params, value_params):
+def _postfetch(
+ mapper, uowtransaction, table, state, dict_, result, params, value_params
+):
"""Expire attributes in need of newly persisted database state,
after an INSERT or UPDATE statement has proceeded for that
state."""
@@ -1184,8 +1448,10 @@ def _postfetch(mapper, uowtransaction, table,
postfetch_cols = result.context.compiled.postfetch
returning_cols = result.context.compiled.returning
- if mapper.version_id_col is not None and \
- mapper.version_id_col in mapper._cols_by_table[table]:
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
@@ -1219,23 +1485,32 @@ def _postfetch(mapper, uowtransaction, table,
if refresh_flush and load_evt_attrs:
mapper.class_manager.dispatch.refresh_flush(
- state, uowtransaction, load_evt_attrs)
+ state, uowtransaction, load_evt_attrs
+ )
if postfetch_cols:
- state._expire_attributes(state.dict,
- [mapper._columntoproperty[c].key
- for c in postfetch_cols if c in
- mapper._columntoproperty]
- )
+ state._expire_attributes(
+ state.dict,
+ [
+ mapper._columntoproperty[c].key
+ for c in postfetch_cols
+ if c in mapper._columntoproperty
+ ],
+ )
# synchronize newly inserted ids from one table to the next
# TODO: this still goes a little too often. would be nice to
# have definitive list of "columns that changed" here
for m, equated_pairs in mapper._table_to_equated[table]:
- sync.populate(state, m, state, m,
- equated_pairs,
- uowtransaction,
- mapper.passive_updates)
+ sync.populate(
+ state,
+ m,
+ state,
+ m,
+ equated_pairs,
+ uowtransaction,
+ mapper.passive_updates,
+ )
def _postfetch_bulk_save(mapper, dict_, table):
@@ -1255,8 +1530,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
# organize individual states with the connection
# to use for update
if uowtransaction.session.connection_callable:
- connection_callable = \
- uowtransaction.session.connection_callable
+ connection_callable = uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(base_mapper)
connection_callable = None
@@ -1275,7 +1549,8 @@ def _cached_connection_dict(base_mapper):
return util.PopulateDict(
lambda conn: conn.execution_options(
compiled_cache=base_mapper._compiled_cache
- ))
+ )
+ )
def _sort_states(states):
@@ -1287,9 +1562,12 @@ def _sort_states(states):
except TypeError as err:
raise sa_exc.InvalidRequestError(
"Could not sort objects by primary key; primary key "
- "values must be sortable in Python (was: %s)" % err)
- return sorted(pending, key=operator.attrgetter("insert_order")) + \
- persistent_sorted
+ "values must be sortable in Python (was: %s)" % err
+ )
+ return (
+ sorted(pending, key=operator.attrgetter("insert_order"))
+ + persistent_sorted
+ )
class BulkUD(object):
@@ -1302,21 +1580,22 @@ class BulkUD(object):
def _validate_query_state(self):
for attr, methname, notset, op in (
- ('_limit', 'limit()', None, operator.is_),
- ('_offset', 'offset()', None, operator.is_),
- ('_order_by', 'order_by()', False, operator.is_),
- ('_group_by', 'group_by()', False, operator.is_),
- ('_distinct', 'distinct()', False, operator.is_),
+ ("_limit", "limit()", None, operator.is_),
+ ("_offset", "offset()", None, operator.is_),
+ ("_order_by", "order_by()", False, operator.is_),
+ ("_group_by", "group_by()", False, operator.is_),
+ ("_distinct", "distinct()", False, operator.is_),
(
- '_from_obj',
- 'join(), outerjoin(), select_from(), or from_self()',
- (), operator.eq)
+ "_from_obj",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
):
if not op(getattr(self.query, attr), notset):
raise sa_exc.InvalidRequestError(
"Can't call Query.update() or Query.delete() "
- "when %s has been called" %
- (methname, )
+ "when %s has been called" % (methname,)
)
@property
@@ -1330,8 +1609,8 @@ class BulkUD(object):
except KeyError:
raise sa_exc.ArgumentError(
"Valid strategies for session synchronization "
- "are %s" % (", ".join(sorted(repr(x)
- for x in lookup))))
+ "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
+ )
else:
return klass(*arg)
@@ -1400,9 +1679,9 @@ class BulkEvaluate(BulkUD):
try:
evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
if query.whereclause is not None:
- eval_condition = evaluator_compiler.process(
- query.whereclause)
+ eval_condition = evaluator_compiler.process(query.whereclause)
else:
+
def eval_condition(obj):
return True
@@ -1411,15 +1690,20 @@ class BulkEvaluate(BulkUD):
except evaluator.UnevaluatableError as err:
raise sa_exc.InvalidRequestError(
'Could not evaluate current criteria in Python: "%s". '
- 'Specify \'fetch\' or False for the '
- 'synchronize_session parameter.' % err)
+ "Specify 'fetch' or False for the "
+ "synchronize_session parameter." % err
+ )
# TODO: detect when the where clause is a trivial primary key match
self.matched_objects = [
- obj for (cls, pk, identity_token), obj in
- query.session.identity_map.items()
- if issubclass(cls, target_cls) and
- eval_condition(obj)]
+ obj
+ for (
+ cls,
+ pk,
+ identity_token,
+ ), obj in query.session.identity_map.items()
+ if issubclass(cls, target_cls) and eval_condition(obj)
+ ]
class BulkFetch(BulkUD):
@@ -1430,11 +1714,11 @@ class BulkFetch(BulkUD):
session = query.session
context = query._compile_context()
select_stmt = context.statement.with_only_columns(
- self.primary_table.primary_key)
+ self.primary_table.primary_key
+ )
self.matched_rows = session.execute(
- select_stmt,
- mapper=self.mapper,
- params=query._params).fetchall()
+ select_stmt, mapper=self.mapper, params=query._params
+ ).fetchall()
class BulkUpdate(BulkUD):
@@ -1447,18 +1731,26 @@ class BulkUpdate(BulkUD):
@classmethod
def factory(cls, query, synchronize_session, values, update_kwargs):
- return BulkUD._factory({
- "evaluate": BulkUpdateEvaluate,
- "fetch": BulkUpdateFetch,
- False: BulkUpdate
- }, synchronize_session, query, values, update_kwargs)
+ return BulkUD._factory(
+ {
+ "evaluate": BulkUpdateEvaluate,
+ "fetch": BulkUpdateFetch,
+ False: BulkUpdate,
+ },
+ synchronize_session,
+ query,
+ values,
+ update_kwargs,
+ )
@property
def _resolved_values(self):
values = []
for k, v in (
- self.values.items() if hasattr(self.values, 'items')
- else self.values):
+ self.values.items()
+ if hasattr(self.values, "items")
+ else self.values
+ ):
if self.mapper:
if isinstance(k, util.string_types):
desc = _entity_descriptor(self.mapper, k)
@@ -1478,7 +1770,7 @@ class BulkUpdate(BulkUD):
if isinstance(k, attributes.QueryableAttribute):
values.append((k.key, v))
continue
- elif hasattr(k, '__clause_element__'):
+ elif hasattr(k, "__clause_element__"):
k = k.__clause_element__()
if self.mapper and isinstance(k, expression.ColumnElement):
@@ -1490,18 +1782,22 @@ class BulkUpdate(BulkUD):
values.append((attr.key, v))
else:
raise sa_exc.InvalidRequestError(
- "Invalid expression type: %r" % k)
+ "Invalid expression type: %r" % k
+ )
return values
def _do_exec(self):
values = self._resolved_values
- if not self.update_kwargs.get('preserve_parameter_order', False):
+ if not self.update_kwargs.get("preserve_parameter_order", False):
values = dict(values)
- update_stmt = sql.update(self.primary_table,
- self.context.whereclause, values,
- **self.update_kwargs)
+ update_stmt = sql.update(
+ self.primary_table,
+ self.context.whereclause,
+ values,
+ **self.update_kwargs
+ )
self._execute_stmt(update_stmt)
@@ -1518,15 +1814,18 @@ class BulkDelete(BulkUD):
@classmethod
def factory(cls, query, synchronize_session):
- return BulkUD._factory({
- "evaluate": BulkDeleteEvaluate,
- "fetch": BulkDeleteFetch,
- False: BulkDelete
- }, synchronize_session, query)
+ return BulkUD._factory(
+ {
+ "evaluate": BulkDeleteEvaluate,
+ "fetch": BulkDeleteFetch,
+ False: BulkDelete,
+ },
+ synchronize_session,
+ query,
+ )
def _do_exec(self):
- delete_stmt = sql.delete(self.primary_table,
- self.context.whereclause)
+ delete_stmt = sql.delete(self.primary_table, self.context.whereclause)
self._execute_stmt(delete_stmt)
@@ -1544,32 +1843,33 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
values = self._resolved_values_keys_as_propnames
for key, value in values:
self.value_evaluators[key] = evaluator_compiler.process(
- expression._literal_as_binds(value))
+ expression._literal_as_binds(value)
+ )
def _do_post_synchronize(self):
session = self.query.session
states = set()
evaluated_keys = list(self.value_evaluators.keys())
for obj in self.matched_objects:
- state, dict_ = attributes.instance_state(obj),\
- attributes.instance_dict(obj)
+ state, dict_ = (
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
# only evaluate unmodified attributes
- to_evaluate = state.unmodified.intersection(
- evaluated_keys)
+ to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
dict_[key] = self.value_evaluators[key](obj)
- state.manager.dispatch.refresh(
- state, None, to_evaluate)
+ state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
# expire attributes with pending changes
# (there was no autoflush, so they are overwritten)
- state._expire_attributes(dict_,
- set(evaluated_keys).
- difference(to_evaluate))
+ state._expire_attributes(
+ dict_, set(evaluated_keys).difference(to_evaluate)
+ )
states.add(state)
session._register_altered(states)
@@ -1580,8 +1880,8 @@ class BulkDeleteEvaluate(BulkEvaluate, BulkDelete):
def _do_post_synchronize(self):
self.query.session._remove_newly_deleted(
- [attributes.instance_state(obj)
- for obj in self.matched_objects])
+ [attributes.instance_state(obj) for obj in self.matched_objects]
+ )
class BulkUpdateFetch(BulkFetch, BulkUpdate):
@@ -1592,15 +1892,18 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate):
session = self.query.session
target_mapper = self.query._mapper_zero()
- states = set([
- attributes.instance_state(session.identity_map[identity_key])
- for identity_key in [
- target_mapper.identity_key_from_primary_key(
- list(primary_key))
- for primary_key in self.matched_rows
+ states = set(
+ [
+ attributes.instance_state(session.identity_map[identity_key])
+ for identity_key in [
+ target_mapper.identity_key_from_primary_key(
+ list(primary_key)
+ )
+ for primary_key in self.matched_rows
+ ]
+ if identity_key in session.identity_map
]
- if identity_key in session.identity_map
- ])
+ )
values = self._resolved_values_keys_as_propnames
attrib = set(k for k, v in values)
@@ -1622,10 +1925,13 @@ class BulkDeleteFetch(BulkFetch, BulkDelete):
# TODO: inline this and call remove_newly_deleted
# once
identity_key = target_mapper.identity_key_from_primary_key(
- list(primary_key))
+ list(primary_key)
+ )
if identity_key in session.identity_map:
session._remove_newly_deleted(
- [attributes.instance_state(
- session.identity_map[identity_key]
- )]
+ [
+ attributes.instance_state(
+ session.identity_map[identity_key]
+ )
+ ]
)