diff options
-rw-r--r-- | kafka/client.py | 23 | ||||
-rw-r--r-- | kafka/client_async.py | 10 | ||||
-rw-r--r-- | kafka/conn.py | 49 | ||||
-rw-r--r-- | test/test_client.py | 11 | ||||
-rw-r--r-- | test/test_client_async.py | 15 | ||||
-rw-r--r-- | test/test_conn.py | 29 |
6 files changed, 101 insertions, 36 deletions
diff --git a/kafka/client.py b/kafka/client.py index 11f54eb..99d6fec 100644 --- a/kafka/client.py +++ b/kafka/client.py @@ -16,7 +16,7 @@ from kafka.common import (TopicPartition, BrokerMetadata, UnknownError, from kafka.conn import ( collect_hosts, BrokerConnection, DEFAULT_SOCKET_TIMEOUT_SECONDS, - ConnectionStates) + ConnectionStates, get_ip_port_afi) from kafka.protocol import KafkaProtocol # New KafkaClient @@ -56,12 +56,12 @@ class SimpleClient(object): # Private API # ################## - def _get_conn(self, host, port): + def _get_conn(self, host, port, afi): """Get or create a connection to a broker using host and port""" host_key = (host, port) if host_key not in self._conns: self._conns[host_key] = BrokerConnection( - host, port, + host, port, afi, request_timeout_ms=self.timeout * 1000, client_id=self.client_id ) @@ -139,13 +139,17 @@ class SimpleClient(object): Attempt to send a broker-agnostic request to one of the available brokers. Keep trying until you succeed. """ - hosts = set([(broker.host, broker.port) for broker in self.brokers.values()]) + hosts = set() + for broker in self.brokers.values(): + host, port, afi = get_ip_port_afi(broker.host) + hosts.add((host, broker.port, afi)) + hosts.update(self.hosts) hosts = list(hosts) random.shuffle(hosts) - for (host, port) in hosts: - conn = self._get_conn(host, port) + for (host, port, afi) in hosts: + conn = self._get_conn(host, port, afi) if not conn.connected(): log.warning("Skipping unconnected connection: %s", conn) continue @@ -227,7 +231,9 @@ class SimpleClient(object): failed_payloads(broker_payloads) continue - conn = self._get_conn(broker.host, broker.port) + + host, port, afi = get_ip_port_afi(broker.host) + conn = self._get_conn(host, broker.port, afi) conn.connect() if not conn.connected(): refresh_metadata = True @@ -323,7 +329,8 @@ class SimpleClient(object): # Send the request, recv the response try: - conn = self._get_conn(broker.host, broker.port) + host, port, afi = get_ip_port_afi(broker.host) + conn = self._get_conn(host, broker.port, afi) conn.send(requestId, request) except ConnectionError as e: diff --git a/kafka/client_async.py b/kafka/client_async.py index ae9dbb4..5a1d624 100644 --- a/kafka/client_async.py +++ b/kafka/client_async.py @@ -14,7 +14,7 @@ import six import kafka.common as Errors # TODO: make Errors a separate class from .cluster import ClusterMetadata -from .conn import BrokerConnection, ConnectionStates, collect_hosts +from .conn import BrokerConnection, ConnectionStates, collect_hosts, get_ip_port_afi from .future import Future from .protocol.metadata import MetadataRequest from .protocol.produce import ProduceRequest @@ -115,9 +115,9 @@ class KafkaClient(object): self._last_bootstrap = time.time() metadata_request = MetadataRequest([]) - for host, port in hosts: + for host, port, afi in hosts: log.debug("Attempting to bootstrap via node at %s:%s", host, port) - bootstrap = BrokerConnection(host, port, **self.config) + bootstrap = BrokerConnection(host, port, afi, **self.config) bootstrap.connect() while bootstrap.state is ConnectionStates.CONNECTING: bootstrap.connect() @@ -160,7 +160,9 @@ class KafkaClient(object): log.debug("Initiating connection to node %s at %s:%s", node_id, broker.host, broker.port) - self._conns[node_id] = BrokerConnection(broker.host, broker.port, + + host, port, afi = get_ip_port_afi(broker.host) + self._conns[node_id] = BrokerConnection(host, broker.port, afi, **self.config) return self._finish_connect(node_id) diff --git a/kafka/conn.py b/kafka/conn.py index 65451f9..0ce469d 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -52,9 +52,10 @@ class BrokerConnection(object): 'api_version': (0, 8, 2), # default to most restrictive } - def __init__(self, host, port, **configs): + def __init__(self, host, port, afi, **configs): self.host = host self.port = port + self.afi = afi self.in_flight_requests = collections.deque() self.config = copy.copy(self.DEFAULT_CONFIG) @@ -76,7 +77,7 @@ class BrokerConnection(object): """Attempt to connect and return ConnectionState""" if self.state is ConnectionStates.DISCONNECTED: self.close() - self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock = socket.socket(self.afi, socket.SOCK_STREAM) if self.config['receive_buffer_bytes'] is not None: self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.config['receive_buffer_bytes']) @@ -356,6 +357,39 @@ class BrokerConnection(object): return "<BrokerConnection host=%s port=%d>" % (self.host, self.port) +def get_ip_port_afi(host_and_port_str): + """ + Parse the IP and port from a string in the format of: + + * host_or_ip <- Can be either IPv4 or IPv6 address or hostname/fqdn + * host_or_ip:port <- This is only for IPv4 + * [host_or_ip]:port. <- This is only for IPv6 + + .. note:: If the port is not specified, default will be returned. + + :return: tuple (host, port, afi), afi will be socket.AF_INET or socket.AF_INET6 + """ + afi = socket.AF_INET + + if host_and_port_str.strip()[0] == '[': + afi = socket.AF_INET6 + res = host_and_port_str.split("]:") + res[0] = res[0].replace("[", "") + res[0] = res[0].replace("]", "") + + elif host_and_port_str.count(":") > 1: + afi = socket.AF_INET6 + res = [host_and_port_str] + + else: + res = host_and_port_str.split(':') + + host = res[0] + port = int(res[1]) if len(res) > 1 else DEFAULT_KAFKA_PORT + + return host.strip(), port, afi + + def collect_hosts(hosts, randomize=True): """ Collects a comma-separated set of hosts (host:port) and optionally @@ -366,12 +400,15 @@ def collect_hosts(hosts, randomize=True): hosts = hosts.strip().split(',') result = [] + afi = socket.AF_INET for host_port in hosts: - res = host_port.split(':') - host = res[0] - port = int(res[1]) if len(res) > 1 else DEFAULT_KAFKA_PORT - result.append((host.strip(), port)) + host, port, afi = get_ip_port_afi(host_port) + + if port < 0: + port = DEFAULT_KAFKA_PORT + + result.append((host, port, afi)) if randomize: shuffle(result) diff --git a/test/test_client.py b/test/test_client.py index a53fce1..6980434 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -41,7 +41,7 @@ class TestSimpleClient(unittest.TestCase): client = SimpleClient(hosts=['kafka01:9092', 'kafka02:9092', 'kafka03:9092']) self.assertEqual( - sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]), + sorted([('kafka01', 9092, socket.AF_INET), ('kafka02', 9092, socket.AF_INET), ('kafka03', 9092, socket.AF_INET)]), sorted(client.hosts)) def test_init_with_csv(self): @@ -49,7 +49,7 @@ class TestSimpleClient(unittest.TestCase): client = SimpleClient(hosts='kafka01:9092,kafka02:9092,kafka03:9092') self.assertEqual( - sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]), + sorted([('kafka01', 9092, socket.AF_INET), ('kafka02', 9092, socket.AF_INET), ('kafka03', 9092, socket.AF_INET)]), sorted(client.hosts)) def test_init_with_unicode_csv(self): @@ -57,7 +57,7 @@ class TestSimpleClient(unittest.TestCase): client = SimpleClient(hosts=u'kafka01:9092,kafka02:9092,kafka03:9092') self.assertEqual( - sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]), + sorted([('kafka01', 9092, socket.AF_INET), ('kafka02', 9092, socket.AF_INET), ('kafka03', 9092, socket.AF_INET)]), sorted(client.hosts)) @patch.object(SimpleClient, '_get_conn') @@ -70,7 +70,7 @@ class TestSimpleClient(unittest.TestCase): for val in mocked_conns.values(): mock_conn(val, success=False) - def mock_get_conn(host, port): + def mock_get_conn(host, port, afi): return mocked_conns[(host, port)] conn.side_effect = mock_get_conn @@ -98,7 +98,7 @@ class TestSimpleClient(unittest.TestCase): mocked_conns[('kafka02', 9092)].send.return_value = future mocked_conns[('kafka02', 9092)].recv.side_effect = lambda: future.success('valid response') - def mock_get_conn(host, port): + def mock_get_conn(host, port, afi): return mocked_conns[(host, port)] # patch to avoid making requests before we want it @@ -409,3 +409,4 @@ class TestSimpleClient(unittest.TestCase): self.assertEqual(big_num + 1, client._next_id()) self.assertEqual(big_num + 2, client._next_id()) self.assertEqual(0, client._next_id()) + diff --git a/test/test_client_async.py b/test/test_client_async.py index 2e0d9b4..e0b98c4 100644 --- a/test/test_client_async.py +++ b/test/test_client_async.py @@ -1,4 +1,5 @@ import time +import socket import pytest @@ -12,11 +13,11 @@ from kafka.protocol.produce import ProduceRequest @pytest.mark.parametrize("bootstrap,expected_hosts", [ - (None, [('localhost', 9092)]), - ('foobar:1234', [('foobar', 1234)]), - ('fizzbuzz', [('fizzbuzz', 9092)]), - ('foo:12,bar:34', [('foo', 12), ('bar', 34)]), - (['fizz:56', 'buzz'], [('fizz', 56), ('buzz', 9092)]), + (None, [('localhost', 9092, socket.AF_INET)]), + ('foobar:1234', [('foobar', 1234, socket.AF_INET)]), + ('fizzbuzz', [('fizzbuzz', 9092, socket.AF_INET)]), + ('foo:12,bar:34', [('foo', 12, socket.AF_INET), ('bar', 34, socket.AF_INET)]), + (['fizz:56', 'buzz'], [('fizz', 56, socket.AF_INET), ('buzz', 9092, socket.AF_INET)]), ]) def test_bootstrap_servers(mocker, bootstrap, expected_hosts): mocker.patch.object(KafkaClient, '_bootstrap') @@ -47,7 +48,7 @@ def conn(mocker): def test_bootstrap_success(conn): conn.state = ConnectionStates.CONNECTED cli = KafkaClient() - conn.assert_called_once_with('localhost', 9092, **cli.config) + conn.assert_called_once_with('localhost', 9092, socket.AF_INET, **cli.config) conn.connect.assert_called_with() conn.send.assert_called_once_with(MetadataRequest([])) assert cli._bootstrap_fails == 0 @@ -57,7 +58,7 @@ def test_bootstrap_success(conn): def test_bootstrap_failure(conn): conn.state = ConnectionStates.DISCONNECTED cli = KafkaClient() - conn.assert_called_once_with('localhost', 9092, **cli.config) + conn.assert_called_once_with('localhost', 9092, socket.AF_INET, **cli.config) conn.connect.assert_called_with() conn.close.assert_called_with() assert cli._bootstrap_fails == 1 diff --git a/test/test_conn.py b/test/test_conn.py index 684ffe5..f0ef8fb 100644 --- a/test/test_conn.py +++ b/test/test_conn.py @@ -51,21 +51,37 @@ class ConnTest(unittest.TestCase): results = collect_hosts(hosts) self.assertEqual(set(results), set([ - ('localhost', 1234), - ('localhost', 9092), + ('localhost', 1234, socket.AF_INET), + ('localhost', 9092, socket.AF_INET), + ])) + + def test_collect_hosts__ipv6(self): + hosts = "[localhost]:1234,[2001:1000:2000::1],[2001:1000:2000::1]:1234" + results = collect_hosts(hosts) + + self.assertEqual(set(results), set([ + ('localhost', 1234, socket.AF_INET6), + ('2001:1000:2000::1', 9092, socket.AF_INET6), + ('2001:1000:2000::1', 1234, socket.AF_INET6), ])) def test_collect_hosts__string_list(self): hosts = [ 'localhost:1234', 'localhost', + '[localhost]', + '2001::1', + '[2001::1]:1234', ] results = collect_hosts(hosts) self.assertEqual(set(results), set([ - ('localhost', 1234), - ('localhost', 9092), + ('localhost', 1234, socket.AF_INET), + ('localhost', 9092, socket.AF_INET), + ('localhost', 9092, socket.AF_INET6), + ('2001::1', 9092, socket.AF_INET6), + ('2001::1', 1234, socket.AF_INET6), ])) def test_collect_hosts__with_spaces(self): @@ -73,10 +89,11 @@ class ConnTest(unittest.TestCase): results = collect_hosts(hosts) self.assertEqual(set(results), set([ - ('localhost', 1234), - ('localhost', 9092), + ('localhost', 1234, socket.AF_INET), + ('localhost', 9092, socket.AF_INET), ])) + def test_send(self): self.conn.send(self.config['request_id'], self.config['payload']) self.conn._sock.sendall.assert_called_with(self.config['payload']) |