diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-05-13 12:32:31 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-05-14 09:50:45 -0400 |
commit | 9f7c97b40b6b6c738bd4cac7ae8bdc8803a88973 (patch) | |
tree | a7c75bcc21059b4f6389b373d039f4435b6411fa | |
parent | eb286c15f096771dbb128acbe8fe03e94aa72f6a (diff) | |
download | sqlalchemy-9f7c97b40b6b6c738bd4cac7ae8bdc8803a88973.tar.gz |
limit joinedload exclusion rules to immediate mapped columns
Fixed issue where using additional relationship criteria with the
:func:`_orm.joinedload` loader option, where the additional criteria itself
contained correlated subqueries that referred to the joined entities and
therefore also required "adaption" to aliased entities, would be excluded
from this adaption, producing the wrong ON clause for the joinedload.
Fixes: #9779
Change-Id: Idcfec3e760057fbf6a09c10ad67a0bb4bf70f03a
-rw-r--r-- | doc/build/changelog/unreleased_20/9779.rst | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 39 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 1 | ||||
-rw-r--r-- | test/orm/inheritance/test_assorted_poly.py | 168 |
5 files changed, 226 insertions, 3 deletions
diff --git a/doc/build/changelog/unreleased_20/9779.rst b/doc/build/changelog/unreleased_20/9779.rst new file mode 100644 index 000000000..ab417b2dc --- /dev/null +++ b/doc/build/changelog/unreleased_20/9779.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm + :tickets: 9779 + + Fixed issue where using additional relationship criteria with the + :func:`_orm.joinedload` loader option, where the additional criteria itself + contained correlated subqueries that referred to the joined entities and + therefore also required "adaption" to aliased entities, would be excluded + from this adaption, producing the wrong ON clause for the joinedload. diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 8d9f3c644..e5a6b9afa 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -80,6 +80,7 @@ from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _HasClauseElement from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement +from ..sql.util import _deep_annotate from ..sql.util import _deep_deannotate from ..sql.util import _shallow_annotate from ..sql.util import adapt_criterion_to_null @@ -115,6 +116,7 @@ if typing.TYPE_CHECKING: from ..sql._typing import _EquivalentColumnMap from ..sql._typing import _InfoType from ..sql.annotation import _AnnotationDict + from ..sql.annotation import SupportsAnnotations from ..sql.elements import BinaryExpression from ..sql.elements import BindParameter from ..sql.elements import ClauseElement @@ -3284,6 +3286,38 @@ class JoinCondition: primaryjoin = primaryjoin & single_crit if extra_criteria: + + def mark_unrelated_columns_as_ok_to_adapt( + elem: SupportsAnnotations, annotations: _AnnotationDict + ) -> SupportsAnnotations: + """note unrelated columns in the "extra criteria" as OK + to adapt, even though they are not part of our "local" + or "remote" side. + + see #9779 for this case + + """ + + parentmapper_for_element = elem._annotations.get( + "parentmapper", None + ) + if ( + parentmapper_for_element is not self.prop.parent + and parentmapper_for_element is not self.prop.mapper + ): + return elem._annotate(annotations) + else: + return elem + + extra_criteria = tuple( + _deep_annotate( + elem, + {"ok_to_adapt_in_join_condition": True}, + annotate_callable=mark_unrelated_columns_as_ok_to_adapt, + ) + for elem in extra_criteria + ) + if secondaryjoin is not None: secondaryjoin = secondaryjoin & sql.and_(*extra_criteria) else: @@ -3409,7 +3443,10 @@ class _ColInAnnotations: self.name = name def __call__(self, c: ClauseElement) -> bool: - return self.name in c._annotations + return ( + self.name in c._annotations + or "ok_to_adapt_in_join_condition" in c._annotations + ) class Relationship( # type: ignore diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 7487e074c..e6dee7d17 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -406,8 +406,12 @@ def _deep_annotate( element: _SA, annotations: _AnnotationDict, exclude: Optional[Sequence[SupportsAnnotations]] = None, + *, detect_subquery_cols: bool = False, ind_cols_on_fromclause: bool = False, + annotate_callable: Optional[ + Callable[[SupportsAnnotations, _AnnotationDict], SupportsAnnotations] + ] = None, ) -> _SA: """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. @@ -446,9 +450,13 @@ def _deep_annotate( newelem = elem._clone(clone=clone, **kw) elif annotations != elem._annotations: if detect_subquery_cols and elem._is_immutable: - newelem = elem._clone(clone=clone, **kw)._annotate(annotations) + to_annotate = elem._clone(clone=clone, **kw) else: - newelem = elem._annotate(annotations) + to_annotate = elem + if annotate_callable: + newelem = annotate_callable(to_annotate, annotations) + else: + newelem = to_annotate._annotate(annotations) else: newelem = elem diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 0a50197a0..18caf5de4 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1343,6 +1343,7 @@ class ColumnAdapter(ClauseAdapter): def traverse( self, obj: Optional[ExternallyTraversible] ) -> Optional[ExternallyTraversible]: + return self.columns[obj] def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index a40a9ae74..28480c89e 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -7,6 +7,7 @@ from __future__ import annotations from typing import Optional +from sqlalchemy import and_ from sqlalchemy import exists from sqlalchemy import ForeignKey from sqlalchemy import func @@ -37,6 +38,7 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import config from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing.fixtures import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session @@ -2765,3 +2767,169 @@ class PolyIntoSelfReferentialTest( assert False self._run_load(opt) + + +class AdaptExistsSubqTest(fixtures.DeclarativeMappedTest): + """test for #9777""" + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Discriminator(Base): + __tablename__ = "discriminator" + id = Column(Integer, primary_key=True, autoincrement=False) + value = Column(String(50)) + + class Entity(Base): + __tablename__ = "entity" + __mapper_args__ = {"polymorphic_on": "type"} + + id = Column(Integer, primary_key=True, autoincrement=False) + type = Column(String(50)) + + discriminator_id = Column( + ForeignKey("discriminator.id"), nullable=False + ) + discriminator = relationship( + "Discriminator", foreign_keys=discriminator_id + ) + + class Parent(Entity): + __tablename__ = "parent" + __mapper_args__ = {"polymorphic_identity": "parent"} + + id = Column(Integer, ForeignKey("entity.id"), primary_key=True) + some_data = Column(String(30)) + + class Child(Entity): + __tablename__ = "child" + __mapper_args__ = {"polymorphic_identity": "child"} + + id = Column(Integer, ForeignKey("entity.id"), primary_key=True) + + some_data = Column(String(30)) + parent_id = Column(ForeignKey("parent.id"), nullable=False) + parent = relationship( + "Parent", + foreign_keys=parent_id, + backref="children", + ) + + @classmethod + def insert_data(cls, connection): + Parent, Child, Discriminator = cls.classes( + "Parent", "Child", "Discriminator" + ) + + with Session(connection) as sess: + discriminator_zero = Discriminator(id=1, value="zero") + discriminator_one = Discriminator(id=2, value="one") + discriminator_two = Discriminator(id=3, value="two") + + parent = Parent(id=1, discriminator=discriminator_zero) + child_1 = Child( + id=2, + discriminator=discriminator_one, + parent=parent, + some_data="c1data", + ) + child_2 = Child( + id=3, + discriminator=discriminator_two, + parent=parent, + some_data="c2data", + ) + sess.add_all([parent, child_1, child_2]) + sess.commit() + + def test_explicit_aliasing(self): + Parent, Child, Discriminator = self.classes( + "Parent", "Child", "Discriminator" + ) + + parent_id = 1 + discriminator_one_id = 2 + + session = fixture_session() + c_alias = aliased(Child, flat=True) + retrieved = ( + session.query(Parent) + .filter_by(id=parent_id) + .outerjoin( + Parent.children.of_type(c_alias).and_( + c_alias.discriminator.has( + and_( + Discriminator.id == discriminator_one_id, + c_alias.some_data == "c1data", + ) + ) + ) + ) + .options(contains_eager(Parent.children.of_type(c_alias))) + .populate_existing() + .one() + ) + eq_(len(retrieved.children), 1) + + def test_implicit_aliasing(self): + Parent, Child, Discriminator = self.classes( + "Parent", "Child", "Discriminator" + ) + + parent_id = 1 + discriminator_one_id = 2 + + session = fixture_session() + q = ( + session.query(Parent) + .filter_by(id=parent_id) + .outerjoin( + Parent.children.and_( + Child.discriminator.has( + and_( + Discriminator.id == discriminator_one_id, + Child.some_data == "c1data", + ) + ) + ) + ) + .options(contains_eager(Parent.children)) + .populate_existing() + ) + + with expect_warnings("An alias is being generated automatically"): + retrieved = q.one() + + eq_(len(retrieved.children), 1) + + @testing.combinations(joinedload, selectinload, argnames="loader") + def test_eager_loaders(self, loader): + Parent, Child, Discriminator = self.classes( + "Parent", "Child", "Discriminator" + ) + + parent_id = 1 + discriminator_one_id = 2 + + session = fixture_session() + retrieved = ( + session.query(Parent) + .filter_by(id=parent_id) + .options( + loader( + Parent.children.and_( + Child.discriminator.has( + and_( + Discriminator.id == discriminator_one_id, + Child.some_data == "c1data", + ) + ) + ) + ) + ) + .populate_existing() + .one() + ) + + eq_(len(retrieved.children), 1) |