summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/loading.py24
-rw-r--r--lib/sqlalchemy/orm/strategies.py2
2 files changed, 18 insertions, 8 deletions
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index a23cafac2..8a20bf0dd 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -394,7 +394,8 @@ def _instance_processor(
callable_ = _load_subclass_via_in(context, path, selectin_load_via)
PostLoad.callable_for_path(
- context, load_path, selectin_load_via,
+ context, load_path, selectin_load_via.mapper,
+ selectin_load_via,
callable_, selectin_load_via)
post_load = PostLoad.for_context(context, load_path, only_load_props)
@@ -574,7 +575,6 @@ def _load_subclass_via_in(context, path, entity):
primary_keys=[
state.key[1][0] if zero_idx else state.key[1]
for state, load_attrs in states
- if state.mapper.isa(mapper)
]
).all()
@@ -738,16 +738,25 @@ class PostLoad(object):
self.load_keys = None
def add_state(self, state, overwrite):
+ # the states for a polymorphic load here are all shared
+ # within a single PostLoad object among multiple subtypes.
+ # Filtering of callables on a per-subclass basis needs to be done at
+ # the invocation level
self.states[state] = overwrite
def invoke(self, context, path):
if not self.states:
return
path = path_registry.PathRegistry.coerce(path)
- for key, loader, arg, kw in self.loaders.values():
+ for token, limit_to_mapper, loader, arg, kw in self.loaders.values():
+ states = [
+ (state, overwrite)
+ for state, overwrite
+ in self.states.items()
+ if state.manager.mapper.isa(limit_to_mapper)
+ ]
loader(
- context, path, self.states.items(),
- self.load_keys, *arg, **kw)
+ context, path, states, self.load_keys, *arg, **kw)
self.states.clear()
@classmethod
@@ -764,12 +773,13 @@ class PostLoad(object):
@classmethod
def callable_for_path(
- cls, context, path, attr_key, loader_callable, *arg, **kw):
+ cls, context, path, limit_to_mapper, token,
+ loader_callable, *arg, **kw):
if path.path in context.post_load_paths:
pl = context.post_load_paths[path.path]
else:
pl = context.post_load_paths[path.path] = PostLoad()
- pl.loaders[attr_key] = (attr_key, loader_callable, arg, kw)
+ pl.loaders[token] = (token, limit_to_mapper, loader_callable, arg, kw)
def load_scalar_attributes(mapper, state, attribute_names):
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index a57b66045..c3eae1e91 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -1883,7 +1883,7 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots):
return
loading.PostLoad.callable_for_path(
- context, selectin_path, self.key,
+ context, selectin_path, self.parent, self.key,
self._load_for_path, effective_entity)
@util.dependencies("sqlalchemy.ext.baked")