diff options
| author | twisteroid ambassador <twisteroidambassador@users.noreply.github.com> | 2019-05-05 19:14:35 +0800 | 
|---|---|---|
| committer | Miss Islington (bot) <31488909+miss-islington@users.noreply.github.com> | 2019-05-05 04:14:35 -0700 | 
| commit | 88f07a804a0adc0b6ee87687b59d8416113c7331 (patch) | |
| tree | e7fcadefb5269eb9b03c5c5946a31387ab7839e7 | |
| parent | c4d92c8ada7ecfc479ebb1dd4a819c9202155970 (diff) | |
| download | cpython-git-88f07a804a0adc0b6ee87687b59d8416113c7331.tar.gz | |
bpo-33530: Implement Happy Eyeballs in asyncio, v2 (GH-7237)
Added two keyword arguments, `delay` and `interleave`, to
`BaseEventLoop.create_connection`. Happy eyeballs is activated if
`delay` is specified.
We now have documentation for the new arguments. `staggered_race()` is in its own module, but not exported to the main asyncio package.
https://bugs.python.org/issue33530
| -rw-r--r-- | Doc/library/asyncio-eventloop.rst | 24 | ||||
| -rw-r--r-- | Lib/asyncio/base_events.py | 125 | ||||
| -rw-r--r-- | Lib/asyncio/events.py | 3 | ||||
| -rw-r--r-- | Lib/asyncio/staggered.py | 147 | ||||
| -rw-r--r-- | Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst | 3 | 
5 files changed, 264 insertions, 38 deletions
diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index e2b3124539..06f673be79 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -397,9 +397,27 @@ Opening network connections       If given, these should all be integers from the corresponding       :mod:`socket` module constants. +   * *happy_eyeballs_delay*, if given, enables Happy Eyeballs for this +     connection. It should +     be a floating-point number representing the amount of time in seconds +     to wait for a connection attempt to complete, before starting the next +     attempt in parallel. This is the "Connection Attempt Delay" as defined +     in :rfc:`8305`. A sensible default value recommended by the RFC is ``0.25`` +     (250 milliseconds). + +   * *interleave* controls address reordering when a host name resolves to +     multiple IP addresses. +     If ``0`` or unspecified, no reordering is done, and addresses are +     tried in the order returned by :meth:`getaddrinfo`. If a positive integer +     is specified, the addresses are interleaved by address family, and the +     given integer is interpreted as "First Address Family Count" as defined +     in :rfc:`8305`. The default is ``0`` if *happy_eyeballs_delay* is not +     specified, and ``1`` if it is. +     * *sock*, if given, should be an existing, already connected       :class:`socket.socket` object to be used by the transport. -     If *sock* is given, none of *host*, *port*, *family*, *proto*, *flags* +     If *sock* is given, none of *host*, *port*, *family*, *proto*, *flags*, +     *happy_eyeballs_delay*, *interleave*       and *local_addr* should be specified.     * *local_addr*, if given, is a ``(local_host, local_port)`` tuple used @@ -410,6 +428,10 @@ Opening network connections       to wait for the TLS handshake to complete before aborting the connection.       ``60.0`` seconds if ``None`` (default). +   .. versionadded:: 3.8 + +      The *happy_eyeballs_delay* and *interleave* parameters. +     .. versionadded:: 3.7        The *ssl_handshake_timeout* parameter. diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 9b4b846131..c58906f8b4 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -16,6 +16,7 @@ to modify the meaning of the API call itself.  import collections  import collections.abc  import concurrent.futures +import functools  import heapq  import itertools  import os @@ -41,6 +42,7 @@ from . import exceptions  from . import futures  from . import protocols  from . import sslproto +from . import staggered  from . import tasks  from . import transports  from .log import logger @@ -159,6 +161,28 @@ def _ipaddr_info(host, port, family, type, proto):      return None +def _interleave_addrinfos(addrinfos, first_address_family_count=1): +    """Interleave list of addrinfo tuples by family.""" +    # Group addresses by family +    addrinfos_by_family = collections.OrderedDict() +    for addr in addrinfos: +        family = addr[0] +        if family not in addrinfos_by_family: +            addrinfos_by_family[family] = [] +        addrinfos_by_family[family].append(addr) +    addrinfos_lists = list(addrinfos_by_family.values()) + +    reordered = [] +    if first_address_family_count > 1: +        reordered.extend(addrinfos_lists[0][:first_address_family_count - 1]) +        del addrinfos_lists[0][:first_address_family_count - 1] +    reordered.extend( +        a for a in itertools.chain.from_iterable( +            itertools.zip_longest(*addrinfos_lists) +        ) if a is not None) +    return reordered + +  def _run_until_complete_cb(fut):      if not fut.cancelled():          exc = fut.exception() @@ -871,12 +895,49 @@ class BaseEventLoop(events.AbstractEventLoop):                  "offset must be a non-negative integer (got {!r})".format(                      offset)) +    async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None): +        """Create, bind and connect one socket.""" +        my_exceptions = [] +        exceptions.append(my_exceptions) +        family, type_, proto, _, address = addr_info +        sock = None +        try: +            sock = socket.socket(family=family, type=type_, proto=proto) +            sock.setblocking(False) +            if local_addr_infos is not None: +                for _, _, _, _, laddr in local_addr_infos: +                    try: +                        sock.bind(laddr) +                        break +                    except OSError as exc: +                        msg = ( +                            f'error while attempting to bind on ' +                            f'address {laddr!r}: ' +                            f'{exc.strerror.lower()}' +                        ) +                        exc = OSError(exc.errno, msg) +                        my_exceptions.append(exc) +                else:  # all bind attempts failed +                    raise my_exceptions.pop() +            await self.sock_connect(sock, address) +            return sock +        except OSError as exc: +            my_exceptions.append(exc) +            if sock is not None: +                sock.close() +            raise +        except: +            if sock is not None: +                sock.close() +            raise +      async def create_connection(              self, protocol_factory, host=None, port=None,              *, ssl=None, family=0,              proto=0, flags=0, sock=None,              local_addr=None, server_hostname=None, -            ssl_handshake_timeout=None): +            ssl_handshake_timeout=None, +            happy_eyeballs_delay=None, interleave=None):          """Connect to a TCP server.          Create a streaming transport connection to a given Internet host and @@ -911,6 +972,10 @@ class BaseEventLoop(events.AbstractEventLoop):              raise ValueError(                  'ssl_handshake_timeout is only meaningful with ssl') +        if happy_eyeballs_delay is not None and interleave is None: +            # If using happy eyeballs, default to interleave addresses by family +            interleave = 1 +          if host is not None or port is not None:              if sock is not None:                  raise ValueError( @@ -929,43 +994,31 @@ class BaseEventLoop(events.AbstractEventLoop):                      flags=flags, loop=self)                  if not laddr_infos:                      raise OSError('getaddrinfo() returned empty list') +            else: +                laddr_infos = None + +            if interleave: +                infos = _interleave_addrinfos(infos, interleave)              exceptions = [] -            for family, type, proto, cname, address in infos: -                try: -                    sock = socket.socket(family=family, type=type, proto=proto) -                    sock.setblocking(False) -                    if local_addr is not None: -                        for _, _, _, _, laddr in laddr_infos: -                            try: -                                sock.bind(laddr) -                                break -                            except OSError as exc: -                                msg = ( -                                    f'error while attempting to bind on ' -                                    f'address {laddr!r}: ' -                                    f'{exc.strerror.lower()}' -                                ) -                                exc = OSError(exc.errno, msg) -                                exceptions.append(exc) -                        else: -                            sock.close() -                            sock = None -                            continue -                    if self._debug: -                        logger.debug("connect %r to %r", sock, address) -                    await self.sock_connect(sock, address) -                except OSError as exc: -                    if sock is not None: -                        sock.close() -                    exceptions.append(exc) -                except: -                    if sock is not None: -                        sock.close() -                    raise -                else: -                    break -            else: +            if happy_eyeballs_delay is None: +                # not using happy eyeballs +                for addrinfo in infos: +                    try: +                        sock = await self._connect_sock( +                            exceptions, addrinfo, laddr_infos) +                        break +                    except OSError: +                        continue +            else:  # using happy eyeballs +                sock, _, _ = await staggered.staggered_race( +                    (functools.partial(self._connect_sock, +                                       exceptions, addrinfo, laddr_infos) +                     for addrinfo in infos), +                    happy_eyeballs_delay, loop=self) + +            if sock is None: +                exceptions = [exc for sub in exceptions for exc in sub]                  if len(exceptions) == 1:                      raise exceptions[0]                  else: diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index 163b868afe..9a923514db 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -298,7 +298,8 @@ class AbstractEventLoop:              *, ssl=None, family=0, proto=0,              flags=0, sock=None, local_addr=None,              server_hostname=None, -            ssl_handshake_timeout=None): +            ssl_handshake_timeout=None, +            happy_eyeballs_delay=None, interleave=None):          raise NotImplementedError      async def create_server( diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py new file mode 100644 index 0000000000..feec681b43 --- /dev/null +++ b/Lib/asyncio/staggered.py @@ -0,0 +1,147 @@ +"""Support for running coroutines in parallel with staggered start times.""" + +__all__ = 'staggered_race', + +import contextlib +import typing + +from . import events +from . import futures +from . import locks +from . import tasks + + +async def staggered_race( +        coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]], +        delay: typing.Optional[float], +        *, +        loop: events.AbstractEventLoop = None, +) -> typing.Tuple[ +    typing.Any, +    typing.Optional[int], +    typing.List[typing.Optional[Exception]] +]: +    """Run coroutines with staggered start times and take the first to finish. + +    This method takes an iterable of coroutine functions. The first one is +    started immediately. From then on, whenever the immediately preceding one +    fails (raises an exception), or when *delay* seconds has passed, the next +    coroutine is started. This continues until one of the coroutines complete +    successfully, in which case all others are cancelled, or until all +    coroutines fail. + +    The coroutines provided should be well-behaved in the following way: + +    * They should only ``return`` if completed successfully. + +    * They should always raise an exception if they did not complete +      successfully. In particular, if they handle cancellation, they should +      probably reraise, like this:: + +        try: +            # do work +        except asyncio.CancelledError: +            # undo partially completed work +            raise + +    Args: +        coro_fns: an iterable of coroutine functions, i.e. callables that +            return a coroutine object when called. Use ``functools.partial`` or +            lambdas to pass arguments. + +        delay: amount of time, in seconds, between starting coroutines. If +            ``None``, the coroutines will run sequentially. + +        loop: the event loop to use. + +    Returns: +        tuple *(winner_result, winner_index, exceptions)* where + +        - *winner_result*: the result of the winning coroutine, or ``None`` +          if no coroutines won. + +        - *winner_index*: the index of the winning coroutine in +          ``coro_fns``, or ``None`` if no coroutines won. If the winning +          coroutine may return None on success, *winner_index* can be used +          to definitively determine whether any coroutine won. + +        - *exceptions*: list of exceptions returned by the coroutines. +          ``len(exceptions)`` is equal to the number of coroutines actually +          started, and the order is the same as in ``coro_fns``. The winning +          coroutine's entry is ``None``. + +    """ +    # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. +    loop = loop or events.get_running_loop() +    enum_coro_fns = enumerate(coro_fns) +    winner_result = None +    winner_index = None +    exceptions = [] +    running_tasks = [] + +    async def run_one_coro( +            previous_failed: typing.Optional[locks.Event]) -> None: +        # Wait for the previous task to finish, or for delay seconds +        if previous_failed is not None: +            with contextlib.suppress(futures.TimeoutError): +                # Use asyncio.wait_for() instead of asyncio.wait() here, so +                # that if we get cancelled at this point, Event.wait() is also +                # cancelled, otherwise there will be a "Task destroyed but it is +                # pending" later. +                await tasks.wait_for(previous_failed.wait(), delay) +        # Get the next coroutine to run +        try: +            this_index, coro_fn = next(enum_coro_fns) +        except StopIteration: +            return +        # Start task that will run the next coroutine +        this_failed = locks.Event() +        next_task = loop.create_task(run_one_coro(this_failed)) +        running_tasks.append(next_task) +        assert len(running_tasks) == this_index + 2 +        # Prepare place to put this coroutine's exceptions if not won +        exceptions.append(None) +        assert len(exceptions) == this_index + 1 + +        try: +            result = await coro_fn() +        except Exception as e: +            exceptions[this_index] = e +            this_failed.set()  # Kickstart the next coroutine +        else: +            # Store winner's results +            nonlocal winner_index, winner_result +            assert winner_index is None +            winner_index = this_index +            winner_result = result +            # Cancel all other tasks. We take care to not cancel the current +            # task as well. If we do so, then since there is no `await` after +            # here and CancelledError are usually thrown at one, we will +            # encounter a curious corner case where the current task will end +            # up as done() == True, cancelled() == False, exception() == +            # asyncio.CancelledError. This behavior is specified in +            # https://bugs.python.org/issue30048 +            for i, t in enumerate(running_tasks): +                if i != this_index: +                    t.cancel() + +    first_task = loop.create_task(run_one_coro(None)) +    running_tasks.append(first_task) +    try: +        # Wait for a growing list of tasks to all finish: poor man's version of +        # curio's TaskGroup or trio's nursery +        done_count = 0 +        while done_count != len(running_tasks): +            done, _ = await tasks.wait(running_tasks) +            done_count = len(done) +            # If run_one_coro raises an unhandled exception, it's probably a +            # programming error, and I want to see it. +            if __debug__: +                for d in done: +                    if d.done() and not d.cancelled() and d.exception(): +                        raise d.exception() +        return winner_result, winner_index, exceptions +    finally: +        # Make sure no tasks are left running if we leave this function +        for t in running_tasks: +            t.cancel() diff --git a/Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst b/Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst new file mode 100644 index 0000000000..747219b1bf --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst @@ -0,0 +1,3 @@ +Implemented Happy Eyeballs in `asyncio.create_connection()`. Added two new +arguments, *happy_eyeballs_delay* and *interleave*, +to specify Happy Eyeballs behavior.  | 
