diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/ext/asyncio/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/asyncio/base.py | 44 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 56 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/asyncio/session.py | 112 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/state.py | 35 |
5 files changed, 212 insertions, 37 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index 2fda2d777..349bc1b75 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -14,5 +14,7 @@ from .events import AsyncSessionEvents from .result import AsyncMappingResult from .result import AsyncResult from .result import AsyncScalarResult +from .session import async_object_session +from .session import async_session from .session import AsyncSession from .session import AsyncSessionTransaction diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 76a2fbbde..3f2c084f4 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -1,8 +1,50 @@ import abc +import functools +import weakref from . import exc as async_exc +class ReversibleProxy: + # weakref.ref(async proxy object) -> weakref.ref(sync proxied object) + _proxy_objects = {} + + def _assign_proxied(self, target): + if target is not None: + target_ref = weakref.ref(target, ReversibleProxy._target_gced) + proxy_ref = weakref.ref( + self, + functools.partial(ReversibleProxy._target_gced, target_ref), + ) + ReversibleProxy._proxy_objects[target_ref] = proxy_ref + + return target + + @classmethod + def _target_gced(cls, ref, proxy_ref=None): + cls._proxy_objects.pop(ref, None) + + @classmethod + def _regenerate_proxy_for_target(cls, target): + raise NotImplementedError() + + @classmethod + def _retrieve_proxy_for_target(cls, target, regenerate=True): + try: + proxy_ref = cls._proxy_objects[weakref.ref(target)] + except KeyError: + pass + else: + proxy = proxy_ref() + if proxy is not None: + return proxy + + if regenerate: + return cls._regenerate_proxy_for_target(target) + else: + return None + + class StartableContext(abc.ABC): @abc.abstractmethod async def start(self, is_ctxmanager=False): @@ -25,7 +67,7 @@ class StartableContext(abc.ABC): ) -class ProxyComparable: +class ProxyComparable(ReversibleProxy): def __hash__(self): return id(self) diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 9cd3cb2f8..8e5c01919 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -11,6 +11,7 @@ from .result import AsyncResult from ... import exc from ... import util from ...engine import create_engine as _create_engine +from ...engine.base import NestedTransaction from ...future import Connection from ...future import Engine from ...util.concurrency import greenlet_spawn @@ -86,7 +87,13 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): def __init__(self, async_engine, sync_connection=None): self.engine = async_engine self.sync_engine = async_engine.sync_engine - self.sync_connection = sync_connection + self.sync_connection = self._assign_proxied(sync_connection) + + @classmethod + def _regenerate_proxy_for_target(cls, target): + return AsyncConnection( + AsyncEngine._retrieve_proxy_for_target(target.engine), target + ) async def start(self, is_ctxmanager=False): """Start this :class:`_asyncio.AsyncConnection` object's context @@ -95,7 +102,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ if self.sync_connection: raise exc.InvalidRequestError("connection is already started") - self.sync_connection = await (greenlet_spawn(self.sync_engine.connect)) + self.sync_connection = self._assign_proxied( + await (greenlet_spawn(self.sync_engine.connect)) + ) return self @property @@ -216,7 +225,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): trans = conn.get_transaction() if trans is not None: - return AsyncTransaction._from_existing_transaction(self, trans) + return AsyncTransaction._retrieve_proxy_for_target(trans) else: return None @@ -236,9 +245,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): trans = conn.get_nested_transaction() if trans is not None: - return AsyncTransaction._from_existing_transaction( - self, trans, True - ) + return AsyncTransaction._retrieve_proxy_for_target(trans) else: return None @@ -522,7 +529,11 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): "The asyncio extension requires an async driver to be used. " f"The loaded {sync_engine.dialect.driver!r} is not async." ) - self.sync_engine = self._proxied = sync_engine + self.sync_engine = self._proxied = self._assign_proxied(sync_engine) + + @classmethod + def _regenerate_proxy_for_target(cls, target): + return AsyncEngine(target) def begin(self): """Return a context manager which when entered will deliver an @@ -605,17 +616,24 @@ class AsyncTransaction(ProxyComparable, StartableContext): __slots__ = ("connection", "sync_transaction", "nested") def __init__(self, connection, nested=False): - self.connection = connection - self.sync_transaction = None + self.connection = connection # AsyncConnection + self.sync_transaction = None # sqlalchemy.engine.Transaction self.nested = nested @classmethod - def _from_existing_transaction( - cls, connection, sync_transaction, nested=False - ): + def _regenerate_proxy_for_target(cls, target): + sync_connection = target.connection + sync_transaction = target + nested = isinstance(target, NestedTransaction) + + async_connection = AsyncConnection._retrieve_proxy_for_target( + sync_connection + ) + assert async_connection is not None + obj = cls.__new__(cls) - obj.connection = connection - obj.sync_transaction = sync_transaction + obj.connection = async_connection + obj.sync_transaction = obj._assign_proxied(sync_transaction) obj.nested = nested return obj @@ -664,10 +682,12 @@ class AsyncTransaction(ProxyComparable, StartableContext): """ - self.sync_transaction = await greenlet_spawn( - self.connection._sync_connection().begin_nested - if self.nested - else self.connection._sync_connection().begin + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.connection._sync_connection().begin_nested + if self.nested + else self.connection._sync_connection().begin + ) ) if is_ctxmanager: self.sync_transaction.__enter__() diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 343465f37..16e15c873 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -6,9 +6,12 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from . import engine from . import result as _result +from .base import ReversibleProxy from .base import StartableContext from ... import util +from ...orm import object_session from ...orm import Session +from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn @@ -29,6 +32,7 @@ from ...util.concurrency import greenlet_spawn "get_bind", "is_modified", "in_transaction", + "in_nested_transaction", ], attributes=[ "dirty", @@ -41,7 +45,7 @@ from ...util.concurrency import greenlet_spawn "info", ], ) -class AsyncSession: +class AsyncSession(ReversibleProxy): """Asyncio version of :class:`_orm.Session`. @@ -72,8 +76,8 @@ class AsyncSession: for key, b in binds.items() } - self.sync_session = self._proxied = Session( - bind=bind, binds=binds, **kw + self.sync_session = self._proxied = self._assign_proxied( + Session(bind=bind, binds=binds, **kw) ) async def refresh( @@ -242,21 +246,46 @@ class AsyncSession: """ await greenlet_spawn(self.sync_session.flush, objects=objects) + def get_transaction(self): + """Return the current root transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. versionadded:: 1.4.18 + + """ + trans = self.sync_session.get_transaction() + if trans is not None: + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_nested_transaction(self): + """Return the current nested transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. versionadded:: 1.4.18 + + """ + + trans = self.sync_session.get_nested_transaction() + if trans is not None: + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + else: + return None + async def connection(self): - r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this - :class:`.Session` object's transactional state. + r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to + this :class:`.Session` object's transactional state. """ - # POSSIBLY TODO: here, we see that the sync engine / connection - # that are generated from AsyncEngine / AsyncConnection don't - # provide any backlink from those sync objects back out to the - # async ones. it's not *too* big a deal since AsyncEngine/Connection - # are just proxies and all the state is actually in the sync - # version of things. However! it has to stay that way :) sync_connection = await greenlet_spawn(self.sync_session.connection) - return engine.AsyncConnection( - engine.AsyncEngine(sync_connection.engine), sync_connection + return engine.AsyncConnection._retrieve_proxy_for_target( + sync_connection ) def begin(self, **kw): @@ -363,7 +392,7 @@ class _AsyncSessionContextManager: await self.async_session.__aexit__(type_, value, traceback) -class AsyncSessionTransaction(StartableContext): +class AsyncSessionTransaction(ReversibleProxy, StartableContext): """A wrapper for the ORM :class:`_orm.SessionTransaction` object. This object is provided so that a transaction-holding object @@ -408,10 +437,12 @@ class AsyncSessionTransaction(StartableContext): await greenlet_spawn(self._sync_transaction().commit) async def start(self, is_ctxmanager=False): - self.sync_transaction = await greenlet_spawn( - self.session.sync_session.begin_nested - if self.nested - else self.session.sync_session.begin + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.session.sync_session.begin_nested + if self.nested + else self.session.sync_session.begin + ) ) if is_ctxmanager: self.sync_transaction.__enter__() @@ -421,3 +452,48 @@ class AsyncSessionTransaction(StartableContext): await greenlet_spawn( self._sync_transaction().__exit__, type_, value, traceback ) + + +def async_object_session(instance): + """Return the :class:`_asyncio.AsyncSession` to which the given instance + belongs. + + This function makes use of the sync-API function + :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which + refers to the given instance, and from there links it to the original + :class:`_asyncio.AsyncSession`. + + If the :class:`_asyncio.AsyncSession` has been garbage collected, the + return value is ``None``. + + This functionality is also available from the + :attr:`_orm.InstanceState.async_session` accessor. + + :param instance: an ORM mapped instance + :return: an :class:`_asyncio.AsyncSession` object, or ``None``. + + .. versionadded:: 1.4.18 + + """ + + session = object_session(instance) + if session is not None: + return async_session(session) + else: + return None + + +def async_session(session): + """Return the :class:`_asyncio.AsyncSession` which is proxying the given + :class:`_orm.Session` object, if any. + + :param session: a :class:`_orm.Session` instance. + :return: a :class:`_asyncio.AsyncSession` instance, or ``None``. + + .. versionadded:: 1.4.18 + + """ + return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) + + +_instance_state._async_provider = async_session diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 08390328e..884e364c6 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -34,6 +34,9 @@ from .. import util # late-populated by session.py _sessions = None +# optionally late-provided by sqlalchemy.ext.asyncio.session +_async_provider = None + @inspection._self_inspects class InstanceState(interfaces.InspectionAttrInfo): @@ -262,6 +265,10 @@ class InstanceState(interfaces.InspectionAttrInfo): Only when the transaction is completed does the object become fully detached under normal circumstances. + .. seealso:: + + :attr:`_orm.InstanceState.async_session` + """ if self.session_id: try: @@ -271,6 +278,34 @@ class InstanceState(interfaces.InspectionAttrInfo): return None @property + def async_session(self): + """Return the owning :class:`_asyncio.AsyncSession` for this instance, + or ``None`` if none available. + + This attribute is only non-None when the :mod:`sqlalchemy.ext.asyncio` + API is in use for this ORM object. The returned + :class:`_asyncio.AsyncSession` object will be a proxy for the + :class:`_orm.Session` object that would be returned from the + :attr:`_orm.InstanceState.session` attribute for this + :class:`_orm.InstanceState`. + + .. versionadded:: 1.4.18 + + .. seealso:: + + :ref:`asyncio_toplevel` + + """ + if _async_provider is None: + return None + + sess = self.session + if sess is not None: + return _async_provider(sess) + else: + return None + + @property def object(self): """Return the mapped object represented by this :class:`.InstanceState`.""" |
