-rw-r--r--   kafka/client_async.py        6
-rw-r--r--   kafka/cluster.py            10
-rw-r--r--   kafka/consumer/group.py     16
-rw-r--r--   test/test_client_async.py   30
4 files changed, 47 insertions, 15 deletions
diff --git a/kafka/client_async.py b/kafka/client_async.py
index 30d4d6f..1838aed 100644
--- a/kafka/client_async.py
+++ b/kafka/client_async.py
@@ -302,7 +302,7 @@ class KafkaClient(object):
self._finish_connect(node_id)
# Send a metadata request if needed
- metadata_timeout = self._maybe_refresh_metadata()
+ metadata_timeout_ms = self._maybe_refresh_metadata()
# Send scheduled tasks
for task, task_future in self._delayed_tasks.pop_ready():
@@ -314,7 +314,9 @@ class KafkaClient(object):
else:
task_future.success(result)
- timeout = min(timeout_ms, metadata_timeout,
+ task_timeout_ms = max(0, 1000 * (
+ self._delayed_tasks.next_at() - time.time()))
+ timeout = min(timeout_ms, metadata_timeout_ms, task_timeout_ms,
self.config['request_timeout_ms'])
timeout /= 1000.0
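
For reference, the timeout selection in KafkaClient.poll() after this change boils down to the following minimal sketch (compute_poll_timeout is a hypothetical standalone helper mirroring the diff above, not part of the patch):

    import time

    def compute_poll_timeout(timeout_ms, metadata_timeout_ms, next_task_at,
                             request_timeout_ms):
        # Milliseconds until the next delayed task is due (never negative).
        task_timeout_ms = max(0, 1000 * (next_task_at - time.time()))
        # The effective wait is the smallest of all pending deadlines.
        timeout_ms = min(timeout_ms, metadata_timeout_ms, task_timeout_ms,
                         request_timeout_ms)
        return timeout_ms / 1000.0  # seconds, as passed to _poll()
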
diff --git a/kafka/cluster.py b/kafka/cluster.py
index 84ad1d3..1cdc8dd 100644
--- a/kafka/cluster.py
+++ b/kafka/cluster.py
@@ -59,9 +59,13 @@ class ClusterMetadata(object):
if self._need_update:
ttl = 0
else:
- ttl = self._last_successful_refresh_ms + self.config['metadata_max_age_ms'] - now
- retry = self._last_refresh_ms + self.config['retry_backoff_ms'] - now
- return max(ttl, retry, 0)
+ metadata_age = now - self._last_successful_refresh_ms
+ ttl = self.config['metadata_max_age_ms'] - metadata_age
+
+ retry_age = now - self._last_refresh_ms
+ next_retry = self.config['retry_backoff_ms'] - retry_age
+
+ return max(ttl, next_retry, 0)
def request_update(self):
"""
diff --git a/kafka/consumer/group.py b/kafka/consumer/group.py
index 4930ba1..75fe3ee 100644
--- a/kafka/consumer/group.py
+++ b/kafka/consumer/group.py
@@ -623,19 +623,19 @@ class KafkaConsumer(six.Iterator):
# fetch positions if we have partitions we're subscribed to that we
# don't know the offset for
if not self._subscription.has_all_fetch_positions():
- self._update_fetch_positions(self._subscription.missing_fetch_positions())
+ partitions = self._subscription.missing_fetch_positions()
+ self._update_fetch_positions(partitions)
# init any new fetches (won't resend pending fetches)
self._fetcher.init_fetches()
- self._client.poll(self.config['request_timeout_ms'] / 1000.0)
- timeout = self._consumer_timeout
- if self.config['api_version'] >= (0, 9):
- heartbeat_timeout = time.time() + (
- self.config['heartbeat_interval_ms'] / 1000.0)
- timeout = min(heartbeat_timeout, timeout)
+ self._client.poll()
+
+ timeout_at = min(self._consumer_timeout,
+ self._client._delayed_tasks.next_at(),
+ self._client.cluster.ttl() / 1000.0 + time.time())
for msg in self._fetcher:
yield msg
- if time.time() > timeout:
+ if time.time() > timeout_at:
break
def __iter__(self): # pylint: disable=non-iterator-returned
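
The iterator now lets client.poll() pick its own timeout and instead bounds how long it drains the fetcher before returning to poll(). Roughly (next_wakeup is a hypothetical helper mirroring the expression in the diff, not part of the patch):

    import time

    def next_wakeup(consumer_timeout, next_task_at, cluster_ttl_ms):
        # Earliest point at which the iterator should stop yielding messages:
        # the user's consumer timeout, the next scheduled task (heartbeat,
        # autocommit, ...), or the next metadata refresh.
        return min(consumer_timeout,
                   next_task_at,
                   cluster_ttl_ms / 1000.0 + time.time())
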
diff --git a/test/test_client_async.py b/test/test_client_async.py
index 447ea49..9191c5e 100644
--- a/test/test_client_async.py
+++ b/test/test_client_async.py
@@ -1,3 +1,4 @@
+import time
import pytest
@@ -242,8 +243,33 @@ def test_send(conn):
assert conn.send.called_with(request, expect_response=True)
-def test_poll():
- pass
+def test_poll(mocker):
+ mocker.patch.object(KafkaClient, '_bootstrap')
+ metadata = mocker.patch.object(KafkaClient, '_maybe_refresh_metadata')
+ _poll = mocker.patch.object(KafkaClient, '_poll')
+ cli = KafkaClient()
+ tasks = mocker.patch.object(cli._delayed_tasks, 'next_at')
+
+ # metadata timeout wins
+ metadata.return_value = 1000
+ tasks.return_value = time.time() + 2 # 2 seconds from now
+ cli.poll()
+ _poll.assert_called_with(1.0)
+
+ # user timeout wins
+ cli.poll(250)
+ _poll.assert_called_with(0.25)
+
+ # tasks timeout wins
+ tasks.return_value = time.time() # next task is now
+ cli.poll(250)
+ _poll.assert_called_with(0)
+
+ # default is request_timeout_ms
+ metadata.return_value = 1000000
+ tasks.return_value = time.time() + 10000
+ cli.poll()
+ _poll.assert_called_with(cli.config['request_timeout_ms'] / 1000.0)
def test__poll():
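
Sanity-checking the first test case above: with _maybe_refresh_metadata returning 1000 ms, the next delayed task roughly 2000 ms away, and the default user timeout larger than one second, the minimum is the 1000 ms metadata timeout, so _poll is asserted to be called with 1.0 second. A 250 ms user timeout then wins (0.25 s), a task due immediately forces 0, and when every other deadline is large the request_timeout_ms default applies.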