diff options
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/state.py | 31 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 4 | ||||
| -rw-r--r-- | test/orm/test_attributes.py | 14 | ||||
| -rw-r--r-- | test/orm/test_extendedattr.py | 8 | ||||
| -rw-r--r-- | test/perf/objselectspeed.py | 37 |
9 files changed, 88 insertions, 32 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5a6e24dfa..a8c525657 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1484,7 +1484,7 @@ class Mapper(object): ) if readonly: - _expire_state(state, readonly) + _expire_state(state, state.dict, readonly) # if specified, eagerly refresh whatever has # been expired. @@ -1524,7 +1524,7 @@ class Mapper(object): deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]] if deferred_props: - _expire_state(state, deferred_props) + _expire_state(state, state.dict, deferred_props) # synchronize newly inserted ids from one table to the next # TODO: this still goes a little too often. would be nice to diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 8dbc6b3db..bb92f39e4 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -115,7 +115,7 @@ class ColumnProperty(StrategizedProperty): impl = dest_state.get_impl(self.key) impl.set(dest_state, dest_dict, value, None) else: - dest_state.expire_attributes([self.key]) + dest_state.expire_attributes(dest_dict, [self.key]) def get_col_value(self, column, value): return value @@ -636,7 +636,7 @@ class RelationProperty(StrategizedProperty): return if not "merge" in self.cascade: - dest_state.expire_attributes([self.key]) + dest_state.expire_attribute(dest_dict, [self.key]) return if self.key not in source_dict: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index ed5535151..456f1f19c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1873,8 +1873,9 @@ class Query(object): state.commit(dict_, list(to_evaluate)) - # expire attributes with pending changes (there was no autoflush, so they are overwritten) - state.expire_attributes(set(evaluated_keys).difference(to_evaluate)) + # expire attributes with pending changes + # (there was no autoflush, so they are overwritten) + state.expire_attributes(dict_, set(evaluated_keys).difference(to_evaluate)) elif synchronize_session == 'fetch': target_mapper = self._mapper_zero() diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 0e5e939b1..d5246bee0 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -289,14 +289,14 @@ class SessionTransaction(object): assert not self.session._deleted for s in self.session.identity_map.all_states(): - _expire_state(s, None, instance_dict=self.session.identity_map) + _expire_state(s, s.dict, None, instance_dict=self.session.identity_map) def _remove_snapshot(self): assert self._is_transaction_boundary if not self.nested and self.session.expire_on_commit: for s in self.session.identity_map.all_states(): - _expire_state(s, None, instance_dict=self.session.identity_map) + _expire_state(s, s.dict, None, instance_dict=self.session.identity_map) def _connection_for_bind(self, bind): self._assert_is_active() @@ -915,7 +915,7 @@ class Session(object): """Expires all persistent instances within this Session.""" for state in self.identity_map.all_states(): - _expire_state(state, None, instance_dict=self.identity_map) + _expire_state(state, state.dict, None, instance_dict=self.identity_map) def expire(self, instance, attribute_names=None): """Expire the attributes on an instance. @@ -936,14 +936,15 @@ class Session(object): raise exc.UnmappedInstanceError(instance) self._validate_persistent(state) if attribute_names: - _expire_state(state, attribute_names=attribute_names, instance_dict=self.identity_map) + _expire_state(state, state.dict, + attribute_names=attribute_names, instance_dict=self.identity_map) else: # pre-fetch the full cascade since the expire is going to # remove associations cascaded = list(_cascade_state_iterator('refresh-expire', state)) - _expire_state(state, None, instance_dict=self.identity_map) + _expire_state(state, state.dict, None, instance_dict=self.identity_map) for (state, m, o) in cascaded: - _expire_state(state, None, instance_dict=self.identity_map) + _expire_state(state, state.dict, None, instance_dict=self.identity_map) def prune(self): """Remove unreferenced instances cached in the identity map. diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 4bb9219f4..a9494a50e 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -239,9 +239,30 @@ class InstanceState(object): return set( key for key in self.manager.iterkeys() if key not in self.committed_state and key not in self.dict) - - def expire_attributes(self, attribute_names, instance_dict=None): - self.expired_attributes = set(self.expired_attributes) + + def expire_attribute_pre_commit(self, dict_, key): + """a fast expire that can be called by column loaders during a load. + + The additional bookkeeping is finished up in commit_all(). + + This method is actually called a lot with joined-table + loading, when the second table isn't present in the result. + + """ + # TODO: yes, this is still a little too busy. + # need to more cleanly separate out handling + # for the various AttributeImpls and the contracts + # they wish to maintain with their strategies + if not self.expired_attributes: + self.expired_attributes = set(self.expired_attributes) + + dict_.pop(key, None) + self.callables[key] = self + self.expired_attributes.add(key) + + def expire_attributes(self, dict_, attribute_names, instance_dict=None): + if not self.expired_attributes: + self.expired_attributes = set(self.expired_attributes) if attribute_names is None: attribute_names = self.manager.keys() @@ -258,7 +279,6 @@ class InstanceState(object): filter_deferred = True else: filter_deferred = False - dict_ = self.dict for key in attribute_names: impl = self.manager[key].impl @@ -354,8 +374,7 @@ class InstanceState(object): self.committed_state = {} self.pending = {} - - # unexpire attributes which have loaded + if self.expired_attributes: for key in self.expired_attributes.intersection(dict_): self.callables.pop(key, None) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 5e81d33ca..4d5ec3da4 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -121,7 +121,7 @@ class ColumnLoader(LoaderStrategy): else: def new_execute(state, dict_, row, isnew): if isnew: - state.expire_attributes([key]) + state.expire_attribute_pre_commit(dict_, key) return new_execute, None log.class_logger(ColumnLoader) @@ -168,7 +168,7 @@ class CompositeColumnLoader(ColumnLoader): if c not in row: def new_execute(state, dict_, row, isnew): if isnew: - state.expire_attributes([key]) + state.expire_attribute_pre_commit(dict_, key) break else: def new_execute(state, dict_, row, isnew): diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index c69021aa3..e6041d566 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -142,21 +142,21 @@ class AttributesTest(_base.ORMTest): attributes.register_attribute(Foo, 'b', uselist=False, useobject=False) f = Foo() - attributes.instance_state(f).expire_attributes(None) + attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) eq_(f.a, "this is a") eq_(f.b, 12) f.a = "this is some new a" - attributes.instance_state(f).expire_attributes(None) + attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) eq_(f.a, "this is a") eq_(f.b, 12) - attributes.instance_state(f).expire_attributes(None) + attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) f.a = "this is another new a" eq_(f.a, "this is another new a") eq_(f.b, 12) - attributes.instance_state(f).expire_attributes(None) + attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) eq_(f.a, "this is a") eq_(f.b, 12) @@ -182,7 +182,7 @@ class AttributesTest(_base.ORMTest): attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False) m = MyTest() - attributes.instance_state(m).expire_attributes(None) + attributes.instance_state(m).expire_attributes(attributes.instance_dict(m), None) assert 'a' not in m.__dict__ m2 = pickle.loads(pickle.dumps(m)) assert 'a' not in m2.__dict__ @@ -355,7 +355,7 @@ class AttributesTest(_base.ORMTest): x.bars b = Bar(id=4) b.foos.append(x) - attributes.instance_state(x).expire_attributes(['bars']) + attributes.instance_state(x).expire_attributes(attributes.instance_dict(x), ['bars']) assert_raises(AssertionError, b.foos.remove, x) @@ -1294,7 +1294,7 @@ class HistoryTest(_base.ORMTest): eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), ([bar4], [], [])) lazy_load = [bar1, bar2, bar3] - attributes.instance_state(f).expire_attributes(['bars']) + attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), ['bars']) eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), ((), [bar1, bar2, bar3], ())) def test_collections_via_lazyload(self): diff --git a/test/orm/test_extendedattr.py b/test/orm/test_extendedattr.py index 685be3a5f..4374b9ecb 100644 --- a/test/orm/test_extendedattr.py +++ b/test/orm/test_extendedattr.py @@ -161,21 +161,21 @@ class UserDefinedExtensionTest(_base.ORMTest): assert Foo in attributes.instrumentation_registry._state_finders f = Foo() - attributes.instance_state(f).expire_attributes(None) + attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) eq_(f.a, "this is a") eq_(f.b, 12) f.a = "this is some new a" - attributes.instance_state(f).expire_attributes(None) + attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) eq_(f.a, "this is a") eq_(f.b, 12) - attributes.instance_state(f).expire_attributes(None) + attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) f.a = "this is another new a" eq_(f.a, "this is another new a") eq_(f.b, 12) - attributes.instance_state(f).expire_attributes(None) + attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None) eq_(f.a, "this is a") eq_(f.b, 12) diff --git a/test/perf/objselectspeed.py b/test/perf/objselectspeed.py index e04ef4efb..d3fd34046 100644 --- a/test/perf/objselectspeed.py +++ b/test/perf/objselectspeed.py @@ -8,21 +8,44 @@ db = create_engine('sqlite://') metadata = MetaData(db) Person_table = Table('Person', metadata, Column('id', Integer, primary_key=True), + Column('type', String(10)), Column('name', String(40)), Column('sex', Integer), Column('age', Integer)) + +Employee_table = Table('Employee', metadata, + Column('id', Integer, ForeignKey('Person.id'), primary_key=True), + Column('foo', String(40)), + Column('bar', Integer), + Column('bat', Integer)) + class RawPerson(object): pass class Person(object): pass mapper(Person, Person_table) + +class JoinedPerson(object):pass +class Employee(JoinedPerson):pass +mapper(JoinedPerson, Person_table, \ + polymorphic_on=Person_table.c.type, polymorphic_identity='person') +mapper(Employee, Employee_table, \ + inherits=JoinedPerson, polymorphic_identity='employee') compile_mappers() def setup(): metadata.create_all() i = Person_table.insert() - data = [{'name':'John Doe','sex':1,'age':35}] * 100 + data = [{'name':'John Doe','sex':1,'age':35, 'type':'employee'}] * 100 for j in xrange(500): i.execute(data) + + # note we arent fetching from employee_table, + # so we can leave it empty even though its "incorrect" + #i = Employee_table.insert() + #data = [{'foo':'foo', 'bar':'bar':'bat':'bat'}] * 100 + #for j in xrange(500): + # i.execute(data) + print "Inserted 50,000 rows" def sqlite_select(entity_cls): @@ -55,6 +78,11 @@ def orm_select(): session = create_session() people = session.query(Person).all() +#@profiling.profiled(report=True, always=True) +def joined_orm_select(): + session = create_session() + people = session.query(JoinedPerson).all() + def all(): setup() try: @@ -103,6 +131,13 @@ def all(): orm_select() t2 = time.clock() usage('sqlalchemy.orm fetch') + + gc_collect() + usage.snap() + t = time.clock() + joined_orm_select() + t2 = time.clock() + usage('sqlalchemy.orm "joined" fetch') finally: metadata.drop_all() |
