diff options
Diffstat (limited to 'lib/sqlalchemy/orm/unitofwork.py')
| -rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 264 |
1 files changed, 143 insertions, 121 deletions
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index a83a99d78..545811bb4 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -41,9 +41,11 @@ def track_cascade_events(descriptor, prop): prop = state.manager.mapper._props[key] item_state = attributes.instance_state(item) - if prop._cascade.save_update and \ - (prop.cascade_backrefs or key == initiator.key) and \ - not sess._contains_state(item_state): + if ( + prop._cascade.save_update + and (prop.cascade_backrefs or key == initiator.key) + and not sess._contains_state(item_state) + ): sess._save_or_update_state(item_state) return item @@ -59,12 +61,15 @@ def track_cascade_events(descriptor, prop): sess._flush_warning( "collection remove" if prop.uselist - else "related attribute delete") + else "related attribute delete" + ) - if item is not None and \ - item is not attributes.NEVER_SET and \ - item is not attributes.PASSIVE_NO_RESULT and \ - prop._cascade.delete_orphan: + if ( + item is not None + and item is not attributes.NEVER_SET + and item is not attributes.PASSIVE_NO_RESULT + and prop._cascade.delete_orphan + ): # expunge pending orphans item_state = attributes.instance_state(item) @@ -93,26 +98,31 @@ def track_cascade_events(descriptor, prop): prop = state.manager.mapper._props[key] if newvalue is not None: newvalue_state = attributes.instance_state(newvalue) - if prop._cascade.save_update and \ - (prop.cascade_backrefs or key == initiator.key) and \ - not sess._contains_state(newvalue_state): + if ( + prop._cascade.save_update + and (prop.cascade_backrefs or key == initiator.key) + and not sess._contains_state(newvalue_state) + ): sess._save_or_update_state(newvalue_state) - if oldvalue is not None and \ - oldvalue is not attributes.NEVER_SET and \ - oldvalue is not attributes.PASSIVE_NO_RESULT and \ - prop._cascade.delete_orphan: + if ( + oldvalue is not None + and oldvalue is not attributes.NEVER_SET + and oldvalue is not attributes.PASSIVE_NO_RESULT + and prop._cascade.delete_orphan + ): # possible to reach here with attributes.NEVER_SET ? oldvalue_state = attributes.instance_state(oldvalue) - if oldvalue_state in sess._new and \ - prop.mapper._is_orphan(oldvalue_state): + if oldvalue_state in sess._new and prop.mapper._is_orphan( + oldvalue_state + ): sess.expunge(oldvalue) return newvalue - event.listen(descriptor, 'append', append, raw=True, retval=True) - event.listen(descriptor, 'remove', remove, raw=True, retval=True) - event.listen(descriptor, 'set', set_, raw=True, retval=True) + event.listen(descriptor, "append", append, raw=True, retval=True) + event.listen(descriptor, "remove", remove, raw=True, retval=True) + event.listen(descriptor, "set", set_, raw=True, retval=True) class UOWTransaction(object): @@ -197,8 +207,9 @@ class UOWTransaction(object): self.states[state] = (isdelete, True) - def get_attribute_history(self, state, key, - passive=attributes.PASSIVE_NO_INITIALIZE): + def get_attribute_history( + self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE + ): """facade to attributes.get_state_history(), including caching of results.""" @@ -213,12 +224,16 @@ class UOWTransaction(object): # if the cached lookup was "passive" and now # we want non-passive, do a non-passive lookup and re-cache - if not cached_passive & attributes.SQL_OK \ - and passive & attributes.SQL_OK: + if ( + not cached_passive & attributes.SQL_OK + and passive & attributes.SQL_OK + ): impl = state.manager[key].impl - history = impl.get_history(state, state.dict, - attributes.PASSIVE_OFF | - attributes.LOAD_AGAINST_COMMITTED) + history = impl.get_history( + state, + state.dict, + attributes.PASSIVE_OFF | attributes.LOAD_AGAINST_COMMITTED, + ) if history and impl.uses_objects: state_history = history.as_state() else: @@ -228,14 +243,14 @@ class UOWTransaction(object): impl = state.manager[key].impl # TODO: store the history as (state, object) tuples # so we don't have to keep converting here - history = impl.get_history(state, state.dict, passive | - attributes.LOAD_AGAINST_COMMITTED) + history = impl.get_history( + state, state.dict, passive | attributes.LOAD_AGAINST_COMMITTED + ) if history and impl.uses_objects: state_history = history.as_state() else: state_history = history - self.attributes[hashkey] = (history, state_history, - passive) + self.attributes[hashkey] = (history, state_history, passive) return state_history @@ -247,17 +262,25 @@ class UOWTransaction(object): if key not in self.presort_actions: self.presort_actions[key] = Preprocess(processor, fromparent) - def register_object(self, state, isdelete=False, - listonly=False, cancel_delete=False, - operation=None, prop=None): + def register_object( + self, + state, + isdelete=False, + listonly=False, + cancel_delete=False, + operation=None, + prop=None, + ): if not self.session._contains_state(state): # this condition is normal when objects are registered # as part of a relationship cascade operation. it should # not occur for the top-level register from Session.flush(). if not state.deleted and operation is not None: - util.warn("Object of type %s not in session, %s operation " - "along '%s' will not proceed" % - (orm_util.state_class_str(state), operation, prop)) + util.warn( + "Object of type %s not in session, %s operation " + "along '%s' will not proceed" + % (orm_util.state_class_str(state), operation, prop) + ) return False if state not in self.states: @@ -340,24 +363,26 @@ class UOWTransaction(object): # see if the graph of mapper dependencies has cycles. self.cycles = cycles = topological.find_cycles( - self.dependencies, - list(self.postsort_actions.values())) + self.dependencies, list(self.postsort_actions.values()) + ) if cycles: # if yes, break the per-mapper actions into # per-state actions convert = dict( - (rec, set(rec.per_state_flush_actions(self))) - for rec in cycles + (rec, set(rec.per_state_flush_actions(self))) for rec in cycles ) # rewrite the existing dependencies to point to # the per-state actions for those per-mapper actions # that were broken up. for edge in list(self.dependencies): - if None in edge or \ - edge[0].disabled or edge[1].disabled or \ - cycles.issuperset(edge): + if ( + None in edge + or edge[0].disabled + or edge[1].disabled + or cycles.issuperset(edge) + ): self.dependencies.remove(edge) elif edge[0] in cycles: self.dependencies.remove(edge) @@ -368,10 +393,9 @@ class UOWTransaction(object): for dep in convert[edge[1]]: self.dependencies.add((edge[0], dep)) - return set([a for a in self.postsort_actions.values() - if not a.disabled - ] - ).difference(cycles) + return set( + [a for a in self.postsort_actions.values() if not a.disabled] + ).difference(cycles) def execute(self): postsort_actions = self._generate_actions() @@ -386,15 +410,13 @@ class UOWTransaction(object): # execute if self.cycles: for set_ in topological.sort_as_subsets( - self.dependencies, - postsort_actions): + self.dependencies, postsort_actions + ): while set_: n = set_.pop() n.execute_aggregate(self, set_) else: - for rec in topological.sort( - self.dependencies, - postsort_actions): + for rec in topological.sort(self.dependencies, postsort_actions): rec.execute(self) def finalize_flush_changes(self): @@ -410,8 +432,7 @@ class UOWTransaction(object): states = set(self.states) isdel = set( - s for (s, (isdelete, listonly)) in self.states.items() - if isdelete + s for (s, (isdelete, listonly)) in self.states.items() if isdelete ) other = states.difference(isdel) if isdel: @@ -424,8 +445,8 @@ class IterateMappersMixin(object): def _mappers(self, uow): if self.fromparent: return iter( - m for m in - self.dependency_processor.parent.self_and_descendants + m + for m in self.dependency_processor.parent.self_and_descendants if uow._mapper_for_dep[(m, self.dependency_processor)] ) else: @@ -434,8 +455,10 @@ class IterateMappersMixin(object): class Preprocess(IterateMappersMixin): __slots__ = ( - 'dependency_processor', 'fromparent', 'processed', - 'setup_flush_actions' + "dependency_processor", + "fromparent", + "processed", + "setup_flush_actions", ) def __init__(self, dependency_processor, fromparent): @@ -464,12 +487,14 @@ class Preprocess(IterateMappersMixin): self.dependency_processor.presort_saves(uow, save_states) self.processed.update(save_states) - if (delete_states or save_states): + if delete_states or save_states: if not self.setup_flush_actions and ( - self.dependency_processor. - prop_has_changes(uow, delete_states, True) or - self.dependency_processor. - prop_has_changes(uow, save_states, False) + self.dependency_processor.prop_has_changes( + uow, delete_states, True + ) + or self.dependency_processor.prop_has_changes( + uow, save_states, False + ) ): self.dependency_processor.per_property_flush_actions(uow) self.setup_flush_actions = True @@ -479,16 +504,14 @@ class Preprocess(IterateMappersMixin): class PostSortRec(object): - __slots__ = 'disabled', + __slots__ = ("disabled",) def __new__(cls, uow, *args): - key = (cls, ) + args + key = (cls,) + args if key in uow.postsort_actions: return uow.postsort_actions[key] else: - uow.postsort_actions[key] = \ - ret = \ - object.__new__(cls) + uow.postsort_actions[key] = ret = object.__new__(cls) ret.disabled = False return ret @@ -497,14 +520,15 @@ class PostSortRec(object): class ProcessAll(IterateMappersMixin, PostSortRec): - __slots__ = 'dependency_processor', 'isdelete', 'fromparent' + __slots__ = "dependency_processor", "isdelete", "fromparent" def __init__(self, uow, dependency_processor, isdelete, fromparent): self.dependency_processor = dependency_processor self.isdelete = isdelete self.fromparent = fromparent - uow.deps[dependency_processor.parent.base_mapper].\ - add(dependency_processor) + uow.deps[dependency_processor.parent.base_mapper].add( + dependency_processor + ) def execute(self, uow): states = self._elements(uow) @@ -524,7 +548,7 @@ class ProcessAll(IterateMappersMixin, PostSortRec): return "%s(%s, isdelete=%s)" % ( self.__class__.__name__, self.dependency_processor, - self.isdelete + self.isdelete, ) def _elements(self, uow): @@ -536,7 +560,7 @@ class ProcessAll(IterateMappersMixin, PostSortRec): class PostUpdateAll(PostSortRec): - __slots__ = 'mapper', 'isdelete' + __slots__ = "mapper", "isdelete" def __init__(self, uow, mapper, isdelete): self.mapper = mapper @@ -550,22 +574,23 @@ class PostUpdateAll(PostSortRec): class SaveUpdateAll(PostSortRec): - __slots__ = 'mapper', + __slots__ = ("mapper",) def __init__(self, uow, mapper): self.mapper = mapper assert mapper is mapper.base_mapper def execute(self, uow): - persistence.save_obj(self.mapper, - uow.states_for_mapper_hierarchy( - self.mapper, False, False), - uow - ) + persistence.save_obj( + self.mapper, + uow.states_for_mapper_hierarchy(self.mapper, False, False), + uow, + ) def per_state_flush_actions(self, uow): - states = list(uow.states_for_mapper_hierarchy( - self.mapper, False, False)) + states = list( + uow.states_for_mapper_hierarchy(self.mapper, False, False) + ) base_mapper = self.mapper.base_mapper delete_all = DeleteAll(uow, base_mapper) for state in states: @@ -580,29 +605,27 @@ class SaveUpdateAll(PostSortRec): dep.per_state_flush_actions(uow, states_for_prop, False) def __repr__(self): - return "%s(%s)" % ( - self.__class__.__name__, - self.mapper - ) + return "%s(%s)" % (self.__class__.__name__, self.mapper) class DeleteAll(PostSortRec): - __slots__ = 'mapper', + __slots__ = ("mapper",) def __init__(self, uow, mapper): self.mapper = mapper assert mapper is mapper.base_mapper def execute(self, uow): - persistence.delete_obj(self.mapper, - uow.states_for_mapper_hierarchy( - self.mapper, True, False), - uow - ) + persistence.delete_obj( + self.mapper, + uow.states_for_mapper_hierarchy(self.mapper, True, False), + uow, + ) def per_state_flush_actions(self, uow): - states = list(uow.states_for_mapper_hierarchy( - self.mapper, True, False)) + states = list( + uow.states_for_mapper_hierarchy(self.mapper, True, False) + ) base_mapper = self.mapper.base_mapper save_all = SaveUpdateAll(uow, base_mapper) for state in states: @@ -617,14 +640,11 @@ class DeleteAll(PostSortRec): dep.per_state_flush_actions(uow, states_for_prop, True) def __repr__(self): - return "%s(%s)" % ( - self.__class__.__name__, - self.mapper - ) + return "%s(%s)" % (self.__class__.__name__, self.mapper) class ProcessState(PostSortRec): - __slots__ = 'dependency_processor', 'isdelete', 'state' + __slots__ = "dependency_processor", "isdelete", "state" def __init__(self, uow, dependency_processor, isdelete, state): self.dependency_processor = dependency_processor @@ -635,10 +655,13 @@ class ProcessState(PostSortRec): cls_ = self.__class__ dependency_processor = self.dependency_processor isdelete = self.isdelete - our_recs = [r for r in recs - if r.__class__ is cls_ and - r.dependency_processor is dependency_processor and - r.isdelete is isdelete] + our_recs = [ + r + for r in recs + if r.__class__ is cls_ + and r.dependency_processor is dependency_processor + and r.isdelete is isdelete + ] recs.difference_update(our_recs) states = [self.state] + [r.state for r in our_recs] if isdelete: @@ -651,12 +674,12 @@ class ProcessState(PostSortRec): self.__class__.__name__, self.dependency_processor, orm_util.state_str(self.state), - self.isdelete + self.isdelete, ) class SaveUpdateState(PostSortRec): - __slots__ = 'state', 'mapper' + __slots__ = "state", "mapper" def __init__(self, uow, state): self.state = state @@ -665,24 +688,23 @@ class SaveUpdateState(PostSortRec): def execute_aggregate(self, uow, recs): cls_ = self.__class__ mapper = self.mapper - our_recs = [r for r in recs - if r.__class__ is cls_ and - r.mapper is mapper] + our_recs = [ + r for r in recs if r.__class__ is cls_ and r.mapper is mapper + ] recs.difference_update(our_recs) - persistence.save_obj(mapper, - [self.state] + - [r.state for r in our_recs], - uow) + persistence.save_obj( + mapper, [self.state] + [r.state for r in our_recs], uow + ) def __repr__(self): return "%s(%s)" % ( self.__class__.__name__, - orm_util.state_str(self.state) + orm_util.state_str(self.state), ) class DeleteState(PostSortRec): - __slots__ = 'state', 'mapper' + __slots__ = "state", "mapper" def __init__(self, uow, state): self.state = state @@ -691,17 +713,17 @@ class DeleteState(PostSortRec): def execute_aggregate(self, uow, recs): cls_ = self.__class__ mapper = self.mapper - our_recs = [r for r in recs - if r.__class__ is cls_ and - r.mapper is mapper] + our_recs = [ + r for r in recs if r.__class__ is cls_ and r.mapper is mapper + ] recs.difference_update(our_recs) states = [self.state] + [r.state for r in our_recs] - persistence.delete_obj(mapper, - [s for s in states if uow.states[s][0]], - uow) + persistence.delete_obj( + mapper, [s for s in states if uow.states[s][0]], uow + ) def __repr__(self): return "%s(%s)" % ( self.__class__.__name__, - orm_util.state_str(self.state) + orm_util.state_str(self.state), ) |
