import datetime
import os

from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import sql
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy import update
from sqlalchemy import util
from sqlalchemy.ext.horizontal_shard import set_shard_id
from sqlalchemy.ext.horizontal_shard import ShardedSession
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import defer
from sqlalchemy.orm import deferred
from sqlalchemy.orm import lazyload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import SingletonThreadPool
from sqlalchemy.sql import operators
from sqlalchemy.sql import Select
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_deprecated
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import provision
from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing.engines import testing_reaper


class ShardTest:
    __skip_if__ = (lambda: util.win32,)
    __requires__ = ("sqlite",)

    run_create_tables = None

    schema = None

    @classmethod
    def define_tables(cls, metadata):
        global db1, db2, db3, db4, weather_locations, weather_reports

        cls.tables.ids = ids = Table(
            "ids", metadata, Column("nextid", Integer, nullable=False)
        )

        def id_generator(ctx):
            # in reality, might want to use a separate transaction for this.
            with db1.begin() as c:
                nextid = c.execute(ids.select().with_for_update()).scalar()
                c.execute(
                    ids.update().values({ids.c.nextid: ids.c.nextid + 1})
                )
            return nextid

        cls.tables.weather_locations = weather_locations = Table(
            "weather_locations",
            metadata,
            Column("id", Integer, primary_key=True, default=id_generator),
            Column("continent", String(30), nullable=False),
            Column("city", String(50), nullable=False),
            schema=cls.schema,
        )

        cls.tables.weather_reports = Table(
            "weather_reports",
            metadata,
            Column("id", Integer, primary_key=True),
            Column(
                "location_id", Integer, ForeignKey(weather_locations.c.id)
            ),
            Column("temperature", Float),
            Column("report_time", DateTime, default=datetime.datetime.now),
            schema=cls.schema,
        )

    def setup_test(self):
        global db1, db2, db3, db4
        db1, db2, db3, db4 = self._dbs = self.dbs = self._init_dbs()

        for db in (db1, db2, db3, db4):
            self.tables_test_metadata.create_all(db)

        ids = self.tables.ids
        with db1.begin() as conn:
            conn.execute(ids.insert(), dict(nextid=1))

        self.setup_session()

    @classmethod
    def setup_session(cls):
        global sharded_session
        shard_lookup = {
            "North America": "north_america",
            "Asia": "asia",
            "Europe": "europe",
            "South America": "south_america",
        }

        def shard_chooser(mapper, instance, clause=None):
            if isinstance(instance, WeatherLocation):
                return shard_lookup[instance.continent]
            else:
                return shard_chooser(mapper, instance.location)

        def identity_chooser(
            mapper,
            primary_key,
            *,
            lazy_loaded_from,
            execution_options,
            bind_arguments,
            **kw,
        ):
            return ["north_america", "asia", "europe", "south_america"]

        def execute_chooser(orm_context):
            ids = []

            query = orm_context.statement

            class FindContinent(sql.ClauseVisitor):
                def visit_binary(self, binary):
                    if binary.left.shares_lineage(
                        weather_locations.c.continent
                    ):
                        if binary.operator == operators.eq:
                            ids.append(shard_lookup[binary.right.value])
                        elif binary.operator == operators.in_op:
                            for value in binary.right.value:
                                ids.append(shard_lookup[value])

            if isinstance(query, Select) and query.whereclause is not None:
                FindContinent().traverse(query.whereclause)
            if len(ids) == 0:
                return ["north_america", "asia", "europe", "south_america"]
            else:
                return ids

        sharded_session = sessionmaker(class_=ShardedSession, autoflush=True)
        sharded_session.configure(
            shards={
                "north_america": db1,
                "asia": db2,
                "europe": db3,
                "south_america": db4,
            },
            shard_chooser=shard_chooser,
            identity_chooser=identity_chooser,
            execute_chooser=execute_chooser,
        )

    @classmethod
    def setup_mappers(cls):
        global WeatherLocation, Report

        class WeatherLocation:
            def __init__(self, continent, city):
                self.continent = continent
                self.city = city

        class Report:
            def __init__(self, temperature, id_=None):
                self.temperature = temperature
                if id_:
                    self.id = id_

        weather_locations = cls.tables.weather_locations

        cls.mapper_registry.map_imperatively(
            WeatherLocation,
            weather_locations,
            properties={
                "reports": relationship(Report, backref="location"),
                "city": deferred(weather_locations.c.city),
            },
        )

        weather_reports = cls.tables.weather_reports

        cls.mapper_registry.map_imperatively(Report, weather_reports)

    def _fixture_data(self):
        tokyo = WeatherLocation("Asia", "Tokyo")
        newyork = WeatherLocation("North America", "New York")
        toronto = WeatherLocation("North America", "Toronto")
        london = WeatherLocation("Europe", "London")
        dublin = WeatherLocation("Europe", "Dublin")
        brasilia = WeatherLocation("South America", "Brasila")
        quito = WeatherLocation("South America", "Quito")
        tokyo.reports.append(Report(80.0, id_=1))
        newyork.reports.append(Report(75, id_=1))
        quito.reports.append(Report(85))

        sess = sharded_session()
        for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
            sess.add(c)

        sess.flush()

        eq_(inspect(newyork).key[2], "north_america")
        eq_(inspect(newyork).identity_token, "north_america")
        eq_(inspect(dublin).key[2], "europe")
        eq_(inspect(dublin).identity_token, "europe")

        sess.commit()
        sess.close()
        return sess

    def test_get(self):
        sess = self._fixture_data()
        tokyo = sess.get(WeatherLocation, 1)
        eq_(tokyo.city, "Tokyo")

        newyork = sess.get(WeatherLocation, 2)
        eq_(newyork.city, "New York")

        t2 = sess.get(WeatherLocation, 1)
        is_(t2, tokyo)

    def test_get_explicit_shard(self):
        sess = self._fixture_data()
        tokyo = (
            sess.query(WeatherLocation)
            .set_shard("europe")
            .where(WeatherLocation.id == 1)
            .first()
        )
        is_(tokyo, None)

        newyork = (
            sess.query(WeatherLocation)
            .set_shard("north_america")
            .where(WeatherLocation.id == 2)
            .first()
        )
        eq_(newyork.city, "New York")

        # now it found it
        t2 = sess.get(WeatherLocation, 1)
        eq_(t2.city, "Tokyo")

    @testing.variation("option_type", ["none", "lazyload", "selectinload"])
    @testing.variation(
        "limit_shard",
        ["none", "lead_only", "propagate_to_loaders", "bind_arg"],
    )
    def test_set_shard_option_relationship(self, option_type, limit_shard):
        sess = self._fixture_data()

        stmt = select(WeatherLocation).filter(
            WeatherLocation.city == "New York"
        )

        bind_arguments = {}

        if limit_shard.none:
            # right now selectinload / lazyload runs all the shards even if
            # the ids are limited to just one shard, since that information
            # is not transferred
            counts = [2, 2, 2, 2]
        elif limit_shard.lead_only:
            if option_type.selectinload:
                counts = [2, 0, 0, 0]
            else:
                counts = [2, 1, 1, 1]
        elif limit_shard.bind_arg:
            counts = [2, 1, 1, 1]
        elif limit_shard.propagate_to_loaders:
            counts = [2, 0, 0, 0]
        else:
            limit_shard.fail()

        if option_type.lazyload:
            stmt = stmt.options(lazyload(WeatherLocation.reports))
        elif option_type.selectinload:
            stmt = stmt.options(selectinload(WeatherLocation.reports))

        if limit_shard.lead_only:
            stmt = stmt.options(
                set_shard_id("north_america", propagate_to_loaders=False)
            )
        elif limit_shard.propagate_to_loaders:
            stmt = stmt.options(set_shard_id("north_america"))
        elif limit_shard.bind_arg:
            bind_arguments["shard_id"] = "north_america"

        with self.assert_statement_count_multi_db(self.dbs, counts):
            w1 = sess.scalars(stmt, bind_arguments=bind_arguments).first()
            w1.reports

    @testing.variation("option_type", ["none", "defer"])
    @testing.variation(
        "limit_shard",
        ["none", "lead_only", "propagate_to_loaders", "bind_arg"],
    )
    def test_set_shard_option_column(self, option_type, limit_shard):
        sess = self._fixture_data()

        stmt = select(WeatherLocation).filter(
            WeatherLocation.city == "New York"
        )

        bind_arguments = {}

        if limit_shard.none:
            if option_type.defer:
                counts = [2, 1, 1, 1]
            else:
                counts = [1, 1, 1, 1]
        elif limit_shard.lead_only or limit_shard.propagate_to_loaders:
            if option_type.defer:
                counts = [2, 0, 0, 0]
            else:
                counts = [1, 0, 0, 0]
        elif limit_shard.bind_arg:
            if option_type.defer:
                counts = [2, 0, 0, 0]
            else:
                counts = [1, 0, 0, 0]
        else:
            limit_shard.fail()

        if option_type.defer:
            stmt = stmt.options(defer(WeatherLocation.continent))

        if limit_shard.lead_only:
            stmt = stmt.options(
                set_shard_id("north_america", propagate_to_loaders=False)
            )
        elif limit_shard.propagate_to_loaders:
            stmt = stmt.options(set_shard_id("north_america"))
        elif limit_shard.bind_arg:
            bind_arguments["shard_id"] = "north_america"

        with self.assert_statement_count_multi_db(self.dbs, counts):
            w1 = sess.scalars(stmt, bind_arguments=bind_arguments).first()
            w1.continent

    def test_query_explicit_shard_via_bind_opts(self):
        sess = self._fixture_data()

        stmt = select(WeatherLocation).filter(WeatherLocation.id == 1)

        tokyo = (
            sess.execute(stmt, bind_arguments={"shard_id": "asia"})
            .scalars()
            .first()
        )

        eq_(tokyo.city, "Tokyo")

    def test_plain_db_lookup(self):
        self._fixture_data()

        # not sure what this is testing except the fixture data itself
        eq_(
            db2.connect().execute(weather_locations.select()).fetchall(),
            [(1, "Asia", "Tokyo")],
        )
        eq_(
            db1.connect().execute(weather_locations.select()).fetchall(),
            [
                (2, "North America", "New York"),
                (3, "North America", "Toronto"),
            ],
        )

    def test_plain_core_lookup_w_shard(self):
        sess = self._fixture_data()

        eq_(
            sess.execute(
                weather_locations.select(),
                bind_arguments=dict(shard_id="asia"),
            ).fetchall(),
            [(1, "Asia", "Tokyo")],
        )

    def test_roundtrip_future(self):
        sess = self._fixture_data()

        tokyo = (
            sess.execute(select(WeatherLocation).filter_by(city="Tokyo"))
            .scalars()
            .one()
        )
        eq_(tokyo.city, "Tokyo")

        asia_and_europe = sess.execute(
            select(WeatherLocation).filter(
                WeatherLocation.continent.in_(["Europe", "Asia"])
            )
        ).scalars()
        eq_(
            {c.city for c in asia_and_europe},
            {"Tokyo", "London", "Dublin"},
        )

    def test_roundtrip(self):
        sess = self._fixture_data()
        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()

        eq_(tokyo.city, "Tokyo")

        tokyo.city  # reload 'city' attribute on tokyo

        sess.expire_all()

        t = sess.get(WeatherLocation, tokyo.id)
        eq_(t.city, tokyo.city)
        eq_(t.reports[0].temperature, 80.0)

        north_american_cities = sess.query(WeatherLocation).filter(
            WeatherLocation.continent == "North America"
        )
        eq_(
            {c.city for c in north_american_cities},
            {"New York", "Toronto"},
        )

        asia_and_europe = sess.query(WeatherLocation).filter(
            WeatherLocation.continent.in_(["Europe", "Asia"])
        )
        eq_(
            {c.city for c in asia_and_europe},
            {"Tokyo", "London", "Dublin"},
        )

        # inspect the shard token stored with each instance
        eq_(
            {inspect(c).key[2] for c in asia_and_europe},
            {"europe", "asia"},
        )
        eq_(
            {inspect(c).identity_token for c in asia_and_europe},
            {"europe", "asia"},
        )

        newyork = sess.query(WeatherLocation).filter_by(city="New York").one()
        newyork_report = newyork.reports[0]
        tokyo_report = tokyo.reports[0]

        # same primary key, two identity keys
        eq_(
            inspect(newyork_report).identity_key,
            (Report, (1,), "north_america"),
        )
        eq_(inspect(tokyo_report).identity_key, (Report, (1,), "asia"))

        # the token representing the originating shard is available
        eq_(inspect(newyork_report).identity_token, "north_america")
        eq_(inspect(tokyo_report).identity_token, "asia")

    def test_get_baked_query(self):
        sess = self._fixture_data()

        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
        tokyo.city
        sess.expunge_all()

        from sqlalchemy.ext.baked import BakedQuery

        bakery = BakedQuery.bakery()

        bq = bakery(lambda session: session.query(WeatherLocation))
        t = bq(sess).get(tokyo.id)
        eq_(t.city, tokyo.city)

        eq_(inspect(t).key[2], "asia")

    def test_get_baked_query_shard_id(self):
        sess = self._fixture_data()

        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
        tokyo.city
        sess.expunge_all()

        from sqlalchemy.ext.baked import BakedQuery

        bakery = BakedQuery.bakery()

        bq = bakery(lambda session: session.query(WeatherLocation))
        t = (
            bq(sess)
            .with_post_criteria(lambda q: q.set_shard("asia"))
            .get(tokyo.id)
        )
        eq_(t.city, tokyo.city)

        eq_(inspect(t).key[2], "asia")

    def test_filter_baked_query_shard_id(self):
        sess = self._fixture_data()

        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
        tokyo.city
        sess.expunge_all()

        from sqlalchemy.ext.baked import BakedQuery

        bakery = BakedQuery.bakery()

        bq = bakery(
            lambda session: session.query(WeatherLocation)
        ).with_criteria(lambda q: q.filter_by(id=tokyo.id))
        t = bq(sess).with_post_criteria(lambda q: q.set_shard("asia")).one()
        eq_(t.city, tokyo.city)

    def test_shard_id_event(self):
        # this test is kind of important, it's testing that
        # when the load event is emitted for an ORM result,
        # the context is set up in the state that is expected.
        # prior to 1.4, we were changing a single context in place,
        # as we would join result sets by fully evaluating and concatenating.
        # in 1.4 onwards we return a Result that has not run for each
        # individual result yet, so each one has its own context that
        # is a shallow copy from the original.
        canary = []

        def load(instance, ctx):
            canary.append(ctx.bind_arguments["shard_id"])

        event.listen(WeatherLocation, "load", load)

        sess = self._fixture_data()

        tokyo = (  # noqa
            sess.query(WeatherLocation)
            .filter_by(city="Tokyo")
            .set_shard("asia")
            .one()
        )

        sess.query(WeatherLocation).all()
        eq_(
            canary,
            [
                "asia",
                "north_america",
                "north_america",
                "europe",
                "europe",
                "south_america",
                "south_america",
            ],
        )

    def test_baked_mix(self):
        sess = self._fixture_data()

        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
        tokyo.city
        sess.expunge_all()

        from sqlalchemy.ext.baked import BakedQuery

        bakery = BakedQuery.bakery()

        def get_tokyo(sess):
            bq = bakery(lambda session: session.query(WeatherLocation))
            t = bq(sess).get(tokyo.id)
            return t

        Sess = sessionmaker(class_=Session, bind=db2, autoflush=True)
        sess2 = Sess()

        t = get_tokyo(sess)
        eq_(t.city, tokyo.city)

        t = get_tokyo(sess2)
        eq_(t.city, tokyo.city)

    @testing.combinations(
        "fetch", "evaluate", "auto", argnames="synchronize_session"
    )
    @testing.combinations(True, False, argnames="legacy")
    def test_orm_update_synchronize(self, synchronize_session, legacy):
        sess = self._fixture_data()

        eq_(
            {row.temperature for row in sess.query(Report.temperature)},
            {80.0, 75.0, 85.0},
        )

        temps = sess.query(Report).all()
        eq_({t.temperature for t in temps}, {80.0, 75.0, 85.0})

        if legacy:
            sess.query(Report).filter(Report.temperature >= 80).update(
                {"temperature": Report.temperature + 6},
                synchronize_session=synchronize_session,
            )
        else:
            sess.execute(
                update(Report)
                .filter(Report.temperature >= 80)
                .values(temperature=Report.temperature + 6)
                .execution_options(synchronize_session=synchronize_session)
            )

        with self.assert_statement_count_multi_db(self.dbs, [0, 0, 0, 0]):
            eq_({t.temperature for t in temps}, {86.0, 75.0, 91.0})

        eq_(
            {row.temperature for row in sess.query(Report.temperature)},
            {86.0, 75.0, 91.0},
        )

    @testing.combinations(
        "fetch", "evaluate", "auto", argnames="synchronize_session"
    )
    @testing.combinations(True, False, argnames="legacy")
    def test_orm_delete_synchronize(self, synchronize_session, legacy):
        sess = self._fixture_data()

        temps = sess.query(Report).all()
        eq_({t.temperature for t in temps}, {80.0, 75.0, 85.0})

        if legacy:
            sess.query(Report).filter(Report.temperature >= 80).delete(
                synchronize_session=synchronize_session
            )
        else:
            sess.execute(
                delete(Report)
                .filter(Report.temperature >= 80)
                .execution_options(synchronize_session=synchronize_session)
            )

        with self.assert_statement_count_multi_db(self.dbs, [0, 0, 0, 0]):
            # test synchronize session
            for t in temps:
                assert inspect(t).deleted is (t.temperature >= 80)

        eq_(
            {row.temperature for row in sess.query(Report.temperature)},
            {75.0},
        )


class DistinctEngineShardTest(ShardTest, fixtures.MappedTest):
    def _init_dbs(self):
        db1 = testing_engine(
            "sqlite:///shard1_%s.db" % provision.FOLLOWER_IDENT,
            options=dict(poolclass=SingletonThreadPool),
        )
        db2 = testing_engine(
            "sqlite:///shard2_%s.db" % provision.FOLLOWER_IDENT
        )
        db3 = testing_engine(
            "sqlite:///shard3_%s.db" % provision.FOLLOWER_IDENT
        )
        db4 = testing_engine(
            "sqlite:///shard4_%s.db" % provision.FOLLOWER_IDENT
        )

        self.dbs = [db1, db2, db3, db4]
        return self.dbs

    def teardown_test(self):
        testing_reaper.checkin_all()
        for i in range(1, 5):
            os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))

    def test_plain_core_textual_lookup_w_shard(self):
        sess = self._fixture_data()

        stmt = text("SELECT * FROM weather_locations")

        eq_(
            sess.execute(
                stmt, bind_arguments=dict(shard_id="asia")
            ).fetchall(),
            [(1, "Asia", "Tokyo")],
        )

    def test_plain_core_textual_lookup(self):
        sess = self._fixture_data()

        stmt = text("SELECT * FROM weather_locations WHERE id=1")

        eq_(
            sess.execute(stmt).fetchall(),
            [(1, "Asia", "Tokyo")],
        )


class LegacyAPIShardTest(DistinctEngineShardTest):
    @classmethod
    def setup_session(cls):
        global sharded_session
        shard_lookup = {
            "North America": "north_america",
            "Asia": "asia",
            "Europe": "europe",
            "South America": "south_america",
        }

        def shard_chooser(mapper, instance, clause=None):
            if isinstance(instance, WeatherLocation):
                return shard_lookup[instance.continent]
            else:
                return shard_chooser(mapper, instance.location)

        def id_chooser(query, primary_key):
            return ["north_america", "asia", "europe", "south_america"]

        def query_chooser(query):
            ids = []

            class FindContinent(sql.ClauseVisitor):
                def visit_binary(self, binary):
                    if binary.left.shares_lineage(
                        weather_locations.c.continent
                    ):
                        if binary.operator == operators.eq:
                            ids.append(shard_lookup[binary.right.value])
                        elif binary.operator == operators.in_op:
                            for value in binary.right.value:
                                ids.append(shard_lookup[value])

            if isinstance(query, Select) and query.whereclause is not None:
                FindContinent().traverse(query.whereclause)
            if len(ids) == 0:
                return ["north_america", "asia", "europe", "south_america"]
            else:
                return ids

        sm = sessionmaker(class_=ShardedSession, autoflush=True)
        sm.configure(
            shards={
                "north_america": db1,
                "asia": db2,
                "europe": db3,
                "south_america": db4,
            },
            shard_chooser=shard_chooser,
            id_chooser=id_chooser,
            query_chooser=query_chooser,
        )

        def sharded_session():
            with expect_deprecated(
                "The ``id_chooser`` parameter is deprecated",
                "The ``query_chooser`` parameter is deprecated",
            ):
                return sm()


class AttachedFileShardTest(ShardTest, fixtures.MappedTest):
    """Use modern schema conventions along with SQLite ATTACH."""

    schema = "changeme"

    def _init_dbs(self):
        e = testing_engine("sqlite://")
        with e.connect() as conn:
            for i in range(1, 5):
                conn.exec_driver_sql(
                    'ATTACH DATABASE "shard%s_%s.db" AS shard%s'
                    % (i, provision.FOLLOWER_IDENT, i)
                )

        db1 = e.execution_options(schema_translate_map={"changeme": "shard1"})
        db2 = e.execution_options(schema_translate_map={"changeme": "shard2"})
        db3 = e.execution_options(schema_translate_map={"changeme": "shard3"})
        db4 = e.execution_options(schema_translate_map={"changeme": "shard4"})

        self.engine = e
        return db1, db2, db3, db4

    def teardown_test(self):
        testing_reaper.checkin_all()
        for i in range(1, 5):
            os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))


class TableNameConventionShardTest(ShardTest, fixtures.MappedTest):
    """This fixture uses a single SQLite database along with a table naming
    convention to achieve sharding.  Event hooks are used to rewrite SQL
    statements.

    This used to be called "AttachedFileShardTest" but I didn't see any
    ATTACH going on.

    A more modern approach here would be to use the schema_translate_map
    option.

    """

    schema = "changeme"

    def _init_dbs(self):
        dbmain = testing_engine("sqlite://")
        db1 = dbmain.execution_options(shard_id="shard1")
        db2 = dbmain.execution_options(shard_id="shard2")
        db3 = dbmain.execution_options(shard_id="shard3")
        db4 = dbmain.execution_options(shard_id="shard4")

        import re

        @event.listens_for(dbmain, "before_cursor_execute", retval=True)
        def _switch_shard(conn, cursor, stmt, params, context, executemany):
            shard_id = conn._execution_options["shard_id"]
            # because SQLite can't just give us a "use" statement, we have
            # to use the schema hack to locate table names
            if shard_id:
                stmt = re.sub(r"\"?changeme\"?\.", shard_id + "_", stmt)
            return stmt, params

        return db1, db2, db3, db4


class MultipleDialectShardTest(ShardTest, fixtures.MappedTest):
    __only_on__ = "postgresql"

    schema = "changeme"

    def _init_dbs(self):
        e1 = testing_engine("sqlite://")
        with e1.connect() as conn:
            for i in [1, 3]:
                conn.exec_driver_sql(
                    'ATTACH DATABASE "shard%s_%s.db" AS shard%s'
                    % (i, provision.FOLLOWER_IDENT, i)
                )

        e2 = testing_engine()
        with e2.begin() as conn:
            for i in [2, 4]:
                conn.exec_driver_sql(
                    "CREATE SCHEMA IF NOT EXISTS shard%s" % (i,)
                )

        db1 = e1.execution_options(schema_translate_map={"changeme": "shard1"})
        db2 = e2.execution_options(schema_translate_map={"changeme": "shard2"})
        db3 = e1.execution_options(schema_translate_map={"changeme": "shard3"})
        db4 = e2.execution_options(schema_translate_map={"changeme": "shard4"})

        self.sqlite_engine = e1
        self.postgresql_engine = e2
        return db1, db2, db3, db4

    def teardown_test(self):
        clear_mappers()

        # the tests in this suite don't cleanly close out the Session
        # at the moment so use the reaper to close all connections
        testing_reaper.checkin_all()

        for i in [1, 3]:
            os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))

        with self.postgresql_engine.begin() as conn:
            self.tables_test_metadata.drop_all(conn)
            for i in [2, 4]:
                conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,))

        self.postgresql_engine.dispose()


class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
    """test #4175"""

    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        class Book(Base):
            __tablename__ = "book"
            id = Column(Integer, primary_key=True)
            pages = relationship("Page")

        class Page(Base):
            __tablename__ = "page"
            id = Column(Integer, primary_key=True)
            book_id = Column(ForeignKey("book.id"))

    def test_selectinload_query(self):
        session = ShardedSession(
            shards={"test": testing.db},
            shard_chooser=lambda *args: "test",
            identity_chooser=lambda *args: None,
            execute_chooser=lambda *args: ["test"],
        )

        Book, Page = self.classes("Book", "Page")

        book = Book()
        book.pages.append(Page())

        session.add(book)
        session.commit()

        result = session.query(Book).options(selectinload(Book.pages)).all()
        eq_(result, [book])


class RefreshDeferExpireTest(fixtures.DeclarativeMappedTest):
    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        class A(Base):
            __tablename__ = "a"
            id = Column(Integer, primary_key=True)
            data = Column(String(30))
            deferred_data = deferred(Column(String(30)))

    @classmethod
    def insert_data(cls, connection):
        A = cls.classes.A
        s = Session(connection)
        s.add(A(data="d1", deferred_data="d2"))
        s.commit()

    def _session_fixture(self, **kw):
        # the "fake" key here is to ensure that neither id_chooser
        # nor query_chooser are actually used, only shard_chooser
        # should be used.
        return ShardedSession(
            shards={"main": testing.db},
            shard_chooser=lambda *args: "main",
            identity_chooser=lambda *args: ["fake", "main"],
            execute_chooser=lambda *args: ["fake", "main"],
            **kw,
        )

    def test_refresh(self):
        A = self.classes.A
        session = self._session_fixture()
        a1 = session.query(A).set_shard("main").first()

        session.refresh(a1)

    def test_deferred(self):
        A = self.classes.A
        session = self._session_fixture()
        a1 = session.query(A).set_shard("main").first()

        eq_(a1.deferred_data, "d2")

    def test_unexpire(self):
        A = self.classes.A
        session = self._session_fixture()
        a1 = session.query(A).set_shard("main").first()

        session.expire(a1)
        eq_(a1.data, "d1")


class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
    def _init_dbs(self):
        self.db1 = db1 = testing_engine(
            "sqlite:///shard1_%s.db" % provision.FOLLOWER_IDENT
        )
        self.db2 = db2 = testing_engine(
            "sqlite:///shard2_%s.db" % provision.FOLLOWER_IDENT
        )

        for db in (db1, db2):
            self.tables_test_metadata.create_all(db)

        self.dbs = [db1, db2]
        return self.dbs

    def teardown_test(self):
        for db in self.dbs:
            db.connect().invalidate()
        testing_reaper.checkin_all()
        for i in range(1, 3):
            os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))

    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        class Book(Base):
            __tablename__ = "book"
            id = Column(Integer, primary_key=True)
            title = Column(String(50), nullable=False)
            pages = relationship("Page", backref="book")

        class Page(Base):
            __tablename__ = "page"
            id = Column(Integer, primary_key=True)
            book_id = Column(ForeignKey("book.id"))
            title = Column(String(50))

    def _fixture(self, lazy_load_book=False, lazy_load_pages=False):
        Book, Page = self.classes("Book", "Page")

        def shard_for_book(book):
            if book.title == "title 1":
                return "test"
            elif book.title == "title 2":
                return "test2"
            else:
                assert False

        def identity_chooser(
            mapper,
            primary_key,
            *,
            lazy_loaded_from,
            execution_options,
            bind_arguments,
            **kw,
        ):
            assert lazy_loaded_from
            if isinstance(lazy_loaded_from.obj(), Book):
                token = shard_for_book(lazy_loaded_from.obj())
                assert lazy_loaded_from.identity_token == token

            return [lazy_loaded_from.identity_token]

        def execute_chooser(orm_context):
            if (
                orm_context.statement.column_descriptions[0]["type"] is Book
                and lazy_load_book
            ):
                assert isinstance(orm_context.lazy_loaded_from.obj(), Page)
            elif (
                orm_context.statement.column_descriptions[0]["type"] is Page
                and lazy_load_pages
            ):
                assert isinstance(orm_context.lazy_loaded_from.obj(), Book)

            if orm_context.lazy_loaded_from is None:
                return ["test", "test2"]
            else:
                return [orm_context.lazy_loaded_from.identity_token]

        def shard_chooser(mapper, instance, **kw):
            if isinstance(instance, Page):
                return shard_for_book(instance.book)
            else:
                return shard_for_book(instance)

        db1, db2 = self._init_dbs()

        session = ShardedSession(
            shards={"test": db1, "test2": db2},
            shard_chooser=shard_chooser,
            identity_chooser=identity_chooser,
            execute_chooser=execute_chooser,
        )

        return session

    def test_lazy_load_from_identity_map(self):
        session = self._fixture()

        Book, Page = self.classes("Book", "Page")
        book = Book(title="title 1")
        book.pages.append(Page())

        session.add(book)
        session.flush()

        page = session.query(Page).first()

        session.expire(page, ["book"])

        with self.assert_statement_count_multi_db(self.dbs, [0, 0]):
            # doesn't emit SQL
            eq_(page.book, book)

    def test_lazy_load_from_db(self):
        session = self._fixture(lazy_load_book=True)

        Book, Page = self.classes("Book", "Page")
        book1 = Book(title="title 1")
        book1.pages.append(Page(title="book 1 page 1"))

        session.add(book1)
        session.flush()
        book1_id = inspect(book1).identity_key
        session.expunge(book1)

        book1_page = session.query(Page).first()
        session.expire(book1_page, ["book"])

        with self.assert_statement_count_multi_db(self.dbs, [1, 0]):
            # emits one query
            eq_(inspect(book1_page.book).identity_key, book1_id)

    def test_lazy_load_no_baked_conflict(self):
        session = self._fixture(lazy_load_pages=True)

        Book, Page = self.classes("Book", "Page")
        book1 = Book(title="title 1")
        book1.pages.append(Page(title="book 1 page 1"))

        book2 = Book(title="title 2")
        book2.pages.append(Page(title="book 2 page 1"))

        session.add(book1)
        session.add(book2)

        session.flush()

        session.expire(book1, ["pages"])
        session.expire(book2, ["pages"])

        eq_(book1.pages[0].title, "book 1 page 1")

        # second lazy load uses correct state
        eq_(book2.pages[0].title, "book 2 page 1")