diff options
author | Dana Powers <dana.powers@gmail.com> | 2016-03-17 11:22:36 -0700 |
---|---|---|
committer | Dana Powers <dana.powers@gmail.com> | 2016-03-17 11:22:36 -0700 |
commit | ab03296b65b2031930a0f04d06502b156dd01657 (patch) | |
tree | 4a2199ee851997c827ab93598d8150ccf98a9e48 | |
parent | f83e27168fe5f0d5c71b962b8788e05640ea8e2b (diff) | |
parent | 2b6d063085dc6bab9e84cc5c714be5cf2716fe38 (diff) | |
download | kafka-python-ab03296b65b2031930a0f04d06502b156dd01657.tar.gz |
Merge pull request #602 from zackdever/KAFKA-2698
KAFKA-2698: add paused API
-rw-r--r-- | kafka/consumer/group.py | 8 | ||||
-rw-r--r-- | kafka/consumer/subscription_state.py | 5 | ||||
-rw-r--r-- | test/test_consumer_group.py | 30 |
3 files changed, 38 insertions, 5 deletions
diff --git a/kafka/consumer/group.py b/kafka/consumer/group.py index f646e05..9172040 100644 --- a/kafka/consumer/group.py +++ b/kafka/consumer/group.py @@ -528,6 +528,14 @@ class KafkaConsumer(six.Iterator): log.debug("Pausing partition %s", partition) self._subscription.pause(partition) + def paused(self): + """Get the partitions that were previously paused by a call to pause(). + + Returns: + set: {partition (TopicPartition), ...} + """ + return self._subscription.paused_partitions() + def resume(self, *partitions): """Resume fetching from the specified (paused) partitions. diff --git a/kafka/consumer/subscription_state.py b/kafka/consumer/subscription_state.py index a4043a1..3d170ae 100644 --- a/kafka/consumer/subscription_state.py +++ b/kafka/consumer/subscription_state.py @@ -265,6 +265,11 @@ class SubscriptionState(object): """Return set of TopicPartitions in current assignment.""" return set(self.assignment.keys()) + def paused_partitions(self): + """Return current set of paused TopicPartitions.""" + return set(partition for partition in self.assignment + if self.is_paused(partition)) + def fetchable_partitions(self): """Return set of TopicPartitions that should be Fetched.""" fetchable = set() diff --git a/test/test_consumer_group.py b/test/test_consumer_group.py index 34b1be4..5fcfbe2 100644 --- a/test/test_consumer_group.py +++ b/test/test_consumer_group.py @@ -17,10 +17,13 @@ from test.conftest import version from test.testutil import random_string +def get_connect_str(kafka_broker): + return 'localhost:' + str(kafka_broker.port) + + @pytest.fixture def simple_client(kafka_broker): - connect_str = 'localhost:' + str(kafka_broker.port) - return SimpleClient(connect_str) + return SimpleClient(get_connect_str(kafka_broker)) @pytest.fixture @@ -37,8 +40,7 @@ def test_consumer(kafka_broker, version): if version >= (0, 8, 2) and version < (0, 9): topic(simple_client(kafka_broker)) - connect_str = 'localhost:' + str(kafka_broker.port) - consumer = KafkaConsumer(bootstrap_servers=connect_str) + consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker)) consumer.poll(500) assert len(consumer._client._conns) > 0 node_id = list(consumer._client._conns.keys())[0] @@ -49,7 +51,7 @@ def test_consumer(kafka_broker, version): @pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set") def test_group(kafka_broker, topic): num_partitions = 4 - connect_str = 'localhost:' + str(kafka_broker.port) + connect_str = get_connect_str(kafka_broker) consumers = {} stop = {} threads = {} @@ -120,6 +122,24 @@ def test_group(kafka_broker, topic): threads[c].join() +@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set") +def test_paused(kafka_broker, topic): + consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker)) + topics = [TopicPartition(topic, 1)] + consumer.assign(topics) + assert set(topics) == consumer.assignment() + assert set() == consumer.paused() + + consumer.pause(topics[0]) + assert set([topics[0]]) == consumer.paused() + + consumer.resume(topics[0]) + assert set() == consumer.paused() + + consumer.unsubscribe() + assert set() == consumer.paused() + + @pytest.fixture def conn(mocker): conn = mocker.patch('kafka.client_async.BrokerConnection') |