diff options
-rw-r--r-- | doc/build/changelog/unreleased_20/9779.rst | 9 | ||||
-rw-r--r-- | doc/build/changelog/unreleased_20/9789.rst | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 39 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/util/_concurrency_py3k.py | 4 | ||||
-rw-r--r-- | test/orm/inheritance/test_assorted_poly.py | 168 | ||||
-rw-r--r-- | test/orm/test_cache_key.py | 4 | ||||
-rw-r--r-- | tox.ini | 3 |
9 files changed, 241 insertions, 6 deletions
diff --git a/doc/build/changelog/unreleased_20/9779.rst b/doc/build/changelog/unreleased_20/9779.rst new file mode 100644 index 000000000..ab417b2dc --- /dev/null +++ b/doc/build/changelog/unreleased_20/9779.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm + :tickets: 9779 + + Fixed issue where using additional relationship criteria with the + :func:`_orm.joinedload` loader option, where the additional criteria itself + contained correlated subqueries that referred to the joined entities and + therefore also required "adaption" to aliased entities, would be excluded + from this adaption, producing the wrong ON clause for the joinedload. diff --git a/doc/build/changelog/unreleased_20/9789.rst b/doc/build/changelog/unreleased_20/9789.rst new file mode 100644 index 000000000..bae1537cd --- /dev/null +++ b/doc/build/changelog/unreleased_20/9789.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, tests, pypy + :tickets: 9789 + + Fixed test that relied on the ``sys.getsizeof()`` function to not run on + pypy, where this function appears to have different behavior than it does + on cpython. diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 8d9f3c644..e5a6b9afa 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -80,6 +80,7 @@ from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _HasClauseElement from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement +from ..sql.util import _deep_annotate from ..sql.util import _deep_deannotate from ..sql.util import _shallow_annotate from ..sql.util import adapt_criterion_to_null @@ -115,6 +116,7 @@ if typing.TYPE_CHECKING: from ..sql._typing import _EquivalentColumnMap from ..sql._typing import _InfoType from ..sql.annotation import _AnnotationDict + from ..sql.annotation import SupportsAnnotations from ..sql.elements import BinaryExpression from ..sql.elements import BindParameter from ..sql.elements import ClauseElement @@ -3284,6 +3286,38 @@ class JoinCondition: primaryjoin = primaryjoin & single_crit if extra_criteria: + + def mark_unrelated_columns_as_ok_to_adapt( + elem: SupportsAnnotations, annotations: _AnnotationDict + ) -> SupportsAnnotations: + """note unrelated columns in the "extra criteria" as OK + to adapt, even though they are not part of our "local" + or "remote" side. + + see #9779 for this case + + """ + + parentmapper_for_element = elem._annotations.get( + "parentmapper", None + ) + if ( + parentmapper_for_element is not self.prop.parent + and parentmapper_for_element is not self.prop.mapper + ): + return elem._annotate(annotations) + else: + return elem + + extra_criteria = tuple( + _deep_annotate( + elem, + {"ok_to_adapt_in_join_condition": True}, + annotate_callable=mark_unrelated_columns_as_ok_to_adapt, + ) + for elem in extra_criteria + ) + if secondaryjoin is not None: secondaryjoin = secondaryjoin & sql.and_(*extra_criteria) else: @@ -3409,7 +3443,10 @@ class _ColInAnnotations: self.name = name def __call__(self, c: ClauseElement) -> bool: - return self.name in c._annotations + return ( + self.name in c._annotations + or "ok_to_adapt_in_join_condition" in c._annotations + ) class Relationship( # type: ignore diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 7487e074c..e6dee7d17 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -406,8 +406,12 @@ def _deep_annotate( element: _SA, annotations: _AnnotationDict, exclude: Optional[Sequence[SupportsAnnotations]] = None, + *, detect_subquery_cols: bool = False, ind_cols_on_fromclause: bool = False, + annotate_callable: Optional[ + Callable[[SupportsAnnotations, _AnnotationDict], SupportsAnnotations] + ] = None, ) -> _SA: """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. @@ -446,9 +450,13 @@ def _deep_annotate( newelem = elem._clone(clone=clone, **kw) elif annotations != elem._annotations: if detect_subquery_cols and elem._is_immutable: - newelem = elem._clone(clone=clone, **kw)._annotate(annotations) + to_annotate = elem._clone(clone=clone, **kw) else: - newelem = elem._annotate(annotations) + to_annotate = elem + if annotate_callable: + newelem = annotate_callable(to_annotate, annotations) + else: + newelem = to_annotate._annotate(annotations) else: newelem = elem diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 0a50197a0..18caf5de4 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1343,6 +1343,7 @@ class ColumnAdapter(ClauseAdapter): def traverse( self, obj: Optional[ExternallyTraversible] ) -> Optional[ExternallyTraversible]: + return self.columns[obj] def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 0e26425b2..3544a0fd5 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -258,4 +258,6 @@ def get_event_loop() -> asyncio.AbstractEventLoop: try: return asyncio.get_running_loop() except RuntimeError: - return asyncio.get_event_loop_policy().get_event_loop() + # avoid "During handling of the above exception, another exception..." + pass + return asyncio.get_event_loop_policy().get_event_loop() diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index a40a9ae74..28480c89e 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -7,6 +7,7 @@ from __future__ import annotations from typing import Optional +from sqlalchemy import and_ from sqlalchemy import exists from sqlalchemy import ForeignKey from sqlalchemy import func @@ -37,6 +38,7 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import config from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing.fixtures import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session @@ -2765,3 +2767,169 @@ class PolyIntoSelfReferentialTest( assert False self._run_load(opt) + + +class AdaptExistsSubqTest(fixtures.DeclarativeMappedTest): + """test for #9777""" + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Discriminator(Base): + __tablename__ = "discriminator" + id = Column(Integer, primary_key=True, autoincrement=False) + value = Column(String(50)) + + class Entity(Base): + __tablename__ = "entity" + __mapper_args__ = {"polymorphic_on": "type"} + + id = Column(Integer, primary_key=True, autoincrement=False) + type = Column(String(50)) + + discriminator_id = Column( + ForeignKey("discriminator.id"), nullable=False + ) + discriminator = relationship( + "Discriminator", foreign_keys=discriminator_id + ) + + class Parent(Entity): + __tablename__ = "parent" + __mapper_args__ = {"polymorphic_identity": "parent"} + + id = Column(Integer, ForeignKey("entity.id"), primary_key=True) + some_data = Column(String(30)) + + class Child(Entity): + __tablename__ = "child" + __mapper_args__ = {"polymorphic_identity": "child"} + + id = Column(Integer, ForeignKey("entity.id"), primary_key=True) + + some_data = Column(String(30)) + parent_id = Column(ForeignKey("parent.id"), nullable=False) + parent = relationship( + "Parent", + foreign_keys=parent_id, + backref="children", + ) + + @classmethod + def insert_data(cls, connection): + Parent, Child, Discriminator = cls.classes( + "Parent", "Child", "Discriminator" + ) + + with Session(connection) as sess: + discriminator_zero = Discriminator(id=1, value="zero") + discriminator_one = Discriminator(id=2, value="one") + discriminator_two = Discriminator(id=3, value="two") + + parent = Parent(id=1, discriminator=discriminator_zero) + child_1 = Child( + id=2, + discriminator=discriminator_one, + parent=parent, + some_data="c1data", + ) + child_2 = Child( + id=3, + discriminator=discriminator_two, + parent=parent, + some_data="c2data", + ) + sess.add_all([parent, child_1, child_2]) + sess.commit() + + def test_explicit_aliasing(self): + Parent, Child, Discriminator = self.classes( + "Parent", "Child", "Discriminator" + ) + + parent_id = 1 + discriminator_one_id = 2 + + session = fixture_session() + c_alias = aliased(Child, flat=True) + retrieved = ( + session.query(Parent) + .filter_by(id=parent_id) + .outerjoin( + Parent.children.of_type(c_alias).and_( + c_alias.discriminator.has( + and_( + Discriminator.id == discriminator_one_id, + c_alias.some_data == "c1data", + ) + ) + ) + ) + .options(contains_eager(Parent.children.of_type(c_alias))) + .populate_existing() + .one() + ) + eq_(len(retrieved.children), 1) + + def test_implicit_aliasing(self): + Parent, Child, Discriminator = self.classes( + "Parent", "Child", "Discriminator" + ) + + parent_id = 1 + discriminator_one_id = 2 + + session = fixture_session() + q = ( + session.query(Parent) + .filter_by(id=parent_id) + .outerjoin( + Parent.children.and_( + Child.discriminator.has( + and_( + Discriminator.id == discriminator_one_id, + Child.some_data == "c1data", + ) + ) + ) + ) + .options(contains_eager(Parent.children)) + .populate_existing() + ) + + with expect_warnings("An alias is being generated automatically"): + retrieved = q.one() + + eq_(len(retrieved.children), 1) + + @testing.combinations(joinedload, selectinload, argnames="loader") + def test_eager_loaders(self, loader): + Parent, Child, Discriminator = self.classes( + "Parent", "Child", "Discriminator" + ) + + parent_id = 1 + discriminator_one_id = 2 + + session = fixture_session() + retrieved = ( + session.query(Parent) + .filter_by(id=parent_id) + .options( + loader( + Parent.children.and_( + Child.discriminator.has( + and_( + Discriminator.id == discriminator_one_id, + Child.some_data == "c1data", + ) + ) + ) + ) + ) + .populate_existing() + .one() + ) + + eq_(len(retrieved.children), 1) diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index 884baed62..5e5a85761 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -1143,7 +1143,9 @@ class EmbeddedSubqTest( Base.registry.configure() @testing.combinations( - "tuples", ("memory", testing.requires.is64bit), argnames="assert_on" + "tuples", + ("memory", testing.requires.is64bit + testing.requires.cpython), + argnames="assert_on", ) def test_cache_key_gen(self, assert_on): Employee = self.classes.Employee @@ -35,7 +35,8 @@ extras= deps= pytest>=7.0.0rc1,<8 - pytest-xdist + # tracked by https://github.com/pytest-dev/pytest-xdist/issues/907 + pytest-xdist!=3.3.0 dbapimain-sqlite: git+https://github.com/omnilib/aiosqlite.git#egg=aiosqlite dbapimain-sqlite: git+https://github.com/coleifer/sqlcipher3.git#egg=sqlcipher3 |