summaryrefslogtreecommitdiff
path: root/kafka/conn.py
diff options
context:
space:
mode:
Diffstat (limited to 'kafka/conn.py')
-rw-r--r--kafka/conn.py60
1 files changed, 57 insertions, 3 deletions
diff --git a/kafka/conn.py b/kafka/conn.py
index 4aa94f7..52ed9d6 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -25,6 +25,7 @@ from kafka.vendor import six
import kafka.errors as Errors
from kafka.future import Future
from kafka.metrics.stats import Avg, Count, Max, Rate
+from kafka.oauth.abstract import AbstractTokenProvider
from kafka.protocol.admin import SaslHandShakeRequest
from kafka.protocol.commit import OffsetFetchRequest
from kafka.protocol.metadata import MetadataRequest
@@ -184,6 +185,8 @@ class BrokerConnection(object):
sasl mechanism handshake. Default: 'kafka'
sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI
sasl mechanism handshake. Default: one of bootstrap servers
+ sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
+ instance. (See kafka.oauth.abstract). Default: None
"""
DEFAULT_CONFIG = {
@@ -216,10 +219,11 @@ class BrokerConnection(object):
'sasl_plain_username': None,
'sasl_plain_password': None,
'sasl_kerberos_service_name': 'kafka',
- 'sasl_kerberos_domain_name': None
+ 'sasl_kerberos_domain_name': None,
+ 'sasl_oauth_token_provider': None
}
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
- SASL_MECHANISMS = ('PLAIN', 'GSSAPI')
+ SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER')
def __init__(self, host, port, afi, **configs):
self.host = host
@@ -263,7 +267,10 @@ class BrokerConnection(object):
if self.config['sasl_mechanism'] == 'GSSAPI':
assert gssapi is not None, 'GSSAPI lib not available'
assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl'
-
+ if self.config['sasl_mechanism'] == 'OAUTHBEARER':
+ token_provider = self.config['sasl_oauth_token_provider']
+ assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl'
+ assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()'
# This is not a general lock / this class is not generally thread-safe yet
# However, to avoid pushing responsibility for maintaining
# per-connection locks to the upstream client, we will use this lock to
@@ -537,6 +544,8 @@ class BrokerConnection(object):
return self._try_authenticate_plain(future)
elif self.config['sasl_mechanism'] == 'GSSAPI':
return self._try_authenticate_gssapi(future)
+ elif self.config['sasl_mechanism'] == 'OAUTHBEARER':
+ return self._try_authenticate_oauth(future)
else:
return future.failure(
Errors.UnsupportedSaslMechanismError(
@@ -660,6 +669,51 @@ class BrokerConnection(object):
log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
return future.success(True)
+ def _try_authenticate_oauth(self, future):
+ data = b''
+
+ msg = bytes(self._build_oauth_client_request().encode("utf-8"))
+ size = Int32.encode(len(msg))
+ try:
+ # Send SASL OAuthBearer request with OAuth token
+ self._send_bytes_blocking(size + msg)
+
+ # The server will send a zero sized message (that is Int32(0)) on success.
+ # The connection is closed on failure
+ data = self._recv_bytes_blocking(4)
+
+ except ConnectionError as e:
+ log.exception("%s: Error receiving reply from server", self)
+ error = Errors.KafkaConnectionError("%s: %s" % (self, e))
+ self.close(error=error)
+ return future.failure(error)
+
+ if data != b'\x00\x00\x00\x00':
+ error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
+ return future.failure(error)
+
+ log.info('%s: Authenticated via OAuth', self)
+ return future.success(True)
+
+ def _build_oauth_client_request(self):
+ token_provider = self.config['sasl_oauth_token_provider']
+ return "n,,\x01auth=Bearer {}{}\x01\x01".format(token_provider.token(), self._token_extensions())
+
+ def _token_extensions(self):
+ """
+ Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER
+ initial request.
+ """
+ token_provider = self.config['sasl_oauth_token_provider']
+
+ # Only run if the #extensions() method is implemented by the clients Token Provider class
+ # Builds up a string separated by \x01 via a dict of key value pairs
+ if callable(getattr(token_provider, "extensions", None)) and len(token_provider.extensions()) > 0:
+ msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()])
+ return "\x01" + msg
+ else:
+ return ""
+
def blacked_out(self):
"""
Return true if we are disconnected from the given node and can't