summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/interfaces.py2
-rw-r--r--lib/sqlalchemy/orm/strategies.py42
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py97
-rw-r--r--lib/sqlalchemy/testing/assertsql.py24
4 files changed, 154 insertions, 11 deletions
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index f17cc6159..4826354dc 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -740,6 +740,8 @@ class ORMOption(ExecutableOption):
_is_criteria_option = False
+ _is_strategy_option = False
+
class LoaderOption(ORMOption):
"""Describe a loader modification to an ORM statement at compilation time.
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 955cd6dd2..bf60c803d 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -1511,11 +1511,26 @@ class SubqueryLoader(PostLoader):
effective_entity,
loadopt,
):
- opts = orig_query._with_options
+
+ if orig_query is context.query:
+ options = new_options = orig_query._with_options
+ else:
+ # There's currently no test that exercises the necessity of
+ # this step for subqueryload. Added in #6881, it is necessary for
+ # selectinload, but its necessity for subqueryload is still
+ # theoretical.
+ options = orig_query._with_options
+
+ new_options = [
+ orig_opt._adjust_for_extra_criteria(context)
+ if orig_opt._is_strategy_option
+ else orig_opt
+ for orig_opt in options
+ ]
if loadopt and loadopt._extra_criteria:
- opts += (
+ new_options += (
orm_util.LoaderCriteriaOption(
self.entity,
loadopt._generate_extra_criteria(context),
@@ -1525,7 +1540,7 @@ class SubqueryLoader(PostLoader):
# propagate loader options etc. to the new query.
# these will fire relative to subq_path.
q = q._with_current_path(rewritten_path)
- q = q.options(*opts)
+ q = q.options(*new_options)
return q
@@ -2916,17 +2931,32 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
effective_path = path[self.parent_property]
- options = orig_query._with_options
+ if orig_query is context.query:
+ options = new_options = orig_query._with_options
+ else:
+ options = orig_query._with_options
+
+ # note this will create a different cache key than
+ # "orig" options if extra_criteria is present, because the copy
+ # of extra_criteria will have different boundparam than that of
+ # the QueryableAttribute in the path
+
+ new_options = [
+ orig_opt._adjust_for_extra_criteria(context)
+ if orig_opt._is_strategy_option
+ else orig_opt
+ for orig_opt in options
+ ]
if loadopt and loadopt._extra_criteria:
- options += (
+ new_options += (
orm_util.LoaderCriteriaOption(
effective_entity,
loadopt._generate_extra_criteria(context),
),
)
- q = q.options(*options)._update_compile_options(
+ q = q.options(*new_options)._update_compile_options(
{"_current_path": effective_path}
)
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index ea1e5ea2a..b7ed4e89b 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -77,6 +77,8 @@ class Load(Generative, LoaderOption):
"""
+ _is_strategy_option = True
+
_cache_key_traversal = [
("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
@@ -115,7 +117,7 @@ class Load(Generative, LoaderOption):
def _generate_extra_criteria(self, context):
"""Apply the current bound parameters in a QueryContext to the
- "extra_criteria" stored with this Load object.
+ immediate "extra_criteria" stored with this Load object.
Load objects are typically pulled from the cached version of
the statement from a QueryContext. The statement currently being
@@ -150,6 +152,69 @@ class Load(Generative, LoaderOption):
return k2._apply_params_to_element(k1, and_(*self._extra_criteria))
+ def _adjust_for_extra_criteria(self, context):
+ """Apply the current bound parameters in a QueryContext to all
+ occurrences "extra_criteria" stored within al this Load object;
+ copying in place.
+
+ """
+ orig_query = context.compile_state.select_statement
+
+ applied = {}
+
+ ck = [None, None]
+
+ def process(opt):
+ if not opt._extra_criteria:
+ return
+
+ if ck[0] is None:
+ ck[:] = (
+ orig_query._generate_cache_key(),
+ context.query._generate_cache_key(),
+ )
+ k1, k2 = ck
+
+ opt._extra_criteria = tuple(
+ k2._apply_params_to_element(k1, crit)
+ for crit in opt._extra_criteria
+ )
+
+ return self._deep_clone(applied, process)
+
+ def _deep_clone(self, applied, process):
+ if self in applied:
+ return applied[self]
+
+ cloned = self._generate()
+
+ applied[self] = cloned
+
+ cloned.strategy = self.strategy
+
+ assert cloned.propagate_to_loaders == self.propagate_to_loaders
+ assert cloned.is_class_strategy == self.is_class_strategy
+ assert cloned.is_opts_only == self.is_opts_only
+
+ if self.context:
+ cloned.context = util.OrderedDict(
+ [
+ (
+ key,
+ value._deep_clone(applied, process)
+ if isinstance(value, Load)
+ else value,
+ )
+ for key, value in self.context.items()
+ ]
+ )
+
+ cloned.local_opts.update(self.local_opts)
+
+ process(cloned)
+
+ return cloned
+
@property
def _context_cache_key(self):
serialized = []
@@ -345,7 +410,10 @@ class Load(Generative, LoaderOption):
else:
return None
- if attr._extra_criteria:
+ if attr._extra_criteria and not self._extra_criteria:
+ # in most cases, the process that brings us here will have
+ # already established _extra_criteria. however if not,
+ # and it's present on the attribute, then use that.
self._extra_criteria = attr._extra_criteria
if getattr(attr, "_of_type", None):
@@ -708,6 +776,30 @@ class _UnboundLoad(Load):
# anonymous clone of the Load / UnboundLoad object since #5056
self._to_bind = None
+ def _deep_clone(self, applied, process):
+ if self in applied:
+ return applied[self]
+
+ cloned = self._generate()
+
+ applied[self] = cloned
+
+ cloned.strategy = self.strategy
+
+ assert cloned.propagate_to_loaders == self.propagate_to_loaders
+ assert cloned.is_class_strategy == self.is_class_strategy
+ assert cloned.is_opts_only == self.is_opts_only
+
+ cloned._to_bind = [
+ elem._deep_clone(applied, process) for elem in self._to_bind or ()
+ ]
+
+ cloned.local_opts.update(self.local_opts)
+
+ process(cloned)
+
+ return cloned
+
def _apply_to_parent(self, parent, applied, bound, to_bind=None):
if self in applied:
return applied[self]
@@ -984,6 +1076,7 @@ class _UnboundLoad(Load):
loader.strategy = self.strategy
loader.is_opts_only = self.is_opts_only
loader.is_class_strategy = self.is_class_strategy
+ loader._extra_criteria = self._extra_criteria
path = loader.path
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 2d41d6dab..4ee4c5844 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -96,6 +96,15 @@ class CompiledSQL(SQLMatchRule):
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
+ # received_statement runs a full compile(). we should not need to
+ # consider extracted_parameters; if we do this indicates some state
+ # is being sent from a previous cached query, which some misbehaviors
+ # in the ORM can cause, see #6881
+ cache_key = None # execute_observed.context.compiled.cache_key
+ extracted_parameters = (
+ None # execute_observed.context.extracted_parameters
+ )
+
if "schema_translate_map" in context.execution_options:
map_ = context.execution_options["schema_translate_map"]
else:
@@ -104,10 +113,12 @@ class CompiledSQL(SQLMatchRule):
if isinstance(execute_observed.clauseelement, _DDLCompiles):
compiled = execute_observed.clauseelement.compile(
- dialect=compare_dialect, schema_translate_map=map_
+ dialect=compare_dialect,
+ schema_translate_map=map_,
)
else:
compiled = execute_observed.clauseelement.compile(
+ cache_key=cache_key,
dialect=compare_dialect,
column_keys=context.compiled.column_keys,
for_executemany=context.compiled.for_executemany,
@@ -117,10 +128,17 @@ class CompiledSQL(SQLMatchRule):
parameters = execute_observed.parameters
if not parameters:
- _received_parameters = [compiled.construct_params()]
+ _received_parameters = [
+ compiled.construct_params(
+ extracted_parameters=extracted_parameters
+ )
+ ]
else:
_received_parameters = [
- compiled.construct_params(m) for m in parameters
+ compiled.construct_params(
+ m, extracted_parameters=extracted_parameters
+ )
+ for m in parameters
]
return _received_statement, _received_parameters