summaryrefslogtreecommitdiff
path: root/test/orm/test_cache_key.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/orm/test_cache_key.py')
-rw-r--r--test/orm/test_cache_key.py77
1 files changed, 77 insertions, 0 deletions
diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py
index 3c6536195..f25a57fe5 100644
--- a/test/orm/test_cache_key.py
+++ b/test/orm/test_cache_key.py
@@ -1,12 +1,17 @@
import random
+import sqlalchemy as sa
+from sqlalchemy import Column
from sqlalchemy import func
from sqlalchemy import inspect
+from sqlalchemy import Integer
from sqlalchemy import null
from sqlalchemy import select
+from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy import true
+from sqlalchemy import update
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Bundle
from sqlalchemy.orm import defaultload
@@ -29,6 +34,7 @@ from sqlalchemy.sql.expression import case
from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
from sqlalchemy.testing import ne_
from sqlalchemy.testing.fixtures import fixture_session
from test.orm import _fixtures
@@ -884,3 +890,74 @@ class RoundTripTest(QueryTest, AssertsCompiledSQL):
go()
eq_(len(cache), lc)
+
+
+class CompositeTest(fixtures.MappedTest):
+ __dialect__ = "default"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "edges",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x1", Integer),
+ Column("y1", Integer),
+ Column("x2", Integer),
+ Column("y2", Integer),
+ )
+
+ @classmethod
+ def setup_mappers(cls):
+ edges = cls.tables.edges
+
+ class Point(cls.Comparable):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __composite_values__(self):
+ return [self.x, self.y]
+
+ __hash__ = None
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, Point)
+ and other.x == self.x
+ and other.y == self.y
+ )
+
+ def __ne__(self, other):
+ return not isinstance(other, Point) or not self.__eq__(other)
+
+ class Edge(cls.Comparable):
+ def __init__(self, *args):
+ if args:
+ self.start, self.end = args
+
+ cls.mapper_registry.map_imperatively(
+ Edge,
+ edges,
+ properties={
+ "start": sa.orm.composite(Point, edges.c.x1, edges.c.y1),
+ "end": sa.orm.composite(Point, edges.c.x2, edges.c.y2),
+ },
+ )
+
+ def test_bulk_update_cache_key(self):
+ """test secondary issue located as part of #7209"""
+ Edge, Point = (self.classes.Edge, self.classes.Point)
+
+ stmt = (
+ update(Edge)
+ .filter(Edge.start == Point(14, 5))
+ .values({Edge.end: Point(16, 10)})
+ )
+ stmt2 = (
+ update(Edge)
+ .filter(Edge.start == Point(14, 5))
+ .values({Edge.end: Point(17, 8)})
+ )
+
+ eq_(stmt._generate_cache_key(), stmt2._generate_cache_key())