summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/horizontal_shard.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/horizontal_shard.py')
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py61
1 files changed, 38 insertions, 23 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index f86e4fc93..7248e5b4d 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -20,7 +20,7 @@ from .. import util
from ..orm.session import Session
from ..orm.query import Query
-__all__ = ['ShardedSession', 'ShardedQuery']
+__all__ = ["ShardedSession", "ShardedQuery"]
class ShardedQuery(Query):
@@ -43,12 +43,10 @@ class ShardedQuery(Query):
def _execute_and_instances(self, context):
def iter_for_shard(shard_id):
- context.attributes['shard_id'] = context.identity_token = shard_id
+ context.attributes["shard_id"] = context.identity_token = shard_id
result = self._connection_from_session(
- mapper=self._bind_mapper(),
- shard_id=shard_id).execute(
- context.statement,
- self._params)
+ mapper=self._bind_mapper(), shard_id=shard_id
+ ).execute(context.statement, self._params)
return self.instances(result, context)
if context.identity_token is not None:
@@ -70,7 +68,8 @@ class ShardedQuery(Query):
mapper=mapper,
shard_id=shard_id,
clause=stmt,
- close_with_result=True)
+ close_with_result=True,
+ )
result = conn.execute(stmt, self._params)
return result
@@ -87,8 +86,13 @@ class ShardedQuery(Query):
return ShardedResult(results, rowcount)
def _identity_lookup(
- self, mapper, primary_key_identity, identity_token=None,
- lazy_loaded_from=None, **kw):
+ self,
+ mapper,
+ primary_key_identity,
+ identity_token=None,
+ lazy_loaded_from=None,
+ **kw
+ ):
"""override the default Query._identity_lookup method so that we
search for a given non-token primary key identity across all
possible identity tokens (e.g. shard ids).
@@ -97,8 +101,10 @@ class ShardedQuery(Query):
if identity_token is not None:
return super(ShardedQuery, self)._identity_lookup(
- mapper, primary_key_identity,
- identity_token=identity_token, **kw
+ mapper,
+ primary_key_identity,
+ identity_token=identity_token,
+ **kw
)
else:
q = self.session.query(mapper)
@@ -113,13 +119,13 @@ class ShardedQuery(Query):
return None
- def _get_impl(
- self, primary_key_identity, db_load_fn, identity_token=None):
+ def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
"""Override the default Query._get_impl() method so that we emit
a query to the DB for each possible identity token, if we don't
have one already.
"""
+
def _db_load_fn(query, primary_key_identity):
# load from the database. The original db_load_fn will
# use the given Query object to load from the DB, so our
@@ -142,7 +148,8 @@ class ShardedQuery(Query):
identity_token = self._shard_id
return super(ShardedQuery, self)._get_impl(
- primary_key_identity, _db_load_fn, identity_token=identity_token)
+ primary_key_identity, _db_load_fn, identity_token=identity_token
+ )
class ShardedResult(object):
@@ -158,7 +165,7 @@ class ShardedResult(object):
.. versionadded:: 1.3
"""
- __slots__ = ('result_proxies', 'aggregate_rowcount',)
+ __slots__ = ("result_proxies", "aggregate_rowcount")
def __init__(self, result_proxies, aggregate_rowcount):
self.result_proxies = result_proxies
@@ -168,9 +175,17 @@ class ShardedResult(object):
def rowcount(self):
return self.aggregate_rowcount
+
class ShardedSession(Session):
- def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None,
- query_cls=ShardedQuery, **kwargs):
+ def __init__(
+ self,
+ shard_chooser,
+ id_chooser,
+ query_chooser,
+ shards=None,
+ query_cls=ShardedQuery,
+ **kwargs
+ ):
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped
@@ -225,16 +240,16 @@ class ShardedSession(Session):
return self.transaction.connection(mapper, shard_id=shard_id)
else:
return self.get_bind(
- mapper,
- shard_id=shard_id,
- instance=instance
+ mapper, shard_id=shard_id, instance=instance
).contextual_connect(**kwargs)
- def get_bind(self, mapper, shard_id=None,
- instance=None, clause=None, **kw):
+ def get_bind(
+ self, mapper, shard_id=None, instance=None, clause=None, **kw
+ ):
if shard_id is None:
shard_id = self._choose_shard_and_assign(
- mapper, instance, clause=clause)
+ mapper, instance, clause=clause
+ )
return self.__binds[shard_id]
def bind_shard(self, shard_id, bind):