diff options
| -rw-r--r-- | examples/sharding/attribute_shard.py | 19 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/shard.py | 5 | ||||
| -rw-r--r-- | test/orm/sharding/shard.py | 23 |
4 files changed, 35 insertions, 19 deletions
diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py index e95b978ae..6e4732989 100644 --- a/examples/sharding/attribute_shard.py +++ b/examples/sharding/attribute_shard.py @@ -34,13 +34,15 @@ db4 = create_engine('sqlite:///shard4.db', echo=echo) # step 3. create session function. this binds the shard ids # to databases within a ShardedSession and returns it. -def create_session(): - s = ShardedSession(shard_chooser, id_chooser, query_chooser) - s.bind_shard('north_america', db1) - s.bind_shard('asia', db2) - s.bind_shard('europe', db3) - s.bind_shard('south_america', db4) - return s +create_session = sessionmaker(class_=ShardedSession) + +create_session.configure(shards={ + 'north_america':db1, + 'asia':db2, + 'europe':db3, + 'south_america':db4 +}) + # step 4. table setup. meta = MetaData() @@ -143,6 +145,9 @@ def query_chooser(query): else: return ids +# further configure create_session to use these functions +create_session.configure(shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser) + # step 6. mapped classes. class WeatherLocation(object): def __init__(self, continent, city): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index f982da536..80c1a5b0d 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -14,14 +14,17 @@ from sqlalchemy.orm.mapper import global_extensions __all__ = ['Session', 'SessionTransaction'] -def sessionmaker(autoflush=True, transactional=True, bind=None, **kwargs): +def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **kwargs): """Generate a Session configuration.""" kwargs['bind'] = bind kwargs['autoflush'] = autoflush kwargs['transactional'] = transactional - class Sess(Session): + if class_ is None: + class_ = Session + + class Sess(class_): def __init__(self, **local_kwargs): for k in kwargs: local_kwargs.setdefault(k, kwargs[k]) diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py index 26d03372f..9d4396d2b 100644 --- a/lib/sqlalchemy/orm/shard.py +++ b/lib/sqlalchemy/orm/shard.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Query __all__ = ['ShardedSession', 'ShardedQuery'] class ShardedSession(Session): - def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs): + def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs): """construct a ShardedSession. shard_chooser @@ -32,6 +32,9 @@ class ShardedSession(Session): self.__binds = {} self._mapper_flush_opts = {'connection_callable':self.connection} self._query_cls = ShardedQuery + if shards is not None: + for k in shards: + self.bind_shard(k, shards[k]) def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): if shard_id is None: diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/shard.py index faa980cc2..c1dd63d65 100644 --- a/test/orm/sharding/shard.py +++ b/test/orm/sharding/shard.py @@ -90,14 +90,16 @@ class ShardTest(PersistTest): return ['north_america', 'asia', 'europe', 'south_america'] else: return ids - - def create_session(): - s = ShardedSession(shard_chooser, id_chooser, query_chooser) - s.bind_shard('north_america', db1) - s.bind_shard('asia', db2) - s.bind_shard('europe', db3) - s.bind_shard('south_america', db4) - return s + + create_session = sessionmaker(class_=ShardedSession, autoflush=True, transactional=True) + + create_session.configure(shards={ + 'north_america':db1, + 'asia':db2, + 'europe':db3, + 'south_america':db4 + }, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser) + def setup_mappers(self): global WeatherLocation, Report @@ -133,10 +135,13 @@ class ShardTest(PersistTest): sess = create_session() for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]: sess.save(c) - sess.flush() + sess.commit() sess.clear() + assert db2.execute(weather_locations.select()).fetchall() == [(1, 'Asia', 'Tokyo')] + assert db1.execute(weather_locations.select()).fetchall() == [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')] + t = sess.query(WeatherLocation).get(tokyo.id) assert t.city == tokyo.city assert t.reports[0].temperature == 80.0 |
