summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/mapper.py4
-rw-r--r--lib/sqlalchemy/orm/properties.py4
-rw-r--r--lib/sqlalchemy/orm/query.py5
-rw-r--r--lib/sqlalchemy/orm/session.py13
-rw-r--r--lib/sqlalchemy/orm/state.py31
-rw-r--r--lib/sqlalchemy/orm/strategies.py4
-rw-r--r--test/orm/test_attributes.py14
-rw-r--r--test/orm/test_extendedattr.py8
-rw-r--r--test/perf/objselectspeed.py37
9 files changed, 88 insertions, 32 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 5a6e24dfa..a8c525657 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1484,7 +1484,7 @@ class Mapper(object):
)
if readonly:
- _expire_state(state, readonly)
+ _expire_state(state, state.dict, readonly)
# if specified, eagerly refresh whatever has
# been expired.
@@ -1524,7 +1524,7 @@ class Mapper(object):
deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]]
if deferred_props:
- _expire_state(state, deferred_props)
+ _expire_state(state, state.dict, deferred_props)
# synchronize newly inserted ids from one table to the next
# TODO: this still goes a little too often. would be nice to
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 8dbc6b3db..bb92f39e4 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -115,7 +115,7 @@ class ColumnProperty(StrategizedProperty):
impl = dest_state.get_impl(self.key)
impl.set(dest_state, dest_dict, value, None)
else:
- dest_state.expire_attributes([self.key])
+ dest_state.expire_attributes(dest_dict, [self.key])
def get_col_value(self, column, value):
return value
@@ -636,7 +636,7 @@ class RelationProperty(StrategizedProperty):
return
if not "merge" in self.cascade:
- dest_state.expire_attributes([self.key])
+ dest_state.expire_attribute(dest_dict, [self.key])
return
if self.key not in source_dict:
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index ed5535151..456f1f19c 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1873,8 +1873,9 @@ class Query(object):
state.commit(dict_, list(to_evaluate))
- # expire attributes with pending changes (there was no autoflush, so they are overwritten)
- state.expire_attributes(set(evaluated_keys).difference(to_evaluate))
+ # expire attributes with pending changes
+ # (there was no autoflush, so they are overwritten)
+ state.expire_attributes(dict_, set(evaluated_keys).difference(to_evaluate))
elif synchronize_session == 'fetch':
target_mapper = self._mapper_zero()
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 0e5e939b1..d5246bee0 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -289,14 +289,14 @@ class SessionTransaction(object):
assert not self.session._deleted
for s in self.session.identity_map.all_states():
- _expire_state(s, None, instance_dict=self.session.identity_map)
+ _expire_state(s, s.dict, None, instance_dict=self.session.identity_map)
def _remove_snapshot(self):
assert self._is_transaction_boundary
if not self.nested and self.session.expire_on_commit:
for s in self.session.identity_map.all_states():
- _expire_state(s, None, instance_dict=self.session.identity_map)
+ _expire_state(s, s.dict, None, instance_dict=self.session.identity_map)
def _connection_for_bind(self, bind):
self._assert_is_active()
@@ -915,7 +915,7 @@ class Session(object):
"""Expires all persistent instances within this Session."""
for state in self.identity_map.all_states():
- _expire_state(state, None, instance_dict=self.identity_map)
+ _expire_state(state, state.dict, None, instance_dict=self.identity_map)
def expire(self, instance, attribute_names=None):
"""Expire the attributes on an instance.
@@ -936,14 +936,15 @@ class Session(object):
raise exc.UnmappedInstanceError(instance)
self._validate_persistent(state)
if attribute_names:
- _expire_state(state, attribute_names=attribute_names, instance_dict=self.identity_map)
+ _expire_state(state, state.dict,
+ attribute_names=attribute_names, instance_dict=self.identity_map)
else:
# pre-fetch the full cascade since the expire is going to
# remove associations
cascaded = list(_cascade_state_iterator('refresh-expire', state))
- _expire_state(state, None, instance_dict=self.identity_map)
+ _expire_state(state, state.dict, None, instance_dict=self.identity_map)
for (state, m, o) in cascaded:
- _expire_state(state, None, instance_dict=self.identity_map)
+ _expire_state(state, state.dict, None, instance_dict=self.identity_map)
def prune(self):
"""Remove unreferenced instances cached in the identity map.
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 4bb9219f4..a9494a50e 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -239,9 +239,30 @@ class InstanceState(object):
return set(
key for key in self.manager.iterkeys()
if key not in self.committed_state and key not in self.dict)
-
- def expire_attributes(self, attribute_names, instance_dict=None):
- self.expired_attributes = set(self.expired_attributes)
+
+ def expire_attribute_pre_commit(self, dict_, key):
+ """a fast expire that can be called by column loaders during a load.
+
+ The additional bookkeeping is finished up in commit_all().
+
+ This method is actually called a lot with joined-table
+ loading, when the second table isn't present in the result.
+
+ """
+ # TODO: yes, this is still a little too busy.
+ # need to more cleanly separate out handling
+ # for the various AttributeImpls and the contracts
+ # they wish to maintain with their strategies
+ if not self.expired_attributes:
+ self.expired_attributes = set(self.expired_attributes)
+
+ dict_.pop(key, None)
+ self.callables[key] = self
+ self.expired_attributes.add(key)
+
+ def expire_attributes(self, dict_, attribute_names, instance_dict=None):
+ if not self.expired_attributes:
+ self.expired_attributes = set(self.expired_attributes)
if attribute_names is None:
attribute_names = self.manager.keys()
@@ -258,7 +279,6 @@ class InstanceState(object):
filter_deferred = True
else:
filter_deferred = False
- dict_ = self.dict
for key in attribute_names:
impl = self.manager[key].impl
@@ -354,8 +374,7 @@ class InstanceState(object):
self.committed_state = {}
self.pending = {}
-
- # unexpire attributes which have loaded
+
if self.expired_attributes:
for key in self.expired_attributes.intersection(dict_):
self.callables.pop(key, None)
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 5e81d33ca..4d5ec3da4 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -121,7 +121,7 @@ class ColumnLoader(LoaderStrategy):
else:
def new_execute(state, dict_, row, isnew):
if isnew:
- state.expire_attributes([key])
+ state.expire_attribute_pre_commit(dict_, key)
return new_execute, None
log.class_logger(ColumnLoader)
@@ -168,7 +168,7 @@ class CompositeColumnLoader(ColumnLoader):
if c not in row:
def new_execute(state, dict_, row, isnew):
if isnew:
- state.expire_attributes([key])
+ state.expire_attribute_pre_commit(dict_, key)
break
else:
def new_execute(state, dict_, row, isnew):
diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py
index c69021aa3..e6041d566 100644
--- a/test/orm/test_attributes.py
+++ b/test/orm/test_attributes.py
@@ -142,21 +142,21 @@ class AttributesTest(_base.ORMTest):
attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
f = Foo()
- attributes.instance_state(f).expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
eq_(f.a, "this is a")
eq_(f.b, 12)
f.a = "this is some new a"
- attributes.instance_state(f).expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
eq_(f.a, "this is a")
eq_(f.b, 12)
- attributes.instance_state(f).expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
f.a = "this is another new a"
eq_(f.a, "this is another new a")
eq_(f.b, 12)
- attributes.instance_state(f).expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
eq_(f.a, "this is a")
eq_(f.b, 12)
@@ -182,7 +182,7 @@ class AttributesTest(_base.ORMTest):
attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False)
m = MyTest()
- attributes.instance_state(m).expire_attributes(None)
+ attributes.instance_state(m).expire_attributes(attributes.instance_dict(m), None)
assert 'a' not in m.__dict__
m2 = pickle.loads(pickle.dumps(m))
assert 'a' not in m2.__dict__
@@ -355,7 +355,7 @@ class AttributesTest(_base.ORMTest):
x.bars
b = Bar(id=4)
b.foos.append(x)
- attributes.instance_state(x).expire_attributes(['bars'])
+ attributes.instance_state(x).expire_attributes(attributes.instance_dict(x), ['bars'])
assert_raises(AssertionError, b.foos.remove, x)
@@ -1294,7 +1294,7 @@ class HistoryTest(_base.ORMTest):
eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), ([bar4], [], []))
lazy_load = [bar1, bar2, bar3]
- attributes.instance_state(f).expire_attributes(['bars'])
+ attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), ['bars'])
eq_(attributes.get_state_history(attributes.instance_state(f), 'bars'), ((), [bar1, bar2, bar3], ()))
def test_collections_via_lazyload(self):
diff --git a/test/orm/test_extendedattr.py b/test/orm/test_extendedattr.py
index 685be3a5f..4374b9ecb 100644
--- a/test/orm/test_extendedattr.py
+++ b/test/orm/test_extendedattr.py
@@ -161,21 +161,21 @@ class UserDefinedExtensionTest(_base.ORMTest):
assert Foo in attributes.instrumentation_registry._state_finders
f = Foo()
- attributes.instance_state(f).expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
eq_(f.a, "this is a")
eq_(f.b, 12)
f.a = "this is some new a"
- attributes.instance_state(f).expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
eq_(f.a, "this is a")
eq_(f.b, 12)
- attributes.instance_state(f).expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
f.a = "this is another new a"
eq_(f.a, "this is another new a")
eq_(f.b, 12)
- attributes.instance_state(f).expire_attributes(None)
+ attributes.instance_state(f).expire_attributes(attributes.instance_dict(f), None)
eq_(f.a, "this is a")
eq_(f.b, 12)
diff --git a/test/perf/objselectspeed.py b/test/perf/objselectspeed.py
index e04ef4efb..d3fd34046 100644
--- a/test/perf/objselectspeed.py
+++ b/test/perf/objselectspeed.py
@@ -8,21 +8,44 @@ db = create_engine('sqlite://')
metadata = MetaData(db)
Person_table = Table('Person', metadata,
Column('id', Integer, primary_key=True),
+ Column('type', String(10)),
Column('name', String(40)),
Column('sex', Integer),
Column('age', Integer))
+
+Employee_table = Table('Employee', metadata,
+ Column('id', Integer, ForeignKey('Person.id'), primary_key=True),
+ Column('foo', String(40)),
+ Column('bar', Integer),
+ Column('bat', Integer))
+
class RawPerson(object): pass
class Person(object): pass
mapper(Person, Person_table)
+
+class JoinedPerson(object):pass
+class Employee(JoinedPerson):pass
+mapper(JoinedPerson, Person_table, \
+ polymorphic_on=Person_table.c.type, polymorphic_identity='person')
+mapper(Employee, Employee_table, \
+ inherits=JoinedPerson, polymorphic_identity='employee')
compile_mappers()
def setup():
metadata.create_all()
i = Person_table.insert()
- data = [{'name':'John Doe','sex':1,'age':35}] * 100
+ data = [{'name':'John Doe','sex':1,'age':35, 'type':'employee'}] * 100
for j in xrange(500):
i.execute(data)
+
+ # note we arent fetching from employee_table,
+ # so we can leave it empty even though its "incorrect"
+ #i = Employee_table.insert()
+ #data = [{'foo':'foo', 'bar':'bar':'bat':'bat'}] * 100
+ #for j in xrange(500):
+ # i.execute(data)
+
print "Inserted 50,000 rows"
def sqlite_select(entity_cls):
@@ -55,6 +78,11 @@ def orm_select():
session = create_session()
people = session.query(Person).all()
+#@profiling.profiled(report=True, always=True)
+def joined_orm_select():
+ session = create_session()
+ people = session.query(JoinedPerson).all()
+
def all():
setup()
try:
@@ -103,6 +131,13 @@ def all():
orm_select()
t2 = time.clock()
usage('sqlalchemy.orm fetch')
+
+ gc_collect()
+ usage.snap()
+ t = time.clock()
+ joined_orm_select()
+ t2 = time.clock()
+ usage('sqlalchemy.orm "joined" fetch')
finally:
metadata.drop_all()