summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/ext/asyncio/__init__.py2
-rw-r--r--lib/sqlalchemy/ext/asyncio/base.py44
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py56
-rw-r--r--lib/sqlalchemy/ext/asyncio/session.py112
-rw-r--r--lib/sqlalchemy/orm/state.py35
5 files changed, 212 insertions, 37 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py
index 2fda2d777..349bc1b75 100644
--- a/lib/sqlalchemy/ext/asyncio/__init__.py
+++ b/lib/sqlalchemy/ext/asyncio/__init__.py
@@ -14,5 +14,7 @@ from .events import AsyncSessionEvents
from .result import AsyncMappingResult
from .result import AsyncResult
from .result import AsyncScalarResult
+from .session import async_object_session
+from .session import async_session
from .session import AsyncSession
from .session import AsyncSessionTransaction
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
index 76a2fbbde..3f2c084f4 100644
--- a/lib/sqlalchemy/ext/asyncio/base.py
+++ b/lib/sqlalchemy/ext/asyncio/base.py
@@ -1,8 +1,50 @@
import abc
+import functools
+import weakref
from . import exc as async_exc
+class ReversibleProxy:
+ # weakref.ref(async proxy object) -> weakref.ref(sync proxied object)
+ _proxy_objects = {}
+
+ def _assign_proxied(self, target):
+ if target is not None:
+ target_ref = weakref.ref(target, ReversibleProxy._target_gced)
+ proxy_ref = weakref.ref(
+ self,
+ functools.partial(ReversibleProxy._target_gced, target_ref),
+ )
+ ReversibleProxy._proxy_objects[target_ref] = proxy_ref
+
+ return target
+
+ @classmethod
+ def _target_gced(cls, ref, proxy_ref=None):
+ cls._proxy_objects.pop(ref, None)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ raise NotImplementedError()
+
+ @classmethod
+ def _retrieve_proxy_for_target(cls, target, regenerate=True):
+ try:
+ proxy_ref = cls._proxy_objects[weakref.ref(target)]
+ except KeyError:
+ pass
+ else:
+ proxy = proxy_ref()
+ if proxy is not None:
+ return proxy
+
+ if regenerate:
+ return cls._regenerate_proxy_for_target(target)
+ else:
+ return None
+
+
class StartableContext(abc.ABC):
@abc.abstractmethod
async def start(self, is_ctxmanager=False):
@@ -25,7 +67,7 @@ class StartableContext(abc.ABC):
)
-class ProxyComparable:
+class ProxyComparable(ReversibleProxy):
def __hash__(self):
return id(self)
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index 9cd3cb2f8..8e5c01919 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -11,6 +11,7 @@ from .result import AsyncResult
from ... import exc
from ... import util
from ...engine import create_engine as _create_engine
+from ...engine.base import NestedTransaction
from ...future import Connection
from ...future import Engine
from ...util.concurrency import greenlet_spawn
@@ -86,7 +87,13 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
def __init__(self, async_engine, sync_connection=None):
self.engine = async_engine
self.sync_engine = async_engine.sync_engine
- self.sync_connection = sync_connection
+ self.sync_connection = self._assign_proxied(sync_connection)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncConnection(
+ AsyncEngine._retrieve_proxy_for_target(target.engine), target
+ )
async def start(self, is_ctxmanager=False):
"""Start this :class:`_asyncio.AsyncConnection` object's context
@@ -95,7 +102,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""
if self.sync_connection:
raise exc.InvalidRequestError("connection is already started")
- self.sync_connection = await (greenlet_spawn(self.sync_engine.connect))
+ self.sync_connection = self._assign_proxied(
+ await (greenlet_spawn(self.sync_engine.connect))
+ )
return self
@property
@@ -216,7 +225,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
trans = conn.get_transaction()
if trans is not None:
- return AsyncTransaction._from_existing_transaction(self, trans)
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
@@ -236,9 +245,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
trans = conn.get_nested_transaction()
if trans is not None:
- return AsyncTransaction._from_existing_transaction(
- self, trans, True
- )
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
@@ -522,7 +529,11 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
"The asyncio extension requires an async driver to be used. "
f"The loaded {sync_engine.dialect.driver!r} is not async."
)
- self.sync_engine = self._proxied = sync_engine
+ self.sync_engine = self._proxied = self._assign_proxied(sync_engine)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncEngine(target)
def begin(self):
"""Return a context manager which when entered will deliver an
@@ -605,17 +616,24 @@ class AsyncTransaction(ProxyComparable, StartableContext):
__slots__ = ("connection", "sync_transaction", "nested")
def __init__(self, connection, nested=False):
- self.connection = connection
- self.sync_transaction = None
+ self.connection = connection # AsyncConnection
+ self.sync_transaction = None # sqlalchemy.engine.Transaction
self.nested = nested
@classmethod
- def _from_existing_transaction(
- cls, connection, sync_transaction, nested=False
- ):
+ def _regenerate_proxy_for_target(cls, target):
+ sync_connection = target.connection
+ sync_transaction = target
+ nested = isinstance(target, NestedTransaction)
+
+ async_connection = AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
+ )
+ assert async_connection is not None
+
obj = cls.__new__(cls)
- obj.connection = connection
- obj.sync_transaction = sync_transaction
+ obj.connection = async_connection
+ obj.sync_transaction = obj._assign_proxied(sync_transaction)
obj.nested = nested
return obj
@@ -664,10 +682,12 @@ class AsyncTransaction(ProxyComparable, StartableContext):
"""
- self.sync_transaction = await greenlet_spawn(
- self.connection._sync_connection().begin_nested
- if self.nested
- else self.connection._sync_connection().begin
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.connection._sync_connection().begin_nested
+ if self.nested
+ else self.connection._sync_connection().begin
+ )
)
if is_ctxmanager:
self.sync_transaction.__enter__()
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index 343465f37..16e15c873 100644
--- a/lib/sqlalchemy/ext/asyncio/session.py
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -6,9 +6,12 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import engine
from . import result as _result
+from .base import ReversibleProxy
from .base import StartableContext
from ... import util
+from ...orm import object_session
from ...orm import Session
+from ...orm import state as _instance_state
from ...util.concurrency import greenlet_spawn
@@ -29,6 +32,7 @@ from ...util.concurrency import greenlet_spawn
"get_bind",
"is_modified",
"in_transaction",
+ "in_nested_transaction",
],
attributes=[
"dirty",
@@ -41,7 +45,7 @@ from ...util.concurrency import greenlet_spawn
"info",
],
)
-class AsyncSession:
+class AsyncSession(ReversibleProxy):
"""Asyncio version of :class:`_orm.Session`.
@@ -72,8 +76,8 @@ class AsyncSession:
for key, b in binds.items()
}
- self.sync_session = self._proxied = Session(
- bind=bind, binds=binds, **kw
+ self.sync_session = self._proxied = self._assign_proxied(
+ Session(bind=bind, binds=binds, **kw)
)
async def refresh(
@@ -242,21 +246,46 @@ class AsyncSession:
"""
await greenlet_spawn(self.sync_session.flush, objects=objects)
+ def get_transaction(self):
+ """Return the current root transaction in progress, if any.
+
+ :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+ ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+ trans = self.sync_session.get_transaction()
+ if trans is not None:
+ return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_nested_transaction(self):
+ """Return the current nested transaction in progress, if any.
+
+ :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+ ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+
+ trans = self.sync_session.get_nested_transaction()
+ if trans is not None:
+ return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
async def connection(self):
- r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this
- :class:`.Session` object's transactional state.
+ r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
+ this :class:`.Session` object's transactional state.
"""
- # POSSIBLY TODO: here, we see that the sync engine / connection
- # that are generated from AsyncEngine / AsyncConnection don't
- # provide any backlink from those sync objects back out to the
- # async ones. it's not *too* big a deal since AsyncEngine/Connection
- # are just proxies and all the state is actually in the sync
- # version of things. However! it has to stay that way :)
sync_connection = await greenlet_spawn(self.sync_session.connection)
- return engine.AsyncConnection(
- engine.AsyncEngine(sync_connection.engine), sync_connection
+ return engine.AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
)
def begin(self, **kw):
@@ -363,7 +392,7 @@ class _AsyncSessionContextManager:
await self.async_session.__aexit__(type_, value, traceback)
-class AsyncSessionTransaction(StartableContext):
+class AsyncSessionTransaction(ReversibleProxy, StartableContext):
"""A wrapper for the ORM :class:`_orm.SessionTransaction` object.
This object is provided so that a transaction-holding object
@@ -408,10 +437,12 @@ class AsyncSessionTransaction(StartableContext):
await greenlet_spawn(self._sync_transaction().commit)
async def start(self, is_ctxmanager=False):
- self.sync_transaction = await greenlet_spawn(
- self.session.sync_session.begin_nested
- if self.nested
- else self.session.sync_session.begin
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.session.sync_session.begin_nested
+ if self.nested
+ else self.session.sync_session.begin
+ )
)
if is_ctxmanager:
self.sync_transaction.__enter__()
@@ -421,3 +452,48 @@ class AsyncSessionTransaction(StartableContext):
await greenlet_spawn(
self._sync_transaction().__exit__, type_, value, traceback
)
+
+
+def async_object_session(instance):
+ """Return the :class:`_asyncio.AsyncSession` to which the given instance
+ belongs.
+
+ This function makes use of the sync-API function
+ :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which
+ refers to the given instance, and from there links it to the original
+ :class:`_asyncio.AsyncSession`.
+
+ If the :class:`_asyncio.AsyncSession` has been garbage collected, the
+ return value is ``None``.
+
+ This functionality is also available from the
+ :attr:`_orm.InstanceState.async_session` accessor.
+
+ :param instance: an ORM mapped instance
+ :return: an :class:`_asyncio.AsyncSession` object, or ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+
+ session = object_session(instance)
+ if session is not None:
+ return async_session(session)
+ else:
+ return None
+
+
+def async_session(session):
+ """Return the :class:`_asyncio.AsyncSession` which is proxying the given
+ :class:`_orm.Session` object, if any.
+
+ :param session: a :class:`_orm.Session` instance.
+ :return: a :class:`_asyncio.AsyncSession` instance, or ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+ return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
+
+
+_instance_state._async_provider = async_session
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 08390328e..884e364c6 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -34,6 +34,9 @@ from .. import util
# late-populated by session.py
_sessions = None
+# optionally late-provided by sqlalchemy.ext.asyncio.session
+_async_provider = None
+
@inspection._self_inspects
class InstanceState(interfaces.InspectionAttrInfo):
@@ -262,6 +265,10 @@ class InstanceState(interfaces.InspectionAttrInfo):
Only when the transaction is completed does the object become
fully detached under normal circumstances.
+ .. seealso::
+
+ :attr:`_orm.InstanceState.async_session`
+
"""
if self.session_id:
try:
@@ -271,6 +278,34 @@ class InstanceState(interfaces.InspectionAttrInfo):
return None
@property
+ def async_session(self):
+ """Return the owning :class:`_asyncio.AsyncSession` for this instance,
+ or ``None`` if none available.
+
+ This attribute is only non-None when the :mod:`sqlalchemy.ext.asyncio`
+ API is in use for this ORM object. The returned
+ :class:`_asyncio.AsyncSession` object will be a proxy for the
+ :class:`_orm.Session` object that would be returned from the
+ :attr:`_orm.InstanceState.session` attribute for this
+ :class:`_orm.InstanceState`.
+
+ .. versionadded:: 1.4.18
+
+ .. seealso::
+
+ :ref:`asyncio_toplevel`
+
+ """
+ if _async_provider is None:
+ return None
+
+ sess = self.session
+ if sess is not None:
+ return _async_provider(sess)
+ else:
+ return None
+
+ @property
def object(self):
"""Return the mapped object represented by this
:class:`.InstanceState`."""