summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--kafka/consumer/multiprocess.py10
-rw-r--r--test/test_consumer.py13
2 files changed, 18 insertions, 5 deletions
diff --git a/kafka/consumer/multiprocess.py b/kafka/consumer/multiprocess.py
index 2ca76b6..8cec92d 100644
--- a/kafka/consumer/multiprocess.py
+++ b/kafka/consumer/multiprocess.py
@@ -100,6 +100,7 @@ class MultiProcessConsumer(Consumer):
topic: the topic to consume
Keyword Arguments:
+ partitions: An optional list of partitions to consume the data from
auto_commit: default True. Whether or not to auto commit the offsets
auto_commit_every_n: default 100. How many messages to consume
before a commit
@@ -116,16 +117,19 @@ class MultiProcessConsumer(Consumer):
commit method on this class. A manual call to commit will also reset
these triggers
"""
- def __init__(self, client, group, topic, auto_commit=True,
+ def __init__(self, client, group, topic,
+ partitions=None,
+ auto_commit=True,
auto_commit_every_n=AUTO_COMMIT_MSG_COUNT,
auto_commit_every_t=AUTO_COMMIT_INTERVAL,
- num_procs=1, partitions_per_proc=0,
+ num_procs=1,
+ partitions_per_proc=0,
**simple_consumer_options):
# Initiate the base consumer class
super(MultiProcessConsumer, self).__init__(
client, group, topic,
- partitions=None,
+ partitions=partitions,
auto_commit=auto_commit,
auto_commit_every_n=auto_commit_every_n,
auto_commit_every_t=auto_commit_every_t)
diff --git a/test/test_consumer.py b/test/test_consumer.py
index 7b8f370..a3d09a8 100644
--- a/test/test_consumer.py
+++ b/test/test_consumer.py
@@ -1,8 +1,8 @@
-from mock import MagicMock
+from mock import MagicMock, patch
from . import unittest
-from kafka import SimpleConsumer, KafkaConsumer
+from kafka import SimpleConsumer, KafkaConsumer, MultiProcessConsumer
from kafka.common import KafkaConfigurationError
class TestKafkaConsumer(unittest.TestCase):
@@ -13,3 +13,12 @@ class TestKafkaConsumer(unittest.TestCase):
def test_broker_list_required(self):
with self.assertRaises(KafkaConfigurationError):
KafkaConsumer()
+
+class TestMultiProcessConsumer(unittest.TestCase):
+ def test_partition_list(self):
+ client = MagicMock()
+ partitions = (0,)
+ with patch.object(MultiProcessConsumer, 'fetch_last_known_offsets') as fetch_last_known_offsets:
+ consumer = MultiProcessConsumer(client, 'testing-group', 'testing-topic', partitions=partitions)
+ self.assertEqual(fetch_last_known_offsets.call_args[0], (partitions,) )
+ self.assertEqual(client.get_partition_ids_for_topic.call_count, 0) # pylint: disable=no-member