summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2021-08-30 19:38:40 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2021-08-30 19:38:40 +0000
commit184e2da5992c55266b37bab5ce3a07e9dfb8caa1 (patch)
treef91fdaef4d135309b47b3da0e6e5b0f9bd2792f2
parentfb874f97fab798eea76f9732a31bb9332877d00e (diff)
parentaf0824fd790bad28beb01c11f262ac1ffe8c53be (diff)
downloadsqlalchemy-184e2da5992c55266b37bab5ce3a07e9dfb8caa1.tar.gz
Merge "Allow custom sync session class in ``AsyncSession``."
-rw-r--r--doc/build/changelog/unreleased_14/6689.rst9
-rw-r--r--doc/build/orm/extensions/asyncio.rst3
-rw-r--r--lib/sqlalchemy/ext/asyncio/session.py52
-rw-r--r--test/ext/asyncio/test_session_py3k.py50
4 files changed, 107 insertions, 7 deletions
diff --git a/doc/build/changelog/unreleased_14/6689.rst b/doc/build/changelog/unreleased_14/6689.rst
new file mode 100644
index 000000000..6abebc5f3
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/6689.rst
@@ -0,0 +1,9 @@
+.. change::
+ :tags: asyncio, usecase
+ :tickets: 6746
+
+ The :class:`_asyncio.AsyncSession` now supports overriding which
+ :class:`_orm.Session` it uses as the proxied instance. A custom ``Session``
+ class can be passed using the :paramref:`.AsyncSession.sync_session_class`
+ parameter or by subclassing the ``AsyncSession`` and specifying a custom
+ :attr:`.AsyncSession.sync_session_class`.
diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst
index 281d9805b..f3e89c647 100644
--- a/doc/build/orm/extensions/asyncio.rst
+++ b/doc/build/orm/extensions/asyncio.rst
@@ -583,6 +583,9 @@ ORM Session API Documentation
.. autoclass:: AsyncSession
:members:
+ :exclude-members: sync_session_class
+
+ .. autoattribute:: sync_session_class
.. autoclass:: AsyncSessionTransaction
:members:
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index 5c6e7f5a7..5c5426d72 100644
--- a/lib/sqlalchemy/ext/asyncio/session.py
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -51,9 +51,16 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True})
class AsyncSession(ReversibleProxy):
"""Asyncio version of :class:`_orm.Session`.
+ The :class:`_asyncio.AsyncSession` is a proxy for a traditional
+ :class:`_orm.Session` instance.
.. versionadded:: 1.4
+ To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session`
+ implementations, see the
+ :paramref:`_asyncio.AsyncSession.sync_session_class` parameter.
+
+
"""
_is_asyncio = True
@@ -68,7 +75,25 @@ class AsyncSession(ReversibleProxy):
dispatch = None
- def __init__(self, bind=None, binds=None, **kw):
+ def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
+ r"""Construct a new :class:`_asyncio.AsyncSession`.
+
+ All parameters other than ``sync_session_class`` are passed to the
+ ``sync_session_class`` callable directly to instantiate a new
+ :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for
+ parameter documentation.
+
+ :param sync_session_class:
+ A :class:`_orm.Session` subclass or other callable which will be used
+ to construct the :class:`_orm.Session` which will be proxied. This
+ parameter may be used to provide custom :class:`_orm.Session`
+ subclasses. Defaults to the
+ :attr:`_asyncio.AsyncSession.sync_session_class` class-level
+ attribute.
+
+ .. versionadded:: 1.4.24
+
+ """
kw["future"] = True
if bind:
self.bind = bind
@@ -81,10 +106,30 @@ class AsyncSession(ReversibleProxy):
for key, b in binds.items()
}
+ if sync_session_class:
+ self.sync_session_class = sync_session_class
+
self.sync_session = self._proxied = self._assign_proxied(
- Session(bind=bind, binds=binds, **kw)
+ self.sync_session_class(bind=bind, binds=binds, **kw)
)
+ sync_session_class = Session
+ """The class or callable that provides the
+ underlying :class:`_orm.Session` instance for a particular
+ :class:`_asyncio.AsyncSession`.
+
+ At the class level, this attribute is the default value for the
+ :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom
+ subclasses of :class:`_asyncio.AsyncSession` can override this.
+
+ At the instance level, this attribute indicates the current class or
+ callable that was used to provide the :class:`_orm.Session` instance for
+ this :class:`_asyncio.AsyncSession` instance.
+
+ .. versionadded:: 1.4.24
+
+ """
+
async def refresh(
self, instance, attribute_names=None, with_for_update=None
):
@@ -141,7 +186,8 @@ class AsyncSession(ReversibleProxy):
**kw
):
"""Execute a statement and return a buffered
- :class:`_engine.Result` object."""
+ :class:`_engine.Result` object.
+ """
if execution_options:
execution_options = util.immutabledict(execution_options).union(
diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py
index a0aaf7ee0..459d95ea6 100644
--- a/test/ext/asyncio/test_session_py3k.py
+++ b/test/ext/asyncio/test_session_py3k.py
@@ -14,11 +14,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.base import ReversibleProxy
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
+from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.testing import async_test
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from .test_engine_py3k import AsyncFixture as _AsyncFixture
from ...orm import _fixtures
@@ -722,8 +724,6 @@ class AsyncProxyTest(AsyncFixture):
is_(inspect(u3).async_session, None)
def test_inspect_session_no_asyncio_used(self):
- from sqlalchemy.orm import Session
-
User = self.classes.User
s1 = Session(testing.db)
@@ -732,8 +732,6 @@ class AsyncProxyTest(AsyncFixture):
is_(inspect(u1).async_session, None)
def test_inspect_session_no_asyncio_imported(self):
- from sqlalchemy.orm import Session
-
with mock.patch("sqlalchemy.orm.state._async_provider", None):
User = self.classes.User
@@ -756,3 +754,47 @@ class AsyncProxyTest(AsyncFixture):
del async_session
eq_(len(ReversibleProxy._proxy_objects), 0)
+
+
+class _MySession(Session):
+ pass
+
+
+class _MyAS(AsyncSession):
+ sync_session_class = _MySession
+
+
+class OverrideSyncSession(AsyncFixture):
+ def test_default(self, async_engine):
+ ass = AsyncSession(async_engine)
+
+ is_true(isinstance(ass.sync_session, Session))
+ is_(ass.sync_session.__class__, Session)
+ is_(ass.sync_session_class, Session)
+
+ def test_init_class(self, async_engine):
+ ass = AsyncSession(async_engine, sync_session_class=_MySession)
+
+ is_true(isinstance(ass.sync_session, _MySession))
+ is_(ass.sync_session_class, _MySession)
+
+ def test_init_sessionmaker(self, async_engine):
+ sm = sessionmaker(
+ async_engine, class_=AsyncSession, sync_session_class=_MySession
+ )
+ ass = sm()
+
+ is_true(isinstance(ass.sync_session, _MySession))
+ is_(ass.sync_session_class, _MySession)
+
+ def test_subclass(self, async_engine):
+ ass = _MyAS(async_engine)
+
+ is_true(isinstance(ass.sync_session, _MySession))
+ is_(ass.sync_session_class, _MySession)
+
+ def test_subclass_override(self, async_engine):
+ ass = _MyAS(async_engine, sync_session_class=Session)
+
+ is_true(not isinstance(ass.sync_session, _MySession))
+ is_(ass.sync_session_class, Session)