summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_12/4156.rst8
-rw-r--r--lib/sqlalchemy/orm/loading.py24
-rw-r--r--lib/sqlalchemy/orm/strategies.py2
-rw-r--r--test/orm/test_selectin_relations.py145
4 files changed, 170 insertions, 9 deletions
diff --git a/doc/build/changelog/unreleased_12/4156.rst b/doc/build/changelog/unreleased_12/4156.rst
new file mode 100644
index 000000000..4511302e3
--- /dev/null
+++ b/doc/build/changelog/unreleased_12/4156.rst
@@ -0,0 +1,8 @@
+.. change::
+ :tags: bug, orm
+ :tickets: 4156
+
+ Fixed bug in new "selectin" relationship loader where the loader could try
+ to load a non-existent relationship when loading a collection of
+ polymorphic objects, where only some of the mappers include that
+ relationship, typically when :meth:`.PropComparator.of_type` is being used.
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index a23cafac2..8a20bf0dd 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -394,7 +394,8 @@ def _instance_processor(
callable_ = _load_subclass_via_in(context, path, selectin_load_via)
PostLoad.callable_for_path(
- context, load_path, selectin_load_via,
+ context, load_path, selectin_load_via.mapper,
+ selectin_load_via,
callable_, selectin_load_via)
post_load = PostLoad.for_context(context, load_path, only_load_props)
@@ -574,7 +575,6 @@ def _load_subclass_via_in(context, path, entity):
primary_keys=[
state.key[1][0] if zero_idx else state.key[1]
for state, load_attrs in states
- if state.mapper.isa(mapper)
]
).all()
@@ -738,16 +738,25 @@ class PostLoad(object):
self.load_keys = None
def add_state(self, state, overwrite):
+ # the states for a polymorphic load here are all shared
+ # within a single PostLoad object among multiple subtypes.
+ # Filtering of callables on a per-subclass basis needs to be done at
+ # the invocation level
self.states[state] = overwrite
def invoke(self, context, path):
if not self.states:
return
path = path_registry.PathRegistry.coerce(path)
- for key, loader, arg, kw in self.loaders.values():
+ for token, limit_to_mapper, loader, arg, kw in self.loaders.values():
+ states = [
+ (state, overwrite)
+ for state, overwrite
+ in self.states.items()
+ if state.manager.mapper.isa(limit_to_mapper)
+ ]
loader(
- context, path, self.states.items(),
- self.load_keys, *arg, **kw)
+ context, path, states, self.load_keys, *arg, **kw)
self.states.clear()
@classmethod
@@ -764,12 +773,13 @@ class PostLoad(object):
@classmethod
def callable_for_path(
- cls, context, path, attr_key, loader_callable, *arg, **kw):
+ cls, context, path, limit_to_mapper, token,
+ loader_callable, *arg, **kw):
if path.path in context.post_load_paths:
pl = context.post_load_paths[path.path]
else:
pl = context.post_load_paths[path.path] = PostLoad()
- pl.loaders[attr_key] = (attr_key, loader_callable, arg, kw)
+ pl.loaders[token] = (token, limit_to_mapper, loader_callable, arg, kw)
def load_scalar_attributes(mapper, state, attribute_names):
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index a57b66045..c3eae1e91 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -1883,7 +1883,7 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
return
loading.PostLoad.callable_for_path(
- context, selectin_path, self.key,
+ context, selectin_path, self.parent, self.key,
self._load_for_path, effective_entity)
@util.dependencies("sqlalchemy.ext.baked")
diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py
index 6f10260cc..ff1d0d40f 100644
--- a/test/orm/test_selectin_relations.py
+++ b/test/orm/test_selectin_relations.py
@@ -5,7 +5,7 @@ from sqlalchemy import Integer, String, ForeignKey, bindparam
from sqlalchemy.orm import selectinload, selectinload_all, \
mapper, relationship, clear_mappers, create_session, \
aliased, joinedload, deferred, undefer,\
- Session, subqueryload
+ Session, subqueryload, defaultload
from sqlalchemy.testing import assert_raises, \
assert_raises_message
from sqlalchemy.testing.assertsql import CompiledSQL
@@ -1334,6 +1334,149 @@ class BaseRelationFromJoinedSubclassTest(_Polymorphic):
)
+class HeterogeneousSubtypesTest(fixtures.DeclarativeMappedTest):
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.DeclarativeBasic
+
+ class Company(Base):
+ __tablename__ = 'company'
+ id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ employees = relationship('Employee', order_by="Employee.id")
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+ id = Column(Integer, primary_key=True)
+ type = Column(String(50))
+ name = Column(String(50))
+ company_id = Column(ForeignKey('company.id'))
+
+ __mapper_args__ = {
+ 'polymorphic_on': 'type',
+ 'with_polymorphic': '*',
+ }
+
+ class Programmer(Employee):
+ __tablename__ = 'programmer'
+ id = Column(ForeignKey('employee.id'), primary_key=True)
+ languages = relationship('Language')
+
+ __mapper_args__ = {
+ 'polymorphic_identity': 'programmer',
+ }
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+ id = Column(ForeignKey('employee.id'), primary_key=True)
+ golf_swing_id = Column(ForeignKey("golf_swing.id"))
+ golf_swing = relationship("GolfSwing")
+
+ __mapper_args__ = {
+ 'polymorphic_identity': 'manager',
+ }
+
+ class Language(Base):
+ __tablename__ = 'language'
+ id = Column(Integer, primary_key=True)
+ programmer_id = Column(
+ Integer,
+ ForeignKey('programmer.id'),
+ nullable=False,
+ )
+ name = Column(String(50))
+
+ class GolfSwing(Base):
+ __tablename__ = 'golf_swing'
+ id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+
+ @classmethod
+ def insert_data(cls):
+ Company, Programmer, Manager, GolfSwing, Language = cls.classes(
+ "Company", "Programmer", "Manager", "GolfSwing", "Language")
+ c1 = Company(
+ id=1,
+ name='Foobar Corp',
+ employees=[Programmer(
+ id=1,
+ name='p1',
+ languages=[Language(id=1, name='Python')],
+ ), Manager(
+ id=2,
+ name='m1',
+ golf_swing=GolfSwing(name="fore")
+ )],
+ )
+ c2 = Company(
+ id=2,
+ name='bat Corp',
+ employees=[
+ Manager(
+ id=3,
+ name='m2',
+ golf_swing=GolfSwing(name="clubs"),
+ ), Programmer(
+ id=4,
+ name='p2',
+ languages=[Language(id=2, name="Java")]
+ )],
+ )
+ sess = Session()
+ sess.add_all([c1, c2])
+ sess.commit()
+
+ def test_one_to_many(self):
+
+ Company, Programmer, Manager, GolfSwing, Language = self.classes(
+ "Company", "Programmer", "Manager", "GolfSwing", "Language")
+ sess = Session()
+ company = sess.query(Company).filter(
+ Company.id == 1,
+ ).options(
+ selectinload(Company.employees.of_type(Programmer)).
+ selectinload(Programmer.languages),
+ ).one()
+
+ def go():
+ eq_(company.employees[0].languages[0].name, "Python")
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_many_to_one(self):
+ Company, Programmer, Manager, GolfSwing, Language = self.classes(
+ "Company", "Programmer", "Manager", "GolfSwing", "Language")
+ sess = Session()
+ company = sess.query(Company).filter(
+ Company.id == 2,
+ ).options(
+ selectinload(Company.employees.of_type(Manager)).
+ selectinload(Manager.golf_swing),
+ ).one()
+
+ def go():
+ eq_(company.employees[0].golf_swing.name, "clubs")
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_both(self):
+ Company, Programmer, Manager, GolfSwing, Language = self.classes(
+ "Company", "Programmer", "Manager", "GolfSwing", "Language")
+ sess = Session()
+ rows = sess.query(Company).options(
+ selectinload(Company.employees.of_type(Manager)).
+ selectinload(Manager.golf_swing),
+ defaultload(Company.employees.of_type(Programmer)).
+ selectinload(Programmer.languages),
+ ).order_by(Company.id).all()
+
+ def go():
+ eq_(rows[0].employees[0].languages[0].name, "Python")
+ eq_(rows[1].employees[0].golf_swing.name, "clubs")
+
+ self.assert_sql_count(testing.db, go, 0)
+
+
class ChunkingTest(fixtures.DeclarativeMappedTest):
"""test IN chunking.