summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- kafka/conn.py 78
-rw-r--r-- test/test_conn.py 139
2 files changed, 177 insertions, 40 deletions
diff --git a/kafka/conn.py b/kafka/conn.py
index b5dafd8..0d17cb8 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -67,7 +67,11 @@ class KafkaConnection(local):
###################
def _raise_connection_error(self):
- self._dirty = True
+ # Cleanup socket if we have one
+ if self._sock:
+ self.close()
+
+ # And then raise
raise ConnectionError("Kafka @ {0}:{1} went away".format(self.host, self.port))
def _read_bytes(self, num_bytes):
@@ -75,20 +79,26 @@ class KafkaConnection(local):
responses = []
log.debug("About to read %d bytes from Kafka", num_bytes)
- if self._dirty:
+
+ # Make sure we have a connection
+ if not self._sock:
self.reinit()
while bytes_left:
+
try:
data = self._sock.recv(min(bytes_left, 4096))
+
+ # Receiving empty string from recv signals
+ # that the socket is in error. we will never get
+ # more data from this socket
+ if data == '':
+ raise socket.error('Not enough data to read message -- did server kill socket?')
+
except socket.error:
log.exception('Unable to receive data from Kafka')
self._raise_connection_error()
- if data == '':
- log.error("Not enough data to read this response")
- self._raise_connection_error()
-
bytes_left -= len(data)
log.debug("Read %d/%d bytes from Kafka", num_bytes - bytes_left, num_bytes)
responses.append(data)
@@ -102,26 +112,34 @@ class KafkaConnection(local):
# TODO multiplex socket communication to allow for multi-threaded clients
def send(self, request_id, payload):
- "Send a request to Kafka"
+ """
+ Send a request to Kafka
+ param: request_id -- can be any int (used only for debug logging...)
+ param: payload -- an encoded kafka packet (see KafkaProtocol)
+ """
+
log.debug("About to send %d bytes to Kafka, request %d" % (len(payload), request_id))
+
+ # Make sure we have a connection
+ if not self._sock:
+ self.reinit()
+
try:
- if self._dirty:
- self.reinit()
- sent = self._sock.sendall(payload)
- if sent is not None:
- self._raise_connection_error()
+ self._sock.sendall(payload)
except socket.error:
log.exception('Unable to send payload to Kafka')
self._raise_connection_error()
def recv(self, request_id):
"""
- Get a response from Kafka
+ Get a response packet from Kafka
+ param: request_id -- can be any int (only used for debug logging...)
+ returns encoded kafka packet response from server as type str
"""
log.debug("Reading response %d from Kafka" % request_id)
+
# Read the size off of the header
resp = self._read_bytes(4)
-
(size,) = struct.unpack('>i', resp)
# Read the remainder of the response
@@ -132,6 +150,7 @@ class KafkaConnection(local):
"""
Create an inactive copy of the connection object
A reinit() has to be done on the copy before it can be used again
+ return a new KafkaConnection object
"""
c = copy.deepcopy(self)
c._sock = None
@@ -139,15 +158,38 @@ class KafkaConnection(local):
def close(self):
"""
- Close this connection
+ Shutdown and close the connection socket
"""
+ log.debug("Closing socket connection for %s:%d" % (self.host, self.port))
if self._sock:
+ # Call shutdown to be a good TCP client
+ # But expect an error if the socket has already been
+ # closed by the server
+ try:
+ self._sock.shutdown(socket.SHUT_RDWR)
+ except socket.error:
+ pass
+
+ # Closing the socket should always succeed
self._sock.close()
+ self._sock = None
+ else:
+ log.debug("No socket found to close!")
def reinit(self):
"""
Re-initialize the socket connection
+ close current socket (if open)
+ and start a fresh connection
+ raise ConnectionError on error
"""
- self.close()
- self._sock = socket.create_connection((self.host, self.port), self.timeout)
- self._dirty = False
+ log.debug("Reinitializing socket connection for %s:%d" % (self.host, self.port))
+
+ if self._sock:
+ self.close()
+
+ try:
+ self._sock = socket.create_connection((self.host, self.port), self.timeout)
+ except socket.error:
+ log.exception('Unable to connect to kafka broker at %s:%d' % (self.host, self.port))
+ self._raise_connection_error()
diff --git a/test/test_conn.py b/test/test_conn.py
index 4ab6d4f..184a99e 100644
--- a/test/test_conn.py
+++ b/test/test_conn.py
@@ -1,13 +1,52 @@
-import os
-import random
+import socket
import struct
+
+import mock
import unittest2
-import kafka.conn
+
+from kafka.common import *
+from kafka.conn import *
class ConnTest(unittest2.TestCase):
+ def setUp(self):
+ self.config = {
+ 'host': 'localhost',
+ 'port': 9090,
+ 'request_id': 0,
+ 'payload': 'test data',
+ 'payload2': 'another packet'
+ }
+
+ # Mocking socket.create_connection will cause _sock to always be a
+ # MagicMock()
+ patcher = mock.patch('socket.create_connection', spec=True)
+ self.MockCreateConn = patcher.start()
+ self.addCleanup(patcher.stop)
+
+ # Also mock socket.sendall() to appear successful
+ socket.create_connection().sendall.return_value = None
+
+ # And mock socket.recv() to return two payloads, then '', then raise
+ # Note that this currently ignores the num_bytes parameter to sock.recv()
+ payload_size = len(self.config['payload'])
+ payload2_size = len(self.config['payload2'])
+ socket.create_connection().recv.side_effect = [
+ struct.pack('>i', payload_size),
+ struct.pack('>%ds' % payload_size, self.config['payload']),
+ struct.pack('>i', payload2_size),
+ struct.pack('>%ds' % payload2_size, self.config['payload2']),
+ ''
+ ]
+
+ # Create a connection object
+ self.conn = KafkaConnection(self.config['host'], self.config['port'])
+
+ # Reset any mock counts caused by __init__
+ socket.create_connection.reset_mock()
+
def test_collect_hosts__happy_path(self):
hosts = "localhost:1234,localhost"
- results = kafka.conn.collect_hosts(hosts)
+ results = collect_hosts(hosts)
self.assertEqual(set(results), set([
('localhost', 1234),
@@ -20,7 +59,7 @@ class ConnTest(unittest2.TestCase):
'localhost',
]
- results = kafka.conn.collect_hosts(hosts)
+ results = collect_hosts(hosts)
self.assertEqual(set(results), set([
('localhost', 1234),
@@ -29,41 +68,97 @@ class ConnTest(unittest2.TestCase):
def test_collect_hosts__with_spaces(self):
hosts = "localhost:1234, localhost"
- results = kafka.conn.collect_hosts(hosts)
+ results = collect_hosts(hosts)
self.assertEqual(set(results), set([
('localhost', 1234),
('localhost', 9092),
]))
- @unittest2.skip("Not Implemented")
def test_send(self):
- pass
+ self.conn.send(self.config['request_id'], self.config['payload'])
+ self.conn._sock.sendall.assert_called_with(self.config['payload'])
+
+ def test_init_creates_socket_connection(self):
+ KafkaConnection(self.config['host'], self.config['port'])
+ socket.create_connection.assert_called_with((self.config['host'], self.config['port']), DEFAULT_SOCKET_TIMEOUT_SECONDS)
+
+ def test_init_failure_raises_connection_error(self):
+
+ def raise_error(*args):
+ raise socket.error
+
+ assert socket.create_connection is self.MockCreateConn
+ socket.create_connection.side_effect=raise_error
+ with self.assertRaises(ConnectionError):
+ KafkaConnection(self.config['host'], self.config['port'])
- @unittest2.skip("Not Implemented")
def test_send__reconnects_on_dirty_conn(self):
- pass
- @unittest2.skip("Not Implemented")
+ # Dirty the connection
+ try:
+ self.conn._raise_connection_error()
+ except ConnectionError:
+ pass
+
+ # Now test that sending attempts to reconnect
+ self.assertEqual(socket.create_connection.call_count, 0)
+ self.conn.send(self.config['request_id'], self.config['payload'])
+ self.assertEqual(socket.create_connection.call_count, 1)
+
def test_send__failure_sets_dirty_connection(self):
- pass
- @unittest2.skip("Not Implemented")
+ def raise_error(*args):
+ raise socket.error
+
+ assert isinstance(self.conn._sock, mock.Mock)
+ self.conn._sock.sendall.side_effect=raise_error
+ try:
+ self.conn.send(self.config['request_id'], self.config['payload'])
+ except ConnectionError:
+ self.assertIsNone(self.conn._sock)
+
def test_recv(self):
- pass
- @unittest2.skip("Not Implemented")
+ self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload'])
+
def test_recv__reconnects_on_dirty_conn(self):
- pass
- @unittest2.skip("Not Implemented")
+ # Dirty the connection
+ try:
+ self.conn._raise_connection_error()
+ except ConnectionError:
+ pass
+
+ # Now test that recv'ing attempts to reconnect
+ self.assertEqual(socket.create_connection.call_count, 0)
+ self.conn.recv(self.config['request_id'])
+ self.assertEqual(socket.create_connection.call_count, 1)
+
def test_recv__failure_sets_dirty_connection(self):
- pass
- @unittest2.skip("Not Implemented")
+ def raise_error(*args):
+ raise socket.error
+
+ # test that recv'ing attempts to reconnect
+ assert isinstance(self.conn._sock, mock.Mock)
+ self.conn._sock.recv.side_effect=raise_error
+ try:
+ self.conn.recv(self.config['request_id'])
+ except ConnectionError:
+ self.assertIsNone(self.conn._sock)
+
def test_recv__doesnt_consume_extra_data_in_stream(self):
- pass
- @unittest2.skip("Not Implemented")
+ # Here just test that each call to recv will return a single payload
+ self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload'])
+ self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload2'])
+
def test_close__object_is_reusable(self):
- pass
+
+ # test that sending to a closed connection
+ # will re-connect and send data to the socket
+ self.conn.close()
+ self.conn.send(self.config['request_id'], self.config['payload'])
+ self.assertEqual(socket.create_connection.call_count, 1)
+ self.conn._sock.sendall.assert_called_with(self.config['payload'])