summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/loading.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/loading.py')
-rw-r--r--lib/sqlalchemy/orm/loading.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 7feec660d..48c0db851 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -19,6 +19,7 @@ from . import attributes, exc as orm_exc
from ..sql import util as sql_util
from . import strategy_options
from . import path_registry
+from .. import sql
from .util import _none_set, state_str
from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE
@@ -353,6 +354,27 @@ def _instance_processor(
session_id = context.session.hash_key
version_check = context.version_check
runid = context.runid
+
+ if not refresh_state and _polymorphic_from is not None:
+ key = ('loader', path.path)
+ if (
+ key in context.attributes and
+ context.attributes[key].strategy ==
+ (('selectinload_polymorphic', True), ) and
+ mapper in context.attributes[key].local_opts['mappers']
+ ) or mapper.polymorphic_load == 'selectin':
+
+ # only_load_props goes w/ refresh_state only, and in a refresh
+ # we are a single row query for the exact entity; polymorphic
+ # loading does not apply
+ assert only_load_props is None
+
+ callable_ = _load_subclass_via_in(context, path, mapper)
+
+ PostLoad.callable_for_path(
+ context, load_path, mapper,
+ callable_, mapper)
+
post_load = PostLoad.for_context(context, load_path, only_load_props)
if refresh_state:
@@ -501,6 +523,37 @@ def _instance_processor(
return _instance
+@util.dependencies("sqlalchemy.ext.baked")
+def _load_subclass_via_in(baked, context, path, mapper):
+
+ zero_idx = len(mapper.base_mapper.primary_key) == 1
+
+ q, enable_opt, disable_opt = mapper._subclass_load_via_in
+
+ def do_load(context, path, states, load_only, effective_entity):
+ orig_query = context.query
+
+ q._add_lazyload_options(
+ (enable_opt, ) + orig_query._with_options + (disable_opt, ),
+ path.parent, cache_path=path
+ )
+
+ if orig_query._populate_existing:
+ q.add_criteria(
+ lambda q: q.populate_existing()
+ )
+
+ q(context.session).params(
+ primary_keys=[
+ state.key[1][0] if zero_idx else state.key[1]
+ for state, load_attrs in states
+ if state.mapper.isa(mapper)
+ ]
+ ).all()
+
+ return do_load
+
+
def _populate_full(
context, row, state, dict_, isnew, load_path,
loaded_instance, populate_existing, populators):