diff options
-rw-r--r-- | kafka/client_async.py | 4 | ||||
-rw-r--r-- | kafka/conn.py | 26 | ||||
-rw-r--r-- | test/test_conn.py | 81 |
3 files changed, 87 insertions, 24 deletions
diff --git a/kafka/client_async.py b/kafka/client_async.py index 6f5d1fe..7719426 100644 --- a/kafka/client_async.py +++ b/kafka/client_async.py @@ -142,6 +142,7 @@ class KafkaClient(object): # Exponential backoff if bootstrap fails backoff_ms = self.config['reconnect_backoff_ms'] * 2 ** self._bootstrap_fails next_at = self._last_bootstrap + backoff_ms / 1000.0 + self._refresh_on_disconnects = False now = time.time() if next_at > now: log.debug("Sleeping %0.4f before bootstrapping again", next_at - now) @@ -180,6 +181,7 @@ class KafkaClient(object): log.error('Unable to bootstrap from %s', hosts) # Max exponential backoff is 2^12, x4000 (50ms -> 200s) self._bootstrap_fails = min(self._bootstrap_fails + 1, 12) + self._refresh_on_disconnects = True def _can_connect(self, node_id): if node_id not in self._conns: @@ -223,7 +225,7 @@ class KafkaClient(object): except KeyError: pass if self._refresh_on_disconnects: - log.warning("Node %s connect failed -- refreshing metadata", node_id) + log.warning("Node %s connection failed -- refreshing metadata", node_id) self.cluster.request_update() def _maybe_connect(self, node_id): diff --git a/kafka/conn.py b/kafka/conn.py index 3571e90..b5c7ba0 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -381,9 +381,17 @@ class BrokerConnection(object): # Not receiving is the state of reading the payload header if not self._receiving: try: - # An extremely small, but non-zero, probability that there are - # more than 0 but not yet 4 bytes available to read - self._rbuffer.write(self._sock.recv(4 - self._rbuffer.tell())) + bytes_to_read = 4 - self._rbuffer.tell() + data = self._sock.recv(bytes_to_read) + # We expect socket.recv to raise an exception if there is not + # enough data to read the full bytes_to_read + # but if the socket is disconnected, we will get empty data + # without an exception raised + if not data: + log.error('%s: socket disconnected', self) + self.close(error=Errors.ConnectionError('socket disconnected')) + return None + self._rbuffer.write(data) except ssl.SSLWantReadError: return None except ConnectionError as e: @@ -411,7 +419,17 @@ class BrokerConnection(object): if self._receiving: staged_bytes = self._rbuffer.tell() try: - self._rbuffer.write(self._sock.recv(self._next_payload_bytes - staged_bytes)) + bytes_to_read = self._next_payload_bytes - staged_bytes + data = self._sock.recv(bytes_to_read) + # We expect socket.recv to raise an exception if there is not + # enough data to read the full bytes_to_read + # but if the socket is disconnected, we will get empty data + # without an exception raised + if not data: + log.error('%s: socket disconnected', self) + self.close(error=Errors.ConnectionError('socket disconnected')) + return None + self._rbuffer.write(data) except ssl.SSLWantReadError: return None except ConnectionError as e: diff --git a/test/test_conn.py b/test/test_conn.py index f0ca2cf..6a3b154 100644 --- a/test/test_conn.py +++ b/test/test_conn.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET +import socket import time import pytest @@ -14,7 +15,7 @@ import kafka.common as Errors @pytest.fixture -def socket(mocker): +def _socket(mocker): socket = mocker.MagicMock() socket.connect_ex.return_value = 0 mocker.patch('socket.socket', return_value=socket) @@ -22,9 +23,8 @@ def socket(mocker): @pytest.fixture -def conn(socket): - from socket import AF_INET - conn = BrokerConnection('localhost', 9092, AF_INET) +def conn(_socket): + conn = BrokerConnection('localhost', 9092, socket.AF_INET) return conn @@ -38,23 +38,23 @@ def conn(socket): ([EALREADY], ConnectionStates.CONNECTING), ([EISCONN], ConnectionStates.CONNECTED)), ]) -def test_connect(socket, conn, states): +def test_connect(_socket, conn, states): assert conn.state is ConnectionStates.DISCONNECTED for errno, state in states: - socket.connect_ex.side_effect = errno + _socket.connect_ex.side_effect = errno conn.connect() assert conn.state is state -def test_connect_timeout(socket, conn): +def test_connect_timeout(_socket, conn): assert conn.state is ConnectionStates.DISCONNECTED # Initial connect returns EINPROGRESS # immediate inline connect returns EALREADY # second explicit connect returns EALREADY # third explicit connect returns EALREADY and times out via last_attempt - socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY] + _socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY] conn.connect() assert conn.state is ConnectionStates.CONNECTING conn.connect() @@ -108,7 +108,7 @@ def test_send_max_ifr(conn): assert isinstance(f.exception, Errors.TooManyInFlightRequests) -def test_send_no_response(socket, conn): +def test_send_no_response(_socket, conn): conn.connect() assert conn.state is ConnectionStates.CONNECTED req = MetadataRequest[0]([]) @@ -116,7 +116,7 @@ def test_send_no_response(socket, conn): payload_bytes = len(header.encode()) + len(req.encode()) third = payload_bytes // 3 remainder = payload_bytes % 3 - socket.send.side_effect = [4, third, third, third, remainder] + _socket.send.side_effect = [4, third, third, third, remainder] assert len(conn.in_flight_requests) == 0 f = conn.send(req, expect_response=False) @@ -125,7 +125,7 @@ def test_send_no_response(socket, conn): assert len(conn.in_flight_requests) == 0 -def test_send_response(socket, conn): +def test_send_response(_socket, conn): conn.connect() assert conn.state is ConnectionStates.CONNECTED req = MetadataRequest[0]([]) @@ -133,7 +133,7 @@ def test_send_response(socket, conn): payload_bytes = len(header.encode()) + len(req.encode()) third = payload_bytes // 3 remainder = payload_bytes % 3 - socket.send.side_effect = [4, third, third, third, remainder] + _socket.send.side_effect = [4, third, third, third, remainder] assert len(conn.in_flight_requests) == 0 f = conn.send(req) @@ -141,20 +141,18 @@ def test_send_response(socket, conn): assert len(conn.in_flight_requests) == 1 -def test_send_error(socket, conn): +def test_send_error(_socket, conn): conn.connect() assert conn.state is ConnectionStates.CONNECTED req = MetadataRequest[0]([]) - header = RequestHeader(req, client_id=conn.config['client_id']) try: - error = ConnectionError + _socket.send.side_effect = ConnectionError except NameError: - from socket import error - socket.send.side_effect = error + _socket.send.side_effect = socket.error f = conn.send(req) assert f.failed() is True assert isinstance(f.exception, Errors.ConnectionError) - assert socket.close.call_count == 1 + assert _socket.close.call_count == 1 assert conn.state is ConnectionStates.DISCONNECTED @@ -167,7 +165,52 @@ def test_can_send_more(conn): assert conn.can_send_more() is False -def test_recv(socket, conn): +def test_recv_disconnected(): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(('127.0.0.1', 0)) + port = sock.getsockname()[1] + sock.listen(5) + + conn = BrokerConnection('127.0.0.1', port, socket.AF_INET) + timeout = time.time() + 1 + while time.time() < timeout: + conn.connect() + if conn.connected(): + break + else: + assert False, 'Connection attempt to local socket timed-out ?' + + conn.send(MetadataRequest[0]([])) + + # Disconnect server socket + sock.close() + + # Attempt to receive should mark connection as disconnected + assert conn.connected() + conn.recv() + assert conn.disconnected() + + +def test_recv_disconnected_too(_socket, conn): + conn.connect() + assert conn.connected() + + req = MetadataRequest[0]([]) + header = RequestHeader(req, client_id=conn.config['client_id']) + payload_bytes = len(header.encode()) + len(req.encode()) + _socket.send.side_effect = [4, payload_bytes] + conn.send(req) + + # Empty data on recv means the socket is disconnected + _socket.recv.return_value = b'' + + # Attempt to receive should mark connection as disconnected + assert conn.connected() + conn.recv() + assert conn.disconnected() + + +def test_recv(_socket, conn): pass # TODO |