diff options
Diffstat (limited to 'lib/sqlalchemy/ext/horizontal_shard.py')
| -rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 156 |
1 files changed, 81 insertions, 75 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 919f4409a..1375a24cd 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -15,10 +15,8 @@ the source distribution. """ -import copy - +from sqlalchemy import event from .. import inspect -from .. import util from ..orm.query import Query from ..orm.session import Session @@ -37,54 +35,32 @@ class ShardedQuery(Query): all subsequent operations with the returned query will be against the single shard regardless of other state. - """ - q = self._clone() - q._shard_id = shard_id - return q + The shard_id can be passed for a 2.0 style execution to the + bind_arguments dictionary of :meth:`.Session.execute`:: - def _execute_and_instances(self, context, params=None): - if params is None: - params = self.load_options._params - - def iter_for_shard(shard_id): - # shallow copy, so that each context may be used by - # ORM load events and similar. - copied_context = copy.copy(context) - copied_context.attributes = context.attributes.copy() - - copied_context.attributes[ - "shard_id" - ] = copied_context.identity_token = shard_id - result_ = self._connection_from_session( - mapper=context.compile_state._bind_mapper(), shard_id=shard_id - ).execute( - copied_context.compile_state.statement, - self.load_options._params, + results = session.execute( + stmt, + bind_arguments={"shard_id": "my_shard"} ) - return self.instances(result_, copied_context) - if context.identity_token is not None: - return iter_for_shard(context.identity_token) - elif self._shard_id is not None: - return iter_for_shard(self._shard_id) - else: - partial = [] - for shard_id in self.query_chooser(self): - result_ = iter_for_shard(shard_id) - partial.append(result_) + """ - return partial[0].merge(*partial[1:]) + q = self._clone() + q._shard_id = shard_id + return q def _execute_crud(self, stmt, mapper): def exec_for_shard(shard_id): - conn = self._connection_from_session( + conn = self.session.connection( mapper=mapper, shard_id=shard_id, clause=stmt, close_with_result=True, ) - result = conn.execute(stmt, self.load_options._params) + result = conn._execute_20( + stmt, self.load_options._params, self._execution_options + ) return result if self._shard_id is not None: @@ -99,38 +75,6 @@ class ShardedQuery(Query): return ShardedResult(results, rowcount) - def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): - """Override the default Query._get_impl() method so that we emit - a query to the DB for each possible identity token, if we don't - have one already. - - """ - - def _db_load_fn(query, primary_key_identity): - # load from the database. The original db_load_fn will - # use the given Query object to load from the DB, so our - # shard_id is what will indicate the DB that we query from. - if self._shard_id is not None: - return db_load_fn(self, primary_key_identity) - else: - ident = util.to_list(primary_key_identity) - # build a ShardedQuery for each shard identifier and - # try to load from the DB - for shard_id in self.id_chooser(self, ident): - q = self.set_shard(shard_id) - o = db_load_fn(q, ident) - if o is not None: - return o - else: - return None - - if identity_token is None and self._shard_id is not None: - identity_token = self._shard_id - - return super(ShardedQuery, self)._get_impl( - primary_key_identity, _db_load_fn, identity_token=identity_token - ) - class ShardedResult(object): """A value object that represents multiple :class:`_engine.CursorResult` @@ -190,11 +134,14 @@ class ShardedSession(Session): """ super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs) + + event.listen( + self, "do_orm_execute", execute_and_instances, retval=True + ) self.shard_chooser = shard_chooser self.id_chooser = id_chooser self.query_chooser = query_chooser self.__binds = {} - self.connection_callable = self.connection if shards is not None: for k in shards: self.bind_shard(k, shards[k]) @@ -207,8 +154,8 @@ class ShardedSession(Session): lazy_loaded_from=None, **kw ): - """override the default :meth:`.Session._identity_lookup` method so that we - search for a given non-token primary key identity across all + """override the default :meth:`.Session._identity_lookup` method so + that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from @@ -255,7 +202,14 @@ class ShardedSession(Session): state.identity_token = shard_id return shard_id - def connection(self, mapper=None, instance=None, shard_id=None, **kwargs): + def connection_callable( + self, mapper=None, instance=None, shard_id=None, **kwargs + ): + """Provide a :class:`_engine.Connection` to use in the unit of work + flush process. + + """ + if shard_id is None: shard_id = self._choose_shard_and_assign(mapper, instance) @@ -267,7 +221,7 @@ class ShardedSession(Session): ).connect(**kwargs) def get_bind( - self, mapper, shard_id=None, instance=None, clause=None, **kw + self, mapper=None, shard_id=None, instance=None, clause=None, **kw ): if shard_id is None: shard_id = self._choose_shard_and_assign( @@ -277,3 +231,55 @@ class ShardedSession(Session): def bind_shard(self, shard_id, bind): self.__binds[shard_id] = bind + + +def execute_and_instances(orm_context): + if orm_context.bind_arguments.get("_horizontal_shard", False): + return None + + params = orm_context.parameters + + load_options = orm_context.load_options + session = orm_context.session + orm_query = orm_context.orm_query + + if params is None: + params = load_options._params + + def iter_for_shard(shard_id, load_options): + execution_options = dict(orm_context.execution_options) + + bind_arguments = dict(orm_context.bind_arguments) + bind_arguments["_horizontal_shard"] = True + bind_arguments["shard_id"] = shard_id + + load_options += {"_refresh_identity_token": shard_id} + execution_options["_sa_orm_load_options"] = load_options + + return session.execute( + orm_context.statement, + orm_context.parameters, + execution_options, + bind_arguments, + ) + + if load_options._refresh_identity_token is not None: + shard_id = load_options._refresh_identity_token + elif orm_query is not None and orm_query._shard_id is not None: + shard_id = orm_query._shard_id + elif "shard_id" in orm_context.bind_arguments: + shard_id = orm_context.bind_arguments["shard_id"] + else: + shard_id = None + + if shard_id is not None: + return iter_for_shard(shard_id, load_options) + else: + partial = [] + for shard_id in session.query_chooser( + orm_query if orm_query is not None else orm_context.statement + ): + result_ = iter_for_shard(shard_id, load_options) + partial.append(result_) + + return partial[0].merge(*partial[1:]) |
