summaryrefslogtreecommitdiff
path: root/test/testutil.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/testutil.py')
-rw-r--r--test/testutil.py42
1 files changed, 38 insertions, 4 deletions
diff --git a/test/testutil.py b/test/testutil.py
index 3a1d2ba..fc3ebfa 100644
--- a/test/testutil.py
+++ b/test/testutil.py
@@ -1,5 +1,6 @@
import functools
import logging
+import operator
import os
import random
import socket
@@ -26,15 +27,48 @@ def random_string(l):
return "".join(random.choice(string.ascii_letters) for i in xrange(l))
def kafka_versions(*versions):
+
+ def version_str_to_list(s):
+ return list(map(int, s.split('.'))) # e.g., [0, 8, 1, 1]
+
+ def construct_lambda(s):
+ if s[0].isdigit():
+ op_str = '='
+ v_str = s
+ elif s[1].isdigit():
+ op_str = s[0] # ! < > =
+ v_str = s[1:]
+ elif s[2].isdigit():
+ op_str = s[0:2] # >= <=
+ v_str = s[2:]
+ else:
+ raise ValueError('Unrecognized kafka version / operator: %s' % s)
+
+ op_map = {
+ '=': operator.eq,
+ '!': operator.ne,
+ '>': operator.gt,
+ '<': operator.lt,
+ '>=': operator.ge,
+ '<=': operator.le
+ }
+ op = op_map[op_str]
+ version = version_str_to_list(v_str)
+ return lambda a: op(version_str_to_list(a), version)
+
+ validators = map(construct_lambda, versions)
+
def kafka_versions(func):
@functools.wraps(func)
def wrapper(self):
kafka_version = os.environ.get('KAFKA_VERSION')
if not kafka_version:
- self.skipTest("no kafka version specified")
- elif 'all' not in versions and kafka_version not in versions:
- self.skipTest("unsupported kafka version")
+ self.skipTest("no kafka version set in KAFKA_VERSION env var")
+
+ for f in validators:
+ if not f(kafka_version):
+ self.skipTest("unsupported kafka version")
return func(self)
return wrapper
@@ -57,7 +91,7 @@ class KafkaIntegrationTestCase(unittest.TestCase):
def setUp(self):
super(KafkaIntegrationTestCase, self).setUp()
if not os.environ.get('KAFKA_VERSION'):
- return
+ self.skipTest('Integration test requires KAFKA_VERSION')
if not self.topic:
topic = "%s-%s" % (self.id()[self.id().rindex(".") + 1:], random_string(10))