summaryrefslogtreecommitdiff
path: root/Lib/asyncio/test_utils.py
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2016-10-05 17:48:59 -0400
committerYury Selivanov <yury@magic.io>2016-10-05 17:48:59 -0400
commit1e31580b4d009f7412b120fc71b22cd2854d2ece (patch)
tree7c31a82045e5cabf02948b984d89eb5f0c5a368a /Lib/asyncio/test_utils.py
parentc9acdb06126d56565fabab072374c098e3ad5089 (diff)
downloadcpython-1e31580b4d009f7412b120fc71b22cd2854d2ece.tar.gz
Issue #28369: Raise an error when transport's FD is used with add_reader
Diffstat (limited to 'Lib/asyncio/test_utils.py')
-rw-r--r--Lib/asyncio/test_utils.py42
1 files changed, 38 insertions, 4 deletions
diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py
index 396e6aed56..307fffccc6 100644
--- a/Lib/asyncio/test_utils.py
+++ b/Lib/asyncio/test_utils.py
@@ -13,6 +13,8 @@ import tempfile
import threading
import time
import unittest
+import weakref
+
from unittest import mock
from http.server import HTTPServer
@@ -300,6 +302,8 @@ class TestLoop(base_events.BaseEventLoop):
self.writers = {}
self.reset_counters()
+ self._transports = weakref.WeakValueDictionary()
+
def time(self):
return self._time
@@ -318,10 +322,10 @@ class TestLoop(base_events.BaseEventLoop):
else: # pragma: no cover
raise AssertionError("Time generator is not finished")
- def add_reader(self, fd, callback, *args):
+ def _add_reader(self, fd, callback, *args):
self.readers[fd] = events.Handle(callback, args, self)
- def remove_reader(self, fd):
+ def _remove_reader(self, fd):
self.remove_reader_count[fd] += 1
if fd in self.readers:
del self.readers[fd]
@@ -337,10 +341,10 @@ class TestLoop(base_events.BaseEventLoop):
assert handle._args == args, '{!r} != {!r}'.format(
handle._args, args)
- def add_writer(self, fd, callback, *args):
+ def _add_writer(self, fd, callback, *args):
self.writers[fd] = events.Handle(callback, args, self)
- def remove_writer(self, fd):
+ def _remove_writer(self, fd):
self.remove_writer_count[fd] += 1
if fd in self.writers:
del self.writers[fd]
@@ -356,6 +360,36 @@ class TestLoop(base_events.BaseEventLoop):
assert handle._args == args, '{!r} != {!r}'.format(
handle._args, args)
+ def _ensure_fd_no_transport(self, fd):
+ try:
+ transport = self._transports[fd]
+ except KeyError:
+ pass
+ else:
+ raise RuntimeError(
+ 'File descriptor {!r} is used by transport {!r}'.format(
+ fd, transport))
+
+ def add_reader(self, fd, callback, *args):
+ """Add a reader callback."""
+ self._ensure_fd_no_transport(fd)
+ return self._add_reader(fd, callback, *args)
+
+ def remove_reader(self, fd):
+ """Remove a reader callback."""
+ self._ensure_fd_no_transport(fd)
+ return self._remove_reader(fd)
+
+ def add_writer(self, fd, callback, *args):
+ """Add a writer callback.."""
+ self._ensure_fd_no_transport(fd)
+ return self._add_writer(fd, callback, *args)
+
+ def remove_writer(self, fd):
+ """Remove a writer callback."""
+ self._ensure_fd_no_transport(fd)
+ return self._remove_writer(fd)
+
def reset_counters(self):
self.remove_reader_count = collections.defaultdict(int)
self.remove_writer_count = collections.defaultdict(int)