summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDana Powers <dana.powers@rd.io>2015-12-17 17:29:54 -0800
committerDana Powers <dana.powers@rd.io>2015-12-17 23:22:35 -0800
commitf1ad0247df5bf6e0315ffbb1633d5979da828de0 (patch)
treeca96d1d960a13ae481b76fd32761ea535234f02b
parent799824535ceeb698152a3078f64ecbf6baca9b39 (diff)
downloadkafka-python-f1ad0247df5bf6e0315ffbb1633d5979da828de0.tar.gz
Switch BrokerConnection to (mostly) non-blocking IO.
- return kafka.Future on send() - recv is now non-blocking call that completes futures when possible - update KafkaClient to block on future completion
-rw-r--r--kafka/client.py90
-rw-r--r--kafka/cluster.py2
-rw-r--r--kafka/common.py16
-rw-r--r--kafka/conn.py299
-rw-r--r--kafka/future.py51
-rw-r--r--test/test_client.py82
6 files changed, 386 insertions, 154 deletions
diff --git a/kafka/client.py b/kafka/client.py
index b09927d..7f9969e 100644
--- a/kafka/client.py
+++ b/kafka/client.py
@@ -3,7 +3,6 @@ import copy
import functools
import logging
import random
-import select
import time
import six
@@ -15,7 +14,9 @@ from kafka.common import (TopicAndPartition, BrokerMetadata, UnknownError,
LeaderNotAvailableError, UnknownTopicOrPartitionError,
NotLeaderForPartitionError, ReplicaNotAvailableError)
-from kafka.conn import collect_hosts, BrokerConnection, DEFAULT_SOCKET_TIMEOUT_SECONDS
+from kafka.conn import (
+ collect_hosts, BrokerConnection, DEFAULT_SOCKET_TIMEOUT_SECONDS,
+ ConnectionStates)
from kafka.protocol import KafkaProtocol
@@ -45,7 +46,6 @@ class KafkaClient(object):
self.load_metadata_for_topics() # bootstrap with all metadata
-
##################
# Private API #
##################
@@ -56,11 +56,14 @@ class KafkaClient(object):
if host_key not in self._conns:
self._conns[host_key] = BrokerConnection(
host, port,
- timeout=self.timeout,
+ request_timeout_ms=self.timeout * 1000,
client_id=self.client_id
)
- return self._conns[host_key]
+ conn = self._conns[host_key]
+ while conn.connect() == ConnectionStates.CONNECTING:
+ pass
+ return conn
def _get_leader_for_partition(self, topic, partition):
"""
@@ -137,16 +140,23 @@ class KafkaClient(object):
for (host, port) in hosts:
conn = self._get_conn(host, port)
+ if not conn.connected():
+ log.warning("Skipping unconnected connection: %s", conn)
+ continue
request = encoder_fn(payloads=payloads)
- correlation_id = conn.send(request)
- if correlation_id is None:
+ future = conn.send(request)
+
+ # Block
+ while not future.is_done:
+ conn.recv()
+
+ if future.failed():
+ log.error("Request failed: %s", future.exception)
continue
- response = conn.recv()
- if response is not None:
- decoded = decoder_fn(response)
- return decoded
- raise KafkaUnavailableError('All servers failed to process request')
+ return decoder_fn(future.value)
+
+ raise KafkaUnavailableError('All servers failed to process request: %s' % hosts)
def _payloads_by_broker(self, payloads):
payloads_by_broker = collections.defaultdict(list)
@@ -204,55 +214,59 @@ class KafkaClient(object):
# For each BrokerConnection keep the real socket so that we can use
# a select to perform unblocking I/O
- connections_by_socket = {}
+ connections_by_future = {}
for broker, broker_payloads in six.iteritems(payloads_by_broker):
if broker is None:
failed_payloads(broker_payloads)
continue
conn = self._get_conn(broker.host, broker.port)
+ conn.connect()
+ if not conn.connected():
+ refresh_metadata = True
+ failed_payloads(broker_payloads)
+ continue
+
request = encoder_fn(payloads=broker_payloads)
# decoder_fn=None signal that the server is expected to not
# send a response. This probably only applies to
# ProduceRequest w/ acks = 0
expect_response = (decoder_fn is not None)
- correlation_id = conn.send(request, expect_response=expect_response)
+ future = conn.send(request, expect_response=expect_response)
- if correlation_id is None:
+ if future.failed():
refresh_metadata = True
failed_payloads(broker_payloads)
- log.warning('Error attempting to send request %s '
- 'to server %s', correlation_id, broker)
continue
if not expect_response:
- log.debug('Request %s does not expect a response '
- '(skipping conn.recv)', correlation_id)
for payload in broker_payloads:
topic_partition = (str(payload.topic), payload.partition)
responses[topic_partition] = None
continue
- connections_by_socket[conn._read_fd] = (conn, broker)
+ connections_by_future[future] = (conn, broker)
conn = None
- while connections_by_socket:
- sockets = connections_by_socket.keys()
- rlist, _, _ = select.select(sockets, [], [], None)
- conn, broker = connections_by_socket.pop(rlist[0])
- correlation_id = conn.next_correlation_id_recv()
- response = conn.recv()
- if response is None:
- refresh_metadata = True
- failed_payloads(payloads_by_broker[broker])
- log.warning('Error receiving response to request %s '
- 'from server %s', correlation_id, broker)
- continue
+ while connections_by_future:
+ futures = list(connections_by_future.keys())
+ for future in futures:
+
+ if not future.is_done:
+ conn, _ = connections_by_future[future]
+ conn.recv()
+ continue
- for payload_response in decoder_fn(response):
- topic_partition = (str(payload_response.topic),
- payload_response.partition)
- responses[topic_partition] = payload_response
+ _, broker = connections_by_future.pop(future)
+ if future.failed():
+ refresh_metadata = True
+ failed_payloads(payloads_by_broker[broker])
+
+ else:
+ for payload_response in decoder_fn(future.value):
+ topic_partition = (str(payload_response.topic),
+ payload_response.partition)
+ responses[topic_partition] = payload_response
if refresh_metadata:
self.reset_all_metadata()
@@ -392,7 +406,9 @@ class KafkaClient(object):
def reinit(self):
for conn in self._conns.values():
- conn.reinit()
+ conn.close()
+ while conn.connect() == ConnectionStates.CONNECTING:
+ pass
def reset_topic_metadata(self, *topics):
for topic in topics:
diff --git a/kafka/cluster.py b/kafka/cluster.py
index 55765dc..15921dc 100644
--- a/kafka/cluster.py
+++ b/kafka/cluster.py
@@ -73,7 +73,7 @@ class Cluster(object):
def _bootstrap(self, hosts, timeout=2):
for host, port in hosts:
- conn = BrokerConnection(host, port, timeout)
+ conn = BrokerConnection(host, port)
if not conn.connect():
continue
self._brokers['bootstrap'] = conn
diff --git a/kafka/common.py b/kafka/common.py
index 253137c..173fc82 100644
--- a/kafka/common.py
+++ b/kafka/common.py
@@ -93,6 +93,22 @@ class KafkaError(RuntimeError):
pass
+class IllegalStateError(KafkaError):
+ pass
+
+
+class RetriableError(KafkaError):
+ pass
+
+
+class DisconnectError(KafkaError):
+ pass
+
+
+class CorrelationIdError(KafkaError):
+ pass
+
+
class BrokerResponseError(KafkaError):
errno = None
message = None
diff --git a/kafka/conn.py b/kafka/conn.py
index d45b824..c2b8fb0 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -1,15 +1,20 @@
-from collections import deque
+import collections
import copy
+import errno
import logging
+import io
from random import shuffle
from select import select
import socket
import struct
from threading import local
+import time
import six
+import kafka.common as Errors
from kafka.common import ConnectionError
+from kafka.future import Future
from kafka.protocol.api import RequestHeader
from kafka.protocol.types import Int32
@@ -20,106 +25,244 @@ DEFAULT_SOCKET_TIMEOUT_SECONDS = 120
DEFAULT_KAFKA_PORT = 9092
-class BrokerConnection(local):
- def __init__(self, host, port, timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS,
- client_id='kafka-python-0.10.0', correlation_id=0):
- super(BrokerConnection, self).__init__()
+class ConnectionStates(object):
+ DISCONNECTED = 1
+ CONNECTING = 2
+ CONNECTED = 3
+
+
+InFlightRequest = collections.namedtuple('InFlightRequest',
+ ['request', 'response_type', 'correlation_id', 'future', 'timestamp'])
+
+
+class BrokerConnection(object):
+ _receive_buffer_bytes = 32768
+ _send_buffer_bytes = 32768
+ _client_id = 'kafka-python-0.10.0'
+ _correlation_id = 0
+ _request_timeout_ms = 40000
+
+ def __init__(self, host, port, **kwargs):
self.host = host
self.port = port
- self.timeout = timeout
- self._write_fd = None
- self._read_fd = None
- self.correlation_id = correlation_id
- self.client_id = client_id
- self.in_flight_requests = deque()
+ self.in_flight_requests = collections.deque()
+
+ for config in ('receive_buffer_bytes', 'send_buffer_bytes',
+ 'client_id', 'correlation_id', 'request_timeout_ms'):
+ if config in kwargs:
+ setattr(self, '_' + config, kwargs.pop(config))
+
+ self.state = ConnectionStates.DISCONNECTED
+ self._sock = None
+ self._rbuffer = io.BytesIO()
+ self._receiving = False
+ self._next_payload_bytes = 0
+ self._last_connection_attempt = None
+ self._last_connection_failure = None
def connect(self):
- if self.connected():
+ """Attempt to connect and return ConnectionState"""
+ if self.state is ConnectionStates.DISCONNECTED:
self.close()
- try:
- sock = socket.create_connection((self.host, self.port), self.timeout)
- self._write_fd = sock.makefile('wb')
- self._read_fd = sock.makefile('rb')
- except socket.error:
- log.exception("Error in BrokerConnection.connect()")
- return None
- self.in_flight_requests.clear()
- return True
+ self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self._receive_buffer_bytes)
+ self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self._send_buffer_bytes)
+ self._sock.setblocking(False)
+ ret = self._sock.connect_ex((self.host, self.port))
+ self._last_connection_attempt = time.time()
+
+ if not ret or ret is errno.EISCONN:
+ self.state = ConnectionStates.CONNECTED
+ elif ret in (errno.EINPROGRESS, errno.EALREADY):
+ self.state = ConnectionStates.CONNECTING
+ else:
+ log.error('Connect attempt returned error %s. Disconnecting.', ret)
+ self.close()
+ self._last_connection_failure = time.time()
+
+ if self.state is ConnectionStates.CONNECTING:
+ # in non-blocking mode, use repeated calls to socket.connect_ex
+ # to check connection status
+ if time.time() > (self._request_timeout_ms / 1000.0) + self._last_connection_attempt:
+ log.error('Connection attempt timed out')
+ self.close() # error=TimeoutError ?
+ self._last_connection_failure = time.time()
+
+ ret = self._sock.connect_ex((self.host, self.port))
+ if not ret or ret is errno.EISCONN:
+ self.state = ConnectionStates.CONNECTED
+ elif ret is not errno.EALREADY:
+ log.error('Connect attempt returned error %s. Disconnecting.', ret)
+ self.close()
+ self._last_connection_failure = time.time()
+ return self.state
def connected(self):
- return (self._read_fd is not None and self._write_fd is not None)
+ return self.state is ConnectionStates.CONNECTED
- def close(self):
- if self.connected():
- try:
- self._read_fd.close()
- self._write_fd.close()
- except socket.error:
- log.exception("Error in BrokerConnection.close()")
- pass
- self._read_fd = None
- self._write_fd = None
+ def close(self, error=None):
+ if self._sock:
+ self._sock.close()
+ self._sock = None
+ self.state = ConnectionStates.DISCONNECTED
+
+ if error is None:
+ error = Errors.DisconnectError()
+ while self.in_flight_requests:
+ ifr = self.in_flight_requests.popleft()
+ ifr.future.failure(error)
self.in_flight_requests.clear()
+ self._receiving = False
+ self._next_payload_bytes = 0
+ self._rbuffer.seek(0)
+ self._rbuffer.truncate()
def send(self, request, expect_response=True):
- if not self.connected() and not self.connect():
- return None
- self.correlation_id += 1
+ """send request, return Future()
+
+ Can block on network if request is larger than send_buffer_bytes
+ """
+ future = Future()
+ if not self.connected():
+ return future.failure(Errors.DisconnectError())
+ self._correlation_id += 1
header = RequestHeader(request,
- correlation_id=self.correlation_id,
- client_id=self.client_id)
+ correlation_id=self._correlation_id,
+ client_id=self._client_id)
message = b''.join([header.encode(), request.encode()])
size = Int32.encode(len(message))
try:
- self._write_fd.write(size)
- self._write_fd.write(message)
- self._write_fd.flush()
- except socket.error:
- log.exception("Error in BrokerConnection.send(): %s", request)
- self.close()
- return None
+ # In the future we might manage an internal write buffer
+ # and send bytes asynchronously. For now, just block
+ # sending each request payload
+ self._sock.setblocking(True)
+ sent_bytes = self._sock.send(size)
+ assert sent_bytes == len(size)
+ sent_bytes = self._sock.send(message)
+ assert sent_bytes == len(message)
+ self._sock.setblocking(False)
+ except (AssertionError, socket.error) as e:
+ log.debug("Error in BrokerConnection.send(): %s", request)
+ self.close(error=e)
+ return future.failure(e)
+ log.debug('Request %d: %s', self._correlation_id, request)
+
if expect_response:
- self.in_flight_requests.append((self.correlation_id, request.RESPONSE_TYPE))
- log.debug('Request %d: %s', self.correlation_id, request)
- return self.correlation_id
+ ifr = InFlightRequest(request=request,
+ correlation_id=self._correlation_id,
+ response_type=request.RESPONSE_TYPE,
+ future=future,
+ timestamp=time.time())
+ self.in_flight_requests.append(ifr)
+ else:
+ future.success(None)
+
+ return future
+
+ def recv(self, timeout=0):
+ """Non-blocking network receive
- def recv(self, timeout=None):
+ Return response if available
+ """
if not self.connected():
+ log.warning('Cannot recv: socket not connected')
+ # If requests are pending, we should close the socket and
+ # fail all the pending request futures
+ if self.in_flight_requests:
+ self.close()
return None
- readable, _, _ = select([self._read_fd], [], [], timeout)
- if not readable:
- return None
+
if not self.in_flight_requests:
log.warning('No in-flight-requests to recv')
return None
- correlation_id, response_type = self.in_flight_requests.popleft()
- # Current implementation does not use size
- # instead we read directly from the socket fd buffer
- # alternatively, we could read size bytes into a separate buffer
- # and decode from that buffer (and verify buffer is empty afterwards)
- try:
- size = Int32.decode(self._read_fd)
- recv_correlation_id = Int32.decode(self._read_fd)
- if correlation_id != recv_correlation_id:
- raise RuntimeError('Correlation ids do not match!')
- response = response_type.decode(self._read_fd)
- except (RuntimeError, socket.error, struct.error):
- log.exception("Error in BrokerConnection.recv() for request %d", correlation_id)
- self.close()
+
+ self._fail_timed_out_requests()
+
+ readable, _, _ = select([self._sock], [], [], timeout)
+ if not readable:
return None
- log.debug('Response %d: %s', correlation_id, response)
- return response
- def next_correlation_id_recv(self):
- if len(self.in_flight_requests) == 0:
+ # Not receiving is the state of reading the payload header
+ if not self._receiving:
+ try:
+ # An extremely small, but non-zero, probability that there are
+ # more than 0 but not yet 4 bytes available to read
+ self._rbuffer.write(self._sock.recv(4 - self._rbuffer.tell()))
+ except socket.error as e:
+ if e.errno == errno.EWOULDBLOCK:
+ # This shouldn't happen after selecting above
+ # but just in case
+ return None
+ log.exception("Error receiving 4-byte payload header - closing socket")
+ self.close(error=e)
+ return None
+
+ if self._rbuffer.tell() == 4:
+ self._rbuffer.seek(0)
+ self._next_payload_bytes = Int32.decode(self._rbuffer)
+ # reset buffer and switch state to receiving payload bytes
+ self._rbuffer.seek(0)
+ self._rbuffer.truncate()
+ self._receiving = True
+ elif self._rbuffer.tell() > 4:
+ raise Errors.KafkaError('this should not happen - are you threading?')
+
+ if self._receiving:
+ staged_bytes = self._rbuffer.tell()
+ try:
+ self._rbuffer.write(self._sock.recv(self._next_payload_bytes - staged_bytes))
+ except socket.error as e:
+ # Extremely small chance that we have exactly 4 bytes for a
+ # header, but nothing to read in the body yet
+ if e.errno == errno.EWOULDBLOCK:
+ return None
+ log.exception()
+ self.close(error=e)
+ return None
+
+ staged_bytes = self._rbuffer.tell()
+ if staged_bytes > self._next_payload_bytes:
+ self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
+
+ if staged_bytes != self._next_payload_bytes:
+ return None
+
+ self._receiving = False
+ self._next_payload_bytes = 0
+ self._rbuffer.seek(0)
+ response = self._process_response(self._rbuffer)
+ self._rbuffer.seek(0)
+ self._rbuffer.truncate()
+ return response
+
+ def _process_response(self, read_buffer):
+ ifr = self.in_flight_requests.popleft()
+
+ # verify send/recv correlation ids match
+ recv_correlation_id = Int32.decode(read_buffer)
+ if ifr.correlation_id != recv_correlation_id:
+ error = Errors.CorrelationIdError(
+ 'Correlation ids do not match: sent %d, recv %d'
+ % (ifr.correlation_id, recv_correlation_id))
+ ifr.future.fail(error)
+ self.close()
return None
- return self.in_flight_requests[0][0]
- def next_correlation_id_send(self):
- return self.correlation_id + 1
+ # decode response
+ response = ifr.response_type.decode(read_buffer)
+ ifr.future.success(response)
+ log.debug('Response %d: %s', ifr.correlation_id, response)
+ return response
- def __getnewargs__(self):
- return (self.host, self.port, self.timeout)
+ def _fail_timed_out_requests(self):
+ now = time.time()
+ while self.in_flight_requests:
+ next_timeout = self.in_flight_requests[0].timestamp + (self._request_timeout_ms / 1000.0)
+ if now < next_timeout:
+ break
+ timed_out = self.in_flight_requests.popleft()
+ error = Errors.RequestTimedOutError('Request timed out after %s ms' % self._request_timeout_ms)
+ timed_out.future.failure(error)
def __repr__(self):
return "<BrokerConnection host=%s port=%d>" % (self.host, self.port)
@@ -149,13 +292,7 @@ def collect_hosts(hosts, randomize=True):
class KafkaConnection(local):
- """
- A socket connection to a single Kafka broker
-
- This class is _not_ thread safe. Each call to `send` must be followed
- by a call to `recv` in order to get the correct response. Eventually,
- we can do something in here to facilitate multiplexed requests/responses
- since the Kafka API includes a correlation id.
+ """A socket connection to a single Kafka broker
Arguments:
host: the host name or IP address of a kafka broker
diff --git a/kafka/future.py b/kafka/future.py
new file mode 100644
index 0000000..24173bb
--- /dev/null
+++ b/kafka/future.py
@@ -0,0 +1,51 @@
+from kafka.common import RetriableError, IllegalStateError
+
+
+class Future(object):
+ def __init__(self):
+ self.is_done = False
+ self.value = None
+ self.exception = None
+ self._callbacks = []
+ self._errbacks = []
+
+ def succeeded(self):
+ return self.is_done and not self.exception
+
+ def failed(self):
+ return self.is_done and self.exception
+
+ def retriable(self):
+ return isinstance(self.exception, RetriableError)
+
+ def success(self, value):
+ if self.is_done:
+ raise IllegalStateError('Invalid attempt to complete a request future which is already complete')
+ self.value = value
+ self.is_done = True
+ for f in self._callbacks:
+ f(value)
+ return self
+
+ def failure(self, e):
+ if self.is_done:
+ raise IllegalStateError('Invalid attempt to complete a request future which is already complete')
+ self.exception = e
+ self.is_done = True
+ for f in self._errbacks:
+ f(e)
+ return self
+
+ def add_callback(self, f):
+ if self.is_done and not self.exception:
+ f(self.value)
+ else:
+ self._callbacks.append(f)
+ return self
+
+ def add_errback(self, f):
+ if self.is_done and self.exception:
+ f(self.exception)
+ else:
+ self._errbacks.append(f)
+ return self
diff --git a/test/test_client.py b/test/test_client.py
index dd8948f..00e888c 100644
--- a/test/test_client.py
+++ b/test/test_client.py
@@ -14,6 +14,7 @@ from kafka.common import (
KafkaTimeoutError, ConnectionError
)
from kafka.conn import KafkaConnection
+from kafka.future import Future
from kafka.protocol import KafkaProtocol, create_message
from kafka.protocol.metadata import MetadataResponse
@@ -23,6 +24,17 @@ NO_ERROR = 0
UNKNOWN_TOPIC_OR_PARTITION = 3
NO_LEADER = 5
+
+def mock_conn(conn, success=True):
+ mocked = MagicMock()
+ mocked.connected.return_value = True
+ if success:
+ mocked.send.return_value = Future().success(True)
+ else:
+ mocked.send.return_value = Future().failure(Exception())
+ conn.return_value = mocked
+
+
class TestKafkaClient(unittest.TestCase):
def test_init_with_list(self):
with patch.object(KafkaClient, 'load_metadata_for_topics'):
@@ -48,32 +60,30 @@ class TestKafkaClient(unittest.TestCase):
sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]),
sorted(client.hosts))
- def test_send_broker_unaware_request_fail(self):
+ @patch.object(KafkaClient, '_get_conn')
+ @patch.object(KafkaClient, 'load_metadata_for_topics')
+ def test_send_broker_unaware_request_fail(self, load_metadata, conn):
mocked_conns = {
('kafka01', 9092): MagicMock(),
('kafka02', 9092): MagicMock()
}
-
- # inject KafkaConnection side effects
- mocked_conns[('kafka01', 9092)].send.return_value = None
- mocked_conns[('kafka02', 9092)].send.return_value = None
+ for val in mocked_conns.values():
+ mock_conn(val, success=False)
def mock_get_conn(host, port):
return mocked_conns[(host, port)]
+ conn.side_effect = mock_get_conn
- # patch to avoid making requests before we want it
- with patch.object(KafkaClient, 'load_metadata_for_topics'):
- with patch.object(KafkaClient, '_get_conn', side_effect=mock_get_conn):
- client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092'])
+ client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092'])
- req = KafkaProtocol.encode_metadata_request()
- with self.assertRaises(KafkaUnavailableError):
- client._send_broker_unaware_request(payloads=['fake request'],
- encoder_fn=MagicMock(return_value='fake encoded message'),
- decoder_fn=lambda x: x)
+ req = KafkaProtocol.encode_metadata_request()
+ with self.assertRaises(KafkaUnavailableError):
+ client._send_broker_unaware_request(payloads=['fake request'],
+ encoder_fn=MagicMock(return_value='fake encoded message'),
+ decoder_fn=lambda x: x)
- for key, conn in six.iteritems(mocked_conns):
- conn.send.assert_called_with('fake encoded message')
+ for key, conn in six.iteritems(mocked_conns):
+ conn.send.assert_called_with('fake encoded message')
def test_send_broker_unaware_request(self):
mocked_conns = {
@@ -82,9 +92,11 @@ class TestKafkaClient(unittest.TestCase):
('kafka03', 9092): MagicMock()
}
# inject KafkaConnection side effects
- mocked_conns[('kafka01', 9092)].send.return_value = None
- mocked_conns[('kafka02', 9092)].recv.return_value = 'valid response'
- mocked_conns[('kafka03', 9092)].send.return_value = None
+ mock_conn(mocked_conns[('kafka01', 9092)], success=False)
+ mock_conn(mocked_conns[('kafka03', 9092)], success=False)
+ future = Future()
+ mocked_conns[('kafka02', 9092)].send.return_value = future
+ mocked_conns[('kafka02', 9092)].recv.side_effect = lambda: future.success('valid response')
def mock_get_conn(host, port):
return mocked_conns[(host, port)]
@@ -101,11 +113,11 @@ class TestKafkaClient(unittest.TestCase):
self.assertEqual('valid response', resp)
mocked_conns[('kafka02', 9092)].recv.assert_called_once_with()
- @patch('kafka.client.BrokerConnection')
+ @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol')
def test_load_metadata(self, protocol, conn):
- conn.recv.return_value = 'response' # anything but None
+ mock_conn(conn)
brokers = [
BrokerMetadata(0, 'broker_1', 4567),
@@ -151,11 +163,11 @@ class TestKafkaClient(unittest.TestCase):
# This should not raise
client.load_metadata_for_topics('topic_no_leader')
- @patch('kafka.client.BrokerConnection')
+ @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol')
def test_has_metadata_for_topic(self, protocol, conn):
- conn.recv.return_value = 'response' # anything but None
+ mock_conn(conn)
brokers = [
BrokerMetadata(0, 'broker_1', 4567),
@@ -181,11 +193,11 @@ class TestKafkaClient(unittest.TestCase):
# Topic with partition metadata, but no leaders return True
self.assertTrue(client.has_metadata_for_topic('topic_noleaders'))
- @patch('kafka.client.BrokerConnection')
+ @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol.decode_metadata_response')
def test_ensure_topic_exists(self, decode_metadata_response, conn):
- conn.recv.return_value = 'response' # anything but None
+ mock_conn(conn)
brokers = [
BrokerMetadata(0, 'broker_1', 4567),
@@ -213,12 +225,12 @@ class TestKafkaClient(unittest.TestCase):
# This should not raise
client.ensure_topic_exists('topic_noleaders', timeout=1)
- @patch('kafka.client.BrokerConnection')
+ @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol')
def test_get_leader_for_partitions_reloads_metadata(self, protocol, conn):
"Get leader for partitions reload metadata if it is not available"
- conn.recv.return_value = 'response' # anything but None
+ mock_conn(conn)
brokers = [
BrokerMetadata(0, 'broker_1', 4567),
@@ -251,11 +263,11 @@ class TestKafkaClient(unittest.TestCase):
TopicAndPartition('topic_one_partition', 0): brokers[0]},
client.topics_to_brokers)
- @patch('kafka.client.BrokerConnection')
+ @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol')
def test_get_leader_for_unassigned_partitions(self, protocol, conn):
- conn.recv.return_value = 'response' # anything but None
+ mock_conn(conn)
brokers = [
BrokerMetadata(0, 'broker_1', 4567),
@@ -278,11 +290,11 @@ class TestKafkaClient(unittest.TestCase):
with self.assertRaises(UnknownTopicOrPartitionError):
client._get_leader_for_partition('topic_unknown', 0)
- @patch('kafka.client.BrokerConnection')
+ @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol')
def test_get_leader_exceptions_when_noleader(self, protocol, conn):
- conn.recv.return_value = 'response' # anything but None
+ mock_conn(conn)
brokers = [
BrokerMetadata(0, 'broker_1', 4567),
@@ -325,10 +337,10 @@ class TestKafkaClient(unittest.TestCase):
self.assertEqual(brokers[0], client._get_leader_for_partition('topic_noleader', 0))
self.assertEqual(brokers[1], client._get_leader_for_partition('topic_noleader', 1))
- @patch('kafka.client.BrokerConnection')
+ @patch.object(KafkaClient, '_get_conn')
@patch('kafka.client.KafkaProtocol')
def test_send_produce_request_raises_when_noleader(self, protocol, conn):
- conn.recv.return_value = 'response' # anything but None
+ mock_conn(conn)
brokers = [
BrokerMetadata(0, 'broker_1', 4567),
@@ -352,11 +364,11 @@ class TestKafkaClient(unittest.TestCase):
with self.assertRaises(LeaderNotAvailableError):
client.send_produce_request(requests)
- @patch('kafka.client.BrokerConnection')
+ @patch('kafka.client.KafkaClient._get_conn')
@patch('kafka.client.KafkaProtocol')
def test_send_produce_request_raises_when_topic_unknown(self, protocol, conn):
- conn.recv.return_value = 'response' # anything but None
+ mock_conn(conn)
brokers = [
BrokerMetadata(0, 'broker_1', 4567),