summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--redis/_compat.py13
-rwxr-xr-xredis/connection.py45
-rw-r--r--redis/selector.py187
-rw-r--r--tests/test_connection_pool.py71
4 files changed, 295 insertions, 21 deletions
diff --git a/redis/_compat.py b/redis/_compat.py
index c9213a6..bde6fb6 100644
--- a/redis/_compat.py
+++ b/redis/_compat.py
@@ -9,17 +9,6 @@ if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and
import socket
import time
- from select import select as _select, error as select_error
-
- def select(rlist, wlist, xlist, timeout):
- while True:
- try:
- return _select(rlist, wlist, xlist, timeout)
- except select_error as e:
- if e.args[0] == errno.EINTR:
- continue
- raise
-
# Wrapper for handling interruptable system calls.
def _retryable_call(s, func, *args, **kwargs):
# Some modules (SSL) use the _fileobject wrapper directly and
@@ -65,8 +54,6 @@ if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and
return _retryable_call(sock, sock.recv_into, *args, **kwargs)
else: # Python 3.5 and above automatically retry EINTR
- from select import select
-
def recv(sock, *args, **kwargs):
return sock.recv(*args, **kwargs)
diff --git a/redis/connection.py b/redis/connection.py
index f631cab..beeba30 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -17,7 +17,7 @@ except ImportError:
from redis._compat import (xrange, imap, byte_to_chr, unicode, long,
nativestr, basestring, iteritems,
LifoQueue, Empty, Full, urlparse, parse_qs,
- recv, recv_into, select, unquote)
+ recv, recv_into, unquote)
from redis.exceptions import (
DataError,
RedisError,
@@ -31,6 +31,7 @@ from redis.exceptions import (
ExecAbortError,
ReadOnlyError
)
+from redis.selector import DefaultSelector
from redis.utils import HIREDIS_AVAILABLE
if HIREDIS_AVAILABLE:
import hiredis
@@ -496,6 +497,7 @@ class Connection(object):
raise ConnectionError(self._error_message(e))
self._sock = sock
+ self._selector = DefaultSelector(sock)
try:
self.on_connect()
except RedisError:
@@ -623,8 +625,11 @@ class Connection(object):
if not sock:
self.connect()
sock = self._sock
- return self._parser.can_read() or \
- bool(select([sock], [], [], timeout)[0])
+ return self._parser.can_read() or self._selector.can_read(timeout)
+
+ def is_ready_for_command(self):
+ "Check if the connection is ready for a command"
+ return self._selector.is_ready_for_command()
def read_response(self):
"Read the response from a previously sent command"
@@ -984,6 +989,23 @@ class ConnectionPool(object):
except IndexError:
connection = self.make_connection()
self._in_use_connections.add(connection)
+ try:
+ # ensure this connection is connected to Redis
+ connection.connect()
+ # connections that the pool provides should be ready to send
+ # a command. if not, the connection was either returned to the
+ # pool before all data has been read or the socket has been
+ # closed. either way, reconnect and verify everything is good.
+ if not connection.is_ready_for_command():
+ connection.disconnect()
+ connection.connect()
+ if not connection.is_ready_for_command():
+ raise ConnectionError('Connection not ready')
+ except: # noqa: E722
+ # release the connection back to the pool so that we don't leak it
+ self.release(connection)
+ raise
+
return connection
def get_encoder(self):
@@ -1115,6 +1137,23 @@ class BlockingConnectionPool(ConnectionPool):
if connection is None:
connection = self.make_connection()
+ try:
+ # ensure this connection is connected to Redis
+ connection.connect()
+ # connections that the pool provides should be ready to send
+ # a command. if not, the connection was either returned to the
+ # pool before all data has been read or the socket has been
+ # closed. either way, reconnect and verify everything is good.
+ if not connection.is_ready_for_command():
+ connection.disconnect()
+ connection.connect()
+ if not connection.is_ready_for_command():
+ raise ConnectionError('Connection not ready')
+ except: # noqa: E722
+ # release the connection back to the pool so that we don't leak it
+ self.release(connection)
+ raise
+
return connection
def release(self, connection):
diff --git a/redis/selector.py b/redis/selector.py
new file mode 100644
index 0000000..fd9b12b
--- /dev/null
+++ b/redis/selector.py
@@ -0,0 +1,187 @@
+import errno
+import select
+from redis.exceptions import RedisError
+
+
+_DEFAULT_SELECTOR = None
+
+
+class BaseSelector(object):
+ """
+ Base class for all Selectors
+ """
+ def __init__(self, sock):
+ self.sock = sock
+
+ def can_read(self, timeout=0):
+ """
+ Return True if data is ready to be read from the socket,
+ otherwise False.
+
+ This doesn't guarentee that the socket is still connected, just that
+ there is data to read.
+
+ Automatically retries EINTR errors based on PEP 475.
+ """
+ while True:
+ try:
+ return self.check_can_read(timeout)
+ except (select.error, IOError) as ex:
+ if self.errno_from_exception(ex) == errno.EINTR:
+ continue
+ return False
+
+ def is_ready_for_command(self, timeout=0):
+ """
+ Return True if the socket is ready to send a command,
+ otherwise False.
+
+ Automatically retries EINTR errors based on PEP 475.
+ """
+ while True:
+ try:
+ return self.check_is_ready_for_command(timeout)
+ except (select.error, IOError) as ex:
+ if self.errno_from_exception(ex) == errno.EINTR:
+ continue
+ return False
+
+ def check_can_read(self, timeout):
+ """
+ Perform the can_read check. Subclasses should implement this.
+ """
+ raise NotImplementedError
+
+ def check_is_ready_for_command(self, timeout):
+ """
+ Perform the is_ready_for_command check. Subclasses should
+ implement this.
+ """
+ raise NotImplementedError
+
+ def close(self):
+ """
+ Close the selector.
+ """
+ self.sock = None
+
+ def errno_from_exception(self, ex):
+ """
+ Get the error number from an exception
+ """
+ if hasattr(ex, 'errno'):
+ return ex.errno
+ elif ex.args:
+ return ex.args[0]
+ else:
+ return None
+
+
+if hasattr(select, 'select'):
+ class SelectSelector(BaseSelector):
+ """
+ A select-based selector that should work on most platforms.
+
+ This is the worst poll strategy and should only be used if no other
+ option is available.
+ """
+ def check_can_read(self, timeout):
+ """
+ Return True if data is ready to be read from the socket,
+ otherwise False.
+
+ This doesn't guarentee that the socket is still connected, just
+ that there is data to read.
+ """
+ return bool(select.select([self.sock], [], [], timeout)[0])
+
+ def check_is_ready_for_command(self, timeout):
+ """
+ Return True if the socket is ready to send a command,
+ otherwise False.
+ """
+ r, w, e = select.select([self.sock], [self.sock], [self.sock],
+ timeout)
+ return bool(w and not r and not e)
+
+
+if hasattr(select, 'poll'):
+ class PollSelector(BaseSelector):
+ """
+ A poll-based selector that should work on (almost?) all versions
+ of Unix
+ """
+ _EVENT_MASK = (select.POLLIN | select.POLLPRI | select.POLLOUT |
+ select.POLLERR | select.POLLHUP)
+ _READ_MASK = select.POLLIN | select.POLLPRI
+ _WRITE_MASK = select.POLLOUT
+
+ def __init__(self, sock):
+ super().__init__(sock)
+ self.poller = select.poll()
+ self.poller.register(sock, self._EVENT_MASK)
+
+ def close(self):
+ """
+ Close the selector.
+ """
+ try:
+ self.poller.unregister(self.sock)
+ except (KeyError, ValueError):
+ # KeyError is raised if somehow the socket was not registered
+ # ValueError is raised if the socket's file descriptor is
+ # negative.
+ # In either case, we can't do anything better than to remove
+ # the reference to the poller.
+ pass
+ self.poller = None
+ self.sock = None
+
+ def check_can_read(self, timeout=0):
+ """
+ Return True if data is ready to be read from the socket,
+ otherwise False.
+
+ This doesn't guarentee that the socket is still connected, just
+ that there is data to read.
+ """
+ events = self.poller.poll(0)
+ return bool(events and events[0][1] & self._READ_MASK)
+
+ def check_is_ready_for_command(self, timeout=0):
+ """
+ Return True if the socket is ready to send a command,
+ otherwise False
+ """
+ events = self.poller.poll(0)
+ return bool(events and events[0][1] == self._WRITE_MASK)
+
+
+def has_selector(selector):
+ "Determine if the current platform has the selector available"
+ try:
+ if selector == 'poll':
+ # the select module offers the poll selector even if the platform
+ # doesn't support it. Attempt to poll for nothing to make sure
+ # poll is available
+ p = select.poll()
+ p.poll(0)
+ else:
+ # the other selectors will fail when instantiated
+ getattr(select, selector)().close()
+ return True
+ except (OSError, AttributeError):
+ return False
+
+
+def DefaultSelector(sock):
+ "Return the best selector for the platform"
+ global _DEFAULT_SELECTOR
+ if _DEFAULT_SELECTOR is None:
+ if has_selector('poll'):
+ _DEFAULT_SELECTOR = PollSelector
+ elif hasattr(select, 'select'):
+ _DEFAULT_SELECTOR = SelectSelector
+ else:
+ raise RedisError('Platform does not support any selectors')
+ return _DEFAULT_SELECTOR(sock)
diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py
index ca56a76..0f5ad72 100644
--- a/tests/test_connection_pool.py
+++ b/tests/test_connection_pool.py
@@ -5,6 +5,7 @@ import time
import re
from threading import Thread
+from redis.client import parse_client_list
from redis.connection import ssl_available, to_bool
from .conftest import skip_if_server_version_lt
@@ -16,10 +17,16 @@ class DummyConnection(object):
self.kwargs = kwargs
self.pid = os.getpid()
+ def connect(self):
+ pass
+
+ def is_ready_for_command(self):
+ return True
+
class TestConnectionPool(object):
def get_pool(self, connection_kwargs=None, max_connections=None,
- connection_class=DummyConnection):
+ connection_class=redis.Connection):
connection_kwargs = connection_kwargs or {}
pool = redis.ConnectionPool(
connection_class=connection_class,
@@ -29,7 +36,8 @@ class TestConnectionPool(object):
def test_connection_creation(self):
connection_kwargs = {'foo': 'bar', 'biz': 'baz'}
- pool = self.get_pool(connection_kwargs=connection_kwargs)
+ pool = self.get_pool(connection_kwargs=connection_kwargs,
+ connection_class=DummyConnection)
connection = pool.get_connection('_')
assert isinstance(connection, DummyConnection)
assert connection.kwargs == connection_kwargs
@@ -68,6 +76,39 @@ class TestConnectionPool(object):
expected = 'ConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
assert repr(pool) == expected
+ def test_pool_provides_healthy_connections(self):
+ pool = self.get_pool(connection_class=redis.Connection,
+ max_connections=2)
+ conn1 = pool.get_connection('_')
+ conn2 = pool.get_connection('_')
+
+ # set a unique name on the connection we'll be testing
+ conn1._same_connection_value = 'killed-client'
+ conn1.send_command('client', 'setname', 'redis-py-1')
+ assert conn1.read_response() == b'OK'
+ pool.release(conn1)
+
+ # find the well named client in the client list
+ conn2.send_command('client', 'list')
+ client_list = parse_client_list(conn2.read_response())
+ for client in client_list:
+ if client['name'] == 'redis-py-1':
+ break
+ else:
+ assert False, 'Client redis-py-1 not found in client list'
+
+ # kill the well named client
+ conn2.send_command('client', 'kill', client['addr'])
+ assert conn2.read_response() == b'OK'
+
+ # our connection should have been disconnected, but a quality
+ # connection pool would know this and only provide a healthy
+ # connection.
+ conn = pool.get_connection('_')
+ assert conn == conn1
+ conn.send_command('ping')
+ assert conn.read_response() == b'PONG'
+
class TestBlockingConnectionPool(object):
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
@@ -399,14 +440,20 @@ class TestSSLConnectionURLParsing(object):
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
def test_cert_reqs_options(self):
import ssl
- pool = redis.ConnectionPool.from_url('rediss://?ssl_cert_reqs=none')
+
+ class DummyConnectionPool(redis.ConnectionPool):
+ def get_connection(self, *args, **kwargs):
+ return self.make_connection()
+
+ pool = DummyConnectionPool.from_url(
+ 'rediss://?ssl_cert_reqs=none')
assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE
- pool = redis.ConnectionPool.from_url(
+ pool = DummyConnectionPool.from_url(
'rediss://?ssl_cert_reqs=optional')
assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL
- pool = redis.ConnectionPool.from_url(
+ pool = DummyConnectionPool.from_url(
'rediss://?ssl_cert_reqs=required')
assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED
@@ -494,3 +541,17 @@ class TestConnection(object):
'UnixDomainSocketConnection',
'path=/path/to/socket,db=0',
)
+
+ def test_can_read(self, r):
+ connection = r.connection_pool.get_connection('ping')
+ assert not connection.can_read()
+ connection.send_command('ping')
+ # wait for the server to respond
+ wait_until = time.time() + 2
+ while time.time() < wait_until:
+ if connection.can_read():
+ break
+ time.sleep(0.01)
+ assert connection.can_read()
+ assert connection.read_response() == b'PONG'
+ assert not connection.can_read()