diff options
| -rw-r--r-- | examples/sharding/attribute_shard.py | 19 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/shard.py | 5 | ||||
| -rw-r--r-- | test/orm/sharding/shard.py | 23 | 
4 files changed, 35 insertions, 19 deletions
diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py index e95b978ae..6e4732989 100644 --- a/examples/sharding/attribute_shard.py +++ b/examples/sharding/attribute_shard.py @@ -34,13 +34,15 @@ db4 = create_engine('sqlite:///shard4.db', echo=echo)  # step 3. create session function.  this binds the shard ids  # to databases within a ShardedSession and returns it. -def create_session(): -    s = ShardedSession(shard_chooser, id_chooser, query_chooser) -    s.bind_shard('north_america', db1) -    s.bind_shard('asia', db2) -    s.bind_shard('europe', db3) -    s.bind_shard('south_america', db4) -    return s +create_session = sessionmaker(class_=ShardedSession) + +create_session.configure(shards={ +    'north_america':db1, +    'asia':db2, +    'europe':db3, +    'south_america':db4 +}) +  # step 4.  table setup.  meta = MetaData() @@ -143,6 +145,9 @@ def query_chooser(query):      else:          return ids +# further configure create_session to use these functions +create_session.configure(shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser) +  # step 6.  mapped classes.      class WeatherLocation(object):      def __init__(self, continent, city): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index f982da536..80c1a5b0d 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -14,14 +14,17 @@ from sqlalchemy.orm.mapper import global_extensions  __all__ = ['Session', 'SessionTransaction'] -def sessionmaker(autoflush=True, transactional=True, bind=None, **kwargs): +def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **kwargs):      """Generate a Session configuration."""      kwargs['bind'] = bind      kwargs['autoflush'] = autoflush      kwargs['transactional'] = transactional -    class Sess(Session): +    if class_ is None: +        class_ = Session +         +    class Sess(class_):          def __init__(self, **local_kwargs):              for k in kwargs:                  local_kwargs.setdefault(k, kwargs[k]) diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py index 26d03372f..9d4396d2b 100644 --- a/lib/sqlalchemy/orm/shard.py +++ b/lib/sqlalchemy/orm/shard.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Query  __all__ = ['ShardedSession', 'ShardedQuery']  class ShardedSession(Session): -    def __init__(self, shard_chooser, id_chooser, query_chooser, **kwargs): +    def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):          """construct a ShardedSession.              shard_chooser @@ -32,6 +32,9 @@ class ShardedSession(Session):          self.__binds = {}          self._mapper_flush_opts = {'connection_callable':self.connection}          self._query_cls = ShardedQuery +        if shards is not None: +            for k in shards: +                self.bind_shard(k, shards[k])      def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):          if shard_id is None: diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/shard.py index faa980cc2..c1dd63d65 100644 --- a/test/orm/sharding/shard.py +++ b/test/orm/sharding/shard.py @@ -90,14 +90,16 @@ class ShardTest(PersistTest):                  return ['north_america', 'asia', 'europe', 'south_america']              else:                  return ids - -        def create_session(): -            s = ShardedSession(shard_chooser, id_chooser, query_chooser) -            s.bind_shard('north_america', db1) -            s.bind_shard('asia', db2) -            s.bind_shard('europe', db3) -            s.bind_shard('south_america', db4) -            return s +         +        create_session = sessionmaker(class_=ShardedSession, autoflush=True, transactional=True) + +        create_session.configure(shards={ +            'north_america':db1, +            'asia':db2, +            'europe':db3, +            'south_america':db4 +        }, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser) +              def setup_mappers(self):          global WeatherLocation, Report @@ -133,10 +135,13 @@ class ShardTest(PersistTest):          sess = create_session()          for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:              sess.save(c) -        sess.flush() +        sess.commit()          sess.clear() +        assert db2.execute(weather_locations.select()).fetchall() == [(1, 'Asia', 'Tokyo')] +        assert db1.execute(weather_locations.select()).fetchall() == [(2, 'North America', 'New York'), (3, 'North America', 'Toronto')] +                  t = sess.query(WeatherLocation).get(tokyo.id)          assert t.city == tokyo.city          assert t.reports[0].temperature == 80.0  | 
