summaryrefslogtreecommitdiff
path: root/test/testutil.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/testutil.py')
-rw-r--r--test/testutil.py89
1 files changed, 54 insertions, 35 deletions
diff --git a/test/testutil.py b/test/testutil.py
index 0ec1cff..850e925 100644
--- a/test/testutil.py
+++ b/test/testutil.py
@@ -1,36 +1,20 @@
-import functools
-import logging
import operator
import os
-import random
import socket
-import string
import time
import uuid
-from six.moves import xrange
+import decorator
+import pytest
from . import unittest
-from kafka import SimpleClient
+from kafka import SimpleClient, create_message
from kafka.errors import LeaderNotAvailableError, KafkaTimeoutError
-from kafka.structs import OffsetRequestPayload
-
-__all__ = [
- 'random_string',
- 'get_open_port',
- 'kafka_versions',
- 'KafkaIntegrationTestCase',
- 'Timer',
-]
-
-def random_string(l):
- return "".join(random.choice(string.ascii_letters) for i in xrange(l))
+from kafka.structs import OffsetRequestPayload, ProduceRequestPayload
+from test.fixtures import random_string, version_str_to_list, version as kafka_version #pylint: disable=wrong-import-order
def kafka_versions(*versions):
- def version_str_to_list(s):
- return list(map(int, s.split('.'))) # e.g., [0, 8, 1, 1]
-
def construct_lambda(s):
if s[0].isdigit():
op_str = '='
@@ -54,25 +38,25 @@ def kafka_versions(*versions):
}
op = op_map[op_str]
version = version_str_to_list(v_str)
- return lambda a: op(version_str_to_list(a), version)
+ return lambda a: op(a, version)
validators = map(construct_lambda, versions)
- def kafka_versions(func):
- @functools.wraps(func)
- def wrapper(self):
- kafka_version = os.environ.get('KAFKA_VERSION')
+ def real_kafka_versions(func):
+ def wrapper(func, *args, **kwargs):
+ version = kafka_version()
- if not kafka_version:
- self.skipTest("no kafka version set in KAFKA_VERSION env var")
+ if not version:
+ pytest.skip("no kafka version set in KAFKA_VERSION env var")
for f in validators:
- if not f(kafka_version):
- self.skipTest("unsupported kafka version")
+ if not f(version):
+ pytest.skip("unsupported kafka version")
- return func(self)
- return wrapper
- return kafka_versions
+ return func(*args, **kwargs)
+ return decorator.decorator(wrapper, func)
+
+ return real_kafka_versions
def get_open_port():
sock = socket.socket()
@@ -81,6 +65,40 @@ def get_open_port():
sock.close()
return port
+_MESSAGES = {}
+def msg(message):
+ """Format, encode and deduplicate a message
+ """
+ global _MESSAGES #pylint: disable=global-statement
+ if message not in _MESSAGES:
+ _MESSAGES[message] = '%s-%s' % (message, str(uuid.uuid4()))
+
+ return _MESSAGES[message].encode('utf-8')
+
+def send_messages(client, topic, partition, messages):
+ """Send messages to a topic's partition
+ """
+ messages = [create_message(msg(str(m))) for m in messages]
+ produce = ProduceRequestPayload(topic, partition, messages=messages)
+ resp, = client.send_produce_request([produce])
+ assert resp.error == 0
+
+ return [x.value for x in messages]
+
+def current_offset(client, topic, partition, kafka_broker=None):
+ """Get the current offset of a topic's partition
+ """
+ try:
+ offsets, = client.send_offset_request([OffsetRequestPayload(topic,
+ partition, -1, 1)])
+ except Exception:
+ # XXX: We've seen some UnknownErrors here and can't debug w/o server logs
+ if kafka_broker:
+ kafka_broker.dump_logs()
+ raise
+ else:
+ return offsets.offsets[0]
+
class KafkaIntegrationTestCase(unittest.TestCase):
create_client = True
topic = None
@@ -122,7 +140,8 @@ class KafkaIntegrationTestCase(unittest.TestCase):
def current_offset(self, topic, partition):
try:
- offsets, = self.client.send_offset_request([OffsetRequestPayload(topic, partition, -1, 1)])
+ offsets, = self.client.send_offset_request([OffsetRequestPayload(topic,
+ partition, -1, 1)])
except Exception:
# XXX: We've seen some UnknownErrors here and can't debug w/o server logs
self.zk.child.dump_logs()
@@ -132,7 +151,7 @@ class KafkaIntegrationTestCase(unittest.TestCase):
return offsets.offsets[0]
def msgs(self, iterable):
- return [ self.msg(x) for x in iterable ]
+ return [self.msg(x) for x in iterable]
def msg(self, s):
if s not in self._messages: