diff options
-rw-r--r-- | kafka/client_async.py | 45 | ||||
-rw-r--r-- | kafka/conn.py | 50 | ||||
-rw-r--r-- | kafka/consumer/fetcher.py | 3 | ||||
-rw-r--r-- | kafka/coordinator/base.py | 22 | ||||
-rw-r--r-- | test/test_client_async.py | 84 | ||||
-rw-r--r-- | test/test_conn.py | 318 | ||||
-rw-r--r-- | test/test_conn_legacy.py | 242 |
7 files changed, 420 insertions, 344 deletions
diff --git a/kafka/client_async.py b/kafka/client_async.py index 5a1d624..d70e4f2 100644 --- a/kafka/client_async.py +++ b/kafka/client_async.py @@ -152,8 +152,8 @@ class KafkaClient(object): conn = self._conns[node_id] return conn.state is ConnectionStates.DISCONNECTED and not conn.blacked_out() - def _initiate_connect(self, node_id): - """Initiate a connection to the given node (must be in metadata)""" + def _maybe_connect(self, node_id): + """Idempotent non-blocking connection attempt to the given node id.""" if node_id not in self._conns: broker = self.cluster.broker_metadata(node_id) assert broker, 'Broker id %s not in current metadata' % node_id @@ -164,22 +164,21 @@ class KafkaClient(object): host, port, afi = get_ip_port_afi(broker.host) self._conns[node_id] = BrokerConnection(host, broker.port, afi, **self.config) - return self._finish_connect(node_id) - - def _finish_connect(self, node_id): - assert node_id in self._conns, '%s is not in current conns' % node_id state = self._conns[node_id].connect() if state is ConnectionStates.CONNECTING: self._connecting.add(node_id) + + # Whether CONNECTED or DISCONNECTED, we need to remove from connecting elif node_id in self._connecting: log.debug("Node %s connection state is %s", node_id, state) self._connecting.remove(node_id) + # Connection failures imply that our metadata is stale, so let's refresh if state is ConnectionStates.DISCONNECTED: log.warning("Node %s connect failed -- refreshing metadata", node_id) self.cluster.request_update() - return state + return self._conns[node_id].connected() def ready(self, node_id): """Check whether a node is connected and ok to send more requests. @@ -190,19 +189,15 @@ class KafkaClient(object): Returns: bool: True if we are ready to send to the given node """ - if self.is_ready(node_id): - return True - - if self._can_connect(node_id): - # if we are interested in sending to a node - # and we don't have a connection to it, initiate one - self._initiate_connect(node_id) - - if node_id in self._connecting: - self._finish_connect(node_id) - + self._maybe_connect(node_id) return self.is_ready(node_id) + def connected(self, node_id): + """Return True iff the node_id is connected.""" + if node_id not in self._conns: + return False + return self._conns[node_id].connected() + def close(self, node_id=None): """Closes the connection to a particular node (if there is one). @@ -295,15 +290,13 @@ class KafkaClient(object): request (Struct): request object (not-encoded) Raises: - NodeNotReadyError: if node_id is not ready + AssertionError: if node_id is not in current cluster metadata Returns: - Future: resolves to Response struct + Future: resolves to Response struct or Error """ - if not self._can_send_request(node_id): - raise Errors.NodeNotReadyError("Attempt to send a request to node" - " which is not ready (node id %s)." - % node_id) + if not self._maybe_connect(node_id): + return Future().failure(Errors.NodeNotReadyError(node_id)) # Every request gets a response, except one special case: expect_response = True @@ -341,7 +334,7 @@ class KafkaClient(object): # Attempt to complete pending connections for node_id in list(self._connecting): - self._finish_connect(node_id) + self._maybe_connect(node_id) # Send a metadata request if needed metadata_timeout_ms = self._maybe_refresh_metadata() @@ -557,7 +550,7 @@ class KafkaClient(object): elif self._can_connect(node_id): log.debug("Initializing connection to node %s for metadata request", node_id) - self._initiate_connect(node_id) + self._maybe_connect(node_id) return 0 diff --git a/kafka/conn.py b/kafka/conn.py index 0ce469d..2b82b6d 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -106,24 +106,22 @@ class BrokerConnection(object): # in non-blocking mode, use repeated calls to socket.connect_ex # to check connection status request_timeout = self.config['request_timeout_ms'] / 1000.0 - if time.time() > request_timeout + self.last_attempt: + try: + ret = self._sock.connect_ex((self.host, self.port)) + except socket.error as ret: + pass + if not ret or ret == errno.EISCONN: + self.state = ConnectionStates.CONNECTED + elif ret not in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK, 10022): + log.error('Connect attempt to %s returned error %s.' + ' Disconnecting.', self, ret) + self.close() + self.last_failure = time.time() + elif time.time() > request_timeout + self.last_attempt: log.error('Connection attempt to %s timed out', self) self.close() # error=TimeoutError ? self.last_failure = time.time() - else: - try: - ret = self._sock.connect_ex((self.host, self.port)) - except socket.error as ret: - pass - if not ret or ret == errno.EISCONN: - self.state = ConnectionStates.CONNECTED - # WSAEINVAL == 10022, but errno.WSAEINVAL is not available on non-win systems - elif ret not in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK, 10022): - log.error('Connect attempt to %s returned error %s.' - ' Disconnecting.', self, ret) - self.close() - self.last_failure = time.time() return self.state def blacked_out(self): @@ -141,6 +139,10 @@ class BrokerConnection(object): """Return True iff socket is connected.""" return self.state is ConnectionStates.CONNECTED + def connecting(self): + """Return True iff socket is in intermediate connecting state.""" + return self.state is ConnectionStates.CONNECTING + def close(self, error=None): """Close socket and fail all in-flight-requests. @@ -158,7 +160,7 @@ class BrokerConnection(object): self._rbuffer.seek(0) self._rbuffer.truncate() if error is None: - error = Errors.ConnectionError() + error = Errors.ConnectionError(str(self)) while self.in_flight_requests: ifr = self.in_flight_requests.popleft() ifr.future.failure(error) @@ -169,10 +171,12 @@ class BrokerConnection(object): Can block on network if request is larger than send_buffer_bytes """ future = Future() - if not self.connected(): - return future.failure(Errors.ConnectionError()) - if not self.can_send_more(): - return future.failure(Errors.TooManyInFlightRequests()) + if self.connecting(): + return future.failure(Errors.NodeNotReadyError(str(self))) + elif not self.connected(): + return future.failure(Errors.ConnectionError(str(self))) + elif not self.can_send_more(): + return future.failure(Errors.TooManyInFlightRequests(str(self))) correlation_id = self._next_correlation_id() header = RequestHeader(request, correlation_id=correlation_id, @@ -191,7 +195,7 @@ class BrokerConnection(object): self._sock.setblocking(False) except (AssertionError, ConnectionError) as e: log.exception("Error sending %s to %s", request, self) - error = Errors.ConnectionError(e) + error = Errors.ConnectionError("%s: %s" % (str(self), e)) self.close(error=error) return future.failure(error) log.debug('%s Request %d: %s', self, correlation_id, request) @@ -324,11 +328,9 @@ class BrokerConnection(object): ' initialized on the broker') elif ifr.correlation_id != recv_correlation_id: - - error = Errors.CorrelationIdError( - 'Correlation ids do not match: sent %d, recv %d' - % (ifr.correlation_id, recv_correlation_id)) + '%s: Correlation ids do not match: sent %d, recv %d' + % (str(self), ifr.correlation_id, recv_correlation_id)) ifr.future.failure(error) self.close() self._processing = False diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index f406a30..7112c7e 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -479,9 +479,6 @@ class Fetcher(six.Iterator): # so create a separate future and attach a callback to update it # based on response error codes future = Future() - if not self._client.ready(node_id): - return future.failure(Errors.NodeNotReadyError(node_id)) - _f = self._client.send(node_id, request) _f.add_callback(self._handle_offset_response, partition, future) _f.add_errback(lambda e: future.failure(e)) diff --git a/kafka/coordinator/base.py b/kafka/coordinator/base.py index dca809e..b0a0981 100644 --- a/kafka/coordinator/base.py +++ b/kafka/coordinator/base.py @@ -186,7 +186,7 @@ class BaseCoordinator(object): self.coordinator_dead() return True - return not self._client.ready(self.coordinator_id) + return False def ensure_coordinator_known(self): """Block until the coordinator for this group is known @@ -288,9 +288,13 @@ class BaseCoordinator(object): return future def _failed_request(self, node_id, request, future, error): - log.error('Error sending %s to node %s [%s] -- marking coordinator dead', + log.error('Error sending %s to node %s [%s]', request.__class__.__name__, node_id, error) - self.coordinator_dead() + # Marking coordinator dead + # unless the error is caused by internal client pipelining + if not isinstance(error, (Errors.NodeNotReadyError, + Errors.TooManyInFlightRequests)): + self.coordinator_dead() future.failure(error) def _handle_join_group_response(self, future, response): @@ -388,7 +392,8 @@ class BaseCoordinator(object): def _send_sync_group_request(self, request): if self.coordinator_unknown(): - return Future().failure(Errors.GroupCoordinatorNotAvailableError()) + e = Errors.GroupCoordinatorNotAvailableError(self.coordinator_id) + return Future().failure(e) future = Future() _f = self._client.send(self.coordinator_id, request) _f.add_callback(self._handle_sync_group_response, future) @@ -439,7 +444,7 @@ class BaseCoordinator(object): Future: resolves to the node id of the coordinator """ node_id = self._client.least_loaded_node() - if node_id is None or not self._client.ready(node_id): + if node_id is None: return Future().failure(Errors.NoBrokersAvailable()) log.debug("Issuing group metadata request to broker %s", node_id) @@ -490,8 +495,8 @@ class BaseCoordinator(object): def coordinator_dead(self, error=None): """Mark the current coordinator as dead.""" if self.coordinator_id is not None: - log.info("Marking the coordinator dead (node %s): %s.", - self.coordinator_id, error) + log.warning("Marking the coordinator dead (node %s): %s.", + self.coordinator_id, error) self.coordinator_id = None def close(self): @@ -501,6 +506,7 @@ class BaseCoordinator(object): self._client.unschedule(self.heartbeat_task) except KeyError: pass + if not self.coordinator_unknown() and self.generation > 0: # this is a minimal effort attempt to leave the group. we do not # attempt any resending if the request fails or times out. @@ -634,7 +640,7 @@ class HeartbeatTask(object): self._client.schedule(self, time.time() + ttl) def _handle_heartbeat_failure(self, e): - log.warning("Heartbeat failed; retrying") + log.warning("Heartbeat failed (%s); retrying", e) self._request_in_flight = False etd = time.time() + self._coordinator.config['retry_backoff_ms'] / 1000.0 self._client.schedule(self, etd) diff --git a/test/test_client_async.py b/test/test_client_async.py index e0b98c4..884686d 100644 --- a/test/test_client_async.py +++ b/test/test_client_async.py @@ -41,7 +41,8 @@ def conn(mocker): [(0, 'foo', 12), (1, 'bar', 34)], # brokers [])) # topics conn.blacked_out.return_value = False - conn.connect.return_value = conn.state + conn.connect.side_effect = lambda: conn.state + conn.connected = lambda: conn.connect() is ConnectionStates.CONNECTED return conn @@ -76,7 +77,7 @@ def test_can_connect(conn): assert cli._can_connect(0) # Node is connected, can't reconnect - cli._initiate_connect(0) + assert cli._maybe_connect(0) is True assert not cli._can_connect(0) # Node is disconnected, can connect @@ -87,60 +88,47 @@ def test_can_connect(conn): conn.blacked_out.return_value = True assert not cli._can_connect(0) -def test_initiate_connect(conn): +def test_maybe_connect(conn): cli = KafkaClient() try: # Node not in metadata, raises AssertionError - cli._initiate_connect(2) + cli._maybe_connect(2) except AssertionError: pass else: assert False, 'Exception not raised' assert 0 not in cli._conns - state = cli._initiate_connect(0) + conn.state = ConnectionStates.DISCONNECTED + conn.connect.side_effect = lambda: ConnectionStates.CONNECTING + assert cli._maybe_connect(0) is False assert cli._conns[0] is conn - assert state is conn.state - - -def test_finish_connect(conn): - cli = KafkaClient() - try: - # Node not in metadata, raises AssertionError - cli._initiate_connect(2) - except AssertionError: - pass - else: - assert False, 'Exception not raised' - - assert 0 not in cli._conns - cli._initiate_connect(0) - - conn.connect.return_value = ConnectionStates.CONNECTING - state = cli._finish_connect(0) assert 0 in cli._connecting - assert state is ConnectionStates.CONNECTING - conn.connect.return_value = ConnectionStates.CONNECTED - state = cli._finish_connect(0) + conn.state = ConnectionStates.CONNECTING + conn.connect.side_effect = lambda: ConnectionStates.CONNECTED + assert cli._maybe_connect(0) is True assert 0 not in cli._connecting - assert state is ConnectionStates.CONNECTED # Failure to connect should trigger metadata update - assert not cli.cluster._need_update + assert cli.cluster._need_update is False cli._connecting.add(0) - conn.connect.return_value = ConnectionStates.DISCONNECTED - state = cli._finish_connect(0) + conn.state = ConnectionStates.CONNECTING + conn.connect.side_effect = lambda: ConnectionStates.DISCONNECTED + assert cli._maybe_connect(0) is False assert 0 not in cli._connecting - assert state is ConnectionStates.DISCONNECTED - assert cli.cluster._need_update + assert cli.cluster._need_update is True def test_ready(conn): cli = KafkaClient() - # Node not in metadata - assert not cli.ready(2) + # Node not in metadata raises Exception + try: + cli.ready(2) + assert False, 'Exception not raised' + except AssertionError: + pass # Node in metadata will connect assert 0 not in cli._conns @@ -176,13 +164,13 @@ def test_ready(conn): # disconnected nodes, not ready assert cli.ready(0) assert cli.ready(1) - conn.connected.return_value = False + conn.state = ConnectionStates.DISCONNECTED assert not cli.ready(0) - conn.connected.return_value = True # connecting node connects cli._connecting.add(0) - conn.connected.return_value = False + conn.state = ConnectionStates.CONNECTING + conn.connect.side_effect = lambda: ConnectionStates.CONNECTED cli.ready(0) assert 0 not in cli._connecting assert cli._conns[0].connect.called_with() @@ -195,13 +183,13 @@ def test_close(conn): cli.close(2) # Single node close - cli._initiate_connect(0) + cli._maybe_connect(0) assert not conn.close.call_count cli.close(0) assert conn.close.call_count == 1 # All node close - cli._initiate_connect(1) + cli._maybe_connect(1) cli.close() assert conn.close.call_count == 3 @@ -213,7 +201,7 @@ def test_is_disconnected(conn): conn.state = ConnectionStates.DISCONNECTED assert not cli.is_disconnected(0) - cli._initiate_connect(0) + cli._maybe_connect(0) assert cli.is_disconnected(0) conn.state = ConnectionStates.CONNECTING @@ -225,14 +213,22 @@ def test_is_disconnected(conn): def test_send(conn): cli = KafkaClient() + + # Send to unknown node => raises AssertionError try: cli.send(2, None) - except Errors.NodeNotReadyError: + assert False, 'Exception not raised' + except AssertionError: pass - else: - assert False, 'NodeNotReadyError not raised' - cli._initiate_connect(0) + # Send to disconnected node => NodeNotReady + conn.state = ConnectionStates.DISCONNECTED + f = cli.send(0, None) + assert f.failed() + assert isinstance(f.exception, Errors.NodeNotReadyError) + + conn.state = ConnectionStates.CONNECTED + cli._maybe_connect(0) # ProduceRequest w/ 0 required_acks -> no response request = ProduceRequest(0, 0, []) ret = cli.send(0, request) diff --git a/test/test_conn.py b/test/test_conn.py index f0ef8fb..d394f74 100644 --- a/test/test_conn.py +++ b/test/test_conn.py @@ -1,242 +1,82 @@ +# pylint: skip-file +from __future__ import absolute_import + +from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET import socket -import struct -from threading import Thread - -import mock -from . import unittest - -from kafka.common import ConnectionError -from kafka.conn import KafkaConnection, collect_hosts, DEFAULT_SOCKET_TIMEOUT_SECONDS - -class ConnTest(unittest.TestCase): - def setUp(self): - - self.config = { - 'host': 'localhost', - 'port': 9090, - 'request_id': 0, - 'payload': b'test data', - 'payload2': b'another packet' - } - - # Mocking socket.create_connection will cause _sock to always be a - # MagicMock() - patcher = mock.patch('socket.create_connection', spec=True) - self.MockCreateConn = patcher.start() - self.addCleanup(patcher.stop) - - # Also mock socket.sendall() to appear successful - self.MockCreateConn().sendall.return_value = None - - # And mock socket.recv() to return two payloads, then '', then raise - # Note that this currently ignores the num_bytes parameter to sock.recv() - payload_size = len(self.config['payload']) - payload2_size = len(self.config['payload2']) - self.MockCreateConn().recv.side_effect = [ - struct.pack('>i', payload_size), - struct.pack('>%ds' % payload_size, self.config['payload']), - struct.pack('>i', payload2_size), - struct.pack('>%ds' % payload2_size, self.config['payload2']), - b'' - ] - - # Create a connection object - self.conn = KafkaConnection(self.config['host'], self.config['port']) - - # Reset any mock counts caused by __init__ - self.MockCreateConn.reset_mock() - - def test_collect_hosts__happy_path(self): - hosts = "localhost:1234,localhost" - results = collect_hosts(hosts) - - self.assertEqual(set(results), set([ - ('localhost', 1234, socket.AF_INET), - ('localhost', 9092, socket.AF_INET), - ])) - - def test_collect_hosts__ipv6(self): - hosts = "[localhost]:1234,[2001:1000:2000::1],[2001:1000:2000::1]:1234" - results = collect_hosts(hosts) - - self.assertEqual(set(results), set([ - ('localhost', 1234, socket.AF_INET6), - ('2001:1000:2000::1', 9092, socket.AF_INET6), - ('2001:1000:2000::1', 1234, socket.AF_INET6), - ])) - - def test_collect_hosts__string_list(self): - hosts = [ - 'localhost:1234', - 'localhost', - '[localhost]', - '2001::1', - '[2001::1]:1234', - ] - - results = collect_hosts(hosts) - - self.assertEqual(set(results), set([ - ('localhost', 1234, socket.AF_INET), - ('localhost', 9092, socket.AF_INET), - ('localhost', 9092, socket.AF_INET6), - ('2001::1', 9092, socket.AF_INET6), - ('2001::1', 1234, socket.AF_INET6), - ])) - - def test_collect_hosts__with_spaces(self): - hosts = "localhost:1234, localhost" - results = collect_hosts(hosts) - - self.assertEqual(set(results), set([ - ('localhost', 1234, socket.AF_INET), - ('localhost', 9092, socket.AF_INET), - ])) - - - def test_send(self): - self.conn.send(self.config['request_id'], self.config['payload']) - self.conn._sock.sendall.assert_called_with(self.config['payload']) - - def test_init_creates_socket_connection(self): - KafkaConnection(self.config['host'], self.config['port']) - self.MockCreateConn.assert_called_with((self.config['host'], self.config['port']), DEFAULT_SOCKET_TIMEOUT_SECONDS) - - def test_init_failure_raises_connection_error(self): - - def raise_error(*args): - raise socket.error - - assert socket.create_connection is self.MockCreateConn - socket.create_connection.side_effect=raise_error - with self.assertRaises(ConnectionError): - KafkaConnection(self.config['host'], self.config['port']) - - def test_send__reconnects_on_dirty_conn(self): - - # Dirty the connection - try: - self.conn._raise_connection_error() - except ConnectionError: - pass - - # Now test that sending attempts to reconnect - self.assertEqual(self.MockCreateConn.call_count, 0) - self.conn.send(self.config['request_id'], self.config['payload']) - self.assertEqual(self.MockCreateConn.call_count, 1) - - def test_send__failure_sets_dirty_connection(self): - - def raise_error(*args): - raise socket.error - - assert isinstance(self.conn._sock, mock.Mock) - self.conn._sock.sendall.side_effect=raise_error - try: - self.conn.send(self.config['request_id'], self.config['payload']) - except ConnectionError: - self.assertIsNone(self.conn._sock) - - def test_recv(self): - - self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload']) - - def test_recv__reconnects_on_dirty_conn(self): - - # Dirty the connection - try: - self.conn._raise_connection_error() - except ConnectionError: - pass - - # Now test that recv'ing attempts to reconnect - self.assertEqual(self.MockCreateConn.call_count, 0) - self.conn.recv(self.config['request_id']) - self.assertEqual(self.MockCreateConn.call_count, 1) - - def test_recv__failure_sets_dirty_connection(self): - - def raise_error(*args): - raise socket.error - - # test that recv'ing attempts to reconnect - assert isinstance(self.conn._sock, mock.Mock) - self.conn._sock.recv.side_effect=raise_error - try: - self.conn.recv(self.config['request_id']) - except ConnectionError: - self.assertIsNone(self.conn._sock) - - def test_recv__doesnt_consume_extra_data_in_stream(self): - - # Here just test that each call to recv will return a single payload - self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload']) - self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload2']) - - def test_get_connected_socket(self): - s = self.conn.get_connected_socket() - - self.assertEqual(s, self.MockCreateConn()) - - def test_get_connected_socket_on_dirty_conn(self): - # Dirty the connection - try: - self.conn._raise_connection_error() - except ConnectionError: - pass - - # Test that get_connected_socket tries to connect - self.assertEqual(self.MockCreateConn.call_count, 0) - self.conn.get_connected_socket() - self.assertEqual(self.MockCreateConn.call_count, 1) +import time + +import pytest + +from kafka.conn import BrokerConnection, ConnectionStates + + +@pytest.fixture +def socket(mocker): + socket = mocker.MagicMock() + socket.connect_ex.return_value = 0 + mocker.patch('socket.socket', return_value=socket) + return socket + + +@pytest.fixture +def conn(socket): + conn = BrokerConnection('localhost', 9092, socket.AF_INET) + return conn + + +@pytest.mark.parametrize("states", [ + (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING),), + (([EALREADY, EALREADY], ConnectionStates.CONNECTING),), + (([0], ConnectionStates.CONNECTED),), + (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING), + ([ECONNRESET], ConnectionStates.DISCONNECTED)), + (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING), + ([EALREADY], ConnectionStates.CONNECTING), + ([EISCONN], ConnectionStates.CONNECTED)), +]) +def test_connect(socket, conn, states): + assert conn.state is ConnectionStates.DISCONNECTED + + for errno, state in states: + socket.connect_ex.side_effect = errno + conn.connect() + assert conn.state is state + + +def test_connect_timeout(socket, conn): + assert conn.state is ConnectionStates.DISCONNECTED + + # Initial connect returns EINPROGRESS + # immediate inline connect returns EALREADY + # second explicit connect returns EALREADY + # third explicit connect returns EALREADY and times out via last_attempt + socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY] + conn.connect() + assert conn.state is ConnectionStates.CONNECTING + conn.connect() + assert conn.state is ConnectionStates.CONNECTING + conn.last_attempt = 0 + conn.connect() + assert conn.state is ConnectionStates.DISCONNECTED + + +def test_blacked_out(conn): + assert not conn.blacked_out() + conn.last_attempt = time.time() + assert conn.blacked_out() + + +def test_connected(conn): + assert not conn.connected() + conn.state = ConnectionStates.CONNECTED + assert conn.connected() + - def test_close__object_is_reusable(self): +def test_connecting(conn): + assert not conn.connecting() + conn.state = ConnectionStates.CONNECTING + assert conn.connecting() + conn.state = ConnectionStates.CONNECTED + assert not conn.connecting() - # test that sending to a closed connection - # will re-connect and send data to the socket - self.conn.close() - self.conn.send(self.config['request_id'], self.config['payload']) - self.assertEqual(self.MockCreateConn.call_count, 1) - self.conn._sock.sendall.assert_called_with(self.config['payload']) - - -class TestKafkaConnection(unittest.TestCase): - @mock.patch('socket.create_connection') - def test_copy(self, socket): - """KafkaConnection copies work as expected""" - - conn = KafkaConnection('kafka', 9092) - self.assertEqual(socket.call_count, 1) - - copy = conn.copy() - self.assertEqual(socket.call_count, 1) - self.assertEqual(copy.host, 'kafka') - self.assertEqual(copy.port, 9092) - self.assertEqual(copy._sock, None) - - copy.reinit() - self.assertEqual(socket.call_count, 2) - self.assertNotEqual(copy._sock, None) - - @mock.patch('socket.create_connection') - def test_copy_thread(self, socket): - """KafkaConnection copies work in other threads""" - - err = [] - copy = KafkaConnection('kafka', 9092).copy() - - def thread_func(err, copy): - try: - self.assertEqual(copy.host, 'kafka') - self.assertEqual(copy.port, 9092) - self.assertNotEqual(copy._sock, None) - except Exception as e: - err.append(e) - else: - err.append(None) - thread = Thread(target=thread_func, args=(err, copy)) - thread.start() - thread.join() - - self.assertEqual(err, [None]) - self.assertEqual(socket.call_count, 2) +# TODO: test_send, test_recv, test_can_send_more, test_close diff --git a/test/test_conn_legacy.py b/test/test_conn_legacy.py new file mode 100644 index 0000000..f0ef8fb --- /dev/null +++ b/test/test_conn_legacy.py @@ -0,0 +1,242 @@ +import socket +import struct +from threading import Thread + +import mock +from . import unittest + +from kafka.common import ConnectionError +from kafka.conn import KafkaConnection, collect_hosts, DEFAULT_SOCKET_TIMEOUT_SECONDS + +class ConnTest(unittest.TestCase): + def setUp(self): + + self.config = { + 'host': 'localhost', + 'port': 9090, + 'request_id': 0, + 'payload': b'test data', + 'payload2': b'another packet' + } + + # Mocking socket.create_connection will cause _sock to always be a + # MagicMock() + patcher = mock.patch('socket.create_connection', spec=True) + self.MockCreateConn = patcher.start() + self.addCleanup(patcher.stop) + + # Also mock socket.sendall() to appear successful + self.MockCreateConn().sendall.return_value = None + + # And mock socket.recv() to return two payloads, then '', then raise + # Note that this currently ignores the num_bytes parameter to sock.recv() + payload_size = len(self.config['payload']) + payload2_size = len(self.config['payload2']) + self.MockCreateConn().recv.side_effect = [ + struct.pack('>i', payload_size), + struct.pack('>%ds' % payload_size, self.config['payload']), + struct.pack('>i', payload2_size), + struct.pack('>%ds' % payload2_size, self.config['payload2']), + b'' + ] + + # Create a connection object + self.conn = KafkaConnection(self.config['host'], self.config['port']) + + # Reset any mock counts caused by __init__ + self.MockCreateConn.reset_mock() + + def test_collect_hosts__happy_path(self): + hosts = "localhost:1234,localhost" + results = collect_hosts(hosts) + + self.assertEqual(set(results), set([ + ('localhost', 1234, socket.AF_INET), + ('localhost', 9092, socket.AF_INET), + ])) + + def test_collect_hosts__ipv6(self): + hosts = "[localhost]:1234,[2001:1000:2000::1],[2001:1000:2000::1]:1234" + results = collect_hosts(hosts) + + self.assertEqual(set(results), set([ + ('localhost', 1234, socket.AF_INET6), + ('2001:1000:2000::1', 9092, socket.AF_INET6), + ('2001:1000:2000::1', 1234, socket.AF_INET6), + ])) + + def test_collect_hosts__string_list(self): + hosts = [ + 'localhost:1234', + 'localhost', + '[localhost]', + '2001::1', + '[2001::1]:1234', + ] + + results = collect_hosts(hosts) + + self.assertEqual(set(results), set([ + ('localhost', 1234, socket.AF_INET), + ('localhost', 9092, socket.AF_INET), + ('localhost', 9092, socket.AF_INET6), + ('2001::1', 9092, socket.AF_INET6), + ('2001::1', 1234, socket.AF_INET6), + ])) + + def test_collect_hosts__with_spaces(self): + hosts = "localhost:1234, localhost" + results = collect_hosts(hosts) + + self.assertEqual(set(results), set([ + ('localhost', 1234, socket.AF_INET), + ('localhost', 9092, socket.AF_INET), + ])) + + + def test_send(self): + self.conn.send(self.config['request_id'], self.config['payload']) + self.conn._sock.sendall.assert_called_with(self.config['payload']) + + def test_init_creates_socket_connection(self): + KafkaConnection(self.config['host'], self.config['port']) + self.MockCreateConn.assert_called_with((self.config['host'], self.config['port']), DEFAULT_SOCKET_TIMEOUT_SECONDS) + + def test_init_failure_raises_connection_error(self): + + def raise_error(*args): + raise socket.error + + assert socket.create_connection is self.MockCreateConn + socket.create_connection.side_effect=raise_error + with self.assertRaises(ConnectionError): + KafkaConnection(self.config['host'], self.config['port']) + + def test_send__reconnects_on_dirty_conn(self): + + # Dirty the connection + try: + self.conn._raise_connection_error() + except ConnectionError: + pass + + # Now test that sending attempts to reconnect + self.assertEqual(self.MockCreateConn.call_count, 0) + self.conn.send(self.config['request_id'], self.config['payload']) + self.assertEqual(self.MockCreateConn.call_count, 1) + + def test_send__failure_sets_dirty_connection(self): + + def raise_error(*args): + raise socket.error + + assert isinstance(self.conn._sock, mock.Mock) + self.conn._sock.sendall.side_effect=raise_error + try: + self.conn.send(self.config['request_id'], self.config['payload']) + except ConnectionError: + self.assertIsNone(self.conn._sock) + + def test_recv(self): + + self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload']) + + def test_recv__reconnects_on_dirty_conn(self): + + # Dirty the connection + try: + self.conn._raise_connection_error() + except ConnectionError: + pass + + # Now test that recv'ing attempts to reconnect + self.assertEqual(self.MockCreateConn.call_count, 0) + self.conn.recv(self.config['request_id']) + self.assertEqual(self.MockCreateConn.call_count, 1) + + def test_recv__failure_sets_dirty_connection(self): + + def raise_error(*args): + raise socket.error + + # test that recv'ing attempts to reconnect + assert isinstance(self.conn._sock, mock.Mock) + self.conn._sock.recv.side_effect=raise_error + try: + self.conn.recv(self.config['request_id']) + except ConnectionError: + self.assertIsNone(self.conn._sock) + + def test_recv__doesnt_consume_extra_data_in_stream(self): + + # Here just test that each call to recv will return a single payload + self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload']) + self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload2']) + + def test_get_connected_socket(self): + s = self.conn.get_connected_socket() + + self.assertEqual(s, self.MockCreateConn()) + + def test_get_connected_socket_on_dirty_conn(self): + # Dirty the connection + try: + self.conn._raise_connection_error() + except ConnectionError: + pass + + # Test that get_connected_socket tries to connect + self.assertEqual(self.MockCreateConn.call_count, 0) + self.conn.get_connected_socket() + self.assertEqual(self.MockCreateConn.call_count, 1) + + def test_close__object_is_reusable(self): + + # test that sending to a closed connection + # will re-connect and send data to the socket + self.conn.close() + self.conn.send(self.config['request_id'], self.config['payload']) + self.assertEqual(self.MockCreateConn.call_count, 1) + self.conn._sock.sendall.assert_called_with(self.config['payload']) + + +class TestKafkaConnection(unittest.TestCase): + @mock.patch('socket.create_connection') + def test_copy(self, socket): + """KafkaConnection copies work as expected""" + + conn = KafkaConnection('kafka', 9092) + self.assertEqual(socket.call_count, 1) + + copy = conn.copy() + self.assertEqual(socket.call_count, 1) + self.assertEqual(copy.host, 'kafka') + self.assertEqual(copy.port, 9092) + self.assertEqual(copy._sock, None) + + copy.reinit() + self.assertEqual(socket.call_count, 2) + self.assertNotEqual(copy._sock, None) + + @mock.patch('socket.create_connection') + def test_copy_thread(self, socket): + """KafkaConnection copies work in other threads""" + + err = [] + copy = KafkaConnection('kafka', 9092).copy() + + def thread_func(err, copy): + try: + self.assertEqual(copy.host, 'kafka') + self.assertEqual(copy.port, 9092) + self.assertNotEqual(copy._sock, None) + except Exception as e: + err.append(e) + else: + err.append(None) + thread = Thread(target=thread_func, args=(err, copy)) + thread.start() + thread.join() + + self.assertEqual(err, [None]) + self.assertEqual(socket.call_count, 2) |