diff options
-rw-r--r-- | redis/_compat.py | 13 | ||||
-rwxr-xr-x | redis/connection.py | 45 | ||||
-rw-r--r-- | redis/selector.py | 187 | ||||
-rw-r--r-- | tests/test_connection_pool.py | 71 |
4 files changed, 295 insertions, 21 deletions
diff --git a/redis/_compat.py b/redis/_compat.py index c9213a6..bde6fb6 100644 --- a/redis/_compat.py +++ b/redis/_compat.py @@ -9,17 +9,6 @@ if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and import socket import time - from select import select as _select, error as select_error - - def select(rlist, wlist, xlist, timeout): - while True: - try: - return _select(rlist, wlist, xlist, timeout) - except select_error as e: - if e.args[0] == errno.EINTR: - continue - raise - # Wrapper for handling interruptable system calls. def _retryable_call(s, func, *args, **kwargs): # Some modules (SSL) use the _fileobject wrapper directly and @@ -65,8 +54,6 @@ if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and return _retryable_call(sock, sock.recv_into, *args, **kwargs) else: # Python 3.5 and above automatically retry EINTR - from select import select - def recv(sock, *args, **kwargs): return sock.recv(*args, **kwargs) diff --git a/redis/connection.py b/redis/connection.py index f631cab..beeba30 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -17,7 +17,7 @@ except ImportError: from redis._compat import (xrange, imap, byte_to_chr, unicode, long, nativestr, basestring, iteritems, LifoQueue, Empty, Full, urlparse, parse_qs, - recv, recv_into, select, unquote) + recv, recv_into, unquote) from redis.exceptions import ( DataError, RedisError, @@ -31,6 +31,7 @@ from redis.exceptions import ( ExecAbortError, ReadOnlyError ) +from redis.selector import DefaultSelector from redis.utils import HIREDIS_AVAILABLE if HIREDIS_AVAILABLE: import hiredis @@ -496,6 +497,7 @@ class Connection(object): raise ConnectionError(self._error_message(e)) self._sock = sock + self._selector = DefaultSelector(sock) try: self.on_connect() except RedisError: @@ -623,8 +625,11 @@ class Connection(object): if not sock: self.connect() sock = self._sock - return self._parser.can_read() or \ - bool(select([sock], [], [], timeout)[0]) + return self._parser.can_read() or self._selector.can_read(timeout) + + def is_ready_for_command(self): + "Check if the connection is ready for a command" + return self._selector.is_ready_for_command() def read_response(self): "Read the response from a previously sent command" @@ -984,6 +989,23 @@ class ConnectionPool(object): except IndexError: connection = self.make_connection() self._in_use_connections.add(connection) + try: + # ensure this connection is connected to Redis + connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. + if not connection.is_ready_for_command(): + connection.disconnect() + connection.connect() + if not connection.is_ready_for_command(): + raise ConnectionError('Connection not ready') + except: # noqa: E722 + # release the connection back to the pool so that we don't leak it + self.release(connection) + raise + return connection def get_encoder(self): @@ -1115,6 +1137,23 @@ class BlockingConnectionPool(ConnectionPool): if connection is None: connection = self.make_connection() + try: + # ensure this connection is connected to Redis + connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. + if not connection.is_ready_for_command(): + connection.disconnect() + connection.connect() + if not connection.is_ready_for_command(): + raise ConnectionError('Connection not ready') + except: # noqa: E722 + # release the connection back to the pool so that we don't leak it + self.release(connection) + raise + return connection def release(self, connection): diff --git a/redis/selector.py b/redis/selector.py new file mode 100644 index 0000000..fd9b12b --- /dev/null +++ b/redis/selector.py @@ -0,0 +1,187 @@ +import errno +import select +from redis.exceptions import RedisError + + +_DEFAULT_SELECTOR = None + + +class BaseSelector(object): + """ + Base class for all Selectors + """ + def __init__(self, sock): + self.sock = sock + + def can_read(self, timeout=0): + """ + Return True if data is ready to be read from the socket, + otherwise False. + + This doesn't guarentee that the socket is still connected, just that + there is data to read. + + Automatically retries EINTR errors based on PEP 475. + """ + while True: + try: + return self.check_can_read(timeout) + except (select.error, IOError) as ex: + if self.errno_from_exception(ex) == errno.EINTR: + continue + return False + + def is_ready_for_command(self, timeout=0): + """ + Return True if the socket is ready to send a command, + otherwise False. + + Automatically retries EINTR errors based on PEP 475. + """ + while True: + try: + return self.check_is_ready_for_command(timeout) + except (select.error, IOError) as ex: + if self.errno_from_exception(ex) == errno.EINTR: + continue + return False + + def check_can_read(self, timeout): + """ + Perform the can_read check. Subclasses should implement this. + """ + raise NotImplementedError + + def check_is_ready_for_command(self, timeout): + """ + Perform the is_ready_for_command check. Subclasses should + implement this. + """ + raise NotImplementedError + + def close(self): + """ + Close the selector. + """ + self.sock = None + + def errno_from_exception(self, ex): + """ + Get the error number from an exception + """ + if hasattr(ex, 'errno'): + return ex.errno + elif ex.args: + return ex.args[0] + else: + return None + + +if hasattr(select, 'select'): + class SelectSelector(BaseSelector): + """ + A select-based selector that should work on most platforms. + + This is the worst poll strategy and should only be used if no other + option is available. + """ + def check_can_read(self, timeout): + """ + Return True if data is ready to be read from the socket, + otherwise False. + + This doesn't guarentee that the socket is still connected, just + that there is data to read. + """ + return bool(select.select([self.sock], [], [], timeout)[0]) + + def check_is_ready_for_command(self, timeout): + """ + Return True if the socket is ready to send a command, + otherwise False. + """ + r, w, e = select.select([self.sock], [self.sock], [self.sock], + timeout) + return bool(w and not r and not e) + + +if hasattr(select, 'poll'): + class PollSelector(BaseSelector): + """ + A poll-based selector that should work on (almost?) all versions + of Unix + """ + _EVENT_MASK = (select.POLLIN | select.POLLPRI | select.POLLOUT | + select.POLLERR | select.POLLHUP) + _READ_MASK = select.POLLIN | select.POLLPRI + _WRITE_MASK = select.POLLOUT + + def __init__(self, sock): + super().__init__(sock) + self.poller = select.poll() + self.poller.register(sock, self._EVENT_MASK) + + def close(self): + """ + Close the selector. + """ + try: + self.poller.unregister(self.sock) + except (KeyError, ValueError): + # KeyError is raised if somehow the socket was not registered + # ValueError is raised if the socket's file descriptor is + # negative. + # In either case, we can't do anything better than to remove + # the reference to the poller. + pass + self.poller = None + self.sock = None + + def check_can_read(self, timeout=0): + """ + Return True if data is ready to be read from the socket, + otherwise False. + + This doesn't guarentee that the socket is still connected, just + that there is data to read. + """ + events = self.poller.poll(0) + return bool(events and events[0][1] & self._READ_MASK) + + def check_is_ready_for_command(self, timeout=0): + """ + Return True if the socket is ready to send a command, + otherwise False + """ + events = self.poller.poll(0) + return bool(events and events[0][1] == self._WRITE_MASK) + + +def has_selector(selector): + "Determine if the current platform has the selector available" + try: + if selector == 'poll': + # the select module offers the poll selector even if the platform + # doesn't support it. Attempt to poll for nothing to make sure + # poll is available + p = select.poll() + p.poll(0) + else: + # the other selectors will fail when instantiated + getattr(select, selector)().close() + return True + except (OSError, AttributeError): + return False + + +def DefaultSelector(sock): + "Return the best selector for the platform" + global _DEFAULT_SELECTOR + if _DEFAULT_SELECTOR is None: + if has_selector('poll'): + _DEFAULT_SELECTOR = PollSelector + elif hasattr(select, 'select'): + _DEFAULT_SELECTOR = SelectSelector + else: + raise RedisError('Platform does not support any selectors') + return _DEFAULT_SELECTOR(sock) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index ca56a76..0f5ad72 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -5,6 +5,7 @@ import time import re from threading import Thread +from redis.client import parse_client_list from redis.connection import ssl_available, to_bool from .conftest import skip_if_server_version_lt @@ -16,10 +17,16 @@ class DummyConnection(object): self.kwargs = kwargs self.pid = os.getpid() + def connect(self): + pass + + def is_ready_for_command(self): + return True + class TestConnectionPool(object): def get_pool(self, connection_kwargs=None, max_connections=None, - connection_class=DummyConnection): + connection_class=redis.Connection): connection_kwargs = connection_kwargs or {} pool = redis.ConnectionPool( connection_class=connection_class, @@ -29,7 +36,8 @@ class TestConnectionPool(object): def test_connection_creation(self): connection_kwargs = {'foo': 'bar', 'biz': 'baz'} - pool = self.get_pool(connection_kwargs=connection_kwargs) + pool = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=DummyConnection) connection = pool.get_connection('_') assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -68,6 +76,39 @@ class TestConnectionPool(object): expected = 'ConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>' assert repr(pool) == expected + def test_pool_provides_healthy_connections(self): + pool = self.get_pool(connection_class=redis.Connection, + max_connections=2) + conn1 = pool.get_connection('_') + conn2 = pool.get_connection('_') + + # set a unique name on the connection we'll be testing + conn1._same_connection_value = 'killed-client' + conn1.send_command('client', 'setname', 'redis-py-1') + assert conn1.read_response() == b'OK' + pool.release(conn1) + + # find the well named client in the client list + conn2.send_command('client', 'list') + client_list = parse_client_list(conn2.read_response()) + for client in client_list: + if client['name'] == 'redis-py-1': + break + else: + assert False, 'Client redis-py-1 not found in client list' + + # kill the well named client + conn2.send_command('client', 'kill', client['addr']) + assert conn2.read_response() == b'OK' + + # our connection should have been disconnected, but a quality + # connection pool would know this and only provide a healthy + # connection. + conn = pool.get_connection('_') + assert conn == conn1 + conn.send_command('ping') + assert conn.read_response() == b'PONG' + class TestBlockingConnectionPool(object): def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): @@ -399,14 +440,20 @@ class TestSSLConnectionURLParsing(object): @pytest.mark.skipif(not ssl_available, reason="SSL not installed") def test_cert_reqs_options(self): import ssl - pool = redis.ConnectionPool.from_url('rediss://?ssl_cert_reqs=none') + + class DummyConnectionPool(redis.ConnectionPool): + def get_connection(self, *args, **kwargs): + return self.make_connection() + + pool = DummyConnectionPool.from_url( + 'rediss://?ssl_cert_reqs=none') assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE - pool = redis.ConnectionPool.from_url( + pool = DummyConnectionPool.from_url( 'rediss://?ssl_cert_reqs=optional') assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL - pool = redis.ConnectionPool.from_url( + pool = DummyConnectionPool.from_url( 'rediss://?ssl_cert_reqs=required') assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED @@ -494,3 +541,17 @@ class TestConnection(object): 'UnixDomainSocketConnection', 'path=/path/to/socket,db=0', ) + + def test_can_read(self, r): + connection = r.connection_pool.get_connection('ping') + assert not connection.can_read() + connection.send_command('ping') + # wait for the server to respond + wait_until = time.time() + 2 + while time.time() < wait_until: + if connection.can_read(): + break + time.sleep(0.01) + assert connection.can_read() + assert connection.read_response() == b'PONG' + assert not connection.can_read() |