summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/asyncio/base.py
blob: 3f77f55007e6a6402552badfddb06803a23be14d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import abc
import functools
import weakref

from . import exc as async_exc


class ReversibleProxy:
    """Mixin that allows a proxied (target) object to be mapped back to the
    proxy object wrapping it.

    The mapping lives in a class-wide registry that holds only weak
    references, so registering does not keep either object alive; entries
    are removed by weakref callbacks when either side is collected.
    """

    # weakref.ref(proxied target object) -> weakref.ref(proxy object).
    # Shared by every subclass (always accessed via ReversibleProxy / cls).
    _proxy_objects = {}
    __slots__ = ("__weakref__",)

    def _assign_proxied(self, target):
        """Register ``self`` as the proxy for ``target`` and return ``target``.

        When ``target`` is ``None`` nothing is registered and ``None`` is
        returned.
        """
        if target is not None:
            # Callback fires with (target_ref) when the target is collected.
            target_ref = weakref.ref(target, ReversibleProxy._target_gced)
            # Callback fires with (target_ref, proxy_ref) when this proxy is
            # collected.
            proxy_ref = weakref.ref(
                self,
                functools.partial(ReversibleProxy._target_gced, target_ref),
            )
            ReversibleProxy._proxy_objects[target_ref] = proxy_ref

        return target

    @classmethod
    def _target_gced(cls, ref, proxy_ref=None):
        """Weakref callback that removes a registry entry.

        :param ref: weak reference to the target object (the registry key).
        :param proxy_ref: the dead proxy's weak reference when called because
          the proxy was collected; ``None`` when called because the target
          was collected.
        """
        if proxy_ref is not None:
            # The proxy died while its target may still be alive.  A newer
            # proxy may have been registered for the same target since then
            # (a live weakref compares equal to any other weakref to the same
            # object), so only drop the entry if it still points at the proxy
            # that just died.
            if cls._proxy_objects.get(ref) is not proxy_ref:
                return
        cls._proxy_objects.pop(ref, None)

    @classmethod
    def _regenerate_proxy_for_target(cls, target):
        """Build a new proxy for ``target``; subclasses must override.

        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    @classmethod
    def _retrieve_proxy_for_target(cls, target, regenerate=True):
        """Return the live proxy registered for ``target``.

        If no live proxy is registered, return the result of
        :meth:`_regenerate_proxy_for_target` when ``regenerate`` is true,
        otherwise ``None``.
        """
        try:
            proxy_ref = cls._proxy_objects[weakref.ref(target)]
        except KeyError:
            pass
        else:
            proxy = proxy_ref()
            if proxy is not None:
                return proxy

        if regenerate:
            return cls._regenerate_proxy_for_target(target)
        else:
            return None


class StartableContext(abc.ABC):
    """Abstract base for async objects usable via ``await`` or ``async with``.

    Subclasses implement :meth:`start` and :meth:`__aexit__`.  ``await obj``
    evaluates to ``start()`` (with ``is_ctxmanager=False``), while
    ``async with obj as x`` binds ``x`` to ``start(is_ctxmanager=True)``.
    """

    __slots__ = ()

    @abc.abstractmethod
    async def start(self, is_ctxmanager=False):
        """Start the context; the result is what ``await``/``async with``
        hand back to the caller."""

    def __await__(self):
        # Delegate ``await obj`` to the coroutine produced by start().
        coro = self.start()
        return coro.__await__()

    async def __aenter__(self):
        started = await self.start(is_ctxmanager=True)
        return started

    @abc.abstractmethod
    async def __aexit__(self, type_, value, traceback):
        """Leave the context; cleanup behavior is defined by subclasses."""

    def _raise_for_not_started(self):
        """Raise ``AsyncContextNotStarted`` naming the concrete class."""
        name = self.__class__.__name__
        message = (
            "%s context has not been started and object has not been awaited."
            % (name,)
        )
        raise async_exc.AsyncContextNotStarted(message)


class ProxyComparable(ReversibleProxy):
    """Reversible proxy whose (in)equality is delegated to ``_proxied``.

    ``_proxied`` is not defined here; subclasses presumably provide it
    (NOTE(review): inferred from its use below — confirm).
    """

    __slots__ = ()

    def __hash__(self):
        # Identity-based hash, independent of ``_proxied``.
        # NOTE(review): two proxies can compare equal (same ``_proxied``)
        # while hashing differently; confirm this is intended.
        return id(self)

    def __eq__(self, other):
        # Only instances of this proxy's own class can compare equal.
        if not isinstance(other, self.__class__):
            return False
        return self._proxied == other._proxied

    def __ne__(self, other):
        if not isinstance(other, self.__class__):
            return True
        return self._proxied != other._proxied