diff options
Diffstat (limited to 'lib/sqlalchemy/ext/horizontal_shard.py')
-rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 61 |
1 files changed, 38 insertions, 23 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index f86e4fc93..7248e5b4d 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -20,7 +20,7 @@ from .. import util from ..orm.session import Session from ..orm.query import Query -__all__ = ['ShardedSession', 'ShardedQuery'] +__all__ = ["ShardedSession", "ShardedQuery"] class ShardedQuery(Query): @@ -43,12 +43,10 @@ class ShardedQuery(Query): def _execute_and_instances(self, context): def iter_for_shard(shard_id): - context.attributes['shard_id'] = context.identity_token = shard_id + context.attributes["shard_id"] = context.identity_token = shard_id result = self._connection_from_session( - mapper=self._bind_mapper(), - shard_id=shard_id).execute( - context.statement, - self._params) + mapper=self._bind_mapper(), shard_id=shard_id + ).execute(context.statement, self._params) return self.instances(result, context) if context.identity_token is not None: @@ -70,7 +68,8 @@ class ShardedQuery(Query): mapper=mapper, shard_id=shard_id, clause=stmt, - close_with_result=True) + close_with_result=True, + ) result = conn.execute(stmt, self._params) return result @@ -87,8 +86,13 @@ class ShardedQuery(Query): return ShardedResult(results, rowcount) def _identity_lookup( - self, mapper, primary_key_identity, identity_token=None, - lazy_loaded_from=None, **kw): + self, + mapper, + primary_key_identity, + identity_token=None, + lazy_loaded_from=None, + **kw + ): """override the default Query._identity_lookup method so that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). @@ -97,8 +101,10 @@ class ShardedQuery(Query): if identity_token is not None: return super(ShardedQuery, self)._identity_lookup( - mapper, primary_key_identity, - identity_token=identity_token, **kw + mapper, + primary_key_identity, + identity_token=identity_token, + **kw ) else: q = self.session.query(mapper) @@ -113,13 +119,13 @@ class ShardedQuery(Query): return None - def _get_impl( - self, primary_key_identity, db_load_fn, identity_token=None): + def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): """Override the default Query._get_impl() method so that we emit a query to the DB for each possible identity token, if we don't have one already. """ + def _db_load_fn(query, primary_key_identity): # load from the database. The original db_load_fn will # use the given Query object to load from the DB, so our @@ -142,7 +148,8 @@ class ShardedQuery(Query): identity_token = self._shard_id return super(ShardedQuery, self)._get_impl( - primary_key_identity, _db_load_fn, identity_token=identity_token) + primary_key_identity, _db_load_fn, identity_token=identity_token + ) class ShardedResult(object): @@ -158,7 +165,7 @@ class ShardedResult(object): .. versionadded:: 1.3 """ - __slots__ = ('result_proxies', 'aggregate_rowcount',) + __slots__ = ("result_proxies", "aggregate_rowcount") def __init__(self, result_proxies, aggregate_rowcount): self.result_proxies = result_proxies @@ -168,9 +175,17 @@ class ShardedResult(object): def rowcount(self): return self.aggregate_rowcount + class ShardedSession(Session): - def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, - query_cls=ShardedQuery, **kwargs): + def __init__( + self, + shard_chooser, + id_chooser, + query_chooser, + shards=None, + query_cls=ShardedQuery, + **kwargs + ): """Construct a ShardedSession. :param shard_chooser: A callable which, passed a Mapper, a mapped @@ -225,16 +240,16 @@ class ShardedSession(Session): return self.transaction.connection(mapper, shard_id=shard_id) else: return self.get_bind( - mapper, - shard_id=shard_id, - instance=instance + mapper, shard_id=shard_id, instance=instance ).contextual_connect(**kwargs) - def get_bind(self, mapper, shard_id=None, - instance=None, clause=None, **kw): + def get_bind( + self, mapper, shard_id=None, instance=None, clause=None, **kw + ): if shard_id is None: shard_id = self._choose_shard_and_assign( - mapper, instance, clause=clause) + mapper, instance, clause=clause + ) return self.__binds[shard_id] def bind_shard(self, shard_id, bind): |