diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
| commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
| tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/ext/horizontal_shard.py | |
| parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
| download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz | |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/ext/horizontal_shard.py')
| -rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 61 |
1 files changed, 38 insertions, 23 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index f86e4fc93..7248e5b4d 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -20,7 +20,7 @@ from .. import util from ..orm.session import Session from ..orm.query import Query -__all__ = ['ShardedSession', 'ShardedQuery'] +__all__ = ["ShardedSession", "ShardedQuery"] class ShardedQuery(Query): @@ -43,12 +43,10 @@ class ShardedQuery(Query): def _execute_and_instances(self, context): def iter_for_shard(shard_id): - context.attributes['shard_id'] = context.identity_token = shard_id + context.attributes["shard_id"] = context.identity_token = shard_id result = self._connection_from_session( - mapper=self._bind_mapper(), - shard_id=shard_id).execute( - context.statement, - self._params) + mapper=self._bind_mapper(), shard_id=shard_id + ).execute(context.statement, self._params) return self.instances(result, context) if context.identity_token is not None: @@ -70,7 +68,8 @@ class ShardedQuery(Query): mapper=mapper, shard_id=shard_id, clause=stmt, - close_with_result=True) + close_with_result=True, + ) result = conn.execute(stmt, self._params) return result @@ -87,8 +86,13 @@ class ShardedQuery(Query): return ShardedResult(results, rowcount) def _identity_lookup( - self, mapper, primary_key_identity, identity_token=None, - lazy_loaded_from=None, **kw): + self, + mapper, + primary_key_identity, + identity_token=None, + lazy_loaded_from=None, + **kw + ): """override the default Query._identity_lookup method so that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). @@ -97,8 +101,10 @@ class ShardedQuery(Query): if identity_token is not None: return super(ShardedQuery, self)._identity_lookup( - mapper, primary_key_identity, - identity_token=identity_token, **kw + mapper, + primary_key_identity, + identity_token=identity_token, + **kw ) else: q = self.session.query(mapper) @@ -113,13 +119,13 @@ class ShardedQuery(Query): return None - def _get_impl( - self, primary_key_identity, db_load_fn, identity_token=None): + def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): """Override the default Query._get_impl() method so that we emit a query to the DB for each possible identity token, if we don't have one already. """ + def _db_load_fn(query, primary_key_identity): # load from the database. The original db_load_fn will # use the given Query object to load from the DB, so our @@ -142,7 +148,8 @@ class ShardedQuery(Query): identity_token = self._shard_id return super(ShardedQuery, self)._get_impl( - primary_key_identity, _db_load_fn, identity_token=identity_token) + primary_key_identity, _db_load_fn, identity_token=identity_token + ) class ShardedResult(object): @@ -158,7 +165,7 @@ class ShardedResult(object): .. versionadded:: 1.3 """ - __slots__ = ('result_proxies', 'aggregate_rowcount',) + __slots__ = ("result_proxies", "aggregate_rowcount") def __init__(self, result_proxies, aggregate_rowcount): self.result_proxies = result_proxies @@ -168,9 +175,17 @@ class ShardedResult(object): def rowcount(self): return self.aggregate_rowcount + class ShardedSession(Session): - def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, - query_cls=ShardedQuery, **kwargs): + def __init__( + self, + shard_chooser, + id_chooser, + query_chooser, + shards=None, + query_cls=ShardedQuery, + **kwargs + ): """Construct a ShardedSession. :param shard_chooser: A callable which, passed a Mapper, a mapped @@ -225,16 +240,16 @@ class ShardedSession(Session): return self.transaction.connection(mapper, shard_id=shard_id) else: return self.get_bind( - mapper, - shard_id=shard_id, - instance=instance + mapper, shard_id=shard_id, instance=instance ).contextual_connect(**kwargs) - def get_bind(self, mapper, shard_id=None, - instance=None, clause=None, **kw): + def get_bind( + self, mapper, shard_id=None, instance=None, clause=None, **kw + ): if shard_id is None: shard_id = self._choose_shard_and_assign( - mapper, instance, clause=clause) + mapper, instance, clause=clause + ) return self.__binds[shard_id] def bind_shard(self, shard_id, bind): |
