summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/horizontal_shard.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/ext/horizontal_shard.py')
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py156
1 files changed, 81 insertions, 75 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index 919f4409a..1375a24cd 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -15,10 +15,8 @@ the source distribution.
"""
-import copy
-
+from sqlalchemy import event
from .. import inspect
-from .. import util
from ..orm.query import Query
from ..orm.session import Session
@@ -37,54 +35,32 @@ class ShardedQuery(Query):
all subsequent operations with the returned query will
be against the single shard regardless of other state.
- """
- q = self._clone()
- q._shard_id = shard_id
- return q
+ The shard_id can be passed for a 2.0 style execution to the
+ bind_arguments dictionary of :meth:`.Session.execute`::
- def _execute_and_instances(self, context, params=None):
- if params is None:
- params = self.load_options._params
-
- def iter_for_shard(shard_id):
- # shallow copy, so that each context may be used by
- # ORM load events and similar.
- copied_context = copy.copy(context)
- copied_context.attributes = context.attributes.copy()
-
- copied_context.attributes[
- "shard_id"
- ] = copied_context.identity_token = shard_id
- result_ = self._connection_from_session(
- mapper=context.compile_state._bind_mapper(), shard_id=shard_id
- ).execute(
- copied_context.compile_state.statement,
- self.load_options._params,
+ results = session.execute(
+ stmt,
+ bind_arguments={"shard_id": "my_shard"}
)
- return self.instances(result_, copied_context)
- if context.identity_token is not None:
- return iter_for_shard(context.identity_token)
- elif self._shard_id is not None:
- return iter_for_shard(self._shard_id)
- else:
- partial = []
- for shard_id in self.query_chooser(self):
- result_ = iter_for_shard(shard_id)
- partial.append(result_)
+ """
- return partial[0].merge(*partial[1:])
+ q = self._clone()
+ q._shard_id = shard_id
+ return q
def _execute_crud(self, stmt, mapper):
def exec_for_shard(shard_id):
- conn = self._connection_from_session(
+ conn = self.session.connection(
mapper=mapper,
shard_id=shard_id,
clause=stmt,
close_with_result=True,
)
- result = conn.execute(stmt, self.load_options._params)
+ result = conn._execute_20(
+ stmt, self.load_options._params, self._execution_options
+ )
return result
if self._shard_id is not None:
@@ -99,38 +75,6 @@ class ShardedQuery(Query):
return ShardedResult(results, rowcount)
- def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
- """Override the default Query._get_impl() method so that we emit
- a query to the DB for each possible identity token, if we don't
- have one already.
-
- """
-
- def _db_load_fn(query, primary_key_identity):
- # load from the database. The original db_load_fn will
- # use the given Query object to load from the DB, so our
- # shard_id is what will indicate the DB that we query from.
- if self._shard_id is not None:
- return db_load_fn(self, primary_key_identity)
- else:
- ident = util.to_list(primary_key_identity)
- # build a ShardedQuery for each shard identifier and
- # try to load from the DB
- for shard_id in self.id_chooser(self, ident):
- q = self.set_shard(shard_id)
- o = db_load_fn(q, ident)
- if o is not None:
- return o
- else:
- return None
-
- if identity_token is None and self._shard_id is not None:
- identity_token = self._shard_id
-
- return super(ShardedQuery, self)._get_impl(
- primary_key_identity, _db_load_fn, identity_token=identity_token
- )
-
class ShardedResult(object):
"""A value object that represents multiple :class:`_engine.CursorResult`
@@ -190,11 +134,14 @@ class ShardedSession(Session):
"""
super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
+
+ event.listen(
+ self, "do_orm_execute", execute_and_instances, retval=True
+ )
self.shard_chooser = shard_chooser
self.id_chooser = id_chooser
self.query_chooser = query_chooser
self.__binds = {}
- self.connection_callable = self.connection
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
@@ -207,8 +154,8 @@ class ShardedSession(Session):
lazy_loaded_from=None,
**kw
):
- """override the default :meth:`.Session._identity_lookup` method so that we
- search for a given non-token primary key identity across all
+ """override the default :meth:`.Session._identity_lookup` method so
+ that we search for a given non-token primary key identity across all
possible identity tokens (e.g. shard ids).
.. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from
@@ -255,7 +202,14 @@ class ShardedSession(Session):
state.identity_token = shard_id
return shard_id
- def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
+ def connection_callable(
+ self, mapper=None, instance=None, shard_id=None, **kwargs
+ ):
+ """Provide a :class:`_engine.Connection` to use in the unit of work
+ flush process.
+
+ """
+
if shard_id is None:
shard_id = self._choose_shard_and_assign(mapper, instance)
@@ -267,7 +221,7 @@ class ShardedSession(Session):
).connect(**kwargs)
def get_bind(
- self, mapper, shard_id=None, instance=None, clause=None, **kw
+ self, mapper=None, shard_id=None, instance=None, clause=None, **kw
):
if shard_id is None:
shard_id = self._choose_shard_and_assign(
@@ -277,3 +231,55 @@ class ShardedSession(Session):
def bind_shard(self, shard_id, bind):
self.__binds[shard_id] = bind
+
+
+def execute_and_instances(orm_context):
+ if orm_context.bind_arguments.get("_horizontal_shard", False):
+ return None
+
+ params = orm_context.parameters
+
+ load_options = orm_context.load_options
+ session = orm_context.session
+ orm_query = orm_context.orm_query
+
+ if params is None:
+ params = load_options._params
+
+ def iter_for_shard(shard_id, load_options):
+ execution_options = dict(orm_context.execution_options)
+
+ bind_arguments = dict(orm_context.bind_arguments)
+ bind_arguments["_horizontal_shard"] = True
+ bind_arguments["shard_id"] = shard_id
+
+ load_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_load_options"] = load_options
+
+ return session.execute(
+ orm_context.statement,
+ orm_context.parameters,
+ execution_options,
+ bind_arguments,
+ )
+
+ if load_options._refresh_identity_token is not None:
+ shard_id = load_options._refresh_identity_token
+ elif orm_query is not None and orm_query._shard_id is not None:
+ shard_id = orm_query._shard_id
+ elif "shard_id" in orm_context.bind_arguments:
+ shard_id = orm_context.bind_arguments["shard_id"]
+ else:
+ shard_id = None
+
+ if shard_id is not None:
+ return iter_for_shard(shard_id, load_options)
+ else:
+ partial = []
+ for shard_id in session.query_chooser(
+ orm_query if orm_query is not None else orm_context.statement
+ ):
+ result_ = iter_for_shard(shard_id, load_options)
+ partial.append(result_)
+
+ return partial[0].merge(*partial[1:])