diff options
Diffstat (limited to 'lib/sqlalchemy/ext/horizontal_shard.py')
| -rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 53 |
1 files changed, 38 insertions, 15 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 1375a24cd..c3ac71c10 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -15,8 +15,10 @@ the source distribution. """ -from sqlalchemy import event +from .. import event +from .. import exc from .. import inspect +from .. import util from ..orm.query import Query from ..orm.session import Session @@ -28,6 +30,7 @@ class ShardedQuery(Query): super(ShardedQuery, self).__init__(*args, **kwargs) self.id_chooser = self.session.id_chooser self.query_chooser = self.session.query_chooser + self.execute_chooser = self.session.execute_chooser self._shard_id = None def set_shard(self, shard_id): @@ -45,10 +48,7 @@ class ShardedQuery(Query): ) """ - - q = self._clone() - q._shard_id = shard_id - return q + return self.execution_options(_sa_shard_id=shard_id) def _execute_crud(self, stmt, mapper): def exec_for_shard(shard_id): @@ -68,7 +68,8 @@ class ShardedQuery(Query): else: rowcount = 0 results = [] - for shard_id in self.query_chooser(self): + # TODO: this will have to be the new object + for shard_id in self.execute_chooser(self): result = exec_for_shard(shard_id) rowcount += result.rowcount results.append(result) @@ -107,7 +108,7 @@ class ShardedSession(Session): self, shard_chooser, id_chooser, - query_chooser, + execute_chooser=None, shards=None, query_cls=ShardedQuery, **kwargs @@ -125,14 +126,19 @@ class ShardedSession(Session): values, which should return a list of shard ids where the ID might reside. The databases will be queried in the order of this listing. - :param query_chooser: For a given Query, returns the list of shard_ids + :param execute_chooser: For a given :class:`.ORMExecuteState`, + returns the list of shard_ids where the query should be issued. Results from all shards returned will be combined together into a single listing. + .. versionchanged:: 1.4 The ``execute_chooser`` paramter + supersedes the ``query_chooser`` parameter. + :param shards: A dictionary of string shard names to :class:`~sqlalchemy.engine.Engine` objects. """ + query_chooser = kwargs.pop("query_chooser", None) super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs) event.listen( @@ -140,6 +146,25 @@ class ShardedSession(Session): ) self.shard_chooser = shard_chooser self.id_chooser = id_chooser + + if query_chooser: + util.warn_deprecated( + "The ``query_choser`` parameter is deprecated; " + "please use ``execute_chooser``.", + "1.4", + ) + if execute_chooser: + raise exc.ArgumentError( + "Can't pass query_chooser and execute_chooser " + "at the same time." + ) + + def execute_chooser(orm_context): + return query_chooser(orm_context.statement) + + self.execute_chooser = execute_chooser + else: + self.execute_chooser = execute_chooser self.query_chooser = query_chooser self.__binds = {} if shards is not None: @@ -241,13 +266,13 @@ def execute_and_instances(orm_context): load_options = orm_context.load_options session = orm_context.session - orm_query = orm_context.orm_query + # orm_query = orm_context.orm_query if params is None: params = load_options._params def iter_for_shard(shard_id, load_options): - execution_options = dict(orm_context.execution_options) + execution_options = dict(orm_context.local_execution_options) bind_arguments = dict(orm_context.bind_arguments) bind_arguments["_horizontal_shard"] = True @@ -265,8 +290,8 @@ def execute_and_instances(orm_context): if load_options._refresh_identity_token is not None: shard_id = load_options._refresh_identity_token - elif orm_query is not None and orm_query._shard_id is not None: - shard_id = orm_query._shard_id + elif "_sa_shard_id" in orm_context.merged_execution_options: + shard_id = orm_context.merged_execution_options["_sa_shard_id"] elif "shard_id" in orm_context.bind_arguments: shard_id = orm_context.bind_arguments["shard_id"] else: @@ -276,9 +301,7 @@ def execute_and_instances(orm_context): return iter_for_shard(shard_id, load_options) else: partial = [] - for shard_id in session.query_chooser( - orm_query if orm_query is not None else orm_context.statement - ): + for shard_id in session.execute_chooser(orm_context): result_ = iter_for_shard(shard_id, load_options) partial.append(result_) |
