summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--kafka/client.py14
-rw-r--r--test/test_client.py8
2 files changed, 15 insertions, 7 deletions
diff --git a/kafka/client.py b/kafka/client.py
index 48a534e..c36cd08 100644
--- a/kafka/client.py
+++ b/kafka/client.py
@@ -2,7 +2,6 @@ import binascii
import collections
import copy
import functools
-import itertools
import logging
import time
import kafka.common
@@ -23,17 +22,18 @@ log = logging.getLogger("kafka")
class KafkaClient(object):
CLIENT_ID = b"kafka-python"
- ID_GEN = itertools.count()
# NOTE: The timeout given to the client should always be greater than the
# one passed to SimpleConsumer.get_message(), otherwise you can get a
# socket timeout.
def __init__(self, hosts, client_id=CLIENT_ID,
- timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS):
+ timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS,
+ correlation_id=0):
# We need one connection to bootstrap
self.client_id = kafka_bytestring(client_id)
self.timeout = timeout
self.hosts = collect_hosts(hosts)
+ self.correlation_id = correlation_id
# create connections only when we need them
self.conns = {}
@@ -98,10 +98,10 @@ class KafkaClient(object):
return self.brokers[meta.leader]
def _next_id(self):
- """
- Generate a new correlation id
- """
- return next(KafkaClient.ID_GEN)
+ """Generate a new correlation id"""
+ # modulo to keep w/i int32
+ self.correlation_id = (self.correlation_id + 1) % 2**31
+ return self.correlation_id
def _send_broker_unaware_request(self, payloads, encoder_fn, decoder_fn):
"""
diff --git a/test/test_client.py b/test/test_client.py
index c522d9a..abda421 100644
--- a/test/test_client.py
+++ b/test/test_client.py
@@ -401,3 +401,11 @@ class TestKafkaClient(unittest.TestCase):
with self.assertRaises(ConnectionError):
KafkaConnection("nowhere", 1234, 1.0)
self.assertGreaterEqual(t.interval, 1.0)
+
+ def test_correlation_rollover(self):
+ with patch.object(KafkaClient, 'load_metadata_for_topics'):
+ big_num = 2**31 - 3
+ client = KafkaClient(hosts=[], correlation_id=big_num)
+ self.assertEqual(big_num + 1, client._next_id())
+ self.assertEqual(big_num + 2, client._next_id())
+ self.assertEqual(0, client._next_id())