diff options
author | Dana Powers <dana.powers@gmail.com> | 2016-08-02 23:07:34 -0700 |
---|---|---|
committer | Dana Powers <dana.powers@gmail.com> | 2016-08-02 23:07:34 -0700 |
commit | 3baca5f306eadf276bb78970bfe2c134b6031449 (patch) | |
tree | d791a034a5b98f88318b3590bcfc5446620ef333 /kafka/conn.py | |
parent | 40afc98dcaf4206040d2366886d9ba52d7ca0026 (diff) | |
download | kafka-python-larsjsol-sasl_plain.tar.gz |
Use callbacks for sasl handshake request / responselarsjsol-sasl_plain
Diffstat (limited to 'kafka/conn.py')
-rw-r--r-- | kafka/conn.py | 140 |
1 files changed, 71 insertions, 69 deletions
diff --git a/kafka/conn.py b/kafka/conn.py index 2e70165..6ccaca4 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -74,10 +74,11 @@ class BrokerConnection(object): 'ssl_password': None, 'api_version': (0, 8, 2), # default to most restrictive 'state_change_callback': lambda conn: True, - 'sasl_mechanism': None, + 'sasl_mechanism': 'PLAIN', 'sasl_plain_username': None, 'sasl_plain_password': None } + SASL_MECHANISMS = ('PLAIN',) def __init__(self, host, port, afi, **configs): self.host = host @@ -100,11 +101,19 @@ class BrokerConnection(object): (socket.SOL_SOCKET, socket.SO_SNDBUF, self.config['send_buffer_bytes'])) + if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'): + assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, ( + 'sasl_mechanism must be in ' + self.SASL_MECHANISMS) + if self.config['sasl_mechanism'] == 'PLAIN': + assert self.config['sasl_plain_username'] is not None, 'sasl_plain_username required for PLAIN sasl' + assert self.config['sasl_plain_password'] is not None, 'sasl_plain_password required for PLAIN sasl' + self.state = ConnectionStates.DISCONNECTED self._sock = None self._ssl_context = None if self.config['ssl_context'] is not None: self._ssl_context = self.config['ssl_context'] + self._sasl_auth_future = None self._rbuffer = io.BytesIO() self._receiving = False self._next_payload_bytes = 0 @@ -224,8 +233,9 @@ class BrokerConnection(object): self.config['state_change_callback'](self) if self.state is ConnectionStates.AUTHENTICATING: + assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL') if self._try_authenticate(): - log.debug('%s: Authenticated as %s', str(self), self.config['sasl_plain_username']) + log.info('%s: Authenticated as %s', str(self), self.config['sasl_plain_username']) self.state = ConnectionStates.CONNECTED self.config['state_change_callback'](self) @@ -289,58 +299,44 @@ class BrokerConnection(object): return False def _try_authenticate(self): - assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL') - - if self.config['security_protocol'] == 'SASL_PLAINTEXT': - log.warning('%s: Sending username and password in the clear', str(self)) - - # Build a SaslHandShakeRequest message - correlation_id = self._next_correlation_id() - request = SaslHandShakeRequest[0](self.config['sasl_mechanism']) - header = RequestHeader(request, - correlation_id=correlation_id, - client_id=self.config['client_id']) - - message = b''.join([header.encode(), request.encode()]) - size = Int32.encode(len(message)) - - # Attempt to send it over our socket - try: - self._sock.setblocking(True) - self._sock.sendall(size + message) - self._sock.setblocking(False) - except (AssertionError, ConnectionError) as e: - log.exception("Error sending %s to %s", request, self) - error = Errors.ConnectionError("%s: %s" % (str(self), e)) + assert self.config['api_version'] >= (0, 10) or self.config['api_version'] is None + + if self._sasl_auth_future is None: + # Build a SaslHandShakeRequest message + request = SaslHandShakeRequest[0](self.config['sasl_mechanism']) + future = Future() + sasl_response = self._send(request) + sasl_response.add_callback(self._handle_sasl_handshake_response, future) + sasl_response.add_errback(lambda f, e: f.failure(e), future) + self._sasl_auth_future = future + self._recv() + if self._sasl_auth_future.failed(): + raise self._sasl_auth_future.exception + return self._sasl_auth_future.succeeded() + + def _handle_sasl_handshake_response(self, future, response): + error_type = Errors.for_code(response.error_code) + if error_type is not Errors.NoError: + error = error_type(self) self.close(error=error) - return False - - future = Future() - ifr = InFlightRequest(request=request, - correlation_id=correlation_id, - response_type=request.RESPONSE_TYPE, - future=future, - timestamp=time.time()) - self.in_flight_requests.append(ifr) - - # Listen for a reply and check that the server supports the PLAIN mechanism - response = None - while not response: - response = self.recv() + return future.failure(error_type(self)) - if not response.error_code is 0: - raise Errors.for_code(response.error_code) - - if not self.config['sasl_mechanism'] in response.enabled_mechanisms: - raise Errors.AuthenticationMethodNotSupported(self.config['sasl_mechanism'] + " is not supported by broker") + if self.config['sasl_mechanism'] == 'PLAIN': + return self._try_authenticate_plain(future) + else: + return future.failure( + Errors.UnsupportedSaslMechanismError( + 'kafka-python does not support SASL mechanism %s' % + self.config['sasl_mechanism'])) - return self._try_authenticate_plain() + def _try_authenticate_plain(self, future): + if self.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Sending username and password in the clear', str(self)) - def _try_authenticate_plain(self): data = b'' try: self._sock.setblocking(True) - # Send our credentials + # Send PLAIN credentials per RFC-4616 msg = bytes('\0'.join([self.config['sasl_plain_username'], self.config['sasl_plain_username'], self.config['sasl_plain_password']]).encode('utf-8')) @@ -351,26 +347,26 @@ class BrokerConnection(object): # The connection is closed on failure received_bytes = 0 while received_bytes < 4: - data = data + self._sock.recv(4 - received_bytes) - received_bytes = received_bytes + len(data) + data += self._sock.recv(4 - received_bytes) + received_bytes += len(data) if not data: log.error('%s: Authentication failed for user %s', self, self.config['sasl_plain_username']) - self.close(error=Errors.ConnectionError('Authentication failed')) - raise Errors.AuthenticationFailedError('Authentication failed for user {}'.format(self.config['sasl_plain_username'])) + error = Errors.AuthenticationFailedError( + 'Authentication failed for user {0}'.format( + self.config['sasl_plain_username'])) + future.failure(error) + raise error self._sock.setblocking(False) except (AssertionError, ConnectionError) as e: log.exception("%s: Error receiving reply from server", self) error = Errors.ConnectionError("%s: %s" % (str(self), e)) + future.failure(error) self.close(error=error) - return False - with io.BytesIO() as buffer: - buffer.write(data) - buffer.seek(0) - if not Int32.decode(buffer) == 0: - raise Errors.KafkaError('Expected a zero sized reply after sending credentials') + if data != '\x00\x00\x00\x00': + return future.failure(Errors.AuthenticationFailedError()) - return True + return future.success(True) def blacked_out(self): """ @@ -430,30 +426,33 @@ class BrokerConnection(object): Can block on network if request is larger than send_buffer_bytes """ - future = Future() if self.connecting(): - return future.failure(Errors.NodeNotReadyError(str(self))) + return Future().failure(Errors.NodeNotReadyError(str(self))) elif not self.connected(): - return future.failure(Errors.ConnectionError(str(self))) + return Future().failure(Errors.ConnectionError(str(self))) elif not self.can_send_more(): - return future.failure(Errors.TooManyInFlightRequests(str(self))) + return Future().failure(Errors.TooManyInFlightRequests(str(self))) + return self._send(request, expect_response=expect_response) + + def _send(self, request, expect_response=True): + future = Future() correlation_id = self._next_correlation_id() header = RequestHeader(request, correlation_id=correlation_id, client_id=self.config['client_id']) message = b''.join([header.encode(), request.encode()]) size = Int32.encode(len(message)) + data = size + message try: # In the future we might manage an internal write buffer # and send bytes asynchronously. For now, just block # sending each request payload self._sock.setblocking(True) - for data in (size, message): - total_sent = 0 - while total_sent < len(data): - sent_bytes = self._sock.send(data[total_sent:]) - total_sent += sent_bytes - assert total_sent == len(data) + total_sent = 0 + while total_sent < len(data): + sent_bytes = self._sock.send(data[total_sent:]) + total_sent += sent_bytes + assert total_sent == len(data) self._sock.setblocking(False) except (AssertionError, ConnectionError) as e: log.exception("Error sending %s to %s", request, self) @@ -505,6 +504,9 @@ class BrokerConnection(object): self.config['request_timeout_ms'])) return None + return self._recv() + + def _recv(self): # Not receiving is the state of reading the payload header if not self._receiving: try: @@ -552,7 +554,7 @@ class BrokerConnection(object): # enough data to read the full bytes_to_read # but if the socket is disconnected, we will get empty data # without an exception raised - if not data: + if bytes_to_read and not data: log.error('%s: socket disconnected', self) self.close(error=Errors.ConnectionError('socket disconnected')) return None |