diff options
Diffstat (limited to 'kafka/conn.py')
-rw-r--r-- | kafka/conn.py | 60 |
1 files changed, 57 insertions, 3 deletions
diff --git a/kafka/conn.py b/kafka/conn.py index 4aa94f7..52ed9d6 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -25,6 +25,7 @@ from kafka.vendor import six import kafka.errors as Errors from kafka.future import Future from kafka.metrics.stats import Avg, Count, Max, Rate +from kafka.oauth.abstract import AbstractTokenProvider from kafka.protocol.admin import SaslHandShakeRequest from kafka.protocol.commit import OffsetFetchRequest from kafka.protocol.metadata import MetadataRequest @@ -184,6 +185,8 @@ class BrokerConnection(object): sasl mechanism handshake. Default: 'kafka' sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI sasl mechanism handshake. Default: one of bootstrap servers + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). Default: None """ DEFAULT_CONFIG = { @@ -216,10 +219,11 @@ class BrokerConnection(object): 'sasl_plain_username': None, 'sasl_plain_password': None, 'sasl_kerberos_service_name': 'kafka', - 'sasl_kerberos_domain_name': None + 'sasl_kerberos_domain_name': None, + 'sasl_oauth_token_provider': None } SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL') - SASL_MECHANISMS = ('PLAIN', 'GSSAPI') + SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER') def __init__(self, host, port, afi, **configs): self.host = host @@ -263,7 +267,10 @@ class BrokerConnection(object): if self.config['sasl_mechanism'] == 'GSSAPI': assert gssapi is not None, 'GSSAPI lib not available' assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl' - + if self.config['sasl_mechanism'] == 'OAUTHBEARER': + token_provider = self.config['sasl_oauth_token_provider'] + assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl' + assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()' # This is not a general lock / this class is not generally thread-safe yet # However, to avoid pushing responsibility for maintaining # per-connection locks to the upstream client, we will use this lock to @@ -537,6 +544,8 @@ class BrokerConnection(object): return self._try_authenticate_plain(future) elif self.config['sasl_mechanism'] == 'GSSAPI': return self._try_authenticate_gssapi(future) + elif self.config['sasl_mechanism'] == 'OAUTHBEARER': + return self._try_authenticate_oauth(future) else: return future.failure( Errors.UnsupportedSaslMechanismError( @@ -660,6 +669,51 @@ class BrokerConnection(object): log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name) return future.success(True) + def _try_authenticate_oauth(self, future): + data = b'' + + msg = bytes(self._build_oauth_client_request().encode("utf-8")) + size = Int32.encode(len(msg)) + try: + # Send SASL OAuthBearer request with OAuth token + self._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = self._recv_bytes_blocking(4) + + except ConnectionError as e: + log.exception("%s: Error receiving reply from server", self) + error = Errors.KafkaConnectionError("%s: %s" % (self, e)) + self.close(error=error) + return future.failure(error) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated via OAuth', self) + return future.success(True) + + def _build_oauth_client_request(self): + token_provider = self.config['sasl_oauth_token_provider'] + return "n,,\x01auth=Bearer {}{}\x01\x01".format(token_provider.token(), self._token_extensions()) + + def _token_extensions(self): + """ + Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER + initial request. + """ + token_provider = self.config['sasl_oauth_token_provider'] + + # Only run if the #extensions() method is implemented by the clients Token Provider class + # Builds up a string separated by \x01 via a dict of key value pairs + if callable(getattr(token_provider, "extensions", None)) and len(token_provider.extensions()) > 0: + msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()]) + return "\x01" + msg + else: + return "" + def blacked_out(self): """ Return true if we are disconnected from the given node and can't |