diff options
Diffstat (limited to 'kafka')
-rw-r--r-- | kafka/conn.py | 57 |
1 files changed, 36 insertions, 21 deletions
diff --git a/kafka/conn.py b/kafka/conn.py index 6b5aff9..c273765 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -17,6 +17,7 @@ except ImportError: import socket import struct import sys +import threading import time from kafka.vendor import six @@ -220,7 +221,6 @@ class BrokerConnection(object): self.afi = afi self._sock_afi = afi self._sock_addr = None - self.in_flight_requests = collections.deque() self._api_versions = None self.config = copy.copy(self.DEFAULT_CONFIG) @@ -255,6 +255,20 @@ class BrokerConnection(object): assert gssapi is not None, 'GSSAPI lib not available' assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl' + # This is not a general lock / this class is not generally thread-safe yet + # However, to avoid pushing responsibility for maintaining + # per-connection locks to the upstream client, we will use this lock to + # make sure that access to the protocol buffer is synchronized + # when sends happen on multiple threads + self._lock = threading.Lock() + + # the protocol parser instance manages actual tracking of the + # sequence of in-flight requests to responses, which should + # function like a FIFO queue. For additional request data, + # including tracking request futures and timestamps, we + # can use a simple dictionary of correlation_id => request data + self.in_flight_requests = dict() + self._protocol = KafkaProtocol( client_id=self.config['client_id'], api_version=self.config['api_version']) @@ -729,7 +743,7 @@ class BrokerConnection(object): if error is None: error = Errors.Cancelled(str(self)) while self.in_flight_requests: - (_, future, _) = self.in_flight_requests.popleft() + (_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem() future.failure(error) self.config['state_change_callback'](self) @@ -747,23 +761,22 @@ class BrokerConnection(object): def _send(self, request, blocking=True): assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED) future = Future() - correlation_id = self._protocol.send_request(request) - - # Attempt to replicate behavior from prior to introduction of - # send_pending_requests() / async sends - if blocking: - error = self.send_pending_requests() - if isinstance(error, Exception): - future.failure(error) - return future + with self._lock: + correlation_id = self._protocol.send_request(request) log.debug('%s Request %d: %s', self, correlation_id, request) if request.expect_response(): sent_time = time.time() - ifr = (correlation_id, future, sent_time) - self.in_flight_requests.append(ifr) + assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!' + self.in_flight_requests[correlation_id] = (future, sent_time) else: future.success(None) + + # Attempt to replicate behavior from prior to introduction of + # send_pending_requests() / async sends + if blocking: + self.send_pending_requests() + return future def send_pending_requests(self): @@ -818,8 +831,12 @@ class BrokerConnection(object): return () # augment respones w/ correlation_id, future, and timestamp - for i, response in enumerate(responses): - (correlation_id, future, timestamp) = self.in_flight_requests.popleft() + for i, (correlation_id, response) in enumerate(responses): + try: + (future, timestamp) = self.in_flight_requests.pop(correlation_id) + except KeyError: + self.close(Errors.KafkaConnectionError('Received unrecognized correlation id')) + return () latency_ms = (time.time() - timestamp) * 1000 if self._sensors: self._sensors.request_time.record(latency_ms) @@ -870,20 +887,18 @@ class BrokerConnection(object): self.close(e) return [] else: - return [resp for (_, resp) in responses] # drop correlation id + return responses def requests_timed_out(self): if self.in_flight_requests: - (_, _, oldest_at) = self.in_flight_requests[0] + get_timestamp = lambda v: v[1] + oldest_at = min(map(get_timestamp, + self.in_flight_requests.values())) timeout = self.config['request_timeout_ms'] / 1000.0 if time.time() >= oldest_at + timeout: return True return False - def _next_correlation_id(self): - self._correlation_id = (self._correlation_id + 1) % 2**31 - return self._correlation_id - def _handle_api_version_response(self, response): error_type = Errors.for_code(response.error_code) assert error_type is Errors.NoError, "API version check failed" |