diff options
author | Dana Powers <dana.powers@rd.io> | 2015-12-17 17:29:54 -0800 |
---|---|---|
committer | Dana Powers <dana.powers@rd.io> | 2015-12-17 23:22:35 -0800 |
commit | f1ad0247df5bf6e0315ffbb1633d5979da828de0 (patch) | |
tree | ca96d1d960a13ae481b76fd32761ea535234f02b | |
parent | 799824535ceeb698152a3078f64ecbf6baca9b39 (diff) | |
download | kafka-python-f1ad0247df5bf6e0315ffbb1633d5979da828de0.tar.gz |
Switch BrokerConnection to (mostly) non-blocking IO.
- return kafka.Future on send()
- recv is now non-blocking call that completes futures when possible
- update KafkaClient to block on future completion
-rw-r--r-- | kafka/client.py | 90 | ||||
-rw-r--r-- | kafka/cluster.py | 2 | ||||
-rw-r--r-- | kafka/common.py | 16 | ||||
-rw-r--r-- | kafka/conn.py | 299 | ||||
-rw-r--r-- | kafka/future.py | 51 | ||||
-rw-r--r-- | test/test_client.py | 82 |
6 files changed, 386 insertions, 154 deletions
diff --git a/kafka/client.py b/kafka/client.py index b09927d..7f9969e 100644 --- a/kafka/client.py +++ b/kafka/client.py @@ -3,7 +3,6 @@ import copy import functools import logging import random -import select import time import six @@ -15,7 +14,9 @@ from kafka.common import (TopicAndPartition, BrokerMetadata, UnknownError, LeaderNotAvailableError, UnknownTopicOrPartitionError, NotLeaderForPartitionError, ReplicaNotAvailableError) -from kafka.conn import collect_hosts, BrokerConnection, DEFAULT_SOCKET_TIMEOUT_SECONDS +from kafka.conn import ( + collect_hosts, BrokerConnection, DEFAULT_SOCKET_TIMEOUT_SECONDS, + ConnectionStates) from kafka.protocol import KafkaProtocol @@ -45,7 +46,6 @@ class KafkaClient(object): self.load_metadata_for_topics() # bootstrap with all metadata - ################## # Private API # ################## @@ -56,11 +56,14 @@ class KafkaClient(object): if host_key not in self._conns: self._conns[host_key] = BrokerConnection( host, port, - timeout=self.timeout, + request_timeout_ms=self.timeout * 1000, client_id=self.client_id ) - return self._conns[host_key] + conn = self._conns[host_key] + while conn.connect() == ConnectionStates.CONNECTING: + pass + return conn def _get_leader_for_partition(self, topic, partition): """ @@ -137,16 +140,23 @@ class KafkaClient(object): for (host, port) in hosts: conn = self._get_conn(host, port) + if not conn.connected(): + log.warning("Skipping unconnected connection: %s", conn) + continue request = encoder_fn(payloads=payloads) - correlation_id = conn.send(request) - if correlation_id is None: + future = conn.send(request) + + # Block + while not future.is_done: + conn.recv() + + if future.failed(): + log.error("Request failed: %s", future.exception) continue - response = conn.recv() - if response is not None: - decoded = decoder_fn(response) - return decoded - raise KafkaUnavailableError('All servers failed to process request') + return decoder_fn(future.value) + + raise KafkaUnavailableError('All servers failed to process request: %s' % hosts) def _payloads_by_broker(self, payloads): payloads_by_broker = collections.defaultdict(list) @@ -204,55 +214,59 @@ class KafkaClient(object): # For each BrokerConnection keep the real socket so that we can use # a select to perform unblocking I/O - connections_by_socket = {} + connections_by_future = {} for broker, broker_payloads in six.iteritems(payloads_by_broker): if broker is None: failed_payloads(broker_payloads) continue conn = self._get_conn(broker.host, broker.port) + conn.connect() + if not conn.connected(): + refresh_metadata = True + failed_payloads(broker_payloads) + continue + request = encoder_fn(payloads=broker_payloads) # decoder_fn=None signal that the server is expected to not # send a response. This probably only applies to # ProduceRequest w/ acks = 0 expect_response = (decoder_fn is not None) - correlation_id = conn.send(request, expect_response=expect_response) + future = conn.send(request, expect_response=expect_response) - if correlation_id is None: + if future.failed(): refresh_metadata = True failed_payloads(broker_payloads) - log.warning('Error attempting to send request %s ' - 'to server %s', correlation_id, broker) continue if not expect_response: - log.debug('Request %s does not expect a response ' - '(skipping conn.recv)', correlation_id) for payload in broker_payloads: topic_partition = (str(payload.topic), payload.partition) responses[topic_partition] = None continue - connections_by_socket[conn._read_fd] = (conn, broker) + connections_by_future[future] = (conn, broker) conn = None - while connections_by_socket: - sockets = connections_by_socket.keys() - rlist, _, _ = select.select(sockets, [], [], None) - conn, broker = connections_by_socket.pop(rlist[0]) - correlation_id = conn.next_correlation_id_recv() - response = conn.recv() - if response is None: - refresh_metadata = True - failed_payloads(payloads_by_broker[broker]) - log.warning('Error receiving response to request %s ' - 'from server %s', correlation_id, broker) - continue + while connections_by_future: + futures = list(connections_by_future.keys()) + for future in futures: + + if not future.is_done: + conn, _ = connections_by_future[future] + conn.recv() + continue - for payload_response in decoder_fn(response): - topic_partition = (str(payload_response.topic), - payload_response.partition) - responses[topic_partition] = payload_response + _, broker = connections_by_future.pop(future) + if future.failed(): + refresh_metadata = True + failed_payloads(payloads_by_broker[broker]) + + else: + for payload_response in decoder_fn(future.value): + topic_partition = (str(payload_response.topic), + payload_response.partition) + responses[topic_partition] = payload_response if refresh_metadata: self.reset_all_metadata() @@ -392,7 +406,9 @@ class KafkaClient(object): def reinit(self): for conn in self._conns.values(): - conn.reinit() + conn.close() + while conn.connect() == ConnectionStates.CONNECTING: + pass def reset_topic_metadata(self, *topics): for topic in topics: diff --git a/kafka/cluster.py b/kafka/cluster.py index 55765dc..15921dc 100644 --- a/kafka/cluster.py +++ b/kafka/cluster.py @@ -73,7 +73,7 @@ class Cluster(object): def _bootstrap(self, hosts, timeout=2): for host, port in hosts: - conn = BrokerConnection(host, port, timeout) + conn = BrokerConnection(host, port) if not conn.connect(): continue self._brokers['bootstrap'] = conn diff --git a/kafka/common.py b/kafka/common.py index 253137c..173fc82 100644 --- a/kafka/common.py +++ b/kafka/common.py @@ -93,6 +93,22 @@ class KafkaError(RuntimeError): pass +class IllegalStateError(KafkaError): + pass + + +class RetriableError(KafkaError): + pass + + +class DisconnectError(KafkaError): + pass + + +class CorrelationIdError(KafkaError): + pass + + class BrokerResponseError(KafkaError): errno = None message = None diff --git a/kafka/conn.py b/kafka/conn.py index d45b824..c2b8fb0 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -1,15 +1,20 @@ -from collections import deque +import collections import copy +import errno import logging +import io from random import shuffle from select import select import socket import struct from threading import local +import time import six +import kafka.common as Errors from kafka.common import ConnectionError +from kafka.future import Future from kafka.protocol.api import RequestHeader from kafka.protocol.types import Int32 @@ -20,106 +25,244 @@ DEFAULT_SOCKET_TIMEOUT_SECONDS = 120 DEFAULT_KAFKA_PORT = 9092 -class BrokerConnection(local): - def __init__(self, host, port, timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS, - client_id='kafka-python-0.10.0', correlation_id=0): - super(BrokerConnection, self).__init__() +class ConnectionStates(object): + DISCONNECTED = 1 + CONNECTING = 2 + CONNECTED = 3 + + +InFlightRequest = collections.namedtuple('InFlightRequest', + ['request', 'response_type', 'correlation_id', 'future', 'timestamp']) + + +class BrokerConnection(object): + _receive_buffer_bytes = 32768 + _send_buffer_bytes = 32768 + _client_id = 'kafka-python-0.10.0' + _correlation_id = 0 + _request_timeout_ms = 40000 + + def __init__(self, host, port, **kwargs): self.host = host self.port = port - self.timeout = timeout - self._write_fd = None - self._read_fd = None - self.correlation_id = correlation_id - self.client_id = client_id - self.in_flight_requests = deque() + self.in_flight_requests = collections.deque() + + for config in ('receive_buffer_bytes', 'send_buffer_bytes', + 'client_id', 'correlation_id', 'request_timeout_ms'): + if config in kwargs: + setattr(self, '_' + config, kwargs.pop(config)) + + self.state = ConnectionStates.DISCONNECTED + self._sock = None + self._rbuffer = io.BytesIO() + self._receiving = False + self._next_payload_bytes = 0 + self._last_connection_attempt = None + self._last_connection_failure = None def connect(self): - if self.connected(): + """Attempt to connect and return ConnectionState""" + if self.state is ConnectionStates.DISCONNECTED: self.close() - try: - sock = socket.create_connection((self.host, self.port), self.timeout) - self._write_fd = sock.makefile('wb') - self._read_fd = sock.makefile('rb') - except socket.error: - log.exception("Error in BrokerConnection.connect()") - return None - self.in_flight_requests.clear() - return True + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self._receive_buffer_bytes) + self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self._send_buffer_bytes) + self._sock.setblocking(False) + ret = self._sock.connect_ex((self.host, self.port)) + self._last_connection_attempt = time.time() + + if not ret or ret is errno.EISCONN: + self.state = ConnectionStates.CONNECTED + elif ret in (errno.EINPROGRESS, errno.EALREADY): + self.state = ConnectionStates.CONNECTING + else: + log.error('Connect attempt returned error %s. Disconnecting.', ret) + self.close() + self._last_connection_failure = time.time() + + if self.state is ConnectionStates.CONNECTING: + # in non-blocking mode, use repeated calls to socket.connect_ex + # to check connection status + if time.time() > (self._request_timeout_ms / 1000.0) + self._last_connection_attempt: + log.error('Connection attempt timed out') + self.close() # error=TimeoutError ? + self._last_connection_failure = time.time() + + ret = self._sock.connect_ex((self.host, self.port)) + if not ret or ret is errno.EISCONN: + self.state = ConnectionStates.CONNECTED + elif ret is not errno.EALREADY: + log.error('Connect attempt returned error %s. Disconnecting.', ret) + self.close() + self._last_connection_failure = time.time() + return self.state def connected(self): - return (self._read_fd is not None and self._write_fd is not None) + return self.state is ConnectionStates.CONNECTED - def close(self): - if self.connected(): - try: - self._read_fd.close() - self._write_fd.close() - except socket.error: - log.exception("Error in BrokerConnection.close()") - pass - self._read_fd = None - self._write_fd = None + def close(self, error=None): + if self._sock: + self._sock.close() + self._sock = None + self.state = ConnectionStates.DISCONNECTED + + if error is None: + error = Errors.DisconnectError() + while self.in_flight_requests: + ifr = self.in_flight_requests.popleft() + ifr.future.failure(error) self.in_flight_requests.clear() + self._receiving = False + self._next_payload_bytes = 0 + self._rbuffer.seek(0) + self._rbuffer.truncate() def send(self, request, expect_response=True): - if not self.connected() and not self.connect(): - return None - self.correlation_id += 1 + """send request, return Future() + + Can block on network if request is larger than send_buffer_bytes + """ + future = Future() + if not self.connected(): + return future.failure(Errors.DisconnectError()) + self._correlation_id += 1 header = RequestHeader(request, - correlation_id=self.correlation_id, - client_id=self.client_id) + correlation_id=self._correlation_id, + client_id=self._client_id) message = b''.join([header.encode(), request.encode()]) size = Int32.encode(len(message)) try: - self._write_fd.write(size) - self._write_fd.write(message) - self._write_fd.flush() - except socket.error: - log.exception("Error in BrokerConnection.send(): %s", request) - self.close() - return None + # In the future we might manage an internal write buffer + # and send bytes asynchronously. For now, just block + # sending each request payload + self._sock.setblocking(True) + sent_bytes = self._sock.send(size) + assert sent_bytes == len(size) + sent_bytes = self._sock.send(message) + assert sent_bytes == len(message) + self._sock.setblocking(False) + except (AssertionError, socket.error) as e: + log.debug("Error in BrokerConnection.send(): %s", request) + self.close(error=e) + return future.failure(e) + log.debug('Request %d: %s', self._correlation_id, request) + if expect_response: - self.in_flight_requests.append((self.correlation_id, request.RESPONSE_TYPE)) - log.debug('Request %d: %s', self.correlation_id, request) - return self.correlation_id + ifr = InFlightRequest(request=request, + correlation_id=self._correlation_id, + response_type=request.RESPONSE_TYPE, + future=future, + timestamp=time.time()) + self.in_flight_requests.append(ifr) + else: + future.success(None) + + return future + + def recv(self, timeout=0): + """Non-blocking network receive - def recv(self, timeout=None): + Return response if available + """ if not self.connected(): + log.warning('Cannot recv: socket not connected') + # If requests are pending, we should close the socket and + # fail all the pending request futures + if self.in_flight_requests: + self.close() return None - readable, _, _ = select([self._read_fd], [], [], timeout) - if not readable: - return None + if not self.in_flight_requests: log.warning('No in-flight-requests to recv') return None - correlation_id, response_type = self.in_flight_requests.popleft() - # Current implementation does not use size - # instead we read directly from the socket fd buffer - # alternatively, we could read size bytes into a separate buffer - # and decode from that buffer (and verify buffer is empty afterwards) - try: - size = Int32.decode(self._read_fd) - recv_correlation_id = Int32.decode(self._read_fd) - if correlation_id != recv_correlation_id: - raise RuntimeError('Correlation ids do not match!') - response = response_type.decode(self._read_fd) - except (RuntimeError, socket.error, struct.error): - log.exception("Error in BrokerConnection.recv() for request %d", correlation_id) - self.close() + + self._fail_timed_out_requests() + + readable, _, _ = select([self._sock], [], [], timeout) + if not readable: return None - log.debug('Response %d: %s', correlation_id, response) - return response - def next_correlation_id_recv(self): - if len(self.in_flight_requests) == 0: + # Not receiving is the state of reading the payload header + if not self._receiving: + try: + # An extremely small, but non-zero, probability that there are + # more than 0 but not yet 4 bytes available to read + self._rbuffer.write(self._sock.recv(4 - self._rbuffer.tell())) + except socket.error as e: + if e.errno == errno.EWOULDBLOCK: + # This shouldn't happen after selecting above + # but just in case + return None + log.exception("Error receiving 4-byte payload header - closing socket") + self.close(error=e) + return None + + if self._rbuffer.tell() == 4: + self._rbuffer.seek(0) + self._next_payload_bytes = Int32.decode(self._rbuffer) + # reset buffer and switch state to receiving payload bytes + self._rbuffer.seek(0) + self._rbuffer.truncate() + self._receiving = True + elif self._rbuffer.tell() > 4: + raise Errors.KafkaError('this should not happen - are you threading?') + + if self._receiving: + staged_bytes = self._rbuffer.tell() + try: + self._rbuffer.write(self._sock.recv(self._next_payload_bytes - staged_bytes)) + except socket.error as e: + # Extremely small chance that we have exactly 4 bytes for a + # header, but nothing to read in the body yet + if e.errno == errno.EWOULDBLOCK: + return None + log.exception() + self.close(error=e) + return None + + staged_bytes = self._rbuffer.tell() + if staged_bytes > self._next_payload_bytes: + self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?')) + + if staged_bytes != self._next_payload_bytes: + return None + + self._receiving = False + self._next_payload_bytes = 0 + self._rbuffer.seek(0) + response = self._process_response(self._rbuffer) + self._rbuffer.seek(0) + self._rbuffer.truncate() + return response + + def _process_response(self, read_buffer): + ifr = self.in_flight_requests.popleft() + + # verify send/recv correlation ids match + recv_correlation_id = Int32.decode(read_buffer) + if ifr.correlation_id != recv_correlation_id: + error = Errors.CorrelationIdError( + 'Correlation ids do not match: sent %d, recv %d' + % (ifr.correlation_id, recv_correlation_id)) + ifr.future.fail(error) + self.close() return None - return self.in_flight_requests[0][0] - def next_correlation_id_send(self): - return self.correlation_id + 1 + # decode response + response = ifr.response_type.decode(read_buffer) + ifr.future.success(response) + log.debug('Response %d: %s', ifr.correlation_id, response) + return response - def __getnewargs__(self): - return (self.host, self.port, self.timeout) + def _fail_timed_out_requests(self): + now = time.time() + while self.in_flight_requests: + next_timeout = self.in_flight_requests[0].timestamp + (self._request_timeout_ms / 1000.0) + if now < next_timeout: + break + timed_out = self.in_flight_requests.popleft() + error = Errors.RequestTimedOutError('Request timed out after %s ms' % self._request_timeout_ms) + timed_out.future.failure(error) def __repr__(self): return "<BrokerConnection host=%s port=%d>" % (self.host, self.port) @@ -149,13 +292,7 @@ def collect_hosts(hosts, randomize=True): class KafkaConnection(local): - """ - A socket connection to a single Kafka broker - - This class is _not_ thread safe. Each call to `send` must be followed - by a call to `recv` in order to get the correct response. Eventually, - we can do something in here to facilitate multiplexed requests/responses - since the Kafka API includes a correlation id. + """A socket connection to a single Kafka broker Arguments: host: the host name or IP address of a kafka broker diff --git a/kafka/future.py b/kafka/future.py new file mode 100644 index 0000000..24173bb --- /dev/null +++ b/kafka/future.py @@ -0,0 +1,51 @@ +from kafka.common import RetriableError, IllegalStateError + + +class Future(object): + def __init__(self): + self.is_done = False + self.value = None + self.exception = None + self._callbacks = [] + self._errbacks = [] + + def succeeded(self): + return self.is_done and not self.exception + + def failed(self): + return self.is_done and self.exception + + def retriable(self): + return isinstance(self.exception, RetriableError) + + def success(self, value): + if self.is_done: + raise IllegalStateError('Invalid attempt to complete a request future which is already complete') + self.value = value + self.is_done = True + for f in self._callbacks: + f(value) + return self + + def failure(self, e): + if self.is_done: + raise IllegalStateError('Invalid attempt to complete a request future which is already complete') + self.exception = e + self.is_done = True + for f in self._errbacks: + f(e) + return self + + def add_callback(self, f): + if self.is_done and not self.exception: + f(self.value) + else: + self._callbacks.append(f) + return self + + def add_errback(self, f): + if self.is_done and self.exception: + f(self.exception) + else: + self._errbacks.append(f) + return self diff --git a/test/test_client.py b/test/test_client.py index dd8948f..00e888c 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -14,6 +14,7 @@ from kafka.common import ( KafkaTimeoutError, ConnectionError ) from kafka.conn import KafkaConnection +from kafka.future import Future from kafka.protocol import KafkaProtocol, create_message from kafka.protocol.metadata import MetadataResponse @@ -23,6 +24,17 @@ NO_ERROR = 0 UNKNOWN_TOPIC_OR_PARTITION = 3 NO_LEADER = 5 + +def mock_conn(conn, success=True): + mocked = MagicMock() + mocked.connected.return_value = True + if success: + mocked.send.return_value = Future().success(True) + else: + mocked.send.return_value = Future().failure(Exception()) + conn.return_value = mocked + + class TestKafkaClient(unittest.TestCase): def test_init_with_list(self): with patch.object(KafkaClient, 'load_metadata_for_topics'): @@ -48,32 +60,30 @@ class TestKafkaClient(unittest.TestCase): sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]), sorted(client.hosts)) - def test_send_broker_unaware_request_fail(self): + @patch.object(KafkaClient, '_get_conn') + @patch.object(KafkaClient, 'load_metadata_for_topics') + def test_send_broker_unaware_request_fail(self, load_metadata, conn): mocked_conns = { ('kafka01', 9092): MagicMock(), ('kafka02', 9092): MagicMock() } - - # inject KafkaConnection side effects - mocked_conns[('kafka01', 9092)].send.return_value = None - mocked_conns[('kafka02', 9092)].send.return_value = None + for val in mocked_conns.values(): + mock_conn(val, success=False) def mock_get_conn(host, port): return mocked_conns[(host, port)] + conn.side_effect = mock_get_conn - # patch to avoid making requests before we want it - with patch.object(KafkaClient, 'load_metadata_for_topics'): - with patch.object(KafkaClient, '_get_conn', side_effect=mock_get_conn): - client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092']) + client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092']) - req = KafkaProtocol.encode_metadata_request() - with self.assertRaises(KafkaUnavailableError): - client._send_broker_unaware_request(payloads=['fake request'], - encoder_fn=MagicMock(return_value='fake encoded message'), - decoder_fn=lambda x: x) + req = KafkaProtocol.encode_metadata_request() + with self.assertRaises(KafkaUnavailableError): + client._send_broker_unaware_request(payloads=['fake request'], + encoder_fn=MagicMock(return_value='fake encoded message'), + decoder_fn=lambda x: x) - for key, conn in six.iteritems(mocked_conns): - conn.send.assert_called_with('fake encoded message') + for key, conn in six.iteritems(mocked_conns): + conn.send.assert_called_with('fake encoded message') def test_send_broker_unaware_request(self): mocked_conns = { @@ -82,9 +92,11 @@ class TestKafkaClient(unittest.TestCase): ('kafka03', 9092): MagicMock() } # inject KafkaConnection side effects - mocked_conns[('kafka01', 9092)].send.return_value = None - mocked_conns[('kafka02', 9092)].recv.return_value = 'valid response' - mocked_conns[('kafka03', 9092)].send.return_value = None + mock_conn(mocked_conns[('kafka01', 9092)], success=False) + mock_conn(mocked_conns[('kafka03', 9092)], success=False) + future = Future() + mocked_conns[('kafka02', 9092)].send.return_value = future + mocked_conns[('kafka02', 9092)].recv.side_effect = lambda: future.success('valid response') def mock_get_conn(host, port): return mocked_conns[(host, port)] @@ -101,11 +113,11 @@ class TestKafkaClient(unittest.TestCase): self.assertEqual('valid response', resp) mocked_conns[('kafka02', 9092)].recv.assert_called_once_with() - @patch('kafka.client.BrokerConnection') + @patch('kafka.client.KafkaClient._get_conn') @patch('kafka.client.KafkaProtocol') def test_load_metadata(self, protocol, conn): - conn.recv.return_value = 'response' # anything but None + mock_conn(conn) brokers = [ BrokerMetadata(0, 'broker_1', 4567), @@ -151,11 +163,11 @@ class TestKafkaClient(unittest.TestCase): # This should not raise client.load_metadata_for_topics('topic_no_leader') - @patch('kafka.client.BrokerConnection') + @patch('kafka.client.KafkaClient._get_conn') @patch('kafka.client.KafkaProtocol') def test_has_metadata_for_topic(self, protocol, conn): - conn.recv.return_value = 'response' # anything but None + mock_conn(conn) brokers = [ BrokerMetadata(0, 'broker_1', 4567), @@ -181,11 +193,11 @@ class TestKafkaClient(unittest.TestCase): # Topic with partition metadata, but no leaders return True self.assertTrue(client.has_metadata_for_topic('topic_noleaders')) - @patch('kafka.client.BrokerConnection') + @patch('kafka.client.KafkaClient._get_conn') @patch('kafka.client.KafkaProtocol.decode_metadata_response') def test_ensure_topic_exists(self, decode_metadata_response, conn): - conn.recv.return_value = 'response' # anything but None + mock_conn(conn) brokers = [ BrokerMetadata(0, 'broker_1', 4567), @@ -213,12 +225,12 @@ class TestKafkaClient(unittest.TestCase): # This should not raise client.ensure_topic_exists('topic_noleaders', timeout=1) - @patch('kafka.client.BrokerConnection') + @patch('kafka.client.KafkaClient._get_conn') @patch('kafka.client.KafkaProtocol') def test_get_leader_for_partitions_reloads_metadata(self, protocol, conn): "Get leader for partitions reload metadata if it is not available" - conn.recv.return_value = 'response' # anything but None + mock_conn(conn) brokers = [ BrokerMetadata(0, 'broker_1', 4567), @@ -251,11 +263,11 @@ class TestKafkaClient(unittest.TestCase): TopicAndPartition('topic_one_partition', 0): brokers[0]}, client.topics_to_brokers) - @patch('kafka.client.BrokerConnection') + @patch('kafka.client.KafkaClient._get_conn') @patch('kafka.client.KafkaProtocol') def test_get_leader_for_unassigned_partitions(self, protocol, conn): - conn.recv.return_value = 'response' # anything but None + mock_conn(conn) brokers = [ BrokerMetadata(0, 'broker_1', 4567), @@ -278,11 +290,11 @@ class TestKafkaClient(unittest.TestCase): with self.assertRaises(UnknownTopicOrPartitionError): client._get_leader_for_partition('topic_unknown', 0) - @patch('kafka.client.BrokerConnection') + @patch('kafka.client.KafkaClient._get_conn') @patch('kafka.client.KafkaProtocol') def test_get_leader_exceptions_when_noleader(self, protocol, conn): - conn.recv.return_value = 'response' # anything but None + mock_conn(conn) brokers = [ BrokerMetadata(0, 'broker_1', 4567), @@ -325,10 +337,10 @@ class TestKafkaClient(unittest.TestCase): self.assertEqual(brokers[0], client._get_leader_for_partition('topic_noleader', 0)) self.assertEqual(brokers[1], client._get_leader_for_partition('topic_noleader', 1)) - @patch('kafka.client.BrokerConnection') + @patch.object(KafkaClient, '_get_conn') @patch('kafka.client.KafkaProtocol') def test_send_produce_request_raises_when_noleader(self, protocol, conn): - conn.recv.return_value = 'response' # anything but None + mock_conn(conn) brokers = [ BrokerMetadata(0, 'broker_1', 4567), @@ -352,11 +364,11 @@ class TestKafkaClient(unittest.TestCase): with self.assertRaises(LeaderNotAvailableError): client.send_produce_request(requests) - @patch('kafka.client.BrokerConnection') + @patch('kafka.client.KafkaClient._get_conn') @patch('kafka.client.KafkaProtocol') def test_send_produce_request_raises_when_topic_unknown(self, protocol, conn): - conn.recv.return_value = 'response' # anything but None + mock_conn(conn) brokers = [ BrokerMetadata(0, 'broker_1', 4567), |