diff options
Diffstat (limited to 'lib/sqlalchemy/orm')
| -rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 672 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/collections.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/dependency.py | 39 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/dynamic.py | 46 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/identity.py | 89 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 144 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 33 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/state.py | 429 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 34 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 20 |
13 files changed, 887 insertions, 684 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 68aa0d93a..4fa41ff3b 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -20,14 +20,13 @@ import types import weakref from sqlalchemy import util -from sqlalchemy.util import EMPTY_SET from sqlalchemy.orm import interfaces, collections, exc import sqlalchemy.exceptions as sa_exc # lazy imports _entity_info = None identity_equal = None - +state = None PASSIVE_NORESULT = util.symbol('PASSIVE_NORESULT') ATTR_WAS_SET = util.symbol('ATTR_WAS_SET') @@ -105,7 +104,7 @@ class QueryableAttribute(interfaces.PropComparator): self.parententity = parententity def get_history(self, instance, **kwargs): - return self.impl.get_history(instance_state(instance), **kwargs) + return self.impl.get_history(instance_state(instance), instance_dict(instance), **kwargs) def __selectable__(self): # TODO: conditionally attach this method based on clause_element ? @@ -148,15 +147,15 @@ class InstrumentedAttribute(QueryableAttribute): """Public-facing descriptor, placed in the mapped class dictionary.""" def __set__(self, instance, value): - self.impl.set(instance_state(instance), value, None) + self.impl.set(instance_state(instance), instance_dict(instance), value, None) def __delete__(self, instance): - self.impl.delete(instance_state(instance)) + self.impl.delete(instance_state(instance), instance_dict(instance)) def __get__(self, instance, owner): if instance is None: return self - return self.impl.get(instance_state(instance)) + return self.impl.get(instance_state(instance), instance_dict(instance)) class _ProxyImpl(object): accepts_scalar_loader = False @@ -335,7 +334,7 @@ class AttributeImpl(object): else: state.callables[self.key] = callable_ - def get_history(self, state, passive=PASSIVE_OFF): + def get_history(self, state, dict_, passive=PASSIVE_OFF): raise NotImplementedError() def _get_callable(self, state): @@ -346,13 +345,13 @@ class AttributeImpl(object): else: return None - def initialize(self, state): + def initialize(self, state, dict_): """Initialize this attribute on the given object instance with an empty value.""" - state.dict[self.key] = None + dict_[self.key] = None return None - def get(self, state, passive=PASSIVE_OFF): + def get(self, state, dict_, passive=PASSIVE_OFF): """Retrieve a value from the given object. If a callable is assembled on this object's attribute, and @@ -361,7 +360,7 @@ class AttributeImpl(object): """ try: - return state.dict[self.key] + return dict_[self.key] except KeyError: # if no history, check for lazy callables, etc. if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET: @@ -374,25 +373,25 @@ class AttributeImpl(object): return PASSIVE_NORESULT value = callable_() if value is not ATTR_WAS_SET: - return self.set_committed_value(state, value) + return self.set_committed_value(state, dict_, value) else: - if self.key not in state.dict: + if self.key not in dict_: return self.get(state, passive=passive) - return state.dict[self.key] + return dict_[self.key] # Return a new, empty value - return self.initialize(state) + return self.initialize(state, dict_) - def append(self, state, value, initiator, passive=PASSIVE_OFF): - self.set(state, value, initiator) + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, value, initiator) - def remove(self, state, value, initiator, passive=PASSIVE_OFF): - self.set(state, None, initiator) + def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, None, initiator) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): raise NotImplementedError() - def get_committed_value(self, state, passive=PASSIVE_OFF): + def get_committed_value(self, state, dict_, passive=PASSIVE_OFF): """return the unchanged value of this attribute""" if self.key in state.committed_state: @@ -401,12 +400,12 @@ class AttributeImpl(object): else: return state.committed_state.get(self.key) else: - return self.get(state, passive=passive) + return self.get(state, dict_, passive=passive) - def set_committed_value(self, state, value): + def set_committed_value(self, state, dict_, value): """set an attribute value on the given instance and 'commit' it.""" - state.commit([self.key]) + state.commit(dict_, [self.key]) state.callables.pop(self.key, None) state.dict[self.key] = value @@ -419,45 +418,45 @@ class ScalarAttributeImpl(AttributeImpl): accepts_scalar_loader = True uses_objects = False - def delete(self, state): + def delete(self, state, dict_): # TODO: catch key errors, convert to attributeerror? if self.active_history or self.extensions: - old = self.get(state) + old = self.get(state, dict_) else: - old = state.dict.get(self.key, NO_VALUE) + old = dict_.get(self.key, NO_VALUE) - state.modified_event(self, False, old) + state.modified_event(dict_, self, False, old) if self.extensions: - self.fire_remove_event(state, old, None) - del state.dict[self.key] + self.fire_remove_event(state, dict_, old, None) + del dict_[self.key] - def get_history(self, state, passive=PASSIVE_OFF): + def get_history(self, state, dict_, passive=PASSIVE_OFF): return History.from_attribute( - self, state, state.dict.get(self.key, NO_VALUE)) + self, state, dict_.get(self.key, NO_VALUE)) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): if initiator is self: return if self.active_history or self.extensions: - old = self.get(state) + old = self.get(state, dict_) else: - old = state.dict.get(self.key, NO_VALUE) + old = dict_.get(self.key, NO_VALUE) - state.modified_event(self, False, old) + state.modified_event(dict_, self, False, old) if self.extensions: - value = self.fire_replace_event(state, value, old, initiator) - state.dict[self.key] = value + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value - def fire_replace_event(self, state, value, previous, initiator): + def fire_replace_event(self, state, dict_, value, previous, initiator): for ext in self.extensions: value = ext.set(state, value, previous, initiator or self) return value - def fire_remove_event(self, state, value, initiator): + def fire_remove_event(self, state, dict_, value, initiator): for ext in self.extensions: ext.remove(state, value, initiator or self) @@ -483,29 +482,48 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl): raise sa_exc.ArgumentError("MutableScalarAttributeImpl requires a copy function") self.copy = copy_function - def get_history(self, state, passive=PASSIVE_OFF): + def get_history(self, state, dict_, passive=PASSIVE_OFF): + if not dict_: + v = state.committed_state.get(self.key, NO_VALUE) + else: + v = dict_.get(self.key, NO_VALUE) + return History.from_attribute( - self, state, state.dict.get(self.key, NO_VALUE)) + self, state, v) - def commit_to_state(self, state, dest): - dest[self.key] = self.copy(state.dict[self.key]) + def commit_to_state(self, state, dict_, dest): + dest[self.key] = self.copy(dict_[self.key]) - def check_mutable_modified(self, state): - (added, unchanged, deleted) = self.get_history(state, passive=PASSIVE_NO_INITIALIZE) + def check_mutable_modified(self, state, dict_): + (added, unchanged, deleted) = self.get_history(state, dict_, passive=PASSIVE_NO_INITIALIZE) return bool(added or deleted) - def set(self, state, value, initiator): + def get(self, state, dict_, passive=PASSIVE_OFF): + if self.key not in state.mutable_dict: + ret = ScalarAttributeImpl.get(self, state, dict_, passive=passive) + if ret is not PASSIVE_NORESULT: + state.mutable_dict[self.key] = ret + return ret + else: + return state.mutable_dict[self.key] + + def delete(self, state, dict_): + ScalarAttributeImpl.delete(self, state, dict_) + state.mutable_dict.pop(self.key) + + def set(self, state, dict_, value, initiator): if initiator is self: return - state.modified_event(self, True, NEVER_SET) - + state.modified_event(dict_, self, True, NEVER_SET) + if self.extensions: - old = self.get(state) - value = self.fire_replace_event(state, value, old, initiator) - state.dict[self.key] = value + old = self.get(state, dict_) + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value else: - state.dict[self.key] = value + dict_[self.key] = value + state.mutable_dict[self.key] = value class ScalarObjectAttributeImpl(ScalarAttributeImpl): @@ -526,22 +544,22 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): if compare_function is None: self.is_equal = identity_equal - def delete(self, state): - old = self.get(state) - self.fire_remove_event(state, old, self) - del state.dict[self.key] + def delete(self, state, dict_): + old = self.get(state, dict_) + self.fire_remove_event(state, dict_, old, self) + del dict_[self.key] - def get_history(self, state, passive=PASSIVE_OFF): - if self.key in state.dict: - return History.from_attribute(self, state, state.dict[self.key]) + def get_history(self, state, dict_, passive=PASSIVE_OFF): + if self.key in dict_: + return History.from_attribute(self, state, dict_[self.key]) else: - current = self.get(state, passive=passive) + current = self.get(state, dict_, passive=passive) if current is PASSIVE_NORESULT: return HISTORY_BLANK else: return History.from_attribute(self, state, current) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): """Set a value on the given InstanceState. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -553,12 +571,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): return # may want to add options to allow the get() here to be passive - old = self.get(state) - value = self.fire_replace_event(state, value, old, initiator) - state.dict[self.key] = value + old = self.get(state, dict_) + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value - def fire_remove_event(self, state, value, initiator): - state.modified_event(self, False, value) + def fire_remove_event(self, state, dict_, value, initiator): + state.modified_event(dict_, self, False, value) if self.trackparent and value is not None: self.sethasparent(instance_state(value), False) @@ -566,8 +584,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): for ext in self.extensions: ext.remove(state, value, initiator or self) - def fire_replace_event(self, state, value, previous, initiator): - state.modified_event(self, False, previous) + def fire_replace_event(self, state, dict_, value, previous, initiator): + state.modified_event(dict_, self, False, previous) if self.trackparent: if previous is not value and previous is not None: @@ -615,15 +633,15 @@ class CollectionAttributeImpl(AttributeImpl): def __copy(self, item): return [y for y in list(collections.collection_adapter(item))] - def get_history(self, state, passive=PASSIVE_OFF): - current = self.get(state, passive=passive) + def get_history(self, state, dict_, passive=PASSIVE_OFF): + current = self.get(state, dict_, passive=passive) if current is PASSIVE_NORESULT: return HISTORY_BLANK else: return History.from_attribute(self, state, current) - def fire_append_event(self, state, value, initiator): - state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + def fire_append_event(self, state, dict_, value, initiator): + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) for ext in self.extensions: value = ext.append(state, value, initiator or self) @@ -633,11 +651,11 @@ class CollectionAttributeImpl(AttributeImpl): return value - def fire_pre_remove_event(self, state, initiator): - state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + def fire_pre_remove_event(self, state, dict_, initiator): + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) - def fire_remove_event(self, state, value, initiator): - state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + def fire_remove_event(self, state, dict_, value, initiator): + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) if self.trackparent and value is not None: self.sethasparent(instance_state(value), False) @@ -645,51 +663,51 @@ class CollectionAttributeImpl(AttributeImpl): for ext in self.extensions: ext.remove(state, value, initiator or self) - def delete(self, state): - if self.key not in state.dict: + def delete(self, state, dict_): + if self.key not in dict_: return - state.modified_event(self, True, NEVER_SET) + state.modified_event(dict_, self, True, NEVER_SET) - collection = self.get_collection(state) + collection = self.get_collection(state, state.dict) collection.clear_with_event() # TODO: catch key errors, convert to attributeerror? - del state.dict[self.key] + del dict_[self.key] - def initialize(self, state): + def initialize(self, state, dict_): """Initialize this attribute with an empty collection.""" _, user_data = self._initialize_collection(state) - state.dict[self.key] = user_data + dict_[self.key] = user_data return user_data def _initialize_collection(self, state): return state.manager.initialize_collection( self.key, state, self.collection_factory) - def append(self, state, value, initiator, passive=PASSIVE_OFF): + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): if initiator is self: return - collection = self.get_collection(state, passive=passive) + collection = self.get_collection(state, dict_, passive=passive) if collection is PASSIVE_NORESULT: - value = self.fire_append_event(state, value, initiator) + value = self.fire_append_event(state, dict_, value, initiator) state.get_pending(self.key).append(value) else: collection.append_with_event(value, initiator) - def remove(self, state, value, initiator, passive=PASSIVE_OFF): + def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): if initiator is self: return - collection = self.get_collection(state, passive=passive) + collection = self.get_collection(state, state.dict, passive=passive) if collection is PASSIVE_NORESULT: - self.fire_remove_event(state, value, initiator) + self.fire_remove_event(state, dict_, value, initiator) state.get_pending(self.key).remove(value) else: collection.remove_with_event(value, initiator) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): """Set a value on the given object. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -701,10 +719,10 @@ class CollectionAttributeImpl(AttributeImpl): return self._set_iterable( - state, value, + state, dict_, value, lambda adapter, i: adapter.adapt_like_to_iterable(i)) - def _set_iterable(self, state, iterable, adapter=None): + def _set_iterable(self, state, dict_, iterable, adapter=None): """Set a collection value from an iterable of state-bearers. ``adapter`` is an optional callable invoked with a CollectionAdapter @@ -722,24 +740,24 @@ class CollectionAttributeImpl(AttributeImpl): else: new_values = list(iterable) - old = self.get(state) + old = self.get(state, dict_) # ignore re-assignment of the current collection, as happens # implicitly with in-place operators (foo.collection |= other) if old is iterable: return - state.modified_event(self, True, old) + state.modified_event(dict_, self, True, old) - old_collection = self.get_collection(state, old) + old_collection = self.get_collection(state, dict_, old) - state.dict[self.key] = user_data + dict_[self.key] = user_data collections.bulk_replace(new_values, old_collection, new_collection) old_collection.unlink(old) - def set_committed_value(self, state, value): + def set_committed_value(self, state, dict_, value): """Set an attribute value on the given instance and 'commit' it.""" collection, user_data = self._initialize_collection(state) @@ -751,13 +769,13 @@ class CollectionAttributeImpl(AttributeImpl): state.callables.pop(self.key, None) state.dict[self.key] = user_data - state.commit([self.key]) + state.commit(dict_, [self.key]) if self.key in state.pending: # pending items exist. issue a modified event, # add/remove new items. - state.modified_event(self, True, user_data) + state.modified_event(dict_, self, True, user_data) pending = state.pending.pop(self.key) added = pending.added_items @@ -769,14 +787,14 @@ class CollectionAttributeImpl(AttributeImpl): return user_data - def get_collection(self, state, user_data=None, passive=PASSIVE_OFF): + def get_collection(self, state, dict_, user_data=None, passive=PASSIVE_OFF): """Retrieve the CollectionAdapter associated with the given state. Creates a new CollectionAdapter if one does not exist. """ if user_data is None: - user_data = self.get(state, passive=passive) + user_data = self.get(state, dict_, passive=passive) if user_data is PASSIVE_NORESULT: return user_data @@ -799,320 +817,26 @@ class GenericBackrefExtension(interfaces.AttributeExtension): if oldchild is not None: # With lazy=None, there's no guarantee that the full collection is # present when updating via a backref. - old_state = instance_state(oldchild) + old_state, old_dict = instance_state(oldchild), instance_dict(oldchild) impl = old_state.get_impl(self.key) try: - impl.remove(old_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) + impl.remove(old_state, old_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) except (ValueError, KeyError, IndexError): pass if child is not None: - new_state = instance_state(child) - new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) + new_state, new_dict = instance_state(child), instance_dict(child) + new_state.get_impl(self.key).append(new_state, new_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) return child def append(self, state, child, initiator): - child_state = instance_state(child) - child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) + child_state, child_dict = instance_state(child), instance_dict(child) + child_state.get_impl(self.key).append(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) return child def remove(self, state, child, initiator): if child is not None: - child_state = instance_state(child) - child_state.get_impl(self.key).remove(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) - - -class InstanceState(object): - """tracks state information at the instance level.""" - - session_id = None - key = None - runid = None - expired_attributes = EMPTY_SET - load_options = EMPTY_SET - load_path = () - insert_order = None - - def __init__(self, obj, manager): - self.class_ = obj.__class__ - self.manager = manager - self.obj = weakref.ref(obj, self._cleanup) - self.dict = obj.__dict__ - self.modified = False - self.callables = {} - self.expired = False - self.committed_state = {} - self.pending = {} - self.parents = {} - - def detach(self): - if self.session_id: - del self.session_id - - def dispose(self): - if self.session_id: - del self.session_id - del self.obj - del self.dict - - def _cleanup(self, ref): - self.dispose() - - def obj(self): - return None - - @util.memoized_property - def dict(self): - # return a blank dict - # if none is available, so that asynchronous gc - # doesn't blow up expiration operations in progress - # (usually expire_attributes) - return {} - - @property - def sort_key(self): - return self.key and self.key[1] or (self.insert_order, ) - - def check_modified(self): - if self.modified: - return True - else: - for key in self.manager.mutable_attributes: - if self.manager[key].impl.check_mutable_modified(self): - return True - else: - return False - - def initialize_instance(*mixed, **kwargs): - self, instance, args = mixed[0], mixed[1], mixed[2:] - manager = self.manager - - for fn in manager.events.on_init: - fn(self, instance, args, kwargs) - try: - return manager.events.original_init(*mixed[1:], **kwargs) - except: - for fn in manager.events.on_init_failure: - fn(self, instance, args, kwargs) - raise - - def get_history(self, key, **kwargs): - return self.manager.get_impl(key).get_history(self, **kwargs) - - def get_impl(self, key): - return self.manager.get_impl(key) - - def get_pending(self, key): - if key not in self.pending: - self.pending[key] = PendingCollection() - return self.pending[key] - - def value_as_iterable(self, key, passive=PASSIVE_OFF): - """return an InstanceState attribute as a list, - regardless of it being a scalar or collection-based - attribute. - - returns None if passive is not PASSIVE_OFF and the getter returns - PASSIVE_NORESULT. - """ - - impl = self.get_impl(key) - x = impl.get(self, passive=passive) - if x is PASSIVE_NORESULT: - - return None - elif hasattr(impl, 'get_collection'): - return impl.get_collection(self, x, passive=passive) - elif isinstance(x, list): - return x - else: - return [x] - - def _run_on_load(self, instance=None): - if instance is None: - instance = self.obj() - self.manager.events.run('on_load', instance) - - def __getstate__(self): - return {'key': self.key, - 'committed_state': self.committed_state, - 'pending': self.pending, - 'parents': self.parents, - 'modified': self.modified, - 'expired':self.expired, - 'load_options':self.load_options, - 'load_path':interfaces.serialize_path(self.load_path), - 'instance': self.obj(), - 'expired_attributes':self.expired_attributes, - 'callables': self.callables} - - def __setstate__(self, state): - self.committed_state = state['committed_state'] - self.parents = state['parents'] - self.key = state['key'] - self.session_id = None - self.pending = state['pending'] - self.modified = state['modified'] - self.obj = weakref.ref(state['instance']) - self.load_options = state['load_options'] or EMPTY_SET - self.load_path = interfaces.deserialize_path(state['load_path']) - self.class_ = self.obj().__class__ - self.manager = manager_of_class(self.class_) - self.dict = self.obj().__dict__ - self.callables = state['callables'] - self.runid = None - self.expired = state['expired'] - self.expired_attributes = state['expired_attributes'] - - def initialize(self, key): - self.manager.get_impl(key).initialize(self) - - def set_callable(self, key, callable_): - self.dict.pop(key, None) - self.callables[key] = callable_ - - def __call__(self): - """__call__ allows the InstanceState to act as a deferred - callable for loading expired attributes, which is also - serializable (picklable). - - """ - unmodified = self.unmodified - class_manager = self.manager - class_manager.deferred_scalar_loader(self, [ - attr.impl.key for attr in class_manager.attributes if - attr.impl.accepts_scalar_loader and - attr.impl.key in self.expired_attributes and - attr.impl.key in unmodified - ]) - for k in self.expired_attributes: - self.callables.pop(k, None) - del self.expired_attributes - return ATTR_WAS_SET - - @property - def unmodified(self): - """a set of keys which have no uncommitted changes""" - - return set( - key for key in self.manager.iterkeys() - if (key not in self.committed_state or - (key in self.manager.mutable_attributes and - not self.manager[key].impl.check_mutable_modified(self)))) - - @property - def unloaded(self): - """a set of keys which do not have a loaded value. - - This includes expired attributes and any other attribute that - was never populated or modified. - - """ - return set( - key for key in self.manager.iterkeys() - if key not in self.committed_state and key not in self.dict) - - def expire_attributes(self, attribute_names): - self.expired_attributes = set(self.expired_attributes) - - if attribute_names is None: - attribute_names = self.manager.keys() - self.expired = True - self.modified = False - filter_deferred = True - else: - filter_deferred = False - for key in attribute_names: - impl = self.manager[key].impl - if not filter_deferred or \ - not impl.dont_expire_missing or \ - key in self.dict: - self.expired_attributes.add(key) - if impl.accepts_scalar_loader: - self.callables[key] = self - self.dict.pop(key, None) - self.pending.pop(key, None) - self.committed_state.pop(key, None) - - def reset(self, key): - """remove the given attribute and any callables associated with it.""" - - self.dict.pop(key, None) - self.callables.pop(key, None) - - def modified_event(self, attr, should_copy, previous, passive=PASSIVE_OFF): - needs_committed = attr.key not in self.committed_state - - if needs_committed: - if previous is NEVER_SET: - if passive: - if attr.key in self.dict: - previous = self.dict[attr.key] - else: - previous = attr.get(self) - - if should_copy and previous not in (None, NO_VALUE, NEVER_SET): - previous = attr.copy(previous) - - if needs_committed: - self.committed_state[attr.key] = previous - - self.modified = True - - def commit(self, keys): - """Commit attributes. - - This is used by a partial-attribute load operation to mark committed - those attributes which were refreshed from the database. - - Attributes marked as "expired" can potentially remain "expired" after - this step if a value was not populated in state.dict. - - """ - class_manager = self.manager - for key in keys: - if key in self.dict and key in class_manager.mutable_attributes: - class_manager[key].impl.commit_to_state(self, self.committed_state) - else: - self.committed_state.pop(key, None) - - self.expired = False - # unexpire attributes which have loaded - for key in self.expired_attributes.intersection(keys): - if key in self.dict: - self.expired_attributes.remove(key) - self.callables.pop(key, None) - - def commit_all(self): - """commit all attributes unconditionally. - - This is used after a flush() or a full load/refresh - to remove all pending state from the instance. - - - all attributes are marked as "committed" - - the "strong dirty reference" is removed - - the "modified" flag is set to False - - any "expired" markers/callables are removed. - - Attributes marked as "expired" can potentially remain "expired" after this step - if a value was not populated in state.dict. - - """ - - self.committed_state = {} - self.pending = {} - - # unexpire attributes which have loaded - if self.expired_attributes: - for key in self.expired_attributes.intersection(self.dict): - self.callables.pop(key, None) - self.expired_attributes.difference_update(self.dict) - - for key in self.manager.mutable_attributes: - if key in self.dict: - self.manager[key].impl.commit_to_state(self, self.committed_state) - - self.modified = self.expired = False - self._strong_obj = None + child_state, child_dict = instance_state(child), instance_dict(child) + child_state.get_impl(self.key).remove(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) class Events(object): @@ -1121,6 +845,7 @@ class Events(object): self.on_init = () self.on_init_failure = () self.on_load = () + self.on_resurrect = () def run(self, event, *args, **kwargs): for fn in getattr(self, event): @@ -1146,7 +871,6 @@ class ClassManager(dict): STATE_ATTR = '_sa_instance_state' event_registry_factory = Events - instance_state_factory = InstanceState deferred_scalar_loader = None def __init__(self, class_): @@ -1170,7 +894,6 @@ class ClassManager(dict): def _configure_create_arguments(self, _source=None, - instance_state_factory=None, deferred_scalar_loader=None): """Accept extra **kw arguments passed to create_manager_for_cls. @@ -1185,11 +908,8 @@ class ClassManager(dict): """ if _source: - instance_state_factory = _source.instance_state_factory deferred_scalar_loader = _source.deferred_scalar_loader - if instance_state_factory: - self.instance_state_factory = instance_state_factory if deferred_scalar_loader: self.deferred_scalar_loader = deferred_scalar_loader @@ -1222,7 +942,16 @@ class ClassManager(dict): if self.new_init: self.uninstall_member('__init__') self.new_init = None - + + def _create_instance_state(self, instance): + global state + if state is None: + from sqlalchemy.orm import state + if self.mutable_attributes: + return state.MutableAttrInstanceState(instance, self) + else: + return state.InstanceState(instance, self) + def manage(self): """Mark this instance as the manager for its class.""" @@ -1330,11 +1059,11 @@ class ClassManager(dict): def new_instance(self, state=None): instance = self.class_.__new__(self.class_) - setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self)) + setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance)) return instance def setup_instance(self, instance, state=None): - setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self)) + setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance)) def teardown_instance(self, instance): delattr(instance, self.STATE_ATTR) @@ -1348,13 +1077,10 @@ class ClassManager(dict): if hasattr(instance, self.STATE_ATTR): return False else: - state = self.instance_state_factory(instance, self) + state = self._create_instance_state(instance) setattr(instance, self.STATE_ATTR, state) return state - def state_of(self, instance): - return getattr(instance, self.STATE_ATTR) - def state_getter(self): """Return a (instance) -> InstanceState callable. @@ -1365,6 +1091,9 @@ class ClassManager(dict): return attrgetter(self.STATE_ATTR) + def dict_getter(self): + return attrgetter('__dict__') + def has_state(self, instance): return hasattr(instance, self.STATE_ATTR) @@ -1385,6 +1114,9 @@ class _ClassInstrumentationAdapter(ClassManager): def __init__(self, class_, override, **kw): self._adapted = override + self._get_state = self._adapted.state_getter(class_) + self._get_dict = self._adapted.dict_getter(class_) + ClassManager.__init__(self, class_, **kw) def manage(self): @@ -1446,36 +1178,27 @@ class _ClassInstrumentationAdapter(ClassManager): self._adapted.initialize_instance_dict(self.class_, instance) if state is None: - state = self.instance_state_factory(instance, self) + state = self._create_instance_state(instance) # the given instance is assumed to have no state self._adapted.install_state(self.class_, instance, state) - state.dict = self._adapted.get_instance_dict(self.class_, instance) return state def teardown_instance(self, instance): self._adapted.remove_state(self.class_, instance) - def state_of(self, instance): - if hasattr(self._adapted, 'state_of'): - return self._adapted.state_of(self.class_, instance) - else: - getter = self._adapted.state_getter(self.class_) - return getter(instance) - def has_state(self, instance): - if hasattr(self._adapted, 'has_state'): - return self._adapted.has_state(self.class_, instance) - else: - try: - state = self.state_of(instance) - return True - except exc.NO_STATE: - return False + try: + state = self._get_state(instance) + return True + except exc.NO_STATE: + return False def state_getter(self): - return self._adapted.state_getter(self.class_) + return self._get_state + def dict_getter(self): + return self._get_dict class History(tuple): """A 3-tuple of added, unchanged and deleted values. @@ -1520,7 +1243,7 @@ class History(tuple): original = state.committed_state.get(attribute.key, NEVER_SET) if hasattr(attribute, 'get_collection'): - current = attribute.get_collection(state, current) + current = attribute.get_collection(state, state.dict, current) if original is NO_VALUE: return cls(list(current), (), ()) elif original is NEVER_SET: @@ -1557,30 +1280,8 @@ class History(tuple): HISTORY_BLANK = History(None, None, None) -class PendingCollection(object): - """A writable placeholder for an unloaded collection. - - Stores items appended to and removed from a collection that has not yet - been loaded. When the collection is loaded, the changes stored in - PendingCollection are applied to it to produce the final result. - - """ - def __init__(self): - self.deleted_items = util.IdentitySet() - self.added_items = util.OrderedIdentitySet() - - def append(self, value): - if value in self.deleted_items: - self.deleted_items.remove(value) - self.added_items.add(value) - - def remove(self, value): - if value in self.added_items: - self.added_items.remove(value) - self.deleted_items.add(value) - def _conditional_instance_state(obj): - if not isinstance(obj, InstanceState): + if not isinstance(obj, state.InstanceState): obj = instance_state(obj) return obj @@ -1690,15 +1391,16 @@ def init_collection(obj, key): this usage is deprecated. """ - - return init_state_collection(_conditional_instance_state(obj), key) + state = _conditional_instance_state(obj) + dict_ = state.dict + return init_state_collection(state, dict_, key) -def init_state_collection(state, key): +def init_state_collection(state, dict_, key): """Initialize a collection attribute and return the collection adapter.""" attr = state.get_impl(key) - user_data = attr.initialize(state) - return attr.get_collection(state, user_data) + user_data = attr.initialize(state, dict_) + return attr.get_collection(state, dict_, user_data) def set_committed_value(instance, key, value): """Set the value of an attribute with no history events. @@ -1715,8 +1417,8 @@ def set_committed_value(instance, key, value): as though it were part of its original loaded state. """ - state = instance_state(instance) - state.get_impl(key).set_committed_value(instance, key, value) + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).set_committed_value(state, dict_, key, value) def set_attribute(instance, key, value): """Set the value of an attribute, firing history events. @@ -1728,8 +1430,8 @@ def set_attribute(instance, key, value): by SQLAlchemy. """ - state = instance_state(instance) - state.get_impl(key).set(state, value, None) + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).set(state, dict_, value, None) def get_attribute(instance, key): """Get the value of an attribute, firing any callables required. @@ -1741,8 +1443,8 @@ def get_attribute(instance, key): by SQLAlchemy. """ - state = instance_state(instance) - return state.get_impl(key).get(state) + state, dict_ = instance_state(instance), instance_dict(instance) + return state.get_impl(key).get(state, dict_) def del_attribute(instance, key): """Delete the value of an attribute, firing history events. @@ -1754,8 +1456,8 @@ def del_attribute(instance, key): by SQLAlchemy. """ - state = instance_state(instance) - state.get_impl(key).delete(state) + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).delete(state, dict_) def is_instrumented(instance, key): """Return True if the given attribute on the given instance is instrumented @@ -1772,6 +1474,7 @@ class InstrumentationRegistry(object): _manager_finders = weakref.WeakKeyDictionary() _state_finders = util.WeakIdentityMapping() + _dict_finders = util.WeakIdentityMapping() _extended = False def create_manager_for_cls(self, class_, **kw): @@ -1806,6 +1509,7 @@ class InstrumentationRegistry(object): manager.factory = factory self._manager_finders[class_] = manager.manager_getter() self._state_finders[class_] = manager.state_getter() + self._dict_finders[class_] = manager.dict_getter() return manager def _collect_management_factories_for(self, cls): @@ -1845,6 +1549,7 @@ class InstrumentationRegistry(object): return finder(cls) def state_of(self, instance): + # this is only called when alternate instrumentation has been established if instance is None: raise AttributeError("None has no persistent state.") try: @@ -1852,21 +1557,15 @@ class InstrumentationRegistry(object): except KeyError: raise AttributeError("%r is not instrumented" % instance.__class__) - def state_or_default(self, instance, default=None): + def dict_of(self, instance): + # this is only called when alternate instrumentation has been established if instance is None: - return default + raise AttributeError("None has no persistent state.") try: - finder = self._state_finders[instance.__class__] + return self._dict_finders[instance.__class__](instance) except KeyError: - return default - else: - try: - return finder(instance) - except exc.NO_STATE: - return default - except: - raise - + raise AttributeError("%r is not instrumented" % instance.__class__) + def unregister(self, class_): if class_ in self._manager_finders: manager = self.manager_of_class(class_) @@ -1874,6 +1573,7 @@ class InstrumentationRegistry(object): manager.dispose() del self._manager_finders[class_] del self._state_finders[class_] + del self._dict_finders[class_] instrumentation_registry = InstrumentationRegistry() @@ -1887,12 +1587,14 @@ def _install_lookup_strategy(implementation): and unit tests specific to this behavior. """ - global instance_state + global instance_state, instance_dict if implementation is util.symbol('native'): instance_state = attrgetter(ClassManager.STATE_ATTR) + instance_dict = attrgetter("__dict__") else: instance_state = instrumentation_registry.state_of - + instance_dict = instrumentation_registry.dict_of + manager_of_class = instrumentation_registry.manager_of_class _create_manager_for_cls = instrumentation_registry.create_manager_for_cls _install_lookup_strategy(util.symbol('native')) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 5638a7e4a..4ca4c5719 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -472,6 +472,7 @@ class CollectionAdapter(object): """ def __init__(self, attr, owner_state, data): self.attr = attr + # TODO: figure out what this being a weakref buys us self._data = weakref.ref(data) self.owner_state = owner_state self.link_to_self(data) @@ -578,7 +579,7 @@ class CollectionAdapter(object): """ if initiator is not False and item is not None: - return self.attr.fire_append_event(self.owner_state, item, initiator) + return self.attr.fire_append_event(self.owner_state, self.owner_state.dict, item, initiator) else: return item @@ -591,7 +592,7 @@ class CollectionAdapter(object): """ if initiator is not False and item is not None: - self.attr.fire_remove_event(self.owner_state, item, initiator) + self.attr.fire_remove_event(self.owner_state, self.owner_state.dict, item, initiator) def fire_pre_remove_event(self, initiator=None): """Notify that an entity is about to be removed from the collection. @@ -600,7 +601,7 @@ class CollectionAdapter(object): fire_remove_event(). """ - self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator) + self.attr.fire_pre_remove_event(self.owner_state, self.owner_state.dict, initiator=initiator) def __getstate__(self): return {'key': self.attr.key, diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index a80727b7f..151c557d7 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -64,17 +64,21 @@ class DependencyProcessor(object): def register_dependencies(self, uowcommit): """Tell a ``UOWTransaction`` what mappers are dependent on which, with regards to the two or three mappers handled by - this ``PropertyLoader``. + this ``DependencyProcessor``. - Also register itself as a *processor* for one of its mappers, - which will be executed after that mapper's objects have been - saved or before they've been deleted. The process operation - manages attributes and dependent operations upon the objects - of one of the involved mappers. """ raise NotImplementedError() + def register_processors(self, uowcommit): + """Tell a ``UOWTransaction`` about this object as a processor, + which will be executed after that mapper's objects have been + saved or before they've been deleted. The process operation + manages attributes and dependent operations between two mappers. + + """ + raise NotImplementedError() + def whose_dependent_on_who(self, state1, state2): """Given an object pair assuming `obj2` is a child of `obj1`, return a tuple with the dependent object second, or None if @@ -181,9 +185,13 @@ class OneToManyDP(DependencyProcessor): if self.post_update: uowcommit.register_dependency(self.mapper, self.dependency_marker) uowcommit.register_dependency(self.parent, self.dependency_marker) - uowcommit.register_processor(self.dependency_marker, self, self.parent) else: uowcommit.register_dependency(self.parent, self.mapper) + + def register_processors(self, uowcommit): + if self.post_update: + uowcommit.register_processor(self.dependency_marker, self, self.parent) + else: uowcommit.register_processor(self.parent, self, self.parent) def process_dependencies(self, task, deplist, uowcommit, delete = False): @@ -285,6 +293,9 @@ class DetectKeySwitch(DependencyProcessor): no_dependencies = True def register_dependencies(self, uowcommit): + pass + + def register_processors(self, uowcommit): uowcommit.register_processor(self.parent, self, self.mapper) def preprocess_dependencies(self, task, deplist, uowcommit, delete=False): @@ -330,12 +341,15 @@ class ManyToOneDP(DependencyProcessor): if self.post_update: uowcommit.register_dependency(self.mapper, self.dependency_marker) uowcommit.register_dependency(self.parent, self.dependency_marker) - uowcommit.register_processor(self.dependency_marker, self, self.parent) else: uowcommit.register_dependency(self.mapper, self.parent) + + def register_processors(self, uowcommit): + if self.post_update: + uowcommit.register_processor(self.dependency_marker, self, self.parent) + else: uowcommit.register_processor(self.mapper, self, self.parent) - def process_dependencies(self, task, deplist, uowcommit, delete=False): if delete: if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all': @@ -408,8 +422,10 @@ class ManyToManyDP(DependencyProcessor): uowcommit.register_dependency(self.parent, self.dependency_marker) uowcommit.register_dependency(self.mapper, self.dependency_marker) - uowcommit.register_processor(self.dependency_marker, self, self.parent) + def register_processors(self, uowcommit): + uowcommit.register_processor(self.dependency_marker, self, self.parent) + def process_dependencies(self, task, deplist, uowcommit, delete = False): connection = uowcommit.transaction.connection(self.mapper) secondary_delete = [] @@ -527,6 +543,9 @@ class MapperStub(object): def _register_dependencies(self, uowcommit): pass + def _register_procesors(self, uowcommit): + pass + def _save_obj(self, *args, **kwargs): pass diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 3d31a686a..70243291d 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -55,21 +55,21 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: self.query_class = mixin_user_query(query_class) - def get(self, state, passive=False): + def get(self, state, dict_, passive=False): if passive: return self._get_collection_history(state, passive=True).added_items else: return self.query_class(self, state) - def get_collection(self, state, user_data=None, passive=True): + def get_collection(self, state, dict_, user_data=None, passive=True): if passive: return self._get_collection_history(state, passive=passive).added_items else: history = self._get_collection_history(state, passive=passive) return history.added_items + history.unchanged_items - def fire_append_event(self, state, value, initiator): - collection_history = self._modified_event(state) + def fire_append_event(self, state, dict_, value, initiator): + collection_history = self._modified_event(state, dict_) collection_history.added_items.append(value) for ext in self.extensions: @@ -78,8 +78,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl): if self.trackparent and value is not None: self.sethasparent(attributes.instance_state(value), True) - def fire_remove_event(self, state, value, initiator): - collection_history = self._modified_event(state) + def fire_remove_event(self, state, dict_, value, initiator): + collection_history = self._modified_event(state, dict_) collection_history.deleted_items.append(value) if self.trackparent and value is not None: @@ -88,31 +88,31 @@ class DynamicAttributeImpl(attributes.AttributeImpl): for ext in self.extensions: ext.remove(state, value, initiator or self) - def _modified_event(self, state): + def _modified_event(self, state, dict_): if self.key not in state.committed_state: state.committed_state[self.key] = CollectionHistory(self, state) - state.modified_event(self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE) + state.modified_event(dict_, self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE) # this is a hack to allow the _base.ComparableEntity fixture # to work - state.dict[self.key] = True + dict_[self.key] = True return state.committed_state[self.key] - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): if initiator is self: return - self._set_iterable(state, value) + self._set_iterable(state, dict_, value) - def _set_iterable(self, state, iterable, adapter=None): + def _set_iterable(self, state, dict_, iterable, adapter=None): - collection_history = self._modified_event(state) + collection_history = self._modified_event(state, dict_) new_values = list(iterable) if _state_has_identity(state): - old_collection = list(self.get(state)) + old_collection = list(self.get(state, dict_)) else: old_collection = [] @@ -121,7 +121,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): def delete(self, *args, **kwargs): raise NotImplementedError() - def get_history(self, state, passive=False): + def get_history(self, state, dict_, passive=False): c = self._get_collection_history(state, passive) return attributes.History(c.added_items, c.unchanged_items, c.deleted_items) @@ -136,13 +136,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: return c - def append(self, state, value, initiator, passive=False): + def append(self, state, dict_, value, initiator, passive=False): if initiator is not self: - self.fire_append_event(state, value, initiator) + self.fire_append_event(state, dict_, value, initiator) - def remove(self, state, value, initiator, passive=False): + def remove(self, state, dict_, value, initiator, passive=False): if initiator is not self: - self.fire_remove_event(state, value, initiator) + self.fire_remove_event(state, dict_, value, initiator) class DynCollectionAdapter(object): """the dynamic analogue to orm.collections.CollectionAdapter""" @@ -156,10 +156,10 @@ class DynCollectionAdapter(object): return iter(self.data) def append_with_event(self, item, initiator=None): - self.attr.append(self.state, item, initiator) + self.attr.append(self.state, self.state.dict, item, initiator) def remove_with_event(self, item, initiator=None): - self.attr.remove(self.state, item, initiator) + self.attr.remove(self.state, self.state.dict, item, initiator) def append_without_event(self, item): pass @@ -240,10 +240,10 @@ class AppenderMixin(object): return query def append(self, item): - self.attr.append(attributes.instance_state(self.instance), item, None) + self.attr.append(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None) def remove(self, item): - self.attr.remove(attributes.instance_state(self.instance), item, None) + self.attr.remove(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None) class AppenderQuery(AppenderMixin, Query): diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 0753ea991..aa041a585 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -15,6 +15,9 @@ class IdentityMap(dict): self._mutable_attrs = {} self.modified = False self._wr = weakref.ref(self) + + def replace(self, state): + raise NotImplementedError() def add(self, state): raise NotImplementedError() @@ -102,6 +105,17 @@ class WeakInstanceDict(IdentityMap): def contains_state(self, state): return dict.get(self, state.key) is state + def replace(self, state): + if dict.__contains__(self, state.key): + existing = dict.__getitem__(self, state.key) + if existing is not state: + self._manage_removed_state(existing) + else: + return + + dict.__setitem__(self, state.key, state) + self._manage_incoming_state(state) + def add(self, state): if state.key in self: if dict.__getitem__(self, state.key) is not state: @@ -161,12 +175,24 @@ class StrongInstanceDict(IdentityMap): def contains_state(self, state): return state.key in self and attributes.instance_state(self[state.key]) is state + def replace(self, state): + if dict.__contains__(self, state.key): + existing = dict.__getitem__(self, state.key) + existing = attributes.instance_state(existing) + if existing is not state: + self._manage_removed_state(existing) + else: + return + + dict.__setitem__(self, state.key, state.obj()) + self._manage_incoming_state(state) + def add(self, state): dict.__setitem__(self, state.key, state.obj()) self._manage_incoming_state(state) def remove(self, state): - if dict.pop(self, state.key) is not state: + if attributes.instance_state(dict.pop(self, state.key)) is not state: raise AssertionError("State %s is not present in this identity map" % state) self._manage_removed_state(state) @@ -176,7 +202,7 @@ class StrongInstanceDict(IdentityMap): self._manage_removed_state(state) def remove_key(self, key): - state = dict.__getitem__(self, key) + state = attributes.instance_state(dict.__getitem__(self, key)) self.remove(state) def prune(self): @@ -190,62 +216,3 @@ class StrongInstanceDict(IdentityMap): self.modified = bool(dirty) return ref_count - len(self) -class IdentityManagedState(attributes.InstanceState): - def _instance_dict(self): - return None - - def modified_event(self, attr, should_copy, previous, passive=False): - attributes.InstanceState.modified_event(self, attr, should_copy, previous, passive) - - instance_dict = self._instance_dict() - if instance_dict: - instance_dict.modified = True - - def _is_really_none(self): - """do a check modified/resurrect. - - This would be called in the extremely rare - race condition that the weakref returned None but - the cleanup handler had not yet established the - __resurrect callable as its replacement. - - """ - if self.check_modified(): - self.obj = self.__resurrect - return self.obj() - else: - return None - - def _cleanup(self, ref): - """weakref callback. - - This method may be called by an asynchronous - gc. - - If the state shows pending changes, the weakref - is replaced by the __resurrect callable which will - re-establish an object reference on next access, - else removes this InstanceState from the owning - identity map, if any. - - """ - if self.check_modified(): - self.obj = self.__resurrect - else: - instance_dict = self._instance_dict() - if instance_dict: - instance_dict.remove(self) - self.dispose() - - def __resurrect(self): - """A substitute for the obj() weakref function which resurrects.""" - - # store strong ref'ed version of the object; will revert - # to weakref when changes are persisted - obj = self.manager.new_instance(state=self) - self.obj = weakref.ref(obj, self._cleanup) - self._strong_obj = obj - obj.__dict__.update(self.dict) - self.dict = obj.__dict__ - self._run_on_load(obj) - return obj diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index d36f51194..0ac771305 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -359,7 +359,7 @@ class MapperProperty(object): Callables are of the following form:: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # process incoming instance state and given row. the instance is # "new" and was just created upon receipt of this row. # flags is a dictionary containing at least the following @@ -368,7 +368,7 @@ class MapperProperty(object): # result of reading this row # instancekey - identity key of the instance - def existing_execute(state, row, **flags): + def existing_execute(state, dict_, row, **flags): # process incoming instance state and given row. the instance is # "existing" and was created based on a previous row. @@ -427,13 +427,23 @@ class MapperProperty(object): def register_dependencies(self, *args, **kwargs): """Called by the ``Mapper`` in response to the UnitOfWork calling the ``Mapper``'s register_dependencies operation. - Should register with the UnitOfWork all inter-mapper - dependencies as well as dependency processors (see UOW docs - for more details). + Establishes a topological dependency between two mappers + which will affect the order in which mappers persist data. + """ pass + def register_processors(self, *args, **kwargs): + """Called by the ``Mapper`` in response to the UnitOfWork + calling the ``Mapper``'s register_processors operation. + Establishes a processor object between two mappers which + will link data and state between parent/child objects. + + """ + + pass + def is_primary(self): """Return True if this ``MapperProperty``'s mapper is the primary mapper for its class. @@ -939,3 +949,7 @@ class InstrumentationManager(object): def state_getter(self, class_): return lambda instance: getattr(instance, '_default_state') + + def dict_getter(self, class_): + return lambda inst: self.get_instance_dict(class_, inst) +
\ No newline at end of file diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 8af6153d6..87c4c8100 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -23,7 +23,6 @@ deque = __import__('collections').deque from sqlalchemy import sql, util, log, exc as sa_exc from sqlalchemy.sql import expression, visitors, operators, util as sqlutil from sqlalchemy.orm import attributes, exc, sync -from sqlalchemy.orm.identity import IdentityManagedState from sqlalchemy.orm.interfaces import ( MapperProperty, EXT_CONTINUE, PropComparator ) @@ -255,7 +254,8 @@ class Mapper(object): for mapper in self.iterate_to_root(): util.reset_memoized(mapper, '_equivalent_columns') - + util.reset_memoized(mapper, '_sorted_tables') + if self.order_by is False and not self.concrete and self.inherits.order_by is not False: self.order_by = self.inherits.order_by @@ -357,7 +357,6 @@ class Mapper(object): if manager is None: manager = attributes.register_class(self.class_, - instance_state_factory = IdentityManagedState, deferred_scalar_loader = _load_scalar_attributes ) @@ -372,6 +371,8 @@ class Mapper(object): event_registry = manager.events event_registry.add_listener('on_init', _event_on_init) event_registry.add_listener('on_init_failure', _event_on_init_failure) + event_registry.add_listener('on_resurrect', _event_on_resurrect) + for key, method in util.iterate_attributes(self.class_): if isinstance(method, types.FunctionType): if hasattr(method, '__sa_reconstructor__'): @@ -1173,6 +1174,19 @@ class Mapper(object): # persistence + @util.memoized_property + def _sorted_tables(self): + table_to_mapper = {} + for mapper in self.base_mapper.polymorphic_iterator(): + for t in mapper.tables: + table_to_mapper[t] = mapper + + sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys()) + ret = util.OrderedDict() + for t in sorted_: + ret[t] = table_to_mapper[t] + return ret + def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. @@ -1198,16 +1212,37 @@ class Mapper(object): # if session has a connection callable, # organize individual states with the connection to use for insert/update + tups = [] if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in _sort_states(states)] + for state in _sort_states(states): + m = _state_mapper(state) + tups.append( + ( + state, + m, + connection_callable(self, state.obj()), + _state_has_identity(state), + state.key or m._identity_key_from_state(state) + ) + ) else: connection = uowtransaction.transaction.connection(self) - tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in _sort_states(states)] + for state in _sort_states(states): + m = _state_mapper(state) + tups.append( + ( + state, + m, + connection, + _state_has_identity(state), + state.key or m._identity_key_from_state(state) + ) + ) if not postupdate: # call before_XXX extensions - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: if not has_identity: if 'before_insert' in mapper.extension: mapper.extension.before_insert(mapper, connection, state.obj()) @@ -1215,39 +1250,44 @@ class Mapper(object): if 'before_update' in mapper.extension: mapper.extension.before_update(mapper, connection, state.obj()) - for state, mapper, connection, has_identity in tups: - # detect if we have a "pending" instance (i.e. has no instance_key attached to it), - # and another instance with the same identity key already exists as persistent. convert to an - # UPDATE if so. - instance_key = mapper._identity_key_from_state(state) - if not postupdate and not has_identity and instance_key in uowtransaction.session.identity_map: - instance = uowtransaction.session.identity_map[instance_key] - existing = attributes.instance_state(instance) - if not uowtransaction.is_deleted(existing): - raise exc.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing))) - if self._should_log_debug: - self._log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing))) - uowtransaction.set_row_switch(existing) - - table_to_mapper = {} - for mapper in self.base_mapper.polymorphic_iterator(): - for t in mapper.tables: - table_to_mapper[t] = mapper + row_switches = set() + if not postupdate: + for state, mapper, connection, has_identity, instance_key in tups: + # detect if we have a "pending" instance (i.e. has no instance_key attached to it), + # and another instance with the same identity key already exists as persistent. convert to an + # UPDATE if so. + if not has_identity and instance_key in uowtransaction.session.identity_map: + instance = uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + if not uowtransaction.is_deleted(existing): + raise exc.FlushError( + "New instance %s with identity key %s conflicts with persistent instance %s" % + (state_str(state), instance_key, state_str(existing))) + if self._should_log_debug: + self._log_debug( + "detected row switch for identity %s. will update %s, remove %s from transaction", + instance_key, state_str(state), state_str(existing)) + + # remove the "delete" flag from the existing element + uowtransaction.set_row_switch(existing) + row_switches.add(state) + + table_to_mapper = self._sorted_tables - for table in sqlutil.sort_tables(table_to_mapper.iterkeys()): + for table in table_to_mapper.iterkeys(): insert = [] update = [] - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: if table not in mapper._pks_by_table: continue + pks = mapper._pks_by_table[table] - instance_key = mapper._identity_key_from_state(state) - + if self._should_log_debug: self._log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key))) - isinsert = not instance_key in uowtransaction.session.identity_map and not postupdate and not has_identity + isinsert = not has_identity and not postupdate and state not in row_switches params = {} value_params = {} @@ -1364,7 +1404,7 @@ class Mapper(object): sync.populate(state, m, state, m, m._inherits_equated_pairs) if not postupdate: - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: # expire readonly attributes readonly = state.unmodified.intersection( @@ -1434,12 +1474,9 @@ class Mapper(object): if 'before_delete' in mapper.extension: mapper.extension.before_delete(mapper, connection, state.obj()) - table_to_mapper = {} - for mapper in self.base_mapper.polymorphic_iterator(): - for t in mapper.tables: - table_to_mapper[t] = mapper + table_to_mapper = self._sorted_tables - for table in reversed(sqlutil.sort_tables(table_to_mapper.iterkeys())): + for table in reversed(table_to_mapper.keys()): delete = {} for state, mapper, connection in tups: if table not in mapper._pks_by_table: @@ -1485,6 +1522,10 @@ class Mapper(object): for dep in self._props.values() + self._dependency_processors: dep.register_dependencies(uowcommit) + def _register_processors(self, uowcommit): + for dep in self._props.values() + self._dependency_processors: + dep.register_processors(uowcommit) + # result set conversion def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None, polymorphic_discriminator=None): @@ -1514,7 +1555,7 @@ class Mapper(object): new_populators = [] existing_populators = [] - def populate_state(state, row, isnew, only_load_props, **flags): + def populate_state(state, dict_, row, isnew, only_load_props, **flags): if isnew: if context.options: state.load_options = context.options @@ -1533,7 +1574,7 @@ class Mapper(object): populators = [p for p in populators if p[0] in only_load_props] for key, populator in populators: - populator(state, row, isnew=isnew, **flags) + populator(state, dict_, row, isnew=isnew, **flags) session_identity_map = context.session.identity_map @@ -1573,9 +1614,11 @@ class Mapper(object): if identitykey in session_identity_map: instance = session_identity_map[identitykey] state = attributes.instance_state(instance) + dict_ = attributes.instance_dict(instance) if self._should_log_debug: - self._log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), identitykey)) + self._log_debug("_instance(): using existing instance %s identity %s", + instance_str(instance), identitykey) isnew = state.runid != context.runid currentload = not isnew @@ -1592,12 +1635,13 @@ class Mapper(object): # when eager_defaults is True. state = refresh_state instance = state.obj() + dict_ = attributes.instance_dict(instance) isnew = state.runid != context.runid currentload = True loaded_instance = False else: if self._should_log_debug: - self._log_debug("_instance(): identity key %s not in session" % str(identitykey)) + self._log_debug("_instance(): identity key %s not in session", identitykey) if self.allow_null_pks: for x in identitykey[1]: @@ -1625,8 +1669,10 @@ class Mapper(object): instance = self.class_manager.new_instance() if self._should_log_debug: - self._log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey))) + self._log_debug("_instance(): created new instance %s identity %s", + instance_str(instance), identitykey) + dict_ = attributes.instance_dict(instance) state = attributes.instance_state(instance) state.key = identitykey @@ -1638,12 +1684,12 @@ class Mapper(object): if currentload or populate_existing: if isnew: state.runid = context.runid - context.progress.add(state) + context.progress[state] = dict_ if not populate_instance or \ populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - populate_state(state, row, isnew, only_load_props) + populate_state(state, dict_, row, isnew, only_load_props) else: # populate attributes on non-loading instances which have been expired @@ -1652,16 +1698,16 @@ class Mapper(object): if state in context.partials: isnew = False - attrs = context.partials[state] + (d_, attrs) = context.partials[state] else: isnew = True attrs = state.unloaded - context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs + context.partials[state] = (dict_, attrs) #<-- allow query.instances to commit the subset of attrs if not populate_instance or \ populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - populate_state(state, row, isnew, attrs, instancekey=identitykey) + populate_state(state, dict_, row, isnew, attrs, instancekey=identitykey) if loaded_instance: state._run_on_load(instance) @@ -1759,6 +1805,14 @@ def _event_on_init_failure(state, instance, args, kwargs): instrumenting_mapper, instrumenting_mapper.class_, state.manager.events.original_init, instance, args, kwargs) +def _event_on_resurrect(state, instance): + # re-populate the primary key elements + # of the dict based on the mapping. + instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + for col, val in zip(instrumenting_mapper.primary_key, state.key[1]): + instrumenting_mapper._set_state_attr_by_column(state, col, val) + + def _sort_states(states): return sorted(states, key=operator.attrgetter('sort_key')) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d0cca2dc1..5605cdcd1 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -96,13 +96,13 @@ class ColumnProperty(StrategizedProperty): return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) def getattr(self, state, column): - return state.get_impl(self.key).get(state) + return state.get_impl(self.key).get(state, state.dict) def getcommitted(self, state, column, passive=False): - return state.get_impl(self.key).get_committed_value(state, passive=passive) + return state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive) def setattr(self, state, value, column): - state.get_impl(self.key).set(state, value, None) + state.get_impl(self.key).set(state, state.dict, value, None) def merge(self, session, source, dest, dont_load, _recursive): value = attributes.instance_state(source).value_as_iterable( @@ -159,7 +159,7 @@ class CompositeProperty(ColumnProperty): super(ColumnProperty, self).do_init() def getattr(self, state, column): - obj = state.get_impl(self.key).get(state) + obj = state.get_impl(self.key).get(state, state.dict) return self.get_col_value(column, obj) def getcommitted(self, state, column, passive=False): @@ -168,7 +168,7 @@ class CompositeProperty(ColumnProperty): def setattr(self, state, value, column): - obj = state.get_impl(self.key).get(state) + obj = state.get_impl(self.key).get(state, state.dict) if obj is None: obj = self.composite_class(*[None for c in self.columns]) state.get_impl(self.key).set(state, obj, None) @@ -635,7 +635,7 @@ class RelationProperty(StrategizedProperty): return source_state = attributes.instance_state(source) - dest_state = attributes.instance_state(dest) + dest_state, dest_dict = attributes.instance_state(dest), attributes.instance_dict(dest) if not "merge" in self.cascade: dest_state.expire_attributes([self.key]) @@ -658,7 +658,7 @@ class RelationProperty(StrategizedProperty): for c in dest_list: coll.append_without_event(c) else: - getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_list) + getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_dict, dest_list) else: current = instances[0] if current is not None: @@ -1119,6 +1119,10 @@ class RelationProperty(StrategizedProperty): if not self.viewonly: self._dependency_processor.register_dependencies(uowcommit) + def register_processors(self, uowcommit): + if not self.viewonly: + self._dependency_processor.register_processors(uowcommit) + PropertyLoader = RelationProperty log.class_logger(RelationProperty) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 28ddcc5ea..e3cc3c756 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1330,7 +1330,7 @@ class Query(object): rowtuple.keys = labels.keys while True: - context.progress = set() + context.progress = {} context.partials = {} if self._yield_per: @@ -1354,13 +1354,13 @@ class Query(object): rows = filter(rows) if context.refresh_state and self._only_load_props and context.refresh_state in context.progress: - context.refresh_state.commit(self._only_load_props) - context.progress.remove(context.refresh_state) + context.refresh_state.commit(context.refresh_state.dict, self._only_load_props) + context.progress.pop(context.refresh_state) session._finalize_loaded(context.progress) - for ii, attrs in context.partials.items(): - ii.commit(attrs) + for ii, (dict_, attrs) in context.partials.items(): + ii.commit(dict_, attrs) for row in rows: yield row @@ -1683,14 +1683,14 @@ class Query(object): evaluated_keys = value_evaluators.keys() if issubclass(cls, target_cls) and eval_condition(obj): - state = attributes.instance_state(obj) + state, dict_ = attributes.instance_state(obj), attributes.instance_dict(obj) # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: - state.dict[key] = value_evaluators[key](obj) + dict_[key] = value_evaluators[key](obj) - state.commit(list(to_evaluate)) + state.commit(dict_, list(to_evaluate)) # expire attributes with pending changes (there was no autoflush, so they are overwritten) state.expire_attributes(set(evaluated_keys).difference(to_evaluate)) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 1e3a750d9..00a7d55e5 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -12,7 +12,7 @@ import sqlalchemy.exceptions as sa_exc from sqlalchemy import util, sql, engine, log from sqlalchemy.sql import util as sql_util, expression from sqlalchemy.orm import ( - SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, + SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, state ) from sqlalchemy.orm.util import object_mapper as _object_mapper from sqlalchemy.orm.util import class_mapper as _class_mapper @@ -899,8 +899,8 @@ class Session(object): self.flush() def _finalize_loaded(self, states): - for state in states: - state.commit_all() + for state, dict_ in states.items(): + state.commit_all(dict_) def refresh(self, instance, attribute_names=None): """Refresh the attributes on the given instance. @@ -1020,11 +1020,9 @@ class Session(object): # primary key switch self.identity_map.remove(state) state.key = instance_key - - if state.key in self.identity_map and not self.identity_map.contains_state(state): - self.identity_map.remove_key(state.key) - self.identity_map.add(state) - state.commit_all() + + self.identity_map.replace(state) + state.commit_all(state.dict) # remove from new last, might be the last strong ref if state in self._new: @@ -1213,7 +1211,7 @@ class Session(object): prop.merge(self, instance, merged, dont_load, _recursive) if dont_load: - attributes.instance_state(merged).commit_all() # remove any history + attributes.instance_state(merged).commit_all(attributes.instance_dict(merged)) # remove any history if new_instance: merged_state._run_on_load(merged) @@ -1368,7 +1366,7 @@ class Session(object): self.identity_map.modified = False return - flush_context = UOWTransaction(self) + flush_context = UOWTransaction(self) if self.extensions: for ext in self.extensions: @@ -1489,7 +1487,7 @@ class Session(object): return util.IdentitySet( [state for state in self.identity_map.all_states() - if state.check_modified()]) + if state.modified]) @property def dirty(self): @@ -1528,7 +1526,7 @@ class Session(object): return util.IdentitySet(self._new.values()) -_expire_state = attributes.InstanceState.expire_attributes +_expire_state = state.InstanceState.expire_attributes UOWEventHandler = unitofwork.UOWEventHandler @@ -1548,16 +1546,19 @@ def _cascade_unknown_state_iterator(cascade, state, **kwargs): yield _state_for_unknown_persistence_instance(o), m def _state_for_unsaved_instance(instance, create=False): - manager = attributes.manager_of_class(instance.__class__) - if manager is None: + try: + state = attributes.instance_state(instance) + except AttributeError: raise exc.UnmappedInstanceError(instance) - if manager.has_state(instance): - state = manager.state_of(instance) + if state: if state.key is not None: raise sa_exc.InvalidRequestError( "Instance '%s' is already persistent" % mapperutil.state_str(state)) elif create: + manager = attributes.manager_of_class(instance.__class__) + if manager is None: + raise exc.UnmappedInstanceError(instance) state = manager.setup_instance(instance) else: raise exc.UnmappedInstanceError(instance) diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py new file mode 100644 index 000000000..c99dfe73c --- /dev/null +++ b/lib/sqlalchemy/orm/state.py @@ -0,0 +1,429 @@ +from sqlalchemy.util import EMPTY_SET +import weakref +from sqlalchemy import util +from sqlalchemy.orm.attributes import PASSIVE_NORESULT, PASSIVE_OFF, NEVER_SET, NO_VALUE, manager_of_class, ATTR_WAS_SET +from sqlalchemy.orm import attributes +from sqlalchemy.orm import interfaces + +class InstanceState(object): + """tracks state information at the instance level.""" + + session_id = None + key = None + runid = None + expired_attributes = EMPTY_SET + load_options = EMPTY_SET + load_path = () + insert_order = None + mutable_dict = None + + def __init__(self, obj, manager): + self.class_ = obj.__class__ + self.manager = manager + self.obj = weakref.ref(obj, self._cleanup) + self.modified = False + self.callables = {} + self.expired = False + self.committed_state = {} + self.pending = {} + self.parents = {} + + def detach(self): + if self.session_id: + del self.session_id + + def dispose(self): + if self.session_id: + del self.session_id + del self.obj + + def _cleanup(self, ref): + instance_dict = self._instance_dict() + if instance_dict: + instance_dict.remove(self) + self.dispose() + + def obj(self): + return None + + @property + def dict(self): + o = self.obj() + if o is not None: + return attributes.instance_dict(o) + else: + return {} + + @property + def sort_key(self): + return self.key and self.key[1] or (self.insert_order, ) + + def check_modified(self): + # TODO: deprecate + return self.modified + + def initialize_instance(*mixed, **kwargs): + self, instance, args = mixed[0], mixed[1], mixed[2:] + manager = self.manager + + for fn in manager.events.on_init: + fn(self, instance, args, kwargs) + + # LESSTHANIDEAL: + # adjust for the case where the InstanceState was created before + # mapper compilation, and this actually needs to be a MutableAttrInstanceState + if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState: + self.__class__ = MutableAttrInstanceState + self.obj = weakref.ref(self.obj(), self._cleanup) + self.mutable_dict = {} + + try: + return manager.events.original_init(*mixed[1:], **kwargs) + except: + for fn in manager.events.on_init_failure: + fn(self, instance, args, kwargs) + raise + + def get_history(self, key, **kwargs): + return self.manager.get_impl(key).get_history(self, self.dict, **kwargs) + + def get_impl(self, key): + return self.manager.get_impl(key) + + def get_pending(self, key): + if key not in self.pending: + self.pending[key] = PendingCollection() + return self.pending[key] + + def value_as_iterable(self, key, passive=PASSIVE_OFF): + """return an InstanceState attribute as a list, + regardless of it being a scalar or collection-based + attribute. + + returns None if passive is not PASSIVE_OFF and the getter returns + PASSIVE_NORESULT. + """ + + impl = self.get_impl(key) + dict_ = self.dict + x = impl.get(self, dict_, passive=passive) + if x is PASSIVE_NORESULT: + return None + elif hasattr(impl, 'get_collection'): + return impl.get_collection(self, dict_, x, passive=passive) + elif isinstance(x, list): + return x + else: + return [x] + + def _run_on_load(self, instance): + self.manager.events.run('on_load', instance) + + def __getstate__(self): + return {'key': self.key, + 'committed_state': self.committed_state, + 'pending': self.pending, + 'parents': self.parents, + 'modified': self.modified, + 'expired':self.expired, + 'load_options':self.load_options, + 'load_path':interfaces.serialize_path(self.load_path), + 'instance': self.obj(), + 'expired_attributes':self.expired_attributes, + 'callables': self.callables} + + def __setstate__(self, state): + self.committed_state = state['committed_state'] + self.parents = state['parents'] + self.key = state['key'] + self.session_id = None + self.pending = state['pending'] + self.modified = state['modified'] + self.obj = weakref.ref(state['instance']) + self.load_options = state['load_options'] or EMPTY_SET + self.load_path = interfaces.deserialize_path(state['load_path']) + self.class_ = self.obj().__class__ + self.manager = manager_of_class(self.class_) + self.callables = state['callables'] + self.runid = None + self.expired = state['expired'] + self.expired_attributes = state['expired_attributes'] + + def initialize(self, key): + self.manager.get_impl(key).initialize(self, self.dict) + + def set_callable(self, key, callable_): + self.dict.pop(key, None) + self.callables[key] = callable_ + + def __call__(self): + """__call__ allows the InstanceState to act as a deferred + callable for loading expired attributes, which is also + serializable (picklable). + + """ + unmodified = self.unmodified + class_manager = self.manager + class_manager.deferred_scalar_loader(self, [ + attr.impl.key for attr in class_manager.attributes if + attr.impl.accepts_scalar_loader and + attr.impl.key in self.expired_attributes and + attr.impl.key in unmodified + ]) + for k in self.expired_attributes: + self.callables.pop(k, None) + del self.expired_attributes + return ATTR_WAS_SET + + @property + def unmodified(self): + """a set of keys which have no uncommitted changes""" + + return set(self.manager).difference(self.committed_state) + + @property + def unloaded(self): + """a set of keys which do not have a loaded value. + + This includes expired attributes and any other attribute that + was never populated or modified. + + """ + return set( + key for key in self.manager.iterkeys() + if key not in self.committed_state and key not in self.dict) + + def expire_attributes(self, attribute_names): + self.expired_attributes = set(self.expired_attributes) + + if attribute_names is None: + attribute_names = self.manager.keys() + self.expired = True + self.modified = False + filter_deferred = True + else: + filter_deferred = False + dict_ = self.dict + + for key in attribute_names: + impl = self.manager[key].impl + if not filter_deferred or \ + not impl.dont_expire_missing or \ + key in dict_: + self.expired_attributes.add(key) + if impl.accepts_scalar_loader: + self.callables[key] = self + dict_.pop(key, None) + self.pending.pop(key, None) + self.committed_state.pop(key, None) + if self.mutable_dict: + self.mutable_dict.pop(key, None) + + def reset(self, key, dict_): + """remove the given attribute and any callables associated with it.""" + + dict_.pop(key, None) + self.callables.pop(key, None) + + def _instance_dict(self): + return None + + def _is_really_none(self): + return self.obj() + + def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF): + needs_committed = attr.key not in self.committed_state + + if needs_committed: + if previous is NEVER_SET: + if passive: + if attr.key in dict_: + previous = dict_[attr.key] + else: + previous = attr.get(self, dict_) + + if should_copy and previous not in (None, NO_VALUE, NEVER_SET): + previous = attr.copy(previous) + + if needs_committed: + self.committed_state[attr.key] = previous + + self.modified = True + self._strong_obj = self.obj() + + instance_dict = self._instance_dict() + if instance_dict: + instance_dict.modified = True + + def commit(self, dict_, keys): + """Commit attributes. + + This is used by a partial-attribute load operation to mark committed + those attributes which were refreshed from the database. + + Attributes marked as "expired" can potentially remain "expired" after + this step if a value was not populated in state.dict. + + """ + class_manager = self.manager + for key in keys: + if key in dict_ and key in class_manager.mutable_attributes: + class_manager[key].impl.commit_to_state(self, dict_, self.committed_state) + else: + self.committed_state.pop(key, None) + + self.expired = False + # unexpire attributes which have loaded + for key in self.expired_attributes.intersection(keys): + if key in dict_: + self.expired_attributes.remove(key) + self.callables.pop(key, None) + + def commit_all(self, dict_): + """commit all attributes unconditionally. + + This is used after a flush() or a full load/refresh + to remove all pending state from the instance. + + - all attributes are marked as "committed" + - the "strong dirty reference" is removed + - the "modified" flag is set to False + - any "expired" markers/callables are removed. + + Attributes marked as "expired" can potentially remain "expired" after this step + if a value was not populated in state.dict. + + """ + + self.committed_state = {} + self.pending = {} + + # unexpire attributes which have loaded + if self.expired_attributes: + for key in self.expired_attributes.intersection(dict_): + self.callables.pop(key, None) + self.expired_attributes.difference_update(dict_) + + for key in self.manager.mutable_attributes: + if key in dict_: + self.manager[key].impl.commit_to_state(self, dict_, self.committed_state) + + self.modified = self.expired = False + self._strong_obj = None + +class MutableAttrInstanceState(InstanceState): + def __init__(self, obj, manager): + self.mutable_dict = {} + InstanceState.__init__(self, obj, manager) + + def _get_modified(self, dict_=None): + if self.__dict__.get('modified', False): + return True + else: + if dict_ is None: + dict_ = self.dict + for key in self.manager.mutable_attributes: + if self.manager[key].impl.check_mutable_modified(self, dict_): + return True + else: + return False + + def _set_modified(self, value): + self.__dict__['modified'] = value + + modified = property(_get_modified, _set_modified) + + @property + def unmodified(self): + """a set of keys which have no uncommitted changes""" + + dict_ = self.dict + return set( + key for key in self.manager.iterkeys() + if (key not in self.committed_state or + (key in self.manager.mutable_attributes and + not self.manager[key].impl.check_mutable_modified(self, dict_)))) + + def _is_really_none(self): + """do a check modified/resurrect. + + This would be called in the extremely rare + race condition that the weakref returned None but + the cleanup handler had not yet established the + __resurrect callable as its replacement. + + """ + if self.modified: + self.obj = self.__resurrect + return self.obj() + else: + return None + + def reset(self, key, dict_): + self.mutable_dict.pop(key, None) + InstanceState.reset(self, key, dict_) + + def _cleanup(self, ref): + """weakref callback. + + This method may be called by an asynchronous + gc. + + If the state shows pending changes, the weakref + is replaced by the __resurrect callable which will + re-establish an object reference on next access, + else removes this InstanceState from the owning + identity map, if any. + + """ + if self._get_modified(self.mutable_dict): + self.obj = self.__resurrect + else: + instance_dict = self._instance_dict() + if instance_dict: + instance_dict.remove(self) + self.dispose() + + def __resurrect(self): + """A substitute for the obj() weakref function which resurrects.""" + + # store strong ref'ed version of the object; will revert + # to weakref when changes are persisted + + obj = self.manager.new_instance(state=self) + self.obj = weakref.ref(obj, self._cleanup) + self._strong_obj = obj + obj.__dict__.update(self.mutable_dict) + + # re-establishes identity attributes from the key + self.manager.events.run('on_resurrect', self, obj) + + # TODO: don't really think we should run this here. + # resurrect is only meant to preserve the minimal state needed to + # do an UPDATE, not to produce a fully usable object + self._run_on_load(obj) + + return obj + +class PendingCollection(object): + """A writable placeholder for an unloaded collection. + + Stores items appended to and removed from a collection that has not yet + been loaded. When the collection is loaded, the changes stored in + PendingCollection are applied to it to produce the final result. + + """ + def __init__(self): + self.deleted_items = util.IdentitySet() + self.added_items = util.OrderedIdentitySet() + + def append(self, value): + if value in self.deleted_items: + self.deleted_items.remove(value) + self.added_items.add(value) + + def remove(self, value): + if value in self.added_items: + self.added_items.remove(value) + self.deleted_items.add(value) + diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 1aeb311e1..20cbb8f4d 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -115,8 +115,8 @@ class ColumnLoader(LoaderStrategy): if adapter: col = adapter.columns[col] if col in row: - def new_execute(state, row, **flags): - state.dict[key] = row[col] + def new_execute(state, dict_, row, **flags): + dict_[key] = row[col] if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, @@ -125,7 +125,7 @@ class ColumnLoader(LoaderStrategy): ) return (new_execute, None) else: - def new_execute(state, row, isnew, **flags): + def new_execute(state, dict_, row, isnew, **flags): if isnew: state.expire_attributes([key]) if self._should_log_debug: @@ -171,15 +171,15 @@ class CompositeColumnLoader(ColumnLoader): columns = [adapter.columns[c] for c in columns] for c in columns: if c not in row: - def new_execute(state, row, isnew, **flags): + def new_execute(state, dict_, row, isnew, **flags): if isnew: state.expire_attributes([key]) if self._should_log_debug: self.logger.debug("%s deferring load" % self) return (new_execute, None) else: - def new_execute(state, row, **flags): - state.dict[key] = composite_class(*[row[c] for c in columns]) + def new_execute(state, dict_, row, **flags): + dict_[key] = composite_class(*[row[c] for c in columns]) if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, @@ -202,13 +202,13 @@ class DeferredColumnLoader(LoaderStrategy): return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, path, mapper, row, adapter) elif not self.is_class_level: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): state.set_callable(self.key, LoadDeferredColumns(state, self.key)) else: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # reset state on the key so that deferred callables # fire off on next access. - state.reset(self.key) + state.reset(self.key, dict_) if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, None, @@ -340,7 +340,7 @@ class NoLoader(AbstractRelationLoader): ) def create_row_processor(self, selectcontext, path, mapper, row, adapter): - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): self._init_instance_attribute(state) if self._should_log_debug: @@ -437,7 +437,7 @@ class LazyLoader(AbstractRelationLoader): def create_row_processor(self, selectcontext, path, mapper, row, adapter): if not self.is_class_level: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, # which will override the class-level behavior. # this currently only happens when using a "lazyload" option on a "no load" attribute - @@ -451,11 +451,11 @@ class LazyLoader(AbstractRelationLoader): return (new_execute, None) else: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # we are the primary manager for this attribute on this class - reset its per-instance attribute state, # so that the class-level lazy loader is executed when next referenced on this instance. # this is needed in populate_existing() types of scenarios to reset any existing state. - state.reset(self.key) + state.reset(self.key, dict_) if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, None, @@ -735,24 +735,24 @@ class EagerLoader(AbstractRelationLoader): _instance = self.mapper._instance_processor(context, path + (self.mapper.base_mapper,), eager_adapter) if not self.uselist: - def execute(state, row, isnew, **flags): + def execute(state, dict_, row, isnew, **flags): if isnew: # set a scalar object instance directly on the # parent object, bypassing InstrumentedAttribute # event handlers. - state.dict[key] = _instance(row, None) + dict_[key] = _instance(row, None) else: # call _instance on the row, even though the object has been created, # so that we further descend into properties _instance(row, None) else: - def execute(state, row, isnew, **flags): + def execute(state, dict_, row, isnew, **flags): if isnew or (state, key) not in context.attributes: # appender_key can be absent from context.attributes with isnew=False # when self-referential eager loading is used; the same instance may be present # in two distinct sets of result columns - collection = attributes.init_state_collection(state, key) + collection = attributes.init_state_collection(state, dict_, key) appender = util.UniqueAppender(collection, 'append_without_event') context.attributes[(state, key)] = appender diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 4ac9c765e..407b702a8 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -96,6 +96,8 @@ class UOWTransaction(object): # information. self.attributes = {} + self.processors = set() + def get_attribute_history(self, state, key, passive=True): hashkey = ("history", state, key) @@ -136,6 +138,16 @@ class UOWTransaction(object): else: task.append(state, listonly=listonly, isdelete=isdelete) + # ensure the mapper for this object has had its + # DependencyProcessors added. + if mapper not in self.processors: + mapper._register_processors(self) + self.processors.add(mapper) + + if mapper.base_mapper not in self.processors: + mapper.base_mapper._register_processors(self) + self.processors.add(mapper.base_mapper) + def set_row_switch(self, state): """mark a deleted object as a 'row switch'. @@ -147,7 +159,7 @@ class UOWTransaction(object): task = self.get_task_by_mapper(mapper) taskelement = task._objects[state] taskelement.isdelete = "rowswitch" - + def is_deleted(self, state): """return true if the given state is marked as deleted within this UOWTransaction.""" @@ -201,9 +213,9 @@ class UOWTransaction(object): self.dependencies.add((mapper, dependency)) def register_processor(self, mapper, processor, mapperfrom): - """register a dependency processor, corresponding to dependencies between - the two given mappers. - + """register a dependency processor, corresponding to + operations which occur between two mappers. + """ # correct for primary mapper mapper = mapper.primary_mapper() |
