diff options
author | Yury Selivanov <yury@magic.io> | 2016-10-05 17:48:59 -0400 |
---|---|---|
committer | Yury Selivanov <yury@magic.io> | 2016-10-05 17:48:59 -0400 |
commit | 1e31580b4d009f7412b120fc71b22cd2854d2ece (patch) | |
tree | 7c31a82045e5cabf02948b984d89eb5f0c5a368a /Lib/asyncio/test_utils.py | |
parent | c9acdb06126d56565fabab072374c098e3ad5089 (diff) | |
download | cpython-1e31580b4d009f7412b120fc71b22cd2854d2ece.tar.gz |
Issue #28369: Raise an error when transport's FD is used with add_reader
Diffstat (limited to 'Lib/asyncio/test_utils.py')
-rw-r--r-- | Lib/asyncio/test_utils.py | 42 |
1 files changed, 38 insertions, 4 deletions
diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py index 396e6aed56..307fffccc6 100644 --- a/Lib/asyncio/test_utils.py +++ b/Lib/asyncio/test_utils.py @@ -13,6 +13,8 @@ import tempfile import threading import time import unittest +import weakref + from unittest import mock from http.server import HTTPServer @@ -300,6 +302,8 @@ class TestLoop(base_events.BaseEventLoop): self.writers = {} self.reset_counters() + self._transports = weakref.WeakValueDictionary() + def time(self): return self._time @@ -318,10 +322,10 @@ class TestLoop(base_events.BaseEventLoop): else: # pragma: no cover raise AssertionError("Time generator is not finished") - def add_reader(self, fd, callback, *args): + def _add_reader(self, fd, callback, *args): self.readers[fd] = events.Handle(callback, args, self) - def remove_reader(self, fd): + def _remove_reader(self, fd): self.remove_reader_count[fd] += 1 if fd in self.readers: del self.readers[fd] @@ -337,10 +341,10 @@ class TestLoop(base_events.BaseEventLoop): assert handle._args == args, '{!r} != {!r}'.format( handle._args, args) - def add_writer(self, fd, callback, *args): + def _add_writer(self, fd, callback, *args): self.writers[fd] = events.Handle(callback, args, self) - def remove_writer(self, fd): + def _remove_writer(self, fd): self.remove_writer_count[fd] += 1 if fd in self.writers: del self.writers[fd] @@ -356,6 +360,36 @@ class TestLoop(base_events.BaseEventLoop): assert handle._args == args, '{!r} != {!r}'.format( handle._args, args) + def _ensure_fd_no_transport(self, fd): + try: + transport = self._transports[fd] + except KeyError: + pass + else: + raise RuntimeError( + 'File descriptor {!r} is used by transport {!r}'.format( + fd, transport)) + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + self._ensure_fd_no_transport(fd) + return self._add_reader(fd, callback, *args) + + def remove_reader(self, fd): + """Remove a reader callback.""" + self._ensure_fd_no_transport(fd) + return self._remove_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + self._ensure_fd_no_transport(fd) + return self._add_writer(fd, callback, *args) + + def remove_writer(self, fd): + """Remove a writer callback.""" + self._ensure_fd_no_transport(fd) + return self._remove_writer(fd) + def reset_counters(self): self.remove_reader_count = collections.defaultdict(int) self.remove_writer_count = collections.defaultdict(int) |