diff options
-rw-r--r-- | kafka/conn.py | 78 |
-rw-r--r-- | test/test_conn.py | 139 |
2 files changed, 177 insertions, 40 deletions
diff --git a/kafka/conn.py b/kafka/conn.py index b5dafd8..0d17cb8 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -67,7 +67,11 @@ class KafkaConnection(local): ################### def _raise_connection_error(self): - self._dirty = True + # Cleanup socket if we have one + if self._sock: + self.close() + + # And then raise raise ConnectionError("Kafka @ {0}:{1} went away".format(self.host, self.port)) def _read_bytes(self, num_bytes): @@ -75,20 +79,26 @@ class KafkaConnection(local): responses = [] log.debug("About to read %d bytes from Kafka", num_bytes) - if self._dirty: + + # Make sure we have a connection + if not self._sock: self.reinit() while bytes_left: + try: data = self._sock.recv(min(bytes_left, 4096)) + + # Receiving empty string from recv signals + # that the socket is in error. we will never get + # more data from this socket + if data == '': + raise socket.error('Not enough data to read message -- did server kill socket?') + except socket.error: log.exception('Unable to receive data from Kafka') self._raise_connection_error() - if data == '': - log.error("Not enough data to read this response") - self._raise_connection_error() - bytes_left -= len(data) log.debug("Read %d/%d bytes from Kafka", num_bytes - bytes_left, num_bytes) responses.append(data) @@ -102,26 +112,34 @@ class KafkaConnection(local): # TODO multiplex socket communication to allow for multi-threaded clients def send(self, request_id, payload): - "Send a request to Kafka" + """ + Send a request to Kafka + param: request_id -- can be any int (used only for debug logging...) 
+ param: payload -- an encoded kafka packet (see KafkaProtocol) + """ + log.debug("About to send %d bytes to Kafka, request %d" % (len(payload), request_id)) + + # Make sure we have a connection + if not self._sock: + self.reinit() + try: - if self._dirty: - self.reinit() - sent = self._sock.sendall(payload) - if sent is not None: - self._raise_connection_error() + self._sock.sendall(payload) except socket.error: log.exception('Unable to send payload to Kafka') self._raise_connection_error() def recv(self, request_id): """ - Get a response from Kafka + Get a response packet from Kafka + param: request_id -- can be any int (only used for debug logging...) + returns encoded kafka packet response from server as type str """ log.debug("Reading response %d from Kafka" % request_id) + # Read the size off of the header resp = self._read_bytes(4) - (size,) = struct.unpack('>i', resp) # Read the remainder of the response @@ -132,6 +150,7 @@ class KafkaConnection(local): """ Create an inactive copy of the connection object A reinit() has to be done on the copy before it can be used again + return a new KafkaConnection object """ c = copy.deepcopy(self) c._sock = None @@ -139,15 +158,38 @@ class KafkaConnection(local): def close(self): """ - Close this connection + Shutdown and close the connection socket """ + log.debug("Closing socket connection for %s:%d" % (self.host, self.port)) if self._sock: + # Call shutdown to be a good TCP client + # But expect an error if the socket has already been + # closed by the server + try: + self._sock.shutdown(socket.SHUT_RDWR) + except socket.error: + pass + + # Closing the socket should always succeed self._sock.close() + self._sock = None + else: + log.debug("No socket found to close!") def reinit(self): """ Re-initialize the socket connection + close current socket (if open) + and start a fresh connection + raise ConnectionError on error """ - self.close() - self._sock = socket.create_connection((self.host, self.port), self.timeout) - 
self._dirty = False + log.debug("Reinitializing socket connection for %s:%d" % (self.host, self.port)) + + if self._sock: + self.close() + + try: + self._sock = socket.create_connection((self.host, self.port), self.timeout) + except socket.error: + log.exception('Unable to connect to kafka broker at %s:%d' % (self.host, self.port)) + self._raise_connection_error() diff --git a/test/test_conn.py b/test/test_conn.py index 4ab6d4f..184a99e 100644 --- a/test/test_conn.py +++ b/test/test_conn.py @@ -1,13 +1,52 @@ -import os -import random +import socket import struct + +import mock import unittest2 -import kafka.conn + +from kafka.common import * +from kafka.conn import * class ConnTest(unittest2.TestCase): + def setUp(self): + self.config = { + 'host': 'localhost', + 'port': 9090, + 'request_id': 0, + 'payload': 'test data', + 'payload2': 'another packet' + } + + # Mocking socket.create_connection will cause _sock to always be a + # MagicMock() + patcher = mock.patch('socket.create_connection', spec=True) + self.MockCreateConn = patcher.start() + self.addCleanup(patcher.stop) + + # Also mock socket.sendall() to appear successful + socket.create_connection().sendall.return_value = None + + # And mock socket.recv() to return two payloads, then '', then raise + # Note that this currently ignores the num_bytes parameter to sock.recv() + payload_size = len(self.config['payload']) + payload2_size = len(self.config['payload2']) + socket.create_connection().recv.side_effect = [ + struct.pack('>i', payload_size), + struct.pack('>%ds' % payload_size, self.config['payload']), + struct.pack('>i', payload2_size), + struct.pack('>%ds' % payload2_size, self.config['payload2']), + '' + ] + + # Create a connection object + self.conn = KafkaConnection(self.config['host'], self.config['port']) + + # Reset any mock counts caused by __init__ + socket.create_connection.reset_mock() + def test_collect_hosts__happy_path(self): hosts = "localhost:1234,localhost" - results = 
kafka.conn.collect_hosts(hosts) + results = collect_hosts(hosts) self.assertEqual(set(results), set([ ('localhost', 1234), @@ -20,7 +59,7 @@ class ConnTest(unittest2.TestCase): 'localhost', ] - results = kafka.conn.collect_hosts(hosts) + results = collect_hosts(hosts) self.assertEqual(set(results), set([ ('localhost', 1234), @@ -29,41 +68,97 @@ class ConnTest(unittest2.TestCase): def test_collect_hosts__with_spaces(self): hosts = "localhost:1234, localhost" - results = kafka.conn.collect_hosts(hosts) + results = collect_hosts(hosts) self.assertEqual(set(results), set([ ('localhost', 1234), ('localhost', 9092), ])) - @unittest2.skip("Not Implemented") def test_send(self): - pass + self.conn.send(self.config['request_id'], self.config['payload']) + self.conn._sock.sendall.assert_called_with(self.config['payload']) + + def test_init_creates_socket_connection(self): + KafkaConnection(self.config['host'], self.config['port']) + socket.create_connection.assert_called_with((self.config['host'], self.config['port']), DEFAULT_SOCKET_TIMEOUT_SECONDS) + + def test_init_failure_raises_connection_error(self): + + def raise_error(*args): + raise socket.error + + assert socket.create_connection is self.MockCreateConn + socket.create_connection.side_effect=raise_error + with self.assertRaises(ConnectionError): + KafkaConnection(self.config['host'], self.config['port']) - @unittest2.skip("Not Implemented") def test_send__reconnects_on_dirty_conn(self): - pass - @unittest2.skip("Not Implemented") + # Dirty the connection + try: + self.conn._raise_connection_error() + except ConnectionError: + pass + + # Now test that sending attempts to reconnect + self.assertEqual(socket.create_connection.call_count, 0) + self.conn.send(self.config['request_id'], self.config['payload']) + self.assertEqual(socket.create_connection.call_count, 1) + def test_send__failure_sets_dirty_connection(self): - pass - @unittest2.skip("Not Implemented") + def raise_error(*args): + raise socket.error + + assert 
isinstance(self.conn._sock, mock.Mock) + self.conn._sock.sendall.side_effect=raise_error + try: + self.conn.send(self.config['request_id'], self.config['payload']) + except ConnectionError: + self.assertIsNone(self.conn._sock) + def test_recv(self): - pass - @unittest2.skip("Not Implemented") + self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload']) + def test_recv__reconnects_on_dirty_conn(self): - pass - @unittest2.skip("Not Implemented") + # Dirty the connection + try: + self.conn._raise_connection_error() + except ConnectionError: + pass + + # Now test that recv'ing attempts to reconnect + self.assertEqual(socket.create_connection.call_count, 0) + self.conn.recv(self.config['request_id']) + self.assertEqual(socket.create_connection.call_count, 1) + def test_recv__failure_sets_dirty_connection(self): - pass - @unittest2.skip("Not Implemented") + def raise_error(*args): + raise socket.error + + # test that recv'ing attempts to reconnect + assert isinstance(self.conn._sock, mock.Mock) + self.conn._sock.recv.side_effect=raise_error + try: + self.conn.recv(self.config['request_id']) + except ConnectionError: + self.assertIsNone(self.conn._sock) + def test_recv__doesnt_consume_extra_data_in_stream(self): - pass - @unittest2.skip("Not Implemented") + # Here just test that each call to recv will return a single payload + self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload']) + self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload2']) + def test_close__object_is_reusable(self): - pass + + # test that sending to a closed connection + # will re-connect and send data to the socket + self.conn.close() + self.conn.send(self.config['request_id'], self.config['payload']) + self.assertEqual(socket.create_connection.call_count, 1) + self.conn._sock.sendall.assert_called_with(self.config['payload']) |