diff options
Diffstat (limited to 'test/orm/test_cache_key.py')
| -rw-r--r-- | test/orm/test_cache_key.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index 3c6536195..f25a57fe5 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -1,12 +1,17 @@ import random +import sqlalchemy as sa +from sqlalchemy import Column from sqlalchemy import func from sqlalchemy import inspect +from sqlalchemy import Integer from sqlalchemy import null from sqlalchemy import select +from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import true +from sqlalchemy import update from sqlalchemy.orm import aliased from sqlalchemy.orm import Bundle from sqlalchemy.orm import defaultload @@ -29,6 +34,7 @@ from sqlalchemy.sql.expression import case from sqlalchemy.sql.visitors import InternalTraversal from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures from sqlalchemy.testing import ne_ from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -884,3 +890,74 @@ class RoundTripTest(QueryTest, AssertsCompiledSQL): go() eq_(len(cache), lc) + + +class CompositeTest(fixtures.MappedTest): + __dialect__ = "default" + + @classmethod + def define_tables(cls, metadata): + Table( + "edges", + metadata, + Column("id", Integer, primary_key=True), + Column("x1", Integer), + Column("y1", Integer), + Column("x2", Integer), + Column("y2", Integer), + ) + + @classmethod + def setup_mappers(cls): + edges = cls.tables.edges + + class Point(cls.Comparable): + def __init__(self, x, y): + self.x = x + self.y = y + + def __composite_values__(self): + return [self.x, self.y] + + __hash__ = None + + def __eq__(self, other): + return ( + isinstance(other, Point) + and other.x == self.x + and other.y == self.y + ) + + def __ne__(self, other): + return not isinstance(other, Point) or not self.__eq__(other) + + class Edge(cls.Comparable): + def __init__(self, *args): + if args: + self.start, self.end = args + + cls.mapper_registry.map_imperatively( + Edge, + edges, + properties={ + "start": sa.orm.composite(Point, edges.c.x1, edges.c.y1), + "end": sa.orm.composite(Point, edges.c.x2, edges.c.y2), + }, + ) + + def test_bulk_update_cache_key(self): + """test secondary issue located as part of #7209""" + Edge, Point = (self.classes.Edge, self.classes.Point) + + stmt = ( + update(Edge) + .filter(Edge.start == Point(14, 5)) + .values({Edge.end: Point(16, 10)}) + ) + stmt2 = ( + update(Edge) + .filter(Edge.start == Point(14, 5)) + .values({Edge.end: Point(17, 8)}) + ) + + eq_(stmt._generate_cache_key(), stmt2._generate_cache_key()) |
