import dataclasses
import operator
import random

import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy import ForeignKey
from sqlalchemy import insert
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import update
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Composite
from sqlalchemy.orm import composite
from sqlalchemy.orm import configure_mappers
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import LoaderCallableStatus
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table


class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
    __dialect__ = "default"

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "graphs",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("name", String(30)),
        )

        Table(
            "edges",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("graph_id", Integer, ForeignKey("graphs.id")),
            Column("x1", Integer),
            Column("y1", Integer),
            Column("x2", Integer),
            Column("y2", Integer),
        )

    @classmethod
    def setup_mappers(cls):
        graphs, edges = cls.tables.graphs, cls.tables.edges

        class Point(cls.Comparable):
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def __composite_values__(self):
                return [self.x, self.y]

            __hash__ = None

            def __eq__(self, other):
                return (
                    isinstance(other, Point)
                    and other.x == self.x
                    and other.y == self.y
                )

            def __ne__(self, other):
                return not isinstance(other, Point) or \
                    not self.__eq__(other)

        class Graph(cls.Comparable):
            pass

        class Edge(cls.Comparable):
            def __init__(self, *args):
                if args:
                    self.start, self.end = args

        cls.mapper_registry.map_imperatively(
            Graph, graphs, properties={"edges": relationship(Edge)}
        )
        cls.mapper_registry.map_imperatively(
            Edge,
            edges,
            properties={
                "start": sa.orm.composite(Point, edges.c.x1, edges.c.y1),
                "end": sa.orm.composite(Point, edges.c.x2, edges.c.y2),
            },
        )

    def _fixture(self):
        Graph, Edge, Point = (
            self.classes.Graph,
            self.classes.Edge,
            self.classes.Point,
        )

        sess = Session(testing.db)
        g = Graph(
            id=1,
            edges=[
                Edge(Point(3, 4), Point(5, 6)),
                Edge(Point(14, 5), Point(2, 7)),
            ],
        )
        sess.add(g)
        sess.commit()
        return sess

    def test_early_configure(self):
        # test [ticket:2935], that we can call a composite
        # expression before configure_mappers()
        Edge = self.classes.Edge
        Edge.start.__clause_element__()

    def test_round_trip(self):
        Graph, Point = self.classes.Graph, self.classes.Point

        sess = self._fixture()

        g1 = sess.query(Graph).first()
        sess.close()

        g = sess.get(Graph, g1.id)
        eq_(
            [(e.start, e.end) for e in g.edges],
            [(Point(3, 4), Point(5, 6)), (Point(14, 5), Point(2, 7))],
        )

    def test_detect_change(self):
        Graph, Edge, Point = (
            self.classes.Graph,
            self.classes.Edge,
            self.classes.Point,
        )

        sess = self._fixture()

        g = sess.query(Graph).first()
        g.edges[1].end = Point(18, 4)
        sess.commit()

        e = sess.get(Edge, g.edges[1].id)
        eq_(e.end, Point(18, 4))

    def test_not_none(self):
        Edge = self.classes.Edge

        # current contract: the composite is None
        # when it hasn't been populated etc. on a
        # pending/transient object.
        e1 = Edge()
        assert e1.end is None
        sess = fixture_session()
        sess.add(e1)

        # however, once it's persistent, the code as of 0.7.3
        # would unconditionally populate it, even though it's
        # all None.  I think this usage contract is inconsistent,
        # and it would be better that the composite is just
        # created unconditionally in all cases.
        # but as we are just trying to fix [ticket:2308] and
        # [ticket:2309] without changing behavior we maintain
        # that only "persistent" gets the composite with the
        # Nones

        sess.flush()
        assert e1.end is not None

    def test_eager_load(self):
        Graph, Point = self.classes.Graph, self.classes.Point

        sess = self._fixture()

        g = sess.query(Graph).first()
        sess.close()

        def go():
            g2 = sess.get(
                Graph, g.id, options=[sa.orm.joinedload(Graph.edges)]
            )

            eq_(
                [(e.start, e.end) for e in g2.edges],
                [(Point(3, 4), Point(5, 6)), (Point(14, 5), Point(2, 7))],
            )

        self.assert_sql_count(testing.db, go, 1)

    def test_comparator(self):
        Graph, Edge, Point = (
            self.classes.Graph,
            self.classes.Edge,
            self.classes.Point,
        )

        sess = self._fixture()

        g = sess.query(Graph).first()

        assert (
            sess.query(Edge).filter(Edge.start == Point(3, 4)).one()
            is g.edges[0]
        )

        assert (
            sess.query(Edge).filter(Edge.start != Point(3, 4)).first()
            is g.edges[1]
        )

        eq_(sess.query(Edge).filter(Edge.start == None).all(), [])  # noqa

    def test_comparator_aliased(self):
        Graph, Edge, Point = (
            self.classes.Graph,
            self.classes.Edge,
            self.classes.Point,
        )

        sess = self._fixture()

        g = sess.query(Graph).first()
        ea = aliased(Edge)
        assert (
            sess.query(ea).filter(ea.start != Point(3, 4)).first()
            is g.edges[1]
        )

    def test_update_crit_sql(self):
        Edge, Point = (self.classes.Edge, self.classes.Point)

        sess = self._fixture()

        e1 = sess.execute(
            select(Edge).filter(Edge.start == Point(14, 5))
        ).scalar_one()

        eq_(e1.end, Point(2, 7))

        stmt = (
            update(Edge)
            .filter(Edge.start == Point(14, 5))
            .values({Edge.end: Point(16, 10)})
        )

        self.assert_compile(
            stmt,
            "UPDATE edges SET x2=:x2, y2=:y2 WHERE edges.x1 = :x1_1 "
            "AND edges.y1 = :y1_1",
            params={"x2": 16, "x1_1": 14, "y2": 10, "y1_1": 5},
            dialect="default",
        )

    def test_update_crit_evaluate(self):
        Edge, Point = (self.classes.Edge, self.classes.Point)

        sess = self._fixture()

        e1 = sess.execute(
            select(Edge).filter(Edge.start == Point(14, 5))
        ).scalar_one()

        eq_(e1.end, Point(2, 7))

        stmt = (
            update(Edge)
            .filter(Edge.start == Point(14, 5))
            .values({Edge.end: Point(16, 10)})
        )
        sess.execute(stmt)

        eq_(e1.end, Point(16, 10))

        stmt = (
            update(Edge)
            .filter(Edge.start == Point(14, 5))
            .values({Edge.end: Point(17, 8)})
        )
        sess.execute(stmt)

        eq_(e1.end, Point(17, 8))

    def test_update_crit_fetch(self):
        Edge, Point = (self.classes.Edge, self.classes.Point)

        sess = self._fixture()

        e1 = sess.query(Edge).filter(Edge.start == Point(14, 5)).one()

        eq_(e1.end, Point(2, 7))

        q = sess.query(Edge).filter(Edge.start == Point(14, 5))
        q.update({Edge.end: Point(16, 10)}, synchronize_session="fetch")

        eq_(e1.end, Point(16, 10))

        q.update({Edge.end: Point(17, 8)}, synchronize_session="fetch")

        eq_(e1.end, Point(17, 8))

    @testing.combinations(
        ("legacy",),
        ("statement",),
        ("values",),
        ("stmt_returning", testing.requires.insertmanyvalues),
        ("values_returning", testing.requires.insert_returning),
    )
    def test_bulk_insert(self, type_):
        Edge, Point = (self.classes.Edge, self.classes.Point)
        Graph = self.classes.Graph

        sess = self._fixture()

        graph = Graph(id=2)
        sess.add(graph)
        sess.flush()
        graph_id = 2

        data = [
            {
                "start": Point(random.randint(1, 50), random.randint(1, 50)),
                "end": Point(random.randint(1, 50), random.randint(1, 50)),
                "graph_id": graph_id,
            }
            for i in range(25)
        ]
        returning = False
        if type_ == "statement":
            sess.execute(insert(Edge), data)
        elif type_ == "stmt_returning":
            result = sess.scalars(insert(Edge).returning(Edge), data)
            returning = True
        elif type_ == "values":
            sess.execute(insert(Edge).values(data))
        elif type_ == "values_returning":
            result = sess.scalars(insert(Edge).values(data).returning(Edge))
            returning = True
        elif type_ == "legacy":
            sess.bulk_insert_mappings(Edge, data)
        else:
            assert False

        if returning:
            eq_(result.all(), [Edge(rec["start"], rec["end"]) for rec in data])

        edges = self.tables.edges
        eq_(
            sess.execute(
                select(edges.c["x1", "y1", "x2", "y2"])
                .where(edges.c.graph_id == graph_id)
                .order_by(edges.c.id)
            ).all(),
            [
                (e["start"].x, e["start"].y, e["end"].x, e["end"].y)
                for e in
                data
            ],
        )

    @testing.combinations("legacy", "statement")
    def test_bulk_insert_heterogeneous(self, type_):
        Edge, Point = (self.classes.Edge, self.classes.Point)
        Graph = self.classes.Graph

        sess = self._fixture()

        graph = Graph(id=2)
        sess.add(graph)
        sess.flush()
        graph_id = 2

        d1 = [
            {
                "start": Point(random.randint(1, 50), random.randint(1, 50)),
                "end": Point(random.randint(1, 50), random.randint(1, 50)),
                "graph_id": graph_id,
            }
            for i in range(3)
        ]
        d2 = [
            {
                "start": Point(random.randint(1, 50), random.randint(1, 50)),
                "graph_id": graph_id,
            }
            for i in range(2)
        ]
        d3 = [
            {
                "x2": random.randint(1, 50),
                "y2": random.randint(1, 50),
                "graph_id": graph_id,
            }
            for i in range(2)
        ]
        data = d1 + d2 + d3
        random.shuffle(data)

        assert_data = [
            {
                "start": d["start"] if "start" in d else None,
                "end": d["end"]
                if "end" in d
                else Point(d["x2"], d["y2"])
                if "x2" in d
                else None,
                "graph_id": d["graph_id"],
            }
            for d in data
        ]

        if type_ == "statement":
            sess.execute(insert(Edge), data)
        elif type_ == "legacy":
            sess.bulk_insert_mappings(Edge, data)
        else:
            assert False

        edges = self.tables.edges
        eq_(
            sess.execute(
                select(edges.c["x1", "y1", "x2", "y2"])
                .where(edges.c.graph_id == graph_id)
                .order_by(edges.c.id)
            ).all(),
            [
                (
                    e["start"].x if e["start"] else None,
                    e["start"].y if e["start"] else None,
                    e["end"].x if e["end"] else None,
                    e["end"].y if e["end"] else None,
                )
                for e in assert_data
            ],
        )

    @testing.combinations("legacy", "statement")
    def test_bulk_update(self, type_):
        Edge, Point = (self.classes.Edge, self.classes.Point)
        Graph = self.classes.Graph

        sess = self._fixture()

        graph = Graph(id=2)
        sess.add(graph)
        sess.flush()
        graph_id = 2

        data = [
            {
                "start": Point(random.randint(1, 50), random.randint(1, 50)),
                "end": Point(random.randint(1, 50), random.randint(1, 50)),
                "graph_id": graph_id,
            }
            for i in range(25)
        ]
        sess.execute(insert(Edge), data)

        inserted_data = [
            dict(row._mapping)
            for row in sess.execute(
                select(Edge.id, Edge.start, Edge.end, Edge.graph_id)
                .where(Edge.graph_id == graph_id)
                .order_by(Edge.id)
            )
        ]

        to_update = []
        updated_pks = {}
        for rec in random.choices(inserted_data, k=7):
            rec_copy = dict(rec)
            updated_pks[rec_copy["id"]] = rec_copy
            rec_copy["start"] = Point(
                random.randint(1, 50), random.randint(1, 50)
            )
            rec_copy["end"] = Point(
                random.randint(1, 50), random.randint(1, 50)
            )
            to_update.append(rec_copy)

        expected_dataset = [
            updated_pks[row["id"]] if row["id"] in updated_pks else row
            for row in inserted_data
        ]

        if type_ == "statement":
            sess.execute(update(Edge), to_update)
        elif type_ == "legacy":
            sess.bulk_update_mappings(Edge, to_update)
        else:
            assert False

        edges = self.tables.edges
        eq_(
            sess.execute(
                select(edges.c["x1", "y1", "x2", "y2"])
                .where(edges.c.graph_id == graph_id)
                .order_by(edges.c.id)
            ).all(),
            [
                (e["start"].x, e["start"].y, e["end"].x, e["end"].y)
                for e in expected_dataset
            ],
        )

    def test_get_history(self):
        Edge = self.classes.Edge
        Point = self.classes.Point
        from sqlalchemy.orm.attributes import get_history

        e1 = Edge()
        e1.start = Point(1, 2)
        eq_(
            get_history(e1, "start"),
            ([Point(x=1, y=2)], (), [Point(x=None, y=None)]),
        )

        eq_(get_history(e1, "end"), ((), [Point(x=None, y=None)], ()))

    def test_query_cols_legacy(self):
        Edge = self.classes.Edge

        sess = self._fixture()

        eq_(
            sess.query(Edge.start.clauses, Edge.end.clauses).all(),
            [(3, 4, 5, 6), (14, 5, 2, 7)],
        )

    def test_query_cols(self):
        Edge = self.classes.Edge
        Point = self.classes.Point

        sess = self._fixture()

        start, end = Edge.start, Edge.end

        eq_(
            sess.query(start, end).filter(start == Point(3, 4)).all(),
            [(Point(3, 4), Point(5, 6))],
        )

    def test_cols_as_core_clauseelement(self):
        Edge = self.classes.Edge
        Point = self.classes.Point

        start, end = Edge.start, Edge.end

        stmt = select(start, end).where(start == Point(3, 4))
        self.assert_compile(
            stmt,
            "SELECT edges.x1, edges.y1, edges.x2, edges.y2 "
            "FROM edges WHERE edges.x1 = :x1_1 AND edges.y1 = :y1_1",
            checkparams={"x1_1": 3, "y1_1": 4},
        )

    def test_query_cols_labeled(self):
        Edge = self.classes.Edge
        Point = self.classes.Point

        sess = self._fixture()
        start, end = Edge.start, Edge.end

        row = (
            sess.query(start.label("s1"), end)
            .filter(start == Point(3, 4))
            .first()
        )
        eq_(row.s1.x, 3)
        eq_(row.s1.y, 4)
        eq_(row.end.x, 5)
        eq_(row.end.y, 6)

    def test_delete(self):
        Point = self.classes.Point
        Graph, Edge = self.classes.Graph, self.classes.Edge

        sess = self._fixture()
        g = sess.query(Graph).first()

        e = g.edges[1]
        del e.end
        sess.flush()
        eq_(
            sess.query(Edge.start, Edge.end).all(),
            [
                (Point(x=3, y=4), Point(x=5, y=6)),
                (Point(x=14, y=5), Point(x=None, y=None)),
            ],
        )

    def test_save_null(self):
        """test saving a null composite value

        See google groups thread for more context:
        https://groups.google.com/group/sqlalchemy/browse_thread/thread/0c6580a1761b2c29

        """

        Graph, Edge = self.classes.Graph, self.classes.Edge

        sess = fixture_session()
        g = Graph(id=1)
        e = Edge(None, None)
        g.edges.append(e)

        sess.add(g)
        sess.commit()

        g2 = sess.get(Graph, 1)
        assert g2.edges[-1].start.x is None
        assert g2.edges[-1].start.y is None

    def test_expire(self):
        Graph, Point = self.classes.Graph, self.classes.Point

        sess = self._fixture()
        g = sess.query(Graph).first()
        e = g.edges[0]
        sess.expire(e)
        assert "start" not in e.__dict__
        assert e.start == Point(3, 4)

    def test_default_value(self):
        Edge = self.classes.Edge

        e = Edge()
        eq_(e.start, None)

    def test_no_name_declarative(self, decl_base, connection):
        """test #7751"""

        class Point:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def __composite_values__(self):
                return self.x, self.y

            def __repr__(self):
                return "Point(x=%r, y=%r)" % (self.x, self.y)

            def __eq__(self, other):
                return (
                    isinstance(other, Point)
                    and other.x == self.x
                    and other.y == self.y
                )

            def __ne__(self, other):
                return not self.__eq__(other)

        class Vertex(decl_base):
            __tablename__ = "vertices"

            id = Column(Integer, primary_key=True)
            x1 = Column(Integer)
            y1 = Column(Integer)
            x2 = Column(Integer)
            y2 = Column(Integer)

            start = composite(Point, x1, y1)
            end = composite(Point, x2, y2)

        self.assert_compile(
            select(Vertex),
            "SELECT vertices.id, vertices.x1, \
vertices.y1, vertices.x2, "
            "vertices.y2 FROM vertices",
        )
        decl_base.metadata.create_all(connection)
        s = Session(connection)
        hv = Vertex(start=Point(1, 2), end=Point(3, 4))
        s.add(hv)
        s.commit()

        is_(
            hv,
            s.scalars(
                select(Vertex).where(Vertex.start == Point(1, 2))
            ).first(),
        )

    def test_no_name_declarative_two(self, decl_base, connection):
        """test #7752"""

        class Point:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def __composite_values__(self):
                return self.x, self.y

            def __repr__(self):
                return "Point(x=%r, y=%r)" % (self.x, self.y)

            def __eq__(self, other):
                return (
                    isinstance(other, Point)
                    and other.x == self.x
                    and other.y == self.y
                )

            def __ne__(self, other):
                return not self.__eq__(other)

        class Vertex:
            def __init__(self, start, end):
                self.start = start
                self.end = end

            @classmethod
            def _generate(self, x1, y1, x2, y2):
                """generate a Vertex from a row"""
                return Vertex(Point(x1, y1), Point(x2, y2))

            def __composite_values__(self):
                return (
                    self.start.__composite_values__()
                    + self.end.__composite_values__()
                )

        class HasVertex(decl_base):
            __tablename__ = "has_vertex"
            id = Column(Integer, primary_key=True)
            x1 = Column(Integer)
            y1 = Column(Integer)
            x2 = Column(Integer)
            y2 = Column(Integer)

            vertex = composite(Vertex._generate, x1, y1, x2, y2)

        self.assert_compile(
            select(HasVertex),
            "SELECT has_vertex.id, has_vertex.x1, has_vertex.y1, "
            "has_vertex.x2, has_vertex.y2 FROM has_vertex",
        )

        decl_base.metadata.create_all(connection)
        s = Session(connection)
        hv = HasVertex(vertex=Vertex(Point(1, 2), Point(3, 4)))
        s.add(hv)
        s.commit()
        is_(
            hv,
            s.scalars(
                select(HasVertex).where(
                    HasVertex.vertex == Vertex(Point(1, 2), Point(3, 4))
                )
            ).first(),
        )


class NestedTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "stuff",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("a", String(30)),
            Column("b", String(30)),
            Column("c", String(30)),
            Column("d", String(30)),
        )

    def _fixture(self):
        class AB:
            def __init__(self, a, b, cd):
                self.a = a
                self.b = b
                self.cd = cd

            @classmethod
            def generate(cls, a, b, c, d):
                return AB(a, b, CD(c, d))

            def __composite_values__(self):
                return (self.a, self.b) + self.cd.__composite_values__()

            def __eq__(self, other):
                return (
                    isinstance(other, AB)
                    and self.a == other.a
                    and self.b == other.b
                    and self.cd == other.cd
                )

            def __ne__(self, other):
                return not self.__eq__(other)

        class CD:
            def __init__(self, c, d):
                self.c = c
                self.d = d

            def __composite_values__(self):
                return (self.c, self.d)

            def __eq__(self, other):
                return (
                    isinstance(other, CD)
                    and self.c == other.c
                    and self.d == other.d
                )

            def __ne__(self, other):
                return not self.__eq__(other)

        class Thing:
            def __init__(self, ab):
                self.ab = ab

        stuff = self.tables.stuff
        self.mapper_registry.map_imperatively(
            Thing,
            stuff,
            properties={
                "ab": composite(
                    AB.generate, stuff.c.a, stuff.c.b, stuff.c.c, stuff.c.d
                )
            },
        )
        return Thing, AB, CD

    def test_round_trip(self):
        Thing, AB, CD = self._fixture()

        s = fixture_session()

        s.add(Thing(AB("a", "b", CD("c", "d"))))
        s.commit()

        s.close()

        t1 = (
            s.query(Thing).filter(Thing.ab == AB("a", "b", CD("c", "d"))).one()
        )
        eq_(t1.ab, AB("a", "b", CD("c", "d")))


class EventsEtcTest(fixtures.MappedTest):
    @testing.fixture
    def point_fixture(self, decl_base):
        def go(active_history):
            @dataclasses.dataclass
            class Point:
                x: int
                y: int

            class Edge(decl_base):
                __tablename__ = "edge"
                id = mapped_column(Integer, primary_key=True)

                start = composite(
                    Point,
                    mapped_column("x1", Integer),
                    mapped_column("y1", Integer),
                    active_history=active_history,
                )
                end = composite(
                    Point,
                    mapped_column("x2", Integer, nullable=True),
                    mapped_column("y2", Integer, nullable=True),
                    active_history=active_history,
                )

            decl_base.metadata.create_all(testing.db)

            return Point, Edge

        return go

    @testing.variation("active_history", [True, False])
    @testing.variation("hist_on_mapping", [True, False])
    def test_event_listener_no_value_to_set(
        self, point_fixture, active_history, hist_on_mapping
    ):
        if hist_on_mapping:
            config_active_history = bool(active_history)
        else:
            config_active_history = False

        Point, Edge = point_fixture(config_active_history)

        if not hist_on_mapping and active_history:
            Edge.start.impl.active_history = True

        m1 = mock.Mock()

        event.listen(Edge.start, "set", m1)

        e1 = Edge()
        e1.start = Point(5, 6)

        eq_(
            m1.mock_calls,
            [
                mock.call(
                    e1,
                    Point(5, 6),
                    LoaderCallableStatus.NO_VALUE
                    if not active_history
                    else None,
                    Edge.start.impl,
                )
            ],
        )

        eq_(
            inspect(e1).attrs.start.history,
            ([Point(5, 6)], (), [Point(None, None)]),
        )

    @testing.variation("active_history", [True, False])
    @testing.variation("hist_on_mapping", [True, False])
    def test_event_listener_set_to_new(
        self, point_fixture, active_history, hist_on_mapping
    ):
        if hist_on_mapping:
            config_active_history = bool(active_history)
        else:
            config_active_history = False

        Point, Edge = point_fixture(config_active_history)

        if not hist_on_mapping and active_history:
            Edge.start.impl.active_history = True

        e1 = Edge()
        e1.start = Point(5, 6)

        sess = fixture_session()

        sess.add(e1)
        sess.commit()
        assert "start" not in e1.__dict__

        m1 = mock.Mock()

        event.listen(Edge.start, "set", m1)

        e1.start = Point(7, 8)

        eq_(
            m1.mock_calls,
            [
                mock.call(
                    e1,
                    Point(7, 8),
                    LoaderCallableStatus.NO_VALUE
                    if not active_history
                    else Point(5, 6),
                    Edge.start.impl,
                )
            ],
        )

        if active_history:
            eq_(
                inspect(e1).attrs.start.history,
                ([Point(7, 8)], (), [Point(5, 6)]),
            )
        else:
            eq_(
                inspect(e1).attrs.start.history,
                ([Point(7, 8)], (), [Point(None, None)]),
            )

    @testing.variation("active_history", [True, False])
    @testing.variation("hist_on_mapping", [True, False])
    def test_event_listener_set_to_deleted(
        self, point_fixture, active_history, hist_on_mapping
    ):
        if hist_on_mapping:
            config_active_history = bool(active_history)
        else:
            config_active_history = False

        Point, Edge = point_fixture(config_active_history)

        if not hist_on_mapping and active_history:
            Edge.start.impl.active_history = True

        e1 = Edge()
        e1.start = Point(5, 6)

        sess = fixture_session()

        sess.add(e1)
        sess.commit()
        assert "start" not in e1.__dict__

        m1 = mock.Mock()

        event.listen(Edge.start, "remove", m1)

        del e1.start

        eq_(
            m1.mock_calls,
            [
                mock.call(
                    e1,
                    LoaderCallableStatus.NO_VALUE
                    if not active_history
                    else Point(5, 6),
                    Edge.start.impl,
                )
            ],
        )

        if active_history:
            eq_(
                inspect(e1).attrs.start.history,
                ([Point(None, None)], (), [Point(5, 6)]),
            )
        else:
            eq_(
                inspect(e1).attrs.start.history,
                ([Point(None, None)], (), [Point(None, None)]),
            )


class PrimaryKeyTest(fixtures.MappedTest):
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "graphs",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("version_id", Integer, primary_key=True, nullable=True),
            Column("name", String(30)),
        )

    @classmethod
    def setup_mappers(cls):
        graphs = cls.tables.graphs

        class Version(cls.Comparable):
            def __init__(self, id_, version):
                self.id = id_
                self.version = version

            def __composite_values__(self):
                return (self.id, self.version)

            __hash__ = None

            def __eq__(self, other):
                return (
                    isinstance(other, Version)
                    and other.id == self.id
                    and other.version == self.version
                )

            def __ne__(self, other):
                return not self.__eq__(other)

        class Graph(cls.Comparable):
            def __init__(self, version):
                self.version = version

        cls.mapper_registry.map_imperatively(
            Graph,
            graphs,
            properties={
                "version": sa.orm.composite(
                    Version, graphs.c.id, graphs.c.version_id
                )
            },
        )

    def _fixture(self):
        Graph, Version = self.classes.Graph, self.classes.Version

        sess = fixture_session()
        g = Graph(Version(1, 1))
        sess.add(g)
        sess.commit()
        return sess

    def test_get_by_col(self):
        Graph = self.classes.Graph

        sess = self._fixture()
        g = sess.query(Graph).first()

        g2 = sess.get(Graph, [g.id, g.version_id])
        eq_(g.version, g2.version)

    def test_get_by_composite(self):
        Graph, Version = self.classes.Graph, self.classes.Version

        sess = self._fixture()
        g = sess.query(Graph).first()

        g2 = sess.get(Graph, Version(g.id, g.version_id))
        eq_(g.version, g2.version)

    def test_pk_mutation(self):
        Graph, Version = self.classes.Graph, self.classes.Version

        sess = self._fixture()
        g = sess.query(Graph).first()
        g.version = Version(2, 1)
        sess.commit()
        g2 = sess.get(Graph, Version(2, 1))
        eq_(g.version, g2.version)

    @testing.fails_on_everything_except("sqlite")
    def test_null_pk(self):
        Graph, Version = self.classes.Graph, self.classes.Version

        sess = fixture_session()

        # test pk with one column NULL
        # only sqlite can really handle this
        g = Graph(Version(2, None))
        sess.add(g)
        sess.commit()
        g2 = sess.query(Graph).filter_by(version=Version(2, None)).one()
        eq_(g.version, g2.version)


class PrimaryKeyTestDataclasses(PrimaryKeyTest):
    @classmethod
    def setup_mappers(cls):
        graphs = cls.tables.graphs

        @dataclasses.dataclass
        class Version:
            id: int
            version: int

        cls.classes.Version = Version

        class Graph(cls.Comparable):
            def __init__(self, version):
                self.version = version

        cls.mapper_registry.map_imperatively(
            Graph,
            graphs,
            properties={
                "version": sa.orm.composite(
                    Version, graphs.c.id, graphs.c.version_id
                )
            },
        )


class DefaultsTest(fixtures.MappedTest):
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "foobars",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("x1", Integer, default=2),
            Column("x2", Integer),
            Column("x3", Integer, server_default="15"),
            Column("x4", Integer),
        )

    @classmethod
    def setup_mappers(cls):
        foobars = cls.tables.foobars

        class Foobar(cls.Comparable):
            pass

        class FBComposite(cls.Comparable):
            def __init__(self, x1, x2, x3, x4):
                self.goofy_x1 = x1
                self.x2 = x2
                self.x3 = x3
                self.x4 = x4

            def __composite_values__(self):
                return self.goofy_x1, self.x2, self.x3, self.x4

            __hash__ = None

            def __eq__(self, other):
                return (
                    other.goofy_x1 == self.goofy_x1
                    and other.x2 == self.x2
                    and other.x3 == self.x3
                    and other.x4 == self.x4
                )

            def __ne__(self, other):
                return not self.__eq__(other)

            def __repr__(self):
                return "FBComposite(%r, %r, %r, %r)" % (
                    self.goofy_x1,
                    self.x2,
                    self.x3,
                    self.x4,
                )

        cls.mapper_registry.map_imperatively(
            Foobar,
            foobars,
            properties=dict(
                foob=sa.orm.composite(
                    FBComposite,
                    foobars.c.x1,
                    foobars.c.x2,
                    foobars.c.x3,
                    foobars.c.x4,
                )
            ),
        )

    def test_attributes_with_defaults(self):
        Foobar, FBComposite = self.classes.Foobar, self.classes.FBComposite

        sess = fixture_session()
        f1 = Foobar()
        f1.foob = FBComposite(None, 5, None, None)
        sess.add(f1)
        sess.flush()

        eq_(f1.foob, FBComposite(2, 5, 15, None))

        f2 = Foobar()
        sess.add(f2)
        sess.flush()
        eq_(f2.foob, FBComposite(2, None, 15, None))

    def test_set_composite_values(self):
        Foobar, FBComposite = self.classes.Foobar, self.classes.FBComposite

        sess = fixture_session()
        f1 = Foobar()
        f1.foob = FBComposite(None, 5, None, None)
        sess.add(f1)
        sess.flush()

        eq_(f1.foob, FBComposite(2, 5, 15, None))


class MappedSelectTest(fixtures.MappedTest):
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "descriptions",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("d1", String(20)),
            Column("d2", String(20)),
        )

        Table(
            "values",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column(
                "description_id",
                Integer,
                ForeignKey("descriptions.id"),
                nullable=False,
            ),
            Column("v1", String(20)),
            Column("v2", String(20)),
        )

    @classmethod
    def setup_mappers(cls):
        values, descriptions = cls.tables.values, cls.tables.descriptions

        class Descriptions(cls.Comparable):
            pass

        class Values(cls.Comparable):
            pass

        class CustomValues(cls.Comparable, list):
            def __init__(self, *args):
                self.extend(args)

            def __composite_values__(self):
                return self

        desc_values = (
            select(values, descriptions.c.d1, descriptions.c.d2)
            .where(
                descriptions.c.id == values.c.description_id,
            )
            .alias("descriptions_values")
        )

        cls.mapper_registry.map_imperatively(
            Descriptions,
            descriptions,
            properties={
                "values": relationship(Values, lazy="dynamic"),
                "custom_descriptions": composite(
                    CustomValues, descriptions.c.d1, descriptions.c.d2
                ),
            },
        )

        cls.mapper_registry.map_imperatively(
            Values,
            desc_values,
            properties={
                "custom_values": composite(
                    CustomValues, desc_values.c.v1, desc_values.c.v2
                )
            },
        )

    def \
            test_set_composite_attrs_via_selectable(self, connection):
        Values, CustomValues, values, Descriptions, descriptions = (
            self.classes.Values,
            self.classes.CustomValues,
            self.tables.values,
            self.classes.Descriptions,
            self.tables.descriptions,
        )

        session = fixture_session()
        d = Descriptions(
            custom_descriptions=CustomValues("Color", "Number"),
            values=[
                Values(custom_values=CustomValues("Red", "5")),
                Values(custom_values=CustomValues("Blue", "1")),
            ],
        )

        session.add(d)
        session.commit()
        eq_(
            connection.execute(descriptions.select()).fetchall(),
            [(1, "Color", "Number")],
        )
        eq_(
            connection.execute(values.select()).fetchall(),
            [(1, 1, "Red", "5"), (2, 1, "Blue", "1")],
        )


class ManyToOneTest(fixtures.MappedTest):
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "a",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("b1", String(20)),
            Column("b2_id", Integer, ForeignKey("b.id")),
        )

        Table(
            "b",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(20)),
        )

    @classmethod
    def setup_mappers(cls):
        a, b = cls.tables.a, cls.tables.b

        class A(cls.Comparable):
            pass

        class B(cls.Comparable):
            pass

        class C(cls.Comparable):
            def __init__(self, b1, b2):
                self.b1, self.b2 = b1, b2

            def __composite_values__(self):
                return self.b1, self.b2

            def __eq__(self, other):
                return (
                    isinstance(other, C)
                    and other.b1 == self.b1
                    and other.b2 == self.b2
                )

        cls.mapper_registry.map_imperatively(
            A,
            a,
            properties={"b2": relationship(B), "c": composite(C, "b1", "b2")},
        )
        cls.mapper_registry.map_imperatively(B, b)

    def test_early_configure(self):
        # test [ticket:2935], that we can call a composite
        # expression before configure_mappers()
        A = self.classes.A
        A.c.__clause_element__()

    def test_persist(self):
        A, C, B = (self.classes.A, self.classes.C, self.classes.B)

        sess = fixture_session()
        sess.add(A(c=C("b1", B(data="b2"))))
        sess.commit()

        a1 = sess.query(A).one()
        eq_(a1.c, C("b1", B(data="b2")))

    def test_query(self):
        A, C, B = \
            (self.classes.A, self.classes.C, self.classes.B)

        sess = fixture_session()
        b1, b2 = B(data="b1"), B(data="b2")
        a1 = A(c=C("a1b1", b1))
        a2 = A(c=C("a2b1", b2))
        sess.add_all([a1, a2])
        sess.commit()

        eq_(sess.query(A).filter(A.c == C("a2b1", b2)).one(), a2)

    def test_query_aliased(self):
        A, C, B = (self.classes.A, self.classes.C, self.classes.B)

        sess = fixture_session()
        b1, b2 = B(data="b1"), B(data="b2")
        a1 = A(c=C("a1b1", b1))
        a2 = A(c=C("a2b1", b2))
        sess.add_all([a1, a2])
        sess.commit()

        ae = aliased(A)
        eq_(sess.query(ae).filter(ae.c == C("a2b1", b2)).one(), a2)


class ConfigurationTest(fixtures.MappedTest):
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "edge",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("x1", Integer),
            Column("y1", Integer),
            Column("x2", Integer),
            Column("y2", Integer),
        )

    @classmethod
    def setup_mappers(cls):
        class Point(cls.Comparable):
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def __composite_values__(self):
                return [self.x, self.y]

            def __eq__(self, other):
                return (
                    isinstance(other, Point)
                    and other.x == self.x
                    and other.y == self.y
                )

            def __ne__(self, other):
                return not isinstance(other, Point) or not self.__eq__(other)

        class Edge(cls.Comparable):
            pass

    def _test_roundtrip(self):
        Edge, Point = self.classes.Edge, self.classes.Point

        e1 = Edge(start=Point(3, 4), end=Point(5, 6))
        sess = fixture_session()
        sess.add(e1)
        sess.commit()

        eq_(sess.query(Edge).one(), Edge(start=Point(3, 4), end=Point(5, 6)))

    def test_columns(self):
        edge, Edge, Point = (
            self.tables.edge,
            self.classes.Edge,
            self.classes.Point,
        )

        self.mapper_registry.map_imperatively(
            Edge,
            edge,
            properties={
                "start": sa.orm.composite(Point, edge.c.x1, edge.c.y1),
                "end": sa.orm.composite(Point, edge.c.x2, edge.c.y2),
            },
        )

        self._test_roundtrip()

    def test_attributes(self):
        edge, Edge, Point = (
            self.tables.edge,
            self.classes.Edge,
            self.classes.Point,
        )

        m = self.mapper_registry.map_imperatively(Edge, edge)
        m.add_property("start",
                       sa.orm.composite(Point, Edge.x1, Edge.y1))
        m.add_property("end", sa.orm.composite(Point, Edge.x2, Edge.y2))

        self._test_roundtrip()

    def test_strings(self):
        edge, Edge, Point = (
            self.tables.edge,
            self.classes.Edge,
            self.classes.Point,
        )

        m = self.mapper_registry.map_imperatively(Edge, edge)
        m.add_property("start", sa.orm.composite(Point, "x1", "y1"))
        m.add_property("end", sa.orm.composite(Point, "x2", "y2"))

        self._test_roundtrip()

    def test_deferred(self):
        edge, Edge, Point = (
            self.tables.edge,
            self.classes.Edge,
            self.classes.Point,
        )
        self.mapper_registry.map_imperatively(
            Edge,
            edge,
            properties={
                "start": sa.orm.composite(
                    Point, edge.c.x1, edge.c.y1, deferred=True, group="s"
                ),
                "end": sa.orm.composite(
                    Point, edge.c.x2, edge.c.y2, deferred=True
                ),
            },
        )
        self._test_roundtrip()

    def test_check_prop_type(self):
        edge, Edge, Point = (
            self.tables.edge,
            self.classes.Edge,
            self.classes.Point,
        )
        self.mapper_registry.map_imperatively(
            Edge,
            edge,
            properties={
                "start": sa.orm.composite(Point, (edge.c.x1,), edge.c.y1)
            },
        )
        assert_raises_message(
            sa.exc.ArgumentError,
            # note that we also are checking that the tuple
            # renders here, so the "%" operator in the string needs to
            # apply the tuple also
            r"Composite expects Column objects or mapped "
            r"attributes/attribute names as "
            r"arguments, got: \(Column",
            configure_mappers,
        )


class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
    __dialect__ = "default"

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "edge",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("x1", Integer),
            Column("y1", Integer),
            Column("x2", Integer),
            Column("y2", Integer),
        )

    @classmethod
    def setup_mappers(cls):
        class Point(cls.Comparable):
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def __composite_values__(self):
                return [self.x, self.y]

            def __eq__(self, other):
                return (
                    isinstance(other, Point)
                    and other.x == self.x
                    and other.y == self.y
                )

            def __ne__(self, other):
                return not isinstance(other,
                                      Point) or not self.__eq__(other)

        class Edge(cls.Comparable):
            def __init__(self, start, end):
                self.start = start
                self.end = end

            def __eq__(self, other):
                return isinstance(other, Edge) and other.id == self.id

    def _fixture(self, custom):
        edge, Edge, Point = (
            self.tables.edge,
            self.classes.Edge,
            self.classes.Point,
        )

        if custom:

            class CustomComparator(sa.orm.Composite.Comparator):
                def near(self, other, d):
                    clauses = self.__clause_element__().clauses
                    diff_x = clauses[0] - other.x
                    diff_y = clauses[1] - other.y
                    return diff_x * diff_x + diff_y * diff_y <= d * d

            self.mapper_registry.map_imperatively(
                Edge,
                edge,
                properties={
                    "start": sa.orm.composite(
                        Point,
                        edge.c.x1,
                        edge.c.y1,
                        comparator_factory=CustomComparator,
                    ),
                    "end": sa.orm.composite(Point, edge.c.x2, edge.c.y2),
                },
            )

        else:
            self.mapper_registry.map_imperatively(
                Edge,
                edge,
                properties={
                    "start": sa.orm.composite(Point, edge.c.x1, edge.c.y1),
                    "end": sa.orm.composite(Point, edge.c.x2, edge.c.y2),
                },
            )

    @testing.combinations(True, False, argnames="custom")
    @testing.combinations(
        (operator.lt, "<", ">"),
        (operator.gt, ">", "<"),
        (operator.eq, "=", "="),
        (operator.ne, "!=", "!="),
        (operator.le, "<=", ">="),
        (operator.ge, ">=", "<="),
        argnames="operator, fwd_op, rev_op",
    )
    def test_comparator_behavior(self, custom, operator, fwd_op, rev_op):
        self._fixture(custom)
        Edge, Point = self.classes("Edge", "Point")

        self.assert_compile(
            select(Edge).filter(operator(Edge.start, Point(3, 4))),
            "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge "
            f"WHERE edge.x1 {fwd_op} :x1_1 AND edge.y1 {fwd_op} :y1_1",
            checkparams={"x1_1": 3, "y1_1": 4},
        )

        self.assert_compile(
            select(Edge).filter(~operator(Edge.start, Point(3, 4))),
            "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge "
            f"WHERE NOT (edge.x1 {fwd_op} :x1_1 AND edge.y1 {fwd_op} :y1_1)",
            checkparams={"x1_1": 3, "y1_1": 4},
        )

    @testing.combinations(True, False, argnames="custom")
    @testing.combinations(
        (operator.lt, "<", ">"),
        (operator.gt, ">", "<"),
        (operator.eq,
"=", "="), (operator.ne, "!=", "!="), (operator.le, "<=", ">="), (operator.ge, ">=", "<="), argnames="op, fwd_op, rev_op", ) def test_comparator_null(self, custom, op, fwd_op, rev_op): self._fixture(custom) Edge, Point = self.classes("Edge", "Point") if op is operator.eq: self.assert_compile( select(Edge).filter(op(Edge.start, None)), "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge " "WHERE edge.x1 IS NULL AND edge.y1 IS NULL", checkparams={}, ) elif op is operator.ne: self.assert_compile( select(Edge).filter(op(Edge.start, None)), "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge " "WHERE edge.x1 IS NOT NULL AND edge.y1 IS NOT NULL", checkparams={}, ) else: with expect_raises_message( sa.exc.ArgumentError, r"Only '=', '!=', .* operators can be used " r"with None/True/False", ): select(Edge).filter(op(Edge.start, None)) def test_default_comparator_factory(self): self._fixture(False) Edge = self.classes.Edge start_prop = Edge.start.property assert start_prop.comparator_factory is Composite.Comparator def test_custom_comparator_factory(self): self._fixture(True) Edge, Point = (self.classes.Edge, self.classes.Point) edge_1, edge_2 = ( Edge(Point(0, 0), Point(3, 5)), Edge(Point(0, 1), Point(3, 5)), ) sess = fixture_session() sess.add_all([edge_1, edge_2]) sess.commit() near_edges = ( sess.query(Edge).filter(Edge.start.near(Point(1, 1), 1)).all() ) assert edge_1 not in near_edges assert edge_2 in near_edges near_edges = ( sess.query(Edge).filter(Edge.start.near(Point(0, 1), 1)).all() ) assert edge_1 in near_edges and edge_2 in near_edges def test_order_by(self): self._fixture(False) Edge = self.classes.Edge s = fixture_session() self.assert_compile( s.query(Edge).order_by(Edge.start, Edge.end), "SELECT edge.id AS edge_id, edge.x1 AS edge_x1, " "edge.y1 AS edge_y1, edge.x2 AS edge_x2, edge.y2 AS edge_y2 " "FROM edge ORDER BY edge.x1, edge.y1, edge.x2, edge.y2", ) def test_order_by_aliased(self): self._fixture(False) Edge = self.classes.Edge s = 
fixture_session() ea = aliased(Edge) self.assert_compile( s.query(ea).order_by(ea.start, ea.end), "SELECT edge_1.id AS edge_1_id, edge_1.x1 AS edge_1_x1, " "edge_1.y1 AS edge_1_y1, edge_1.x2 AS edge_1_x2, " "edge_1.y2 AS edge_1_y2 " "FROM edge AS edge_1 ORDER BY edge_1.x1, edge_1.y1, " "edge_1.x2, edge_1.y2", ) def test_clause_expansion(self): self._fixture(False) Edge = self.classes.Edge from sqlalchemy.orm import configure_mappers configure_mappers() self.assert_compile( select(Edge).order_by(Edge.start), "SELECT edge.id, edge.x1, edge.y1, edge.x2, edge.y2 FROM edge " "ORDER BY edge.x1, edge.y1", )