diff options
-rw-r--r-- | kafka/client.py | 14 | ||||
-rw-r--r-- | test/test_client.py | 8 |
2 files changed, 15 insertions, 7 deletions
diff --git a/kafka/client.py b/kafka/client.py index 48a534e..c36cd08 100644 --- a/kafka/client.py +++ b/kafka/client.py @@ -2,7 +2,6 @@ import binascii import collections import copy import functools -import itertools import logging import time import kafka.common @@ -23,17 +22,18 @@ log = logging.getLogger("kafka") class KafkaClient(object): CLIENT_ID = b"kafka-python" - ID_GEN = itertools.count() # NOTE: The timeout given to the client should always be greater than the # one passed to SimpleConsumer.get_message(), otherwise you can get a # socket timeout. def __init__(self, hosts, client_id=CLIENT_ID, - timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS): + timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS, + correlation_id=0): # We need one connection to bootstrap self.client_id = kafka_bytestring(client_id) self.timeout = timeout self.hosts = collect_hosts(hosts) + self.correlation_id = correlation_id # create connections only when we need them self.conns = {} @@ -98,10 +98,10 @@ class KafkaClient(object): return self.brokers[meta.leader] def _next_id(self): - """ - Generate a new correlation id - """ - return next(KafkaClient.ID_GEN) + """Generate a new correlation id""" + # modulo to keep w/i int32 + self.correlation_id = (self.correlation_id + 1) % 2**31 + return self.correlation_id def _send_broker_unaware_request(self, payloads, encoder_fn, decoder_fn): """ diff --git a/test/test_client.py b/test/test_client.py index c522d9a..abda421 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -401,3 +401,11 @@ class TestKafkaClient(unittest.TestCase): with self.assertRaises(ConnectionError): KafkaConnection("nowhere", 1234, 1.0) self.assertGreaterEqual(t.interval, 1.0) + + def test_correlation_rollover(self): + with patch.object(KafkaClient, 'load_metadata_for_topics'): + big_num = 2**31 - 3 + client = KafkaClient(hosts=[], correlation_id=big_num) + self.assertEqual(big_num + 1, client._next_id()) + self.assertEqual(big_num + 2, client._next_id()) + self.assertEqual(0, client._next_id()) |