summary | refs | log | tree | commit | diff
path: root/lib/sqlalchemy/ext/asyncio/scoping.py
blob: 4e7f15c1fdad49ec39f2e421f3c9107319220c7d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# ext/asyncio/scoping.py
# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

from .session import AsyncSession
from ...orm.scoping import ScopedSessionMixin
from ...util import create_proxy_methods
from ...util import ScopedRegistry


@create_proxy_methods(
    AsyncSession,
    ":class:`_asyncio.AsyncSession`",
    ":class:`_asyncio.scoping.async_scoped_session`",
    classmethods=["close_all", "object_session", "identity_key"],
    methods=[
        "__contains__",
        "__iter__",
        "add",
        "add_all",
        "begin",
        "begin_nested",
        "close",
        "commit",
        "connection",
        "delete",
        "execute",
        "expire",
        "expire_all",
        "expunge",
        "expunge_all",
        "flush",
        "get",
        "get_bind",
        "is_modified",
        "merge",
        "refresh",
        "rollback",
        "scalar",
        "scalars",
        "stream",
        "stream_scalars",
    ],
    attributes=[
        "bind",
        "dirty",
        "deleted",
        "new",
        "identity_map",
        "is_active",
        "autoflush",
        "no_autoflush",
        "info",
    ],
)
class async_scoped_session(ScopedSessionMixin):
    """Scoped management of :class:`.AsyncSession` objects.

    Sessions are created on demand by a factory and tracked in a
    ``ScopedRegistry`` keyed by a caller-supplied scope function; the
    methods and attributes named in the decorator above are proxied to
    the session belonging to the current scope.

    Usage details are in the section :ref:`asyncio_scoped_session`.

    .. versionadded:: 1.4.19

    """

    # Class-level flag; presumably consulted by ScopedSessionMixin --
    # confirm against orm/scoping.py.
    _support_async = True

    def __init__(self, session_factory, scopefunc):
        """Construct a new :class:`_asyncio.async_scoped_session`.

        :param session_factory: callable that produces new
         :class:`_asyncio.AsyncSession` instances.  Typically, though not
         necessarily, this is a :class:`_orm.sessionmaker` that was given
         :class:`_asyncio.AsyncSession` as its
         :paramref:`_orm.sessionmaker.class_` argument::

            from asyncio import current_task

            async_session_factory = sessionmaker(
                some_async_engine, class_=AsyncSession
            )
            AsyncScopedSession = async_scoped_session(
                async_session_factory, scopefunc=current_task
            )

        :param scopefunc: callable that identifies the current scope;
         something like ``asyncio.current_task`` may be a good fit.

        """

        # Keep the factory reachable on the instance, and hand the same
        # factory to the registry that creates sessions per scope.
        self.session_factory = session_factory
        self.registry = ScopedRegistry(session_factory, scopefunc)

    @property
    def _proxied(self):
        """Return the object the registry yields for the current scope."""
        scoped_registry = self.registry
        return scoped_registry()

    async def remove(self):
        """Dispose of the current :class:`.AsyncSession`, if present.

        Unlike the ``remove`` method of ``scoped_session``, this one is a
        coroutine: it awaits the ``close`` method of the
        :class:`.AsyncSession` before clearing the registry.

        """

        scoped_registry = self.registry
        if scoped_registry.has():
            current_session = scoped_registry()
            await current_session.close()

        # The registry is cleared whether or not a session was present.
        self.registry.clear()