summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--kafka/partitioner/base.py7
-rw-r--r--kafka/partitioner/hashed.py4
-rw-r--r--kafka/partitioner/roundrobin.py4
3 files changed, 8 insertions, 7 deletions
diff --git a/kafka/partitioner/base.py b/kafka/partitioner/base.py
index 0b1bb59..857f634 100644
--- a/kafka/partitioner/base.py
+++ b/kafka/partitioner/base.py
@@ -12,14 +12,13 @@ class Partitioner(object):
"""
self.partitions = partitions
- def partition(self, key, partitions):
+ def partition(self, key, partitions=None):
"""
Takes a string key and num_partitions as argument and returns
a partition to be used for the message
Arguments:
- partitions: The list of partitions is passed in every call. This
- may look like an overhead, but it will be useful
- (in future) when we handle cases like rebalancing
+ key: the key to use for partitioning
+ partitions: (optional) a list of partitions.
"""
raise NotImplementedError('partition function has to be implemented')
diff --git a/kafka/partitioner/hashed.py b/kafka/partitioner/hashed.py
index 587a3de..fb5e598 100644
--- a/kafka/partitioner/hashed.py
+++ b/kafka/partitioner/hashed.py
@@ -5,7 +5,9 @@ class HashedPartitioner(Partitioner):
Implements a partitioner which selects the target partition based on
the hash of the key
"""
- def partition(self, key, partitions):
+ def partition(self, key, partitions=None):
+ if not partitions:
+ partitions = self.partitions
size = len(partitions)
idx = hash(key) % size
diff --git a/kafka/partitioner/roundrobin.py b/kafka/partitioner/roundrobin.py
index 54d00da..6439e53 100644
--- a/kafka/partitioner/roundrobin.py
+++ b/kafka/partitioner/roundrobin.py
@@ -15,9 +15,9 @@ class RoundRobinPartitioner(Partitioner):
self.partitions = partitions
self.iterpart = cycle(partitions)
- def partition(self, key, partitions):
+ def partition(self, key, partitions=None):
# Refresh the partition list if necessary
- if self.partitions != partitions:
+ if partitions and self.partitions != partitions:
self._set_partitions(partitions)
return next(self.iterpart)