summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-05-13 12:32:31 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2023-05-14 09:50:45 -0400
commit9f7c97b40b6b6c738bd4cac7ae8bdc8803a88973 (patch)
treea7c75bcc21059b4f6389b373d039f4435b6411fa
parenteb286c15f096771dbb128acbe8fe03e94aa72f6a (diff)
downloadsqlalchemy-9f7c97b40b6b6c738bd4cac7ae8bdc8803a88973.tar.gz
limit joinedload exclusion rules to immediate mapped columns
Fixed issue where using additional relationship criteria with the :func:`_orm.joinedload` loader option, where the additional criteria itself contained correlated subqueries that referred to the joined entities and therefore also required "adaption" to aliased entities, would be excluded from this adaption, producing the wrong ON clause for the joinedload. Fixes: #9779 Change-Id: Idcfec3e760057fbf6a09c10ad67a0bb4bf70f03a
-rw-r--r--doc/build/changelog/unreleased_20/9779.rst9
-rw-r--r--lib/sqlalchemy/orm/relationships.py39
-rw-r--r--lib/sqlalchemy/sql/annotation.py12
-rw-r--r--lib/sqlalchemy/sql/util.py1
-rw-r--r--test/orm/inheritance/test_assorted_poly.py168
5 files changed, 226 insertions, 3 deletions
diff --git a/doc/build/changelog/unreleased_20/9779.rst b/doc/build/changelog/unreleased_20/9779.rst
new file mode 100644
index 000000000..ab417b2dc
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9779.rst
@@ -0,0 +1,9 @@
+.. change::
+ :tags: bug, orm
+ :tickets: 9779
+
+ Fixed issue where using additional relationship criteria with the
+ :func:`_orm.joinedload` loader option, where the additional criteria itself
+ contained correlated subqueries that referred to the joined entities and
+ therefore also required "adaption" to aliased entities, would be excluded
+ from this adaption, producing the wrong ON clause for the joinedload.
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 8d9f3c644..e5a6b9afa 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -80,6 +80,7 @@ from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _HasClauseElement
from ..sql.elements import ColumnClause
from ..sql.elements import ColumnElement
+from ..sql.util import _deep_annotate
from ..sql.util import _deep_deannotate
from ..sql.util import _shallow_annotate
from ..sql.util import adapt_criterion_to_null
@@ -115,6 +116,7 @@ if typing.TYPE_CHECKING:
from ..sql._typing import _EquivalentColumnMap
from ..sql._typing import _InfoType
from ..sql.annotation import _AnnotationDict
+ from ..sql.annotation import SupportsAnnotations
from ..sql.elements import BinaryExpression
from ..sql.elements import BindParameter
from ..sql.elements import ClauseElement
@@ -3284,6 +3286,38 @@ class JoinCondition:
primaryjoin = primaryjoin & single_crit
if extra_criteria:
+
+ def mark_unrelated_columns_as_ok_to_adapt(
+ elem: SupportsAnnotations, annotations: _AnnotationDict
+ ) -> SupportsAnnotations:
+ """note unrelated columns in the "extra criteria" as OK
+ to adapt, even though they are not part of our "local"
+ or "remote" side.
+
+ see #9779 for this case
+
+ """
+
+ parentmapper_for_element = elem._annotations.get(
+ "parentmapper", None
+ )
+ if (
+ parentmapper_for_element is not self.prop.parent
+ and parentmapper_for_element is not self.prop.mapper
+ ):
+ return elem._annotate(annotations)
+ else:
+ return elem
+
+ extra_criteria = tuple(
+ _deep_annotate(
+ elem,
+ {"ok_to_adapt_in_join_condition": True},
+ annotate_callable=mark_unrelated_columns_as_ok_to_adapt,
+ )
+ for elem in extra_criteria
+ )
+
if secondaryjoin is not None:
secondaryjoin = secondaryjoin & sql.and_(*extra_criteria)
else:
@@ -3409,7 +3443,10 @@ class _ColInAnnotations:
self.name = name
def __call__(self, c: ClauseElement) -> bool:
- return self.name in c._annotations
+ return (
+ self.name in c._annotations
+ or "ok_to_adapt_in_join_condition" in c._annotations
+ )
class Relationship( # type: ignore
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 7487e074c..e6dee7d17 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -406,8 +406,12 @@ def _deep_annotate(
element: _SA,
annotations: _AnnotationDict,
exclude: Optional[Sequence[SupportsAnnotations]] = None,
+ *,
detect_subquery_cols: bool = False,
ind_cols_on_fromclause: bool = False,
+ annotate_callable: Optional[
+ Callable[[SupportsAnnotations, _AnnotationDict], SupportsAnnotations]
+ ] = None,
) -> _SA:
"""Deep copy the given ClauseElement, annotating each element
with the given annotations dictionary.
@@ -446,9 +450,13 @@ def _deep_annotate(
newelem = elem._clone(clone=clone, **kw)
elif annotations != elem._annotations:
if detect_subquery_cols and elem._is_immutable:
- newelem = elem._clone(clone=clone, **kw)._annotate(annotations)
+ to_annotate = elem._clone(clone=clone, **kw)
else:
- newelem = elem._annotate(annotations)
+ to_annotate = elem
+ if annotate_callable:
+ newelem = annotate_callable(to_annotate, annotations)
+ else:
+ newelem = to_annotate._annotate(annotations)
else:
newelem = elem
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 0a50197a0..18caf5de4 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -1343,6 +1343,7 @@ class ColumnAdapter(ClauseAdapter):
def traverse(
self, obj: Optional[ExternallyTraversible]
) -> Optional[ExternallyTraversible]:
+
return self.columns[obj]
def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py
index a40a9ae74..28480c89e 100644
--- a/test/orm/inheritance/test_assorted_poly.py
+++ b/test/orm/inheritance/test_assorted_poly.py
@@ -7,6 +7,7 @@ from __future__ import annotations
from typing import Optional
+from sqlalchemy import and_
from sqlalchemy import exists
from sqlalchemy import ForeignKey
from sqlalchemy import func
@@ -37,6 +38,7 @@ from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import config
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.fixtures import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
@@ -2765,3 +2767,169 @@ class PolyIntoSelfReferentialTest(
assert False
self._run_load(opt)
+
+
+class AdaptExistsSubqTest(fixtures.DeclarativeMappedTest):
+ """test for #9777"""
+
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.DeclarativeBasic
+
+ class Discriminator(Base):
+ __tablename__ = "discriminator"
+ id = Column(Integer, primary_key=True, autoincrement=False)
+ value = Column(String(50))
+
+ class Entity(Base):
+ __tablename__ = "entity"
+ __mapper_args__ = {"polymorphic_on": "type"}
+
+ id = Column(Integer, primary_key=True, autoincrement=False)
+ type = Column(String(50))
+
+ discriminator_id = Column(
+ ForeignKey("discriminator.id"), nullable=False
+ )
+ discriminator = relationship(
+ "Discriminator", foreign_keys=discriminator_id
+ )
+
+ class Parent(Entity):
+ __tablename__ = "parent"
+ __mapper_args__ = {"polymorphic_identity": "parent"}
+
+ id = Column(Integer, ForeignKey("entity.id"), primary_key=True)
+ some_data = Column(String(30))
+
+ class Child(Entity):
+ __tablename__ = "child"
+ __mapper_args__ = {"polymorphic_identity": "child"}
+
+ id = Column(Integer, ForeignKey("entity.id"), primary_key=True)
+
+ some_data = Column(String(30))
+ parent_id = Column(ForeignKey("parent.id"), nullable=False)
+ parent = relationship(
+ "Parent",
+ foreign_keys=parent_id,
+ backref="children",
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ Parent, Child, Discriminator = cls.classes(
+ "Parent", "Child", "Discriminator"
+ )
+
+ with Session(connection) as sess:
+ discriminator_zero = Discriminator(id=1, value="zero")
+ discriminator_one = Discriminator(id=2, value="one")
+ discriminator_two = Discriminator(id=3, value="two")
+
+ parent = Parent(id=1, discriminator=discriminator_zero)
+ child_1 = Child(
+ id=2,
+ discriminator=discriminator_one,
+ parent=parent,
+ some_data="c1data",
+ )
+ child_2 = Child(
+ id=3,
+ discriminator=discriminator_two,
+ parent=parent,
+ some_data="c2data",
+ )
+ sess.add_all([parent, child_1, child_2])
+ sess.commit()
+
+ def test_explicit_aliasing(self):
+ Parent, Child, Discriminator = self.classes(
+ "Parent", "Child", "Discriminator"
+ )
+
+ parent_id = 1
+ discriminator_one_id = 2
+
+ session = fixture_session()
+ c_alias = aliased(Child, flat=True)
+ retrieved = (
+ session.query(Parent)
+ .filter_by(id=parent_id)
+ .outerjoin(
+ Parent.children.of_type(c_alias).and_(
+ c_alias.discriminator.has(
+ and_(
+ Discriminator.id == discriminator_one_id,
+ c_alias.some_data == "c1data",
+ )
+ )
+ )
+ )
+ .options(contains_eager(Parent.children.of_type(c_alias)))
+ .populate_existing()
+ .one()
+ )
+ eq_(len(retrieved.children), 1)
+
+ def test_implicit_aliasing(self):
+ Parent, Child, Discriminator = self.classes(
+ "Parent", "Child", "Discriminator"
+ )
+
+ parent_id = 1
+ discriminator_one_id = 2
+
+ session = fixture_session()
+ q = (
+ session.query(Parent)
+ .filter_by(id=parent_id)
+ .outerjoin(
+ Parent.children.and_(
+ Child.discriminator.has(
+ and_(
+ Discriminator.id == discriminator_one_id,
+ Child.some_data == "c1data",
+ )
+ )
+ )
+ )
+ .options(contains_eager(Parent.children))
+ .populate_existing()
+ )
+
+ with expect_warnings("An alias is being generated automatically"):
+ retrieved = q.one()
+
+ eq_(len(retrieved.children), 1)
+
+ @testing.combinations(joinedload, selectinload, argnames="loader")
+ def test_eager_loaders(self, loader):
+ Parent, Child, Discriminator = self.classes(
+ "Parent", "Child", "Discriminator"
+ )
+
+ parent_id = 1
+ discriminator_one_id = 2
+
+ session = fixture_session()
+ retrieved = (
+ session.query(Parent)
+ .filter_by(id=parent_id)
+ .options(
+ loader(
+ Parent.children.and_(
+ Child.discriminator.has(
+ and_(
+ Discriminator.id == discriminator_one_id,
+ Child.some_data == "c1data",
+ )
+ )
+ )
+ )
+ )
+ .populate_existing()
+ .one()
+ )
+
+ eq_(len(retrieved.children), 1)