summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDana Powers <dana.powers@gmail.com>2016-03-17 11:22:36 -0700
committerDana Powers <dana.powers@gmail.com>2016-03-17 11:22:36 -0700
commitab03296b65b2031930a0f04d06502b156dd01657 (patch)
tree4a2199ee851997c827ab93598d8150ccf98a9e48
parentf83e27168fe5f0d5c71b962b8788e05640ea8e2b (diff)
parent2b6d063085dc6bab9e84cc5c714be5cf2716fe38 (diff)
downloadkafka-python-ab03296b65b2031930a0f04d06502b156dd01657.tar.gz
Merge pull request #602 from zackdever/KAFKA-2698
KAFKA-2698: add paused API
-rw-r--r--kafka/consumer/group.py8
-rw-r--r--kafka/consumer/subscription_state.py5
-rw-r--r--test/test_consumer_group.py30
3 files changed, 38 insertions, 5 deletions
diff --git a/kafka/consumer/group.py b/kafka/consumer/group.py
index f646e05..9172040 100644
--- a/kafka/consumer/group.py
+++ b/kafka/consumer/group.py
@@ -528,6 +528,14 @@ class KafkaConsumer(six.Iterator):
log.debug("Pausing partition %s", partition)
self._subscription.pause(partition)
+ def paused(self):
+ """Get the partitions that were previously paused by a call to pause().
+
+ Returns:
+ set: {partition (TopicPartition), ...}
+ """
+ return self._subscription.paused_partitions()
+
def resume(self, *partitions):
"""Resume fetching from the specified (paused) partitions.
diff --git a/kafka/consumer/subscription_state.py b/kafka/consumer/subscription_state.py
index a4043a1..3d170ae 100644
--- a/kafka/consumer/subscription_state.py
+++ b/kafka/consumer/subscription_state.py
@@ -265,6 +265,11 @@ class SubscriptionState(object):
"""Return set of TopicPartitions in current assignment."""
return set(self.assignment.keys())
+ def paused_partitions(self):
+ """Return current set of paused TopicPartitions."""
+ return set(partition for partition in self.assignment
+ if self.is_paused(partition))
+
def fetchable_partitions(self):
"""Return set of TopicPartitions that should be Fetched."""
fetchable = set()
diff --git a/test/test_consumer_group.py b/test/test_consumer_group.py
index 34b1be4..5fcfbe2 100644
--- a/test/test_consumer_group.py
+++ b/test/test_consumer_group.py
@@ -17,10 +17,13 @@ from test.conftest import version
from test.testutil import random_string
+def get_connect_str(kafka_broker):
+ return 'localhost:' + str(kafka_broker.port)
+
+
@pytest.fixture
def simple_client(kafka_broker):
- connect_str = 'localhost:' + str(kafka_broker.port)
- return SimpleClient(connect_str)
+ return SimpleClient(get_connect_str(kafka_broker))
@pytest.fixture
@@ -37,8 +40,7 @@ def test_consumer(kafka_broker, version):
if version >= (0, 8, 2) and version < (0, 9):
topic(simple_client(kafka_broker))
- connect_str = 'localhost:' + str(kafka_broker.port)
- consumer = KafkaConsumer(bootstrap_servers=connect_str)
+ consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
consumer.poll(500)
assert len(consumer._client._conns) > 0
node_id = list(consumer._client._conns.keys())[0]
@@ -49,7 +51,7 @@ def test_consumer(kafka_broker, version):
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_group(kafka_broker, topic):
num_partitions = 4
- connect_str = 'localhost:' + str(kafka_broker.port)
+ connect_str = get_connect_str(kafka_broker)
consumers = {}
stop = {}
threads = {}
@@ -120,6 +122,24 @@ def test_group(kafka_broker, topic):
threads[c].join()
+@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
+def test_paused(kafka_broker, topic):
+ consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
+ topics = [TopicPartition(topic, 1)]
+ consumer.assign(topics)
+ assert set(topics) == consumer.assignment()
+ assert set() == consumer.paused()
+
+ consumer.pause(topics[0])
+ assert set([topics[0]]) == consumer.paused()
+
+ consumer.resume(topics[0])
+ assert set() == consumer.paused()
+
+ consumer.unsubscribe()
+ assert set() == consumer.paused()
+
+
@pytest.fixture
def conn(mocker):
conn = mocker.patch('kafka.client_async.BrokerConnection')