summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/sharding/attribute_shard.py19
-rw-r--r--lib/sqlalchemy/orm/session.py7
-rw-r--r--lib/sqlalchemy/orm/shard.py5
-rw-r--r--test/orm/sharding/shard.py23
4 files changed, 35 insertions, 19 deletions
diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py
index e95b978ae..6e4732989 100644
--- a/examples/sharding/attribute_shard.py
+++ b/examples/sharding/attribute_shard.py
@@ -34,13 +34,15 @@ db4 = create_engine('sqlite:///shard4.db', echo=echo)
# step 3. create session function. this binds the shard ids
# to databases within a ShardedSession and returns it.
-def create_session():
- s = ShardedSession(shard_chooser, id_chooser, query_chooser)
- s.bind_shard('north_america', db1)
- s.bind_shard('asia', db2)
- s.bind_shard('europe', db3)
- s.bind_shard('south_america', db4)
- return s
+create_session = sessionmaker(class_=ShardedSession)
+
+create_session.configure(shards={
+ 'north_america':db1,
+ 'asia':db2,
+ 'europe':db3,
+ 'south_america':db4
+})
+
# step 4. table setup.
meta = MetaData()
@@ -143,6 +145,9 @@ def query_chooser(query):
else:
return ids
+# further configure create_session to use these functions
+create_session.configure(shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser)
+
# step 6. mapped classes.
class WeatherLocation(object):
def __init__(self, continent, city):
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index f982da536..80c1a5b0d 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -14,14 +14,17 @@ from sqlalchemy.orm.mapper import global_extensions
__all__ = ['Session', 'SessionTransaction']
-def sessionmaker(autoflush=True, transactional=True, bind=None, **kwargs):
+def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **kwargs):
"""Generate a Session configuration."""
kwargs['bind'] = bind
kwargs['autoflush'] = autoflush
kwargs['transactional'] = transactional
- class Sess(Session):
+ if class_ is None:
+ class_ = Session
+
+ class Sess(class_):
def __init__(self, **local_kwargs):
for k in kwargs:
local_kwargs.setdefault(k, kwargs[k])
diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py
index 26d03372f..9d4396d2b 100644
--- a/lib/sqlalchemy/orm/shard.py
+++ b/lib/sqlalchemy/orm/shard.py
@@ -4,7 +4,7 @@ from sqlalchemy.orm import Query
__all__ = ['ShardedSession', 'ShardedQuery']
class ShardedSession(Session):
- def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs):
+ def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
"""construct a ShardedSession.
shard_chooser
@@ -32,6 +32,9 @@ class ShardedSession(Session):
self.__binds = {}
self._mapper_flush_opts = {'connection_callable':self.connection}
self._query_cls = ShardedQuery
+ if shards is not None:
+ for k in shards:
+ self.bind_shard(k, shards[k])
def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
if shard_id is None:
diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/shard.py
index faa980cc2..c1dd63d65 100644
--- a/test/orm/sharding/shard.py
+++ b/test/orm/sharding/shard.py
@@ -90,14 +90,16 @@ class ShardTest(PersistTest):
return ['north_america', 'asia', 'europe', 'south_america']
else:
return ids
-
- def create_session():
- s = ShardedSession(shard_chooser, id_chooser, query_chooser)
- s.bind_shard('north_america', db1)
- s.bind_shard('asia', db2)
- s.bind_shard('europe', db3)
- s.bind_shard('south_america', db4)
- return s
+
+ create_session = sessionmaker(class_=ShardedSession, autoflush=True, transactional=True)
+
+ create_session.configure(shards={
+ 'north_america':db1,
+ 'asia':db2,
+ 'europe':db3,
+ 'south_america':db4
+ }, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser)
+
def setup_mappers(self):
global WeatherLocation, Report
@@ -133,10 +135,13 @@ class ShardTest(PersistTest):
sess = create_session()
for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
sess.save(c)
- sess.flush()
+ sess.commit()
sess.clear()
+ assert db2.execute(weather_locations.select()).fetchall() == [(1, 'Asia', 'Tokyo')]
+ assert db1.execute(weather_locations.select()).fetchall() == [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')]
+
t = sess.query(WeatherLocation).get(tokyo.id)
assert t.city == tokyo.city
assert t.reports[0].temperature == 80.0