Diffstat (limited to 'test')
-rw-r--r--  test/fixtures.py                      309
-rw-r--r--  test/resources/kafka.properties        59
-rw-r--r--  test/resources/log4j.properties        24
-rw-r--r--  test/resources/zookeeper.properties    19
-rw-r--r--  test/service.py                       106
-rw-r--r--  test/test_client.py                   249
-rw-r--r--  test/test_client_integration.py        66
-rw-r--r--  test/test_codec.py                     70
-rw-r--r--  test/test_conn.py                      69
-rw-r--r--  test/test_consumer.py                  22
-rw-r--r--  test/test_consumer_integration.py     257
-rw-r--r--  test/test_failover_integration.py     123
-rw-r--r--  test/test_integration.py              971
-rw-r--r--  test/test_package.py                   29
-rw-r--r--  test/test_producer_integration.py     404
-rw-r--r--  test/test_protocol.py                 694
-rw-r--r--  test/test_unit.py                     674
-rw-r--r--  test/test_util.py                     102
-rw-r--r--  test/testutil.py                      108
19 files changed, 2389 insertions, 1966 deletions
diff --git a/test/fixtures.py b/test/fixtures.py
index 9e283d3..df8cd42 100644
--- a/test/fixtures.py
+++ b/test/fixtures.py
@@ -1,204 +1,74 @@
+import logging
import glob
import os
-import re
-import select
import shutil
-import socket
import subprocess
-import sys
import tempfile
-import threading
-import time
import uuid
from urlparse import urlparse
-
-
-PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-KAFKA_ROOT = os.path.join(PROJECT_ROOT, "kafka-src")
-IVY_ROOT = os.path.expanduser("~/.ivy2/cache")
-SCALA_VERSION = '2.8.0'
-
-if "PROJECT_ROOT" in os.environ:
- PROJECT_ROOT = os.environ["PROJECT_ROOT"]
-if "KAFKA_ROOT" in os.environ:
- KAFKA_ROOT = os.environ["KAFKA_ROOT"]
-if "IVY_ROOT" in os.environ:
- IVY_ROOT = os.environ["IVY_ROOT"]
-if "SCALA_VERSION" in os.environ:
- SCALA_VERSION = os.environ["SCALA_VERSION"]
-
-
-def test_resource(file):
- return os.path.join(PROJECT_ROOT, "test", "resources", file)
-
-
-def test_classpath():
- # ./kafka-src/bin/kafka-run-class.sh is the authority.
- jars = ["."]
- # assume all dependencies have been packaged into one jar with sbt-assembly's task "assembly-package-dependency"
- jars.extend(glob.glob(KAFKA_ROOT + "/core/target/scala-%s/*.jar" % SCALA_VERSION))
-
- jars = filter(os.path.exists, map(os.path.abspath, jars))
- return ":".join(jars)
-
-
-def kafka_run_class_args(*args):
- # ./kafka-src/bin/kafka-run-class.sh is the authority.
- result = ["java", "-Xmx512M", "-server"]
- result.append("-Dlog4j.configuration=file:%s" % test_resource("log4j.properties"))
- result.append("-Dcom.sun.management.jmxremote")
- result.append("-Dcom.sun.management.jmxremote.authenticate=false")
- result.append("-Dcom.sun.management.jmxremote.ssl=false")
- result.append("-cp")
- result.append(test_classpath())
- result.extend(args)
- return result
-
-
-def get_open_port():
- sock = socket.socket()
- sock.bind(("", 0))
- port = sock.getsockname()[1]
- sock.close()
- return port
-
-
-def render_template(source_file, target_file, binding):
- with open(source_file, "r") as handle:
- template = handle.read()
- with open(target_file, "w") as handle:
- handle.write(template.format(**binding))
-
-
-class ExternalService(object):
- def __init__(self, host, port):
- print("Using already running service at %s:%d" % (host, port))
- self.host = host
- self.port = port
-
- def open(self):
- pass
-
- def close(self):
- pass
-
-
-class SpawnedService(threading.Thread):
- def __init__(self, args=[]):
- threading.Thread.__init__(self)
-
- self.args = args
- self.captured_stdout = ""
- self.captured_stderr = ""
- self.stdout_file = None
- self.stderr_file = None
- self.capture_stdout = True
- self.capture_stderr = True
- self.show_stdout = True
- self.show_stderr = True
-
- self.should_die = threading.Event()
-
- def configure_stdout(self, file=None, capture=True, show=False):
- self.stdout_file = file
- self.capture_stdout = capture
- self.show_stdout = show
-
- def configure_stderr(self, file=None, capture=False, show=True):
- self.stderr_file = file
- self.capture_stderr = capture
- self.show_stderr = show
-
- def run(self):
- stdout_handle = None
- stderr_handle = None
- try:
- if self.stdout_file:
- stdout_handle = open(self.stdout_file, "w")
- if self.stderr_file:
- stderr_handle = open(self.stderr_file, "w")
- self.run_with_handles(stdout_handle, stderr_handle)
- finally:
- if stdout_handle:
- stdout_handle.close()
- if stderr_handle:
- stderr_handle.close()
-
- def run_with_handles(self, stdout_handle, stderr_handle):
- child = subprocess.Popen(
- self.args,
- bufsize=1,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
- alive = True
-
- while True:
- (rds, wds, xds) = select.select([child.stdout, child.stderr], [], [], 1)
-
- if child.stdout in rds:
- line = child.stdout.readline()
- if stdout_handle:
- stdout_handle.write(line)
- stdout_handle.flush()
- if self.capture_stdout:
- self.captured_stdout += line
- if self.show_stdout:
- sys.stdout.write(line)
- sys.stdout.flush()
-
- if child.stderr in rds:
- line = child.stderr.readline()
- if stderr_handle:
- stderr_handle.write(line)
- stderr_handle.flush()
- if self.capture_stderr:
- self.captured_stderr += line
- if self.show_stderr:
- sys.stderr.write(line)
- sys.stderr.flush()
-
- if self.should_die.is_set():
- child.terminate()
- alive = False
-
- if child.poll() is not None:
- if not alive:
- break
- else:
- raise RuntimeError("Subprocess has died. Aborting.")
-
- def wait_for(self, pattern, timeout=10):
- t1 = time.time()
- while True:
- t2 = time.time()
- if t2 - t1 >= timeout:
- raise RuntimeError("Waiting for %r timed out" % pattern)
- if re.search(pattern, self.captured_stdout) is not None:
- return
- if re.search(pattern, self.captured_stderr) is not None:
- return
- time.sleep(0.1)
-
- def start(self):
- threading.Thread.start(self)
-
- def stop(self):
- self.should_die.set()
- self.join()
-
-
-class ZookeeperFixture(object):
- @staticmethod
- def instance():
+from service import ExternalService, SpawnedService
+from testutil import get_open_port
+
+class Fixture(object):
+ kafka_version = os.environ.get('KAFKA_VERSION', '0.8.0')
+ scala_version = os.environ.get("SCALA_VERSION", '2.8.0')
+ project_root = os.environ.get('PROJECT_ROOT', os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+ kafka_root = os.environ.get("KAFKA_ROOT", os.path.join(project_root, 'servers', kafka_version, "kafka-src"))
+ ivy_root = os.environ.get('IVY_ROOT', os.path.expanduser("~/.ivy2/cache"))
+
+ @classmethod
+ def test_resource(cls, filename):
+ return os.path.join(cls.project_root, "servers", cls.kafka_version, "resources", filename)
+
+ @classmethod
+ def test_classpath(cls):
+ # ./kafka-src/bin/kafka-run-class.sh is the authority.
+ jars = ["."]
+
+ # 0.8.0 build path, should contain the core jar and a deps jar
+ jars.extend(glob.glob(cls.kafka_root + "/core/target/scala-%s/*.jar" % cls.scala_version))
+
+ # 0.8.1 build path, should contain the core jar and several dep jars
+ jars.extend(glob.glob(cls.kafka_root + "/core/build/libs/*.jar"))
+ jars.extend(glob.glob(cls.kafka_root + "/core/build/dependant-libs-%s/*.jar" % cls.scala_version))
+
+ jars = filter(os.path.exists, map(os.path.abspath, jars))
+ return ":".join(jars)
+
+ @classmethod
+ def kafka_run_class_args(cls, *args):
+ # ./kafka-src/bin/kafka-run-class.sh is the authority.
+ result = ["java", "-Xmx512M", "-server"]
+ result.append("-Dlog4j.configuration=file:%s" % cls.test_resource("log4j.properties"))
+ result.append("-Dcom.sun.management.jmxremote")
+ result.append("-Dcom.sun.management.jmxremote.authenticate=false")
+ result.append("-Dcom.sun.management.jmxremote.ssl=false")
+ result.append("-cp")
+ result.append(cls.test_classpath())
+ result.extend(args)
+ return result
+
+ @classmethod
+ def render_template(cls, source_file, target_file, binding):
+ with open(source_file, "r") as handle:
+ template = handle.read()
+ with open(target_file, "w") as handle:
+ handle.write(template.format(**binding))
+
+
+class ZookeeperFixture(Fixture):
+ @classmethod
+ def instance(cls):
if "ZOOKEEPER_URI" in os.environ:
parse = urlparse(os.environ["ZOOKEEPER_URI"])
(host, port) = (parse.hostname, parse.port)
fixture = ExternalService(host, port)
else:
(host, port) = ("127.0.0.1", get_open_port())
- fixture = ZookeeperFixture(host, port)
- fixture.open()
+ fixture = cls(host, port)
+
+ fixture.open()
return fixture
def __init__(self, host, port):
@@ -209,27 +79,25 @@ class ZookeeperFixture(object):
self.child = None
def out(self, message):
- print("*** Zookeeper [%s:%d]: %s" % (self.host, self.port, message))
+ logging.info("*** Zookeeper [%s:%d]: %s", self.host, self.port, message)
def open(self):
self.tmp_dir = tempfile.mkdtemp()
self.out("Running local instance...")
- print(" host = %s" % self.host)
- print(" port = %s" % self.port)
- print(" tmp_dir = %s" % self.tmp_dir)
+ logging.info(" host = %s", self.host)
+ logging.info(" port = %s", self.port)
+ logging.info(" tmp_dir = %s", self.tmp_dir)
# Generate configs
- template = test_resource("zookeeper.properties")
+ template = self.test_resource("zookeeper.properties")
properties = os.path.join(self.tmp_dir, "zookeeper.properties")
- render_template(template, properties, vars(self))
+ self.render_template(template, properties, vars(self))
# Configure Zookeeper child process
- self.child = SpawnedService(kafka_run_class_args(
+ self.child = SpawnedService(self.kafka_run_class_args(
"org.apache.zookeeper.server.quorum.QuorumPeerMain",
properties
))
- self.child.configure_stdout(os.path.join(self.tmp_dir, "stdout.txt"))
- self.child.configure_stderr(os.path.join(self.tmp_dir, "stderr.txt"))
# Party!
self.out("Starting...")
@@ -245,9 +113,9 @@ class ZookeeperFixture(object):
shutil.rmtree(self.tmp_dir)
-class KafkaFixture(object):
- @staticmethod
- def instance(broker_id, zk_host, zk_port, zk_chroot=None, replicas=1, partitions=2):
+class KafkaFixture(Fixture):
+ @classmethod
+ def instance(cls, broker_id, zk_host, zk_port, zk_chroot=None, replicas=1, partitions=2):
if zk_chroot is None:
zk_chroot = "kafka-python_" + str(uuid.uuid4()).replace("-", "_")
if "KAFKA_URI" in os.environ:
@@ -278,7 +146,7 @@ class KafkaFixture(object):
self.running = False
def out(self, message):
- print("*** Kafka [%s:%d]: %s" % (self.host, self.port, message))
+ logging.info("*** Kafka [%s:%d]: %s", self.host, self.port, message)
def open(self):
if self.running:
@@ -287,41 +155,44 @@ class KafkaFixture(object):
self.tmp_dir = tempfile.mkdtemp()
self.out("Running local instance...")
- print(" host = %s" % self.host)
- print(" port = %s" % self.port)
- print(" broker_id = %s" % self.broker_id)
- print(" zk_host = %s" % self.zk_host)
- print(" zk_port = %s" % self.zk_port)
- print(" zk_chroot = %s" % self.zk_chroot)
- print(" replicas = %s" % self.replicas)
- print(" partitions = %s" % self.partitions)
- print(" tmp_dir = %s" % self.tmp_dir)
+ logging.info(" host = %s", self.host)
+ logging.info(" port = %s", self.port)
+ logging.info(" broker_id = %s", self.broker_id)
+ logging.info(" zk_host = %s", self.zk_host)
+ logging.info(" zk_port = %s", self.zk_port)
+ logging.info(" zk_chroot = %s", self.zk_chroot)
+ logging.info(" replicas = %s", self.replicas)
+ logging.info(" partitions = %s", self.partitions)
+ logging.info(" tmp_dir = %s", self.tmp_dir)
# Create directories
os.mkdir(os.path.join(self.tmp_dir, "logs"))
os.mkdir(os.path.join(self.tmp_dir, "data"))
# Generate configs
- template = test_resource("kafka.properties")
+ template = self.test_resource("kafka.properties")
properties = os.path.join(self.tmp_dir, "kafka.properties")
- render_template(template, properties, vars(self))
+ self.render_template(template, properties, vars(self))
# Configure Kafka child process
- self.child = SpawnedService(kafka_run_class_args(
+ self.child = SpawnedService(self.kafka_run_class_args(
"kafka.Kafka", properties
))
- self.child.configure_stdout(os.path.join(self.tmp_dir, "stdout.txt"))
- self.child.configure_stderr(os.path.join(self.tmp_dir, "stderr.txt"))
# Party!
self.out("Creating Zookeeper chroot node...")
- proc = subprocess.Popen(kafka_run_class_args(
- "org.apache.zookeeper.ZooKeeperMain",
- "-server", "%s:%d" % (self.zk_host, self.zk_port),
- "create", "/%s" % self.zk_chroot, "kafka-python"
- ))
+ proc = subprocess.Popen(self.kafka_run_class_args(
+ "org.apache.zookeeper.ZooKeeperMain",
+ "-server", "%s:%d" % (self.zk_host, self.zk_port),
+ "create", "/%s" % self.zk_chroot, "kafka-python"
+ ),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+
if proc.wait() != 0:
self.out("Failed to create Zookeeper chroot node")
+ self.out(proc.stdout.read())
+ self.out(proc.stderr.read())
raise RuntimeError("Failed to create Zookeeper chroot node")
self.out("Done!")
diff --git a/test/resources/kafka.properties b/test/resources/kafka.properties
deleted file mode 100644
index f8732fb..0000000
--- a/test/resources/kafka.properties
+++ /dev/null
@@ -1,59 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-############################# Server Basics #############################
-
-broker.id={broker_id}
-
-############################# Socket Server Settings #############################
-
-port={port}
-host.name={host}
-
-num.network.threads=2
-num.io.threads=2
-
-socket.send.buffer.bytes=1048576
-socket.receive.buffer.bytes=1048576
-socket.request.max.bytes=104857600
-
-############################# Log Basics #############################
-
-log.dir={tmp_dir}/data
-num.partitions={partitions}
-default.replication.factor={replicas}
-
-############################# Log Flush Policy #############################
-
-log.flush.interval.messages=10000
-log.flush.interval.ms=1000
-
-############################# Log Retention Policy #############################
-
-log.retention.hours=168
-log.segment.bytes=536870912
-log.cleanup.interval.mins=1
-
-############################# Zookeeper #############################
-
-zookeeper.connect={zk_host}:{zk_port}/{zk_chroot}
-zookeeper.connection.timeout.ms=1000000
-
-kafka.metrics.polling.interval.secs=5
-kafka.metrics.reporters=kafka.metrics.KafkaCSVMetricsReporter
-kafka.csv.metrics.dir={tmp_dir}
-kafka.csv.metrics.reporter.enabled=false
-
-log.cleanup.policy=delete
diff --git a/test/resources/log4j.properties b/test/resources/log4j.properties
deleted file mode 100644
index f863b3b..0000000
--- a/test/resources/log4j.properties
+++ /dev/null
@@ -1,24 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-log4j.rootLogger=INFO, stdout
-
-log4j.appender.stdout=org.apache.log4j.ConsoleAppender
-log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
-log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c)%n
-
-log4j.logger.kafka=DEBUG, stdout
-log4j.logger.org.I0Itec.zkclient.ZkClient=INFO, stdout
-log4j.logger.org.apache.zookeeper=INFO, stdout
diff --git a/test/resources/zookeeper.properties b/test/resources/zookeeper.properties
deleted file mode 100644
index 68e1ef9..0000000
--- a/test/resources/zookeeper.properties
+++ /dev/null
@@ -1,19 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-dataDir={tmp_dir}
-clientPortAddress={host}
-clientPort={port}
-maxClientCnxns=0
diff --git a/test/service.py b/test/service.py
new file mode 100644
index 0000000..8872c82
--- /dev/null
+++ b/test/service.py
@@ -0,0 +1,106 @@
+import logging
+import re
+import select
+import subprocess
+import sys
+import threading
+import time
+
+__all__ = [
+ 'ExternalService',
+ 'SpawnedService',
+]
+
+class ExternalService(object):
+ def __init__(self, host, port):
+ print("Using already running service at %s:%d" % (host, port))
+ self.host = host
+ self.port = port
+
+ def open(self):
+ pass
+
+ def close(self):
+ pass
+
+
+class SpawnedService(threading.Thread):
+ def __init__(self, args=[]):
+ threading.Thread.__init__(self)
+
+ self.args = args
+ self.captured_stdout = []
+ self.captured_stderr = []
+
+ self.should_die = threading.Event()
+
+ def run(self):
+ self.run_with_handles()
+
+ def run_with_handles(self):
+ self.child = subprocess.Popen(
+ self.args,
+ bufsize=1,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ alive = True
+
+ while True:
+ (rds, wds, xds) = select.select([self.child.stdout, self.child.stderr], [], [], 1)
+
+ if self.child.stdout in rds:
+ line = self.child.stdout.readline()
+ self.captured_stdout.append(line)
+
+ if self.child.stderr in rds:
+ line = self.child.stderr.readline()
+ self.captured_stderr.append(line)
+
+ if self.should_die.is_set():
+ self.child.terminate()
+ alive = False
+
+ poll_results = self.child.poll()
+ if poll_results is not None:
+ if not alive:
+ break
+ else:
+ self.dump_logs()
+ raise RuntimeError("Subprocess has died. Aborting. (args=%s)" % ' '.join(str(x) for x in self.args))
+
+ def dump_logs(self):
+ logging.critical('stderr')
+ for line in self.captured_stderr:
+ logging.critical(line.rstrip())
+
+ logging.critical('stdout')
+ for line in self.captured_stdout:
+ logging.critical(line.rstrip())
+
+ def wait_for(self, pattern, timeout=10):
+ t1 = time.time()
+ while True:
+ t2 = time.time()
+ if t2 - t1 >= timeout:
+ try:
+ self.child.kill()
+ except:
+ logging.exception("Received exception when killing child process")
+ self.dump_logs()
+
+ raise RuntimeError("Waiting for %r timed out" % pattern)
+
+ if re.search(pattern, '\n'.join(self.captured_stdout), re.IGNORECASE) is not None:
+ return
+ if re.search(pattern, '\n'.join(self.captured_stderr), re.IGNORECASE) is not None:
+ return
+ time.sleep(0.1)
+
+ def start(self):
+ threading.Thread.start(self)
+
+ def stop(self):
+ self.should_die.set()
+ self.join()
+
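
SpawnedService runs the child under a reader thread so tests can block on log output instead of sleeping. A minimal lifecycle sketch, with a made-up main class standing in for a real Kafka entry point:

    from service import SpawnedService

    # stdout/stderr of the child are captured line by line by the run() loop.
    svc = SpawnedService(["java", "-cp", ".", "some.MainClass"])
    svc.start()

    # Block until either stream matches the pattern (case-insensitive), or
    # raise RuntimeError after the timeout, dumping captured logs first.
    svc.wait_for(r"started", timeout=10)

    # Set should_die so the select() loop terminates the child, then join.
    svc.stop()
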
diff --git a/test/test_client.py b/test/test_client.py
new file mode 100644
index 0000000..fe9beff
--- /dev/null
+++ b/test/test_client.py
@@ -0,0 +1,249 @@
+import os
+import random
+import struct
+import unittest2
+
+from mock import MagicMock, patch
+
+from kafka import KafkaClient
+from kafka.common import (
+ ProduceRequest, BrokerMetadata, PartitionMetadata,
+ TopicAndPartition, KafkaUnavailableError,
+ LeaderUnavailableError, PartitionUnavailableError
+)
+from kafka.protocol import (
+ create_message, KafkaProtocol
+)
+
+class TestKafkaClient(unittest2.TestCase):
+ def test_init_with_list(self):
+ with patch.object(KafkaClient, 'load_metadata_for_topics'):
+ client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092', 'kafka03:9092'])
+
+ self.assertItemsEqual(
+ [('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)],
+ client.hosts)
+
+ def test_init_with_csv(self):
+ with patch.object(KafkaClient, 'load_metadata_for_topics'):
+ client = KafkaClient(hosts='kafka01:9092,kafka02:9092,kafka03:9092')
+
+ self.assertItemsEqual(
+ [('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)],
+ client.hosts)
+
+ def test_init_with_unicode_csv(self):
+ with patch.object(KafkaClient, 'load_metadata_for_topics'):
+ client = KafkaClient(hosts=u'kafka01:9092,kafka02:9092,kafka03:9092')
+
+ self.assertItemsEqual(
+ [('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)],
+ client.hosts)
+
+ def test_send_broker_unaware_request_fail(self):
+ 'Tests that the call fails when all hosts are unavailable'
+
+ mocked_conns = {
+ ('kafka01', 9092): MagicMock(),
+ ('kafka02', 9092): MagicMock()
+ }
+
+ # inject KafkaConnection side effects
+ mocked_conns[('kafka01', 9092)].send.side_effect = RuntimeError("kafka01 went away (unittest)")
+ mocked_conns[('kafka02', 9092)].send.side_effect = RuntimeError("Kafka02 went away (unittest)")
+
+ def mock_get_conn(host, port):
+ return mocked_conns[(host, port)]
+
+ # patch to avoid making requests before we want them
+ with patch.object(KafkaClient, 'load_metadata_for_topics'):
+ with patch.object(KafkaClient, '_get_conn', side_effect=mock_get_conn):
+ client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092'])
+
+ with self.assertRaises(KafkaUnavailableError):
+ client._send_broker_unaware_request(1, 'fake request')
+
+ for key, conn in mocked_conns.iteritems():
+ conn.send.assert_called_with(1, 'fake request')
+
+ def test_send_broker_unaware_request(self):
+ 'Tests that the call works when at least one of the hosts is available'
+
+ mocked_conns = {
+ ('kafka01', 9092): MagicMock(),
+ ('kafka02', 9092): MagicMock(),
+ ('kafka03', 9092): MagicMock()
+ }
+ # inject KafkaConnection side effects
+ mocked_conns[('kafka01', 9092)].send.side_effect = RuntimeError("kafka01 went away (unittest)")
+ mocked_conns[('kafka02', 9092)].recv.return_value = 'valid response'
+ mocked_conns[('kafka03', 9092)].send.side_effect = RuntimeError("kafka03 went away (unittest)")
+
+ def mock_get_conn(host, port):
+ return mocked_conns[(host, port)]
+
+ # patch to avoid making requests before we want them
+ with patch.object(KafkaClient, 'load_metadata_for_topics'):
+ with patch.object(KafkaClient, '_get_conn', side_effect=mock_get_conn):
+ client = KafkaClient(hosts='kafka01:9092,kafka02:9092')
+
+ resp = client._send_broker_unaware_request(1, 'fake request')
+
+ self.assertEqual('valid response', resp)
+ mocked_conns[('kafka02', 9092)].recv.assert_called_with(1)
+
+ @patch('kafka.client.KafkaConnection')
+ @patch('kafka.client.KafkaProtocol')
+ def test_load_metadata(self, protocol, conn):
+ "Load metadata for all topics"
+
+ conn.recv.return_value = 'response' # anything but None
+
+ brokers = {}
+ brokers[0] = BrokerMetadata(1, 'broker_1', 4567)
+ brokers[1] = BrokerMetadata(2, 'broker_2', 5678)
+
+ topics = {}
+ topics['topic_1'] = {
+ 0: PartitionMetadata('topic_1', 0, 1, [1, 2], [1, 2])
+ }
+ topics['topic_noleader'] = {
+ 0: PartitionMetadata('topic_noleader', 0, -1, [], []),
+ 1: PartitionMetadata('topic_noleader', 1, -1, [], [])
+ }
+ topics['topic_no_partitions'] = {}
+ topics['topic_3'] = {
+ 0: PartitionMetadata('topic_3', 0, 0, [0, 1], [0, 1]),
+ 1: PartitionMetadata('topic_3', 1, 1, [1, 0], [1, 0]),
+ 2: PartitionMetadata('topic_3', 2, 0, [0, 1], [0, 1])
+ }
+ protocol.decode_metadata_response.return_value = (brokers, topics)
+
+ # client loads metadata at init
+ client = KafkaClient(hosts=['broker_1:4567'])
+ self.assertDictEqual({
+ TopicAndPartition('topic_1', 0): brokers[1],
+ TopicAndPartition('topic_noleader', 0): None,
+ TopicAndPartition('topic_noleader', 1): None,
+ TopicAndPartition('topic_3', 0): brokers[0],
+ TopicAndPartition('topic_3', 1): brokers[1],
+ TopicAndPartition('topic_3', 2): brokers[0]},
+ client.topics_to_brokers)
+
+ @patch('kafka.client.KafkaConnection')
+ @patch('kafka.client.KafkaProtocol')
+ def test_get_leader_for_partitions_reloads_metadata(self, protocol, conn):
+ "Get leader for partitions reload metadata if it is not available"
+
+ conn.recv.return_value = 'response' # anything but None
+
+ brokers = {}
+ brokers[0] = BrokerMetadata(0, 'broker_1', 4567)
+ brokers[1] = BrokerMetadata(1, 'broker_2', 5678)
+
+ topics = {'topic_no_partitions': {}}
+ protocol.decode_metadata_response.return_value = (brokers, topics)
+
+ client = KafkaClient(hosts=['broker_1:4567'])
+
+ # topic metadata is loaded but empty
+ self.assertDictEqual({}, client.topics_to_brokers)
+
+ topics['topic_no_partitions'] = {
+ 0: PartitionMetadata('topic_no_partitions', 0, 0, [0, 1], [0, 1])
+ }
+ protocol.decode_metadata_response.return_value = (brokers, topics)
+
+ # calling _get_leader_for_partition (from any broker aware request)
+ # will try loading metadata again for the same topic
+ leader = client._get_leader_for_partition('topic_no_partitions', 0)
+
+ self.assertEqual(brokers[0], leader)
+ self.assertDictEqual({
+ TopicAndPartition('topic_no_partitions', 0): brokers[0]},
+ client.topics_to_brokers)
+
+ @patch('kafka.client.KafkaConnection')
+ @patch('kafka.client.KafkaProtocol')
+ def test_get_leader_for_unassigned_partitions(self, protocol, conn):
+ "Get leader raises if no partitions is defined for a topic"
+
+ conn.recv.return_value = 'response' # anything but None
+
+ brokers = {}
+ brokers[0] = BrokerMetadata(0, 'broker_1', 4567)
+ brokers[1] = BrokerMetadata(1, 'broker_2', 5678)
+
+ topics = {'topic_no_partitions': {}}
+ protocol.decode_metadata_response.return_value = (brokers, topics)
+
+ client = KafkaClient(hosts=['broker_1:4567'])
+
+ self.assertDictEqual({}, client.topics_to_brokers)
+
+ with self.assertRaises(PartitionUnavailableError):
+ client._get_leader_for_partition('topic_no_partitions', 0)
+
+ @patch('kafka.client.KafkaConnection')
+ @patch('kafka.client.KafkaProtocol')
+ def test_get_leader_returns_none_when_noleader(self, protocol, conn):
+ "Getting leader for partitions returns None when the partiion has no leader"
+
+ conn.recv.return_value = 'response' # anything but None
+
+ brokers = {}
+ brokers[0] = BrokerMetadata(0, 'broker_1', 4567)
+ brokers[1] = BrokerMetadata(1, 'broker_2', 5678)
+
+ topics = {}
+ topics['topic_noleader'] = {
+ 0: PartitionMetadata('topic_noleader', 0, -1, [], []),
+ 1: PartitionMetadata('topic_noleader', 1, -1, [], [])
+ }
+ protocol.decode_metadata_response.return_value = (brokers, topics)
+
+ client = KafkaClient(hosts=['broker_1:4567'])
+ self.assertDictEqual(
+ {
+ TopicAndPartition('topic_noleader', 0): None,
+ TopicAndPartition('topic_noleader', 1): None
+ },
+ client.topics_to_brokers)
+ self.assertIsNone(client._get_leader_for_partition('topic_noleader', 0))
+ self.assertIsNone(client._get_leader_for_partition('topic_noleader', 1))
+
+ topics['topic_noleader'] = {
+ 0: PartitionMetadata('topic_noleader', 0, 0, [0, 1], [0, 1]),
+ 1: PartitionMetadata('topic_noleader', 1, 1, [1, 0], [1, 0])
+ }
+ protocol.decode_metadata_response.return_value = (brokers, topics)
+ self.assertEqual(brokers[0], client._get_leader_for_partition('topic_noleader', 0))
+ self.assertEqual(brokers[1], client._get_leader_for_partition('topic_noleader', 1))
+
+ @patch('kafka.client.KafkaConnection')
+ @patch('kafka.client.KafkaProtocol')
+ def test_send_produce_request_raises_when_noleader(self, protocol, conn):
+ "Send producer request raises LeaderUnavailableError if leader is not available"
+
+ conn.recv.return_value = 'response' # anything but None
+
+ brokers = {}
+ brokers[0] = BrokerMetadata(0, 'broker_1', 4567)
+ brokers[1] = BrokerMetadata(1, 'broker_2', 5678)
+
+ topics = {}
+ topics['topic_noleader'] = {
+ 0: PartitionMetadata('topic_noleader', 0, -1, [], []),
+ 1: PartitionMetadata('topic_noleader', 1, -1, [], [])
+ }
+ protocol.decode_metadata_response.return_value = (brokers, topics)
+
+ client = KafkaClient(hosts=['broker_1:4567'])
+
+ requests = [ProduceRequest(
+ "topic_noleader", 0,
+ [create_message("a"), create_message("b")])]
+
+ with self.assertRaises(LeaderUnavailableError):
+ client.send_produce_request(requests)
+
diff --git a/test/test_client_integration.py b/test/test_client_integration.py
new file mode 100644
index 0000000..261d168
--- /dev/null
+++ b/test/test_client_integration.py
@@ -0,0 +1,66 @@
+import os
+import random
+import socket
+import time
+import unittest2
+
+import kafka
+from kafka.common import *
+from fixtures import ZookeeperFixture, KafkaFixture
+from testutil import *
+
+class TestKafkaClientIntegration(KafkaIntegrationTestCase):
+ @classmethod
+ def setUpClass(cls): # noqa
+ if not os.environ.get('KAFKA_VERSION'):
+ return
+
+ cls.zk = ZookeeperFixture.instance()
+ cls.server = KafkaFixture.instance(0, cls.zk.host, cls.zk.port)
+
+ @classmethod
+ def tearDownClass(cls): # noqa
+ if not os.environ.get('KAFKA_VERSION'):
+ return
+
+ cls.server.close()
+ cls.zk.close()
+
+ @unittest2.skip("This doesn't appear to work on Linux?")
+ def test_timeout(self):
+ server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ server_port = get_open_port()
+ server_socket.bind(('localhost', server_port))
+
+ with Timer() as t:
+ with self.assertRaises((socket.timeout, socket.error)):
+ conn = kafka.conn.KafkaConnection("localhost", server_port, 1.0)
+ self.assertGreaterEqual(t.interval, 1.0)
+
+ @kafka_versions("all")
+ def test_consume_none(self):
+ fetch = FetchRequest(self.topic, 0, 0, 1024)
+
+ fetch_resp, = self.client.send_fetch_request([fetch])
+ self.assertEquals(fetch_resp.error, 0)
+ self.assertEquals(fetch_resp.topic, self.topic)
+ self.assertEquals(fetch_resp.partition, 0)
+
+ messages = list(fetch_resp.messages)
+ self.assertEquals(len(messages), 0)
+
+ ####################
+ # Offset Tests #
+ ####################
+
+ @kafka_versions("0.8.1")
+ def test_commit_fetch_offsets(self):
+ req = OffsetCommitRequest(self.topic, 0, 42, "metadata")
+ (resp,) = self.client.send_offset_commit_request("group", [req])
+ self.assertEquals(resp.error, 0)
+
+ req = OffsetFetchRequest(self.topic, 0)
+ (resp,) = self.client.send_offset_fetch_request("group", [req])
+ self.assertEquals(resp.error, 0)
+ self.assertEquals(resp.offset, 42)
+ self.assertEquals(resp.metadata, "") # Metadata isn't stored for now
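
These tests are gated by @kafka_versions from testutil, whose diff is not shown here: "all" runs whenever KAFKA_VERSION is set, while "0.8.1" restricts a test to that broker build. A plausible sketch of such a decorator, offered as an assumption about testutil rather than its actual code:

    import os
    from functools import wraps

    def kafka_versions(*versions):
        def decorator(func):
            @wraps(func)
            def wrapper(self, *args, **kwargs):
                # Hypothetical reconstruction; the real testutil may differ.
                kafka_version = os.environ.get('KAFKA_VERSION')
                if not kafka_version:
                    self.skipTest("integration tests disabled: no KAFKA_VERSION set")
                if "all" not in versions and kafka_version not in versions:
                    self.skipTest("unsupported broker version %s" % kafka_version)
                return func(self, *args, **kwargs)
            return wrapper
        return decorator
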
diff --git a/test/test_codec.py b/test/test_codec.py
new file mode 100644
index 0000000..2e6f67e
--- /dev/null
+++ b/test/test_codec.py
@@ -0,0 +1,70 @@
+import struct
+import unittest2
+
+from kafka.codec import (
+ has_snappy, gzip_encode, gzip_decode,
+ snappy_encode, snappy_decode
+)
+from kafka.protocol import (
+ create_gzip_message, create_message, create_snappy_message, KafkaProtocol
+)
+from testutil import *
+
+class TestCodec(unittest2.TestCase):
+ def test_gzip(self):
+ for i in xrange(1000):
+ s1 = random_string(100)
+ s2 = gzip_decode(gzip_encode(s1))
+ self.assertEquals(s1, s2)
+
+ @unittest2.skipUnless(has_snappy(), "Snappy not available")
+ def test_snappy(self):
+ for i in xrange(1000):
+ s1 = random_string(100)
+ s2 = snappy_decode(snappy_encode(s1))
+ self.assertEquals(s1, s2)
+
+ @unittest2.skipUnless(has_snappy(), "Snappy not available")
+ def test_snappy_detect_xerial(self):
+ import kafka as kafka1
+ _detect_xerial_stream = kafka1.codec._detect_xerial_stream
+
+ header = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes'
+ false_header = b'\x01SNAPPY\x00\x00\x00\x01\x00\x00\x00\x01'
+ random_snappy = snappy_encode('SNAPPY' * 50)
+ short_data = b'\x01\x02\x03\x04'
+
+ self.assertTrue(_detect_xerial_stream(header))
+ self.assertFalse(_detect_xerial_stream(b''))
+ self.assertFalse(_detect_xerial_stream(b'\x00'))
+ self.assertFalse(_detect_xerial_stream(false_header))
+ self.assertFalse(_detect_xerial_stream(random_snappy))
+ self.assertFalse(_detect_xerial_stream(short_data))
+
+ @unittest2.skipUnless(has_snappy(), "Snappy not available")
+ def test_snappy_decode_xerial(self):
+ header = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01'
+ random_snappy = snappy_encode('SNAPPY' * 50)
+ block_len = len(random_snappy)
+ random_snappy2 = snappy_encode('XERIAL' * 50)
+ block_len2 = len(random_snappy2)
+
+ to_test = header \
+ + struct.pack('!i', block_len) + random_snappy \
+ + struct.pack('!i', block_len2) + random_snappy2
+
+ self.assertEquals(snappy_decode(to_test), ('SNAPPY' * 50) + ('XERIAL' * 50))
+
+ @unittest2.skipUnless(has_snappy(), "Snappy not available")
+ def test_snappy_encode_xerial(self):
+ to_ensure = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01' + \
+ '\x00\x00\x00\x18' + \
+ '\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00' + \
+ '\x00\x00\x00\x18' + \
+ '\xac\x02\x14XERIAL\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00'
+
+ to_test = ('SNAPPY' * 50) + ('XERIAL' * 50)
+
+ compressed = snappy_encode(to_test, xerial_compatible=True, xerial_blocksize=300)
+ self.assertEquals(compressed, to_ensure)
+
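
The xerial tests above pin down the stream framing: an 8-byte magic ('\x82SNAPPY\x00'), two big-endian int32 version fields, then length-prefixed snappy blocks. A sketch of assembling such a frame by hand, mirroring the structs used in test_snappy_decode_xerial (the helper name is made up for illustration):

    import struct
    from kafka.codec import snappy_encode

    # magic + version + min-compatible version, as asserted above
    XERIAL_HEADER = b'\x82SNAPPY\x00' + struct.pack('!ii', 1, 1)

    def xerial_frame(*chunks):
        out = XERIAL_HEADER
        for chunk in chunks:
            block = snappy_encode(chunk)  # one raw snappy block per chunk
            out += struct.pack('!i', len(block)) + block
        return out
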
diff --git a/test/test_conn.py b/test/test_conn.py
new file mode 100644
index 0000000..4ab6d4f
--- /dev/null
+++ b/test/test_conn.py
@@ -0,0 +1,69 @@
+import os
+import random
+import struct
+import unittest2
+import kafka.conn
+
+class ConnTest(unittest2.TestCase):
+ def test_collect_hosts__happy_path(self):
+ hosts = "localhost:1234,localhost"
+ results = kafka.conn.collect_hosts(hosts)
+
+ self.assertEqual(set(results), set([
+ ('localhost', 1234),
+ ('localhost', 9092),
+ ]))
+
+ def test_collect_hosts__string_list(self):
+ hosts = [
+ 'localhost:1234',
+ 'localhost',
+ ]
+
+ results = kafka.conn.collect_hosts(hosts)
+
+ self.assertEqual(set(results), set([
+ ('localhost', 1234),
+ ('localhost', 9092),
+ ]))
+
+ def test_collect_hosts__with_spaces(self):
+ hosts = "localhost:1234, localhost"
+ results = kafka.conn.collect_hosts(hosts)
+
+ self.assertEqual(set(results), set([
+ ('localhost', 1234),
+ ('localhost', 9092),
+ ]))
+
+ @unittest2.skip("Not Implemented")
+ def test_send(self):
+ pass
+
+ @unittest2.skip("Not Implemented")
+ def test_send__reconnects_on_dirty_conn(self):
+ pass
+
+ @unittest2.skip("Not Implemented")
+ def test_send__failure_sets_dirty_connection(self):
+ pass
+
+ @unittest2.skip("Not Implemented")
+ def test_recv(self):
+ pass
+
+ @unittest2.skip("Not Implemented")
+ def test_recv__reconnects_on_dirty_conn(self):
+ pass
+
+ @unittest2.skip("Not Implemented")
+ def test_recv__failure_sets_dirty_connection(self):
+ pass
+
+ @unittest2.skip("Not Implemented")
+ def test_recv__doesnt_consume_extra_data_in_stream(self):
+ pass
+
+ @unittest2.skip("Not Implemented")
+ def test_close__object_is_reusable(self):
+ pass
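
The three passing tests pin down the collect_hosts contract: accept a comma-separated string (spaces tolerated) or a list of strings, and default missing ports to 9092. An equivalent parser might look like the sketch below, which illustrates the contract rather than kafka.conn's actual implementation:

    DEFAULT_KAFKA_PORT = 9092

    def collect_hosts_sketch(hosts):
        # Normalize "host:port, host" strings into a list first.
        if isinstance(hosts, basestring):
            hosts = hosts.split(',')
        result = []
        for entry in hosts:
            host, _, port = entry.strip().partition(':')
            result.append((host, int(port) if port else DEFAULT_KAFKA_PORT))
        return result
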
diff --git a/test/test_consumer.py b/test/test_consumer.py
new file mode 100644
index 0000000..778d76a
--- /dev/null
+++ b/test/test_consumer.py
@@ -0,0 +1,22 @@
+import os
+import random
+import struct
+import unittest2
+
+from mock import MagicMock, patch
+
+from kafka import KafkaClient
+from kafka.consumer import SimpleConsumer
+from kafka.common import (
+ ProduceRequest, BrokerMetadata, PartitionMetadata,
+ TopicAndPartition, KafkaUnavailableError,
+ LeaderUnavailableError, PartitionUnavailableError
+)
+from kafka.protocol import (
+ create_message, KafkaProtocol
+)
+
+class TestKafkaConsumer(unittest2.TestCase):
+ def test_non_integer_partitions(self):
+ with self.assertRaises(AssertionError):
+ consumer = SimpleConsumer(MagicMock(), 'group', 'topic', partitions = [ '0' ])
diff --git a/test/test_consumer_integration.py b/test/test_consumer_integration.py
new file mode 100644
index 0000000..a6589b3
--- /dev/null
+++ b/test/test_consumer_integration.py
@@ -0,0 +1,257 @@
+import os
+from datetime import datetime
+
+from kafka import * # noqa
+from kafka.common import * # noqa
+from kafka.consumer import MAX_FETCH_BUFFER_SIZE_BYTES
+from fixtures import ZookeeperFixture, KafkaFixture
+from testutil import *
+
+class TestConsumerIntegration(KafkaIntegrationTestCase):
+ @classmethod
+ def setUpClass(cls):
+ if not os.environ.get('KAFKA_VERSION'):
+ return
+
+ cls.zk = ZookeeperFixture.instance()
+ cls.server1 = KafkaFixture.instance(0, cls.zk.host, cls.zk.port)
+ cls.server2 = KafkaFixture.instance(1, cls.zk.host, cls.zk.port)
+
+ cls.server = cls.server1 # Bootstrapping server
+
+ @classmethod
+ def tearDownClass(cls):
+ if not os.environ.get('KAFKA_VERSION'):
+ return
+
+ cls.server1.close()
+ cls.server2.close()
+ cls.zk.close()
+
+ def send_messages(self, partition, messages):
+ messages = [ create_message(self.msg(str(msg))) for msg in messages ]
+ produce = ProduceRequest(self.topic, partition, messages = messages)
+ resp, = self.client.send_produce_request([produce])
+ self.assertEquals(resp.error, 0)
+
+ return [ x.value for x in messages ]
+
+ def assert_message_count(self, messages, num_messages):
+ # Make sure we got them all
+ self.assertEquals(len(messages), num_messages)
+
+ # Make sure there are no duplicates
+ self.assertEquals(len(set(messages)), num_messages)
+
+ @kafka_versions("all")
+ def test_simple_consumer(self):
+ self.send_messages(0, range(0, 100))
+ self.send_messages(1, range(100, 200))
+
+ # Start a consumer
+ consumer = self.consumer()
+
+ self.assert_message_count([ message for message in consumer ], 200)
+
+ consumer.stop()
+
+ @kafka_versions("all")
+ def test_simple_consumer__seek(self):
+ self.send_messages(0, range(0, 100))
+ self.send_messages(1, range(100, 200))
+
+ consumer = self.consumer()
+
+ # Rewind 10 messages from the end
+ consumer.seek(-10, 2)
+ self.assert_message_count([ message for message in consumer ], 10)
+
+ # Rewind 13 messages from the end
+ consumer.seek(-13, 2)
+ self.assert_message_count([ message for message in consumer ], 13)
+
+ consumer.stop()
+
+ @kafka_versions("all")
+ def test_simple_consumer_blocking(self):
+ consumer = self.consumer()
+
+ # Ask for 5 messages, nothing in queue, block 5 seconds
+ with Timer() as t:
+ messages = consumer.get_messages(block=True, timeout=5)
+ self.assert_message_count(messages, 0)
+ self.assertGreaterEqual(t.interval, 5)
+
+ self.send_messages(0, range(0, 10))
+
+ # Ask for 5 messages, 10 in queue. Get 5 back, no blocking
+ with Timer() as t:
+ messages = consumer.get_messages(count=5, block=True, timeout=5)
+ self.assert_message_count(messages, 5)
+ self.assertLessEqual(t.interval, 1)
+
+ # Ask for 10 messages, get 5 back, block 5 seconds
+ with Timer() as t:
+ messages = consumer.get_messages(count=10, block=True, timeout=5)
+ self.assert_message_count(messages, 5)
+ self.assertGreaterEqual(t.interval, 5)
+
+ consumer.stop()
+
+ @kafka_versions("all")
+ def test_simple_consumer_pending(self):
+ # Produce 10 messages to partitions 0 and 1
+ self.send_messages(0, range(0, 10))
+ self.send_messages(1, range(10, 20))
+
+ consumer = self.consumer()
+
+ self.assertEquals(consumer.pending(), 20)
+ self.assertEquals(consumer.pending(partitions=[0]), 10)
+ self.assertEquals(consumer.pending(partitions=[1]), 10)
+
+ consumer.stop()
+
+ @kafka_versions("all")
+ def test_multi_process_consumer(self):
+ # Produce 100 messages to partitions 0 and 1
+ self.send_messages(0, range(0, 100))
+ self.send_messages(1, range(100, 200))
+
+ consumer = self.consumer(consumer = MultiProcessConsumer)
+
+ self.assert_message_count([ message for message in consumer ], 200)
+
+ consumer.stop()
+
+ @kafka_versions("all")
+ def test_multi_process_consumer_blocking(self):
+ consumer = self.consumer(consumer = MultiProcessConsumer)
+
+ # Ask for 5 messages, No messages in queue, block 5 seconds
+ with Timer() as t:
+ messages = consumer.get_messages(block=True, timeout=5)
+ self.assert_message_count(messages, 0)
+
+ self.assertGreaterEqual(t.interval, 5)
+
+ # Send 10 messages
+ self.send_messages(0, range(0, 10))
+
+ # Ask for 5 messages, 10 messages in queue, block 0 seconds
+ with Timer() as t:
+ messages = consumer.get_messages(count=5, block=True, timeout=5)
+ self.assert_message_count(messages, 5)
+ self.assertLessEqual(t.interval, 1)
+
+ # Ask for 10 messages, 5 in queue, block 5 seconds
+ with Timer() as t:
+ messages = consumer.get_messages(count=10, block=True, timeout=5)
+ self.assert_message_count(messages, 5)
+ self.assertGreaterEqual(t.interval, 5)
+
+ consumer.stop()
+
+ @kafka_versions("all")
+ def test_multi_proc_pending(self):
+ self.send_messages(0, range(0, 10))
+ self.send_messages(1, range(10, 20))
+
+ consumer = MultiProcessConsumer(self.client, "group1", self.topic, auto_commit=False)
+
+ self.assertEquals(consumer.pending(), 20)
+ self.assertEquals(consumer.pending(partitions=[0]), 10)
+ self.assertEquals(consumer.pending(partitions=[1]), 10)
+
+ consumer.stop()
+
+ @kafka_versions("all")
+ def test_large_messages(self):
+ # Produce 10 "normal" size messages
+ small_messages = self.send_messages(0, [ str(x) for x in range(10) ])
+
+ # Produce 10 messages that are large (bigger than default fetch size)
+ large_messages = self.send_messages(0, [ random_string(5000) for x in range(10) ])
+
+ # Consumer should still get all of them
+ consumer = self.consumer()
+
+ expected_messages = set(small_messages + large_messages)
+ actual_messages = set([ x.message.value for x in consumer ])
+ self.assertEqual(expected_messages, actual_messages)
+
+ consumer.stop()
+
+ @kafka_versions("all")
+ def test_huge_messages(self):
+ huge_message, = self.send_messages(0, [
+ create_message(random_string(MAX_FETCH_BUFFER_SIZE_BYTES + 10)),
+ ])
+
+ # Create a consumer with the default buffer size
+ consumer = self.consumer()
+
+ # This consumer fails to get the message
+ with self.assertRaises(ConsumerFetchSizeTooSmall):
+ consumer.get_message(False, 0.1)
+
+ consumer.stop()
+
+ # Create a consumer with no fetch size limit
+ big_consumer = self.consumer(
+ max_buffer_size = None,
+ partitions = [0],
+ )
+
+ # Seek to the last message
+ big_consumer.seek(-1, 2)
+
+ # Consume giant message successfully
+ message = big_consumer.get_message(block=False, timeout=10)
+ self.assertIsNotNone(message)
+ self.assertEquals(message.message.value, huge_message)
+
+ big_consumer.stop()
+
+ @kafka_versions("0.8.1")
+ def test_offset_behavior__resuming_behavior(self):
+ msgs1 = self.send_messages(0, range(0, 100))
+ msgs2 = self.send_messages(1, range(100, 200))
+
+ # Start a consumer
+ consumer1 = self.consumer(
+ auto_commit_every_t = None,
+ auto_commit_every_n = 20,
+ )
+
+ # Grab the first 195 messages
+ output_msgs1 = [ consumer1.get_message().message.value for _ in xrange(195) ]
+ self.assert_message_count(output_msgs1, 195)
+
+ # The total offset across both partitions should be at 180
+ consumer2 = self.consumer(
+ auto_commit_every_t = None,
+ auto_commit_every_n = 20,
+ )
+
+ # 181-200
+ self.assert_message_count([ message for message in consumer2 ], 20)
+
+ consumer1.stop()
+ consumer2.stop()
+
+ def consumer(self, **kwargs):
+ if os.environ['KAFKA_VERSION'] == "0.8.0":
+ # Kafka 0.8.0 simply doesn't support offset requests, so hard-code auto_commit off
+ kwargs['auto_commit'] = False
+ else:
+ kwargs.setdefault('auto_commit', True)
+
+ consumer_class = kwargs.pop('consumer', SimpleConsumer)
+ group = kwargs.pop('group', self.id())
+ topic = kwargs.pop('topic', self.topic)
+
+ if consumer_class == SimpleConsumer:
+ kwargs.setdefault('iter_timeout', 0)
+
+ return consumer_class(self.client, group, topic, **kwargs)
diff --git a/test/test_failover_integration.py b/test/test_failover_integration.py
new file mode 100644
index 0000000..6298f62
--- /dev/null
+++ b/test/test_failover_integration.py
@@ -0,0 +1,123 @@
+import os
+import time
+
+from kafka import * # noqa
+from kafka.common import * # noqa
+from fixtures import ZookeeperFixture, KafkaFixture
+from testutil import *
+
+class TestFailover(KafkaIntegrationTestCase):
+ create_client = False
+
+ @classmethod
+ def setUpClass(cls): # noqa
+ if not os.environ.get('KAFKA_VERSION'):
+ return
+
+ zk_chroot = random_string(10)
+ replicas = 2
+ partitions = 2
+
+ # mini zookeeper, 2 kafka brokers
+ cls.zk = ZookeeperFixture.instance()
+ kk_args = [cls.zk.host, cls.zk.port, zk_chroot, replicas, partitions]
+ cls.brokers = [KafkaFixture.instance(i, *kk_args) for i in range(replicas)]
+
+ hosts = ['%s:%d' % (b.host, b.port) for b in cls.brokers]
+ cls.client = KafkaClient(hosts)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not os.environ.get('KAFKA_VERSION'):
+ return
+
+ cls.client.close()
+ for broker in cls.brokers:
+ broker.close()
+ cls.zk.close()
+
+ @kafka_versions("all")
+ def test_switch_leader(self):
+ key, topic, partition = random_string(5), self.topic, 0
+ producer = SimpleProducer(self.client)
+
+ for i in range(1, 4):
+
+ # XXX unfortunately, the conns dict needs to be warmed for this to work
+ # XXX unfortunately, for warming to work, we need at least as many partitions as brokers
+ self._send_random_messages(producer, self.topic, 10)
+
+ # kill the leader for partition 0
+ broker = self._kill_leader(topic, partition)
+
+ # expect failure, reload metadata
+ with self.assertRaises(FailedPayloadsError):
+ producer.send_messages(self.topic, 'part 1')
+ producer.send_messages(self.topic, 'part 2')
+ time.sleep(1)
+
+ # send to new leader
+ self._send_random_messages(producer, self.topic, 10)
+
+ broker.open()
+ time.sleep(3)
+
+ # count number of messages
+ count = self._count_messages('test_switch_leader group %s' % i, topic)
+ self.assertIn(count, range(20 * i, 22 * i + 1))
+
+ producer.stop()
+
+ @kafka_versions("all")
+ def test_switch_leader_async(self):
+ key, topic, partition = random_string(5), self.topic, 0
+ producer = SimpleProducer(self.client, async=True)
+
+ for i in range(1, 4):
+
+ self._send_random_messages(producer, self.topic, 10)
+
+ # kill the leader for partition 0
+ broker = self._kill_leader(topic, partition)
+
+ # expect failure, reload metadata
+ producer.send_messages(self.topic, 'part 1')
+ producer.send_messages(self.topic, 'part 2')
+ time.sleep(1)
+
+ # send to new leader
+ self._send_random_messages(producer, self.topic, 10)
+
+ broker.open()
+ time.sleep(3)
+
+ # count number of messages
+ count = self._count_messages('test_switch_leader_async group %s' % i, topic)
+ self.assertIn(count, range(20 * i, 22 * i + 1))
+
+ producer.stop()
+
+ def _send_random_messages(self, producer, topic, n):
+ for j in range(n):
+ resp = producer.send_messages(topic, random_string(10))
+ if len(resp) > 0:
+ self.assertEquals(resp[0].error, 0)
+ time.sleep(1) # give it some time
+
+ def _kill_leader(self, topic, partition):
+ leader = self.client.topics_to_brokers[TopicAndPartition(topic, partition)]
+ broker = self.brokers[leader.nodeId]
+ broker.close()
+ time.sleep(1) # give it some time
+ return broker
+
+ def _count_messages(self, group, topic):
+ hosts = '%s:%d' % (self.brokers[0].host, self.brokers[0].port)
+ client = KafkaClient(hosts)
+ consumer = SimpleConsumer(client, group, topic, auto_commit=False, iter_timeout=0)
+ all_messages = []
+ for message in consumer:
+ all_messages.append(message)
+ consumer.stop()
+ client.close()
+ return len(all_messages)
diff --git a/test/test_integration.py b/test/test_integration.py
deleted file mode 100644
index 4087df7..0000000
--- a/test/test_integration.py
+++ /dev/null
@@ -1,971 +0,0 @@
-import logging
-import unittest
-import time
-from datetime import datetime
-import string
-import random
-
-from kafka import * # noqa
-from kafka.common import * # noqa
-from kafka.codec import has_gzip, has_snappy
-from kafka.consumer import MAX_FETCH_BUFFER_SIZE_BYTES
-from .fixtures import ZookeeperFixture, KafkaFixture
-
-
-def random_string(l):
- s = "".join(random.choice(string.letters) for i in xrange(l))
- return s
-
-
-def ensure_topic_creation(client, topic_name):
- times = 0
- while True:
- times += 1
- client.load_metadata_for_topics(topic_name)
- if client.has_metadata_for_topic(topic_name):
- break
- print "Waiting for %s topic to be created" % topic_name
- time.sleep(1)
-
- if times > 30:
- raise Exception("Unable to create topic %s" % topic_name)
-
-
-class KafkaTestCase(unittest.TestCase):
- def setUp(self):
- self.topic = "%s-%s" % (self.id()[self.id().rindex(".") + 1:], random_string(10))
- ensure_topic_creation(self.client, self.topic)
-
-
-class TestKafkaClient(KafkaTestCase):
- @classmethod
- def setUpClass(cls): # noqa
- cls.zk = ZookeeperFixture.instance()
- cls.server = KafkaFixture.instance(0, cls.zk.host, cls.zk.port)
- cls.client = KafkaClient('%s:%d' % (cls.server.host, cls.server.port))
-
- @classmethod
- def tearDownClass(cls): # noqa
- cls.client.close()
- cls.server.close()
- cls.zk.close()
-
- #####################
- # Produce Tests #
- #####################
-
- def test_produce_many_simple(self):
-
- produce = ProduceRequest(self.topic, 0, messages=[
- create_message("Test message %d" % i) for i in range(100)
- ])
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- (offset, ) = self.client.send_offset_request([OffsetRequest(self.topic, 0, -1, 1)])
- self.assertEquals(offset.offsets[0], 100)
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 100)
-
- (offset, ) = self.client.send_offset_request([OffsetRequest(self.topic, 0, -1, 1)])
- self.assertEquals(offset.offsets[0], 200)
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 200)
-
- (offset, ) = self.client.send_offset_request([OffsetRequest(self.topic, 0, -1, 1)])
- self.assertEquals(offset.offsets[0], 300)
-
- def test_produce_10k_simple(self):
- produce = ProduceRequest(self.topic, 0, messages=[
- create_message("Test message %d" % i) for i in range(10000)
- ])
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- (offset, ) = self.client.send_offset_request([OffsetRequest(self.topic, 0, -1, 1)])
- self.assertEquals(offset.offsets[0], 10000)
-
- def test_produce_many_gzip(self):
- if not has_gzip():
- return
- message1 = create_gzip_message(["Gzipped 1 %d" % i for i in range(100)])
- message2 = create_gzip_message(["Gzipped 2 %d" % i for i in range(100)])
-
- produce = ProduceRequest(self.topic, 0, messages=[message1, message2])
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- (offset, ) = self.client.send_offset_request([OffsetRequest(self.topic, 0, -1, 1)])
- self.assertEquals(offset.offsets[0], 200)
-
- def test_produce_many_snappy(self):
- if not has_snappy():
- return
- message1 = create_snappy_message(["Snappy 1 %d" % i for i in range(100)])
- message2 = create_snappy_message(["Snappy 2 %d" % i for i in range(100)])
-
- produce = ProduceRequest(self.topic, 0, messages=[message1, message2])
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- (offset, ) = self.client.send_offset_request([OffsetRequest(self.topic, 0, -1, 1)])
- self.assertEquals(offset.offsets[0], 200)
-
- def test_produce_mixed(self):
- if not has_gzip() or not has_snappy():
- return
- message1 = create_message("Just a plain message")
- message2 = create_gzip_message(["Gzipped %d" % i for i in range(100)])
- message3 = create_snappy_message(["Snappy %d" % i for i in range(100)])
-
- produce = ProduceRequest(self.topic, 0, messages=[message1, message2, message3])
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- (offset, ) = self.client.send_offset_request([OffsetRequest(self.topic, 0, -1, 1)])
- self.assertEquals(offset.offsets[0], 201)
-
- def test_produce_100k_gzipped(self):
- req1 = ProduceRequest(self.topic, 0, messages=[
- create_gzip_message(["Gzipped batch 1, message %d" % i for i in range(50000)])
- ])
-
- for resp in self.client.send_produce_request([req1]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- (offset, ) = self.client.send_offset_request([OffsetRequest(self.topic, 0, -1, 1)])
- self.assertEquals(offset.offsets[0], 50000)
-
- req2 = ProduceRequest(self.topic, 0, messages=[
- create_gzip_message(["Gzipped batch 2, message %d" % i for i in range(50000)])
- ])
-
- for resp in self.client.send_produce_request([req2]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 50000)
-
- (offset, ) = self.client.send_offset_request([OffsetRequest(self.topic, 0, -1, 1)])
- self.assertEquals(offset.offsets[0], 100000)
-
- #####################
- # Consume Tests #
- #####################
-
- def test_consume_none(self):
- fetch = FetchRequest(self.topic, 0, 0, 1024)
-
- fetch_resp = self.client.send_fetch_request([fetch])[0]
- self.assertEquals(fetch_resp.error, 0)
- self.assertEquals(fetch_resp.topic, self.topic)
- self.assertEquals(fetch_resp.partition, 0)
-
- messages = list(fetch_resp.messages)
- self.assertEquals(len(messages), 0)
-
- def test_produce_consume(self):
- produce = ProduceRequest(self.topic, 0, messages=[
- create_message("Just a test message"),
- create_message("Message with a key", "foo"),
- ])
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- fetch = FetchRequest(self.topic, 0, 0, 1024)
-
- fetch_resp = self.client.send_fetch_request([fetch])[0]
- self.assertEquals(fetch_resp.error, 0)
-
- messages = list(fetch_resp.messages)
- self.assertEquals(len(messages), 2)
- self.assertEquals(messages[0].offset, 0)
- self.assertEquals(messages[0].message.value, "Just a test message")
- self.assertEquals(messages[0].message.key, None)
- self.assertEquals(messages[1].offset, 1)
- self.assertEquals(messages[1].message.value, "Message with a key")
- self.assertEquals(messages[1].message.key, "foo")
-
- def test_produce_consume_many(self):
- produce = ProduceRequest(self.topic, 0, messages=[
- create_message("Test message %d" % i) for i in range(100)
- ])
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- # 1024 is not enough for 100 messages...
- fetch1 = FetchRequest(self.topic, 0, 0, 1024)
-
- (fetch_resp1,) = self.client.send_fetch_request([fetch1])
-
- self.assertEquals(fetch_resp1.error, 0)
- self.assertEquals(fetch_resp1.highwaterMark, 100)
- messages = list(fetch_resp1.messages)
- self.assertTrue(len(messages) < 100)
-
- # 10240 should be enough
- fetch2 = FetchRequest(self.topic, 0, 0, 10240)
- (fetch_resp2,) = self.client.send_fetch_request([fetch2])
-
- self.assertEquals(fetch_resp2.error, 0)
- self.assertEquals(fetch_resp2.highwaterMark, 100)
- messages = list(fetch_resp2.messages)
- self.assertEquals(len(messages), 100)
- for i, message in enumerate(messages):
- self.assertEquals(message.offset, i)
- self.assertEquals(message.message.value, "Test message %d" % i)
- self.assertEquals(message.message.key, None)
-
- def test_produce_consume_two_partitions(self):
- produce1 = ProduceRequest(self.topic, 0, messages=[
- create_message("Partition 0 %d" % i) for i in range(10)
- ])
- produce2 = ProduceRequest(self.topic, 1, messages=[
- create_message("Partition 1 %d" % i) for i in range(10)
- ])
-
- for resp in self.client.send_produce_request([produce1, produce2]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- fetch1 = FetchRequest(self.topic, 0, 0, 1024)
- fetch2 = FetchRequest(self.topic, 1, 0, 1024)
- fetch_resp1, fetch_resp2 = self.client.send_fetch_request([fetch1, fetch2])
- self.assertEquals(fetch_resp1.error, 0)
- self.assertEquals(fetch_resp1.highwaterMark, 10)
- messages = list(fetch_resp1.messages)
- self.assertEquals(len(messages), 10)
- for i, message in enumerate(messages):
- self.assertEquals(message.offset, i)
- self.assertEquals(message.message.value, "Partition 0 %d" % i)
- self.assertEquals(message.message.key, None)
- self.assertEquals(fetch_resp2.error, 0)
- self.assertEquals(fetch_resp2.highwaterMark, 10)
- messages = list(fetch_resp2.messages)
- self.assertEquals(len(messages), 10)
- for i, message in enumerate(messages):
- self.assertEquals(message.offset, i)
- self.assertEquals(message.message.value, "Partition 1 %d" % i)
- self.assertEquals(message.message.key, None)
-
- ####################
- # Offset Tests #
- ####################
-
- @unittest.skip('commmit offset not supported in this version')
- def test_commit_fetch_offsets(self):
- req = OffsetCommitRequest(self.topic, 0, 42, "metadata")
- (resp,) = self.client.send_offset_commit_request("group", [req])
- self.assertEquals(resp.error, 0)
-
- req = OffsetFetchRequest(self.topic, 0)
- (resp,) = self.client.send_offset_fetch_request("group", [req])
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 42)
- self.assertEquals(resp.metadata, "") # Metadata isn't stored for now
-
- # Producer Tests
-
- def test_simple_producer(self):
- producer = SimpleProducer(self.client)
- resp = producer.send_messages(self.topic, "one", "two")
-
- partition_for_first_batch = resp[0].partition
-
- self.assertEquals(len(resp), 1)
- self.assertEquals(resp[0].error, 0)
- self.assertEquals(resp[0].offset, 0) # offset of first msg
-
- # ensure this partition is different from the first partition
- resp = producer.send_messages(self.topic, "three")
- partition_for_second_batch = resp[0].partition
- self.assertNotEquals(partition_for_first_batch, partition_for_second_batch)
-
- self.assertEquals(len(resp), 1)
- self.assertEquals(resp[0].error, 0)
- self.assertEquals(resp[0].offset, 0) # offset of first msg
-
- fetch_requests = (
- FetchRequest(self.topic, partition_for_first_batch, 0, 1024),
- FetchRequest(self.topic, partition_for_second_batch, 0, 1024),
- )
- fetch_resp1, fetch_resp2 = self.client.send_fetch_request(fetch_requests)
- self.assertEquals(fetch_resp1.error, 0)
- self.assertEquals(fetch_resp1.highwaterMark, 2)
- messages = list(fetch_resp1.messages)
- self.assertEquals(len(messages), 2)
- self.assertEquals(messages[0].message.value, "one")
- self.assertEquals(messages[1].message.value, "two")
- self.assertEquals(fetch_resp2.error, 0)
- self.assertEquals(fetch_resp2.highwaterMark, 1)
- messages = list(fetch_resp2.messages)
- self.assertEquals(len(messages), 1)
- self.assertEquals(messages[0].message.value, "three")
-
- # Will go to same partition as first batch
- resp = producer.send_messages(self.topic, "four", "five")
- self.assertEquals(len(resp), 1)
- self.assertEquals(resp[0].error, 0)
- self.assertEquals(resp[0].offset, 2) # offset of first msg
- self.assertEquals(resp[0].partition, partition_for_first_batch)
-
- producer.stop()
-
- def test_round_robin_partitioner(self):
- producer = KeyedProducer(self.client,
- partitioner=RoundRobinPartitioner)
- producer.send(self.topic, "key1", "one")
- producer.send(self.topic, "key2", "two")
- producer.send(self.topic, "key3", "three")
- producer.send(self.topic, "key4", "four")
-
- fetch1 = FetchRequest(self.topic, 0, 0, 1024)
- fetch2 = FetchRequest(self.topic, 1, 0, 1024)
-
- fetch_resp1, fetch_resp2 = self.client.send_fetch_request([fetch1,
- fetch2])
-
- self.assertEquals(fetch_resp1.error, 0)
- self.assertEquals(fetch_resp1.highwaterMark, 2)
- self.assertEquals(fetch_resp1.partition, 0)
-
- messages = list(fetch_resp1.messages)
- self.assertEquals(len(messages), 2)
- self.assertEquals(messages[0].message.value, "one")
- self.assertEquals(messages[1].message.value, "three")
-
- self.assertEquals(fetch_resp2.error, 0)
- self.assertEquals(fetch_resp2.highwaterMark, 2)
- self.assertEquals(fetch_resp2.partition, 1)
-
- messages = list(fetch_resp2.messages)
- self.assertEquals(len(messages), 2)
- self.assertEquals(messages[0].message.value, "two")
- self.assertEquals(messages[1].message.value, "four")
-
- producer.stop()
-
- def test_hashed_partitioner(self):
- producer = KeyedProducer(self.client,
- partitioner=HashedPartitioner)
- producer.send(self.topic, 1, "one")
- producer.send(self.topic, 2, "two")
- producer.send(self.topic, 3, "three")
- producer.send(self.topic, 4, "four")
-
- fetch1 = FetchRequest(self.topic, 0, 0, 1024)
- fetch2 = FetchRequest(self.topic, 1, 0, 1024)
-
- fetch_resp1, fetch_resp2 = self.client.send_fetch_request([fetch1,
- fetch2])
-
- self.assertEquals(fetch_resp1.error, 0)
- self.assertEquals(fetch_resp1.highwaterMark, 2)
- self.assertEquals(fetch_resp1.partition, 0)
-
- messages = list(fetch_resp1.messages)
- self.assertEquals(len(messages), 2)
- self.assertEquals(messages[0].message.value, "two")
- self.assertEquals(messages[1].message.value, "four")
-
- self.assertEquals(fetch_resp2.error, 0)
- self.assertEquals(fetch_resp2.highwaterMark, 2)
- self.assertEquals(fetch_resp2.partition, 1)
-
- messages = list(fetch_resp2.messages)
- self.assertEquals(len(messages), 2)
- self.assertEquals(messages[0].message.value, "one")
- self.assertEquals(messages[1].message.value, "three")
-
- producer.stop()
-
- def test_acks_none(self):
- producer = SimpleProducer(self.client,
- req_acks=SimpleProducer.ACK_NOT_REQUIRED)
- resp = producer.send_messages(self.topic, "one")
- self.assertEquals(len(resp), 0)
-
- # fetch from both partitions
- fetch_requests = (
- FetchRequest(self.topic, 0, 0, 1024),
- FetchRequest(self.topic, 1, 0, 1024),
- )
- fetch_resps = self.client.send_fetch_request(fetch_requests)
-
- # determine which partition was selected (due to random round-robin)
- published_to_resp = max(fetch_resps, key=lambda x: x.highwaterMark)
- not_published_to_resp = min(fetch_resps, key=lambda x: x.highwaterMark)
- self.assertNotEquals(published_to_resp.partition, not_published_to_resp.partition)
-
- self.assertEquals(published_to_resp.error, 0)
- self.assertEquals(published_to_resp.highwaterMark, 1)
-
- self.assertEquals(not_published_to_resp.error, 0)
- self.assertEquals(not_published_to_resp.highwaterMark, 0)
-
- messages = list(published_to_resp.messages)
- self.assertEquals(len(messages), 1)
- self.assertEquals(messages[0].message.value, "one")
-
- producer.stop()
-
- def test_acks_local_write(self):
- producer = SimpleProducer(self.client,
- req_acks=SimpleProducer.ACK_AFTER_LOCAL_WRITE)
- resp = producer.send_messages(self.topic, "one")
- self.assertEquals(len(resp), 1)
-
- partition = resp[0].partition
-
- fetch = FetchRequest(self.topic, partition, 0, 1024)
- fetch_resp = self.client.send_fetch_request([fetch])
-
- self.assertEquals(fetch_resp[0].error, 0)
- self.assertEquals(fetch_resp[0].highwaterMark, 1)
- self.assertEquals(fetch_resp[0].partition, partition)
-
- messages = list(fetch_resp[0].messages)
- self.assertEquals(len(messages), 1)
- self.assertEquals(messages[0].message.value, "one")
-
- producer.stop()
-
- def test_acks_cluster_commit(self):
- producer = SimpleProducer(
- self.client,
- req_acks=SimpleProducer.ACK_AFTER_CLUSTER_COMMIT)
- resp = producer.send_messages(self.topic, "one")
- self.assertEquals(len(resp), 1)
-
- partition = resp[0].partition
-
- fetch = FetchRequest(self.topic, partition, 0, 1024)
- fetch_resp = self.client.send_fetch_request([fetch])
-
- self.assertEquals(fetch_resp[0].error, 0)
- self.assertEquals(fetch_resp[0].highwaterMark, 1)
- self.assertEquals(fetch_resp[0].partition, partition)
-
- messages = list(fetch_resp[0].messages)
- self.assertEquals(len(messages), 1)
- self.assertEquals(messages[0].message.value, "one")
-
- producer.stop()
-
- def test_async_simple_producer(self):
- producer = SimpleProducer(self.client, async=True)
- resp = producer.send_messages(self.topic, "one")
- self.assertEquals(len(resp), 0)
-
- # Give it some time
- time.sleep(2)
-
- # fetch from both partitions
- fetch_requests = (
- FetchRequest(self.topic, 0, 0, 1024),
- FetchRequest(self.topic, 1, 0, 1024),
- )
- fetch_resps = self.client.send_fetch_request(fetch_requests)
-
- # determine which partition was selected (due to random round-robin)
- published_to_resp = max(fetch_resps, key=lambda x: x.highwaterMark)
- not_published_to_resp = min(fetch_resps, key=lambda x: x.highwaterMark)
- self.assertNotEquals(published_to_resp.partition, not_published_to_resp.partition)
-
- self.assertEquals(published_to_resp.error, 0)
- self.assertEquals(published_to_resp.highwaterMark, 1)
-
- self.assertEquals(not_published_to_resp.error, 0)
- self.assertEquals(not_published_to_resp.highwaterMark, 0)
-
- messages = list(published_to_resp.messages)
- self.assertEquals(len(messages), 1)
- self.assertEquals(messages[0].message.value, "one")
-
- messages = list(not_published_to_resp.messages)
- self.assertEquals(len(messages), 0)
-
- producer.stop()
-
- def test_async_keyed_producer(self):
- producer = KeyedProducer(self.client, async=True)
-
- resp = producer.send(self.topic, "key1", "one")
- self.assertEquals(len(resp), 0)
-
- # Give it some time
- time.sleep(2)
-
- fetch = FetchRequest(self.topic, 0, 0, 1024)
- fetch_resp = self.client.send_fetch_request([fetch])
-
- self.assertEquals(fetch_resp[0].error, 0)
- self.assertEquals(fetch_resp[0].highwaterMark, 1)
- self.assertEquals(fetch_resp[0].partition, 0)
-
- messages = list(fetch_resp[0].messages)
- self.assertEquals(len(messages), 1)
- self.assertEquals(messages[0].message.value, "one")
-
- producer.stop()
-
- def test_batched_simple_producer(self):
- producer = SimpleProducer(self.client,
- batch_send=True,
- batch_send_every_n=10,
- batch_send_every_t=20)
-
- # Send 5 messages and do a fetch
- msgs = ["message-%d" % i for i in range(0, 5)]
- resp = producer.send_messages(self.topic, *msgs)
-
- # Batch mode is async. No ack
- self.assertEquals(len(resp), 0)
-
- # Give it some time
- time.sleep(2)
-
- fetch1 = FetchRequest(self.topic, 0, 0, 1024)
- fetch2 = FetchRequest(self.topic, 1, 0, 1024)
- fetch_resp1, fetch_resp2 = self.client.send_fetch_request([fetch1,
- fetch2])
-
- self.assertEquals(fetch_resp1.error, 0)
- messages = list(fetch_resp1.messages)
- self.assertEquals(len(messages), 0)
-
- self.assertEquals(fetch_resp2.error, 0)
- messages = list(fetch_resp2.messages)
- self.assertEquals(len(messages), 0)
-
- # Send 5 more messages, wait for 2 seconds and do a fetch
- msgs = ["message-%d" % i for i in range(5, 10)]
- resp = producer.send_messages(self.topic, *msgs)
-
- # Give it some time
- time.sleep(2)
-
- fetch1 = FetchRequest(self.topic, 0, 0, 1024)
- fetch2 = FetchRequest(self.topic, 1, 0, 1024)
- fetch_resp1, fetch_resp2 = self.client.send_fetch_request([fetch1,
- fetch2])
-
- self.assertEquals(fetch_resp1.error, 0)
- messages = list(fetch_resp1.messages)
- self.assertEquals(len(messages), 5)
-
- self.assertEquals(fetch_resp2.error, 0)
- messages = list(fetch_resp2.messages)
- self.assertEquals(len(messages), 5)
-
- # Send 7 messages and wait for 20 seconds
- msgs = ["message-%d" % i for i in range(10, 15)]
- resp = producer.send_messages(self.topic, *msgs)
- msgs = ["message-%d" % i for i in range(15, 17)]
- resp = producer.send_messages(self.topic, *msgs)
-
- fetch1 = FetchRequest(self.topic, 0, 5, 1024)
- fetch2 = FetchRequest(self.topic, 1, 5, 1024)
- fetch_resp1, fetch_resp2 = self.client.send_fetch_request([fetch1,
- fetch2])
-
- self.assertEquals(fetch_resp1.error, 0)
- self.assertEquals(fetch_resp2.error, 0)
- messages = list(fetch_resp1.messages) + list(fetch_resp2.messages)
- self.assertEquals(len(messages), 0)
-
- # Give it some time
- time.sleep(22)
-
- fetch1 = FetchRequest(self.topic, 0, 5, 1024)
- fetch2 = FetchRequest(self.topic, 1, 5, 1024)
- fetch_resp1, fetch_resp2 = self.client.send_fetch_request([fetch1,
- fetch2])
-
- self.assertEquals(fetch_resp1.error, 0)
- self.assertEquals(fetch_resp2.error, 0)
- messages = list(fetch_resp1.messages) + list(fetch_resp2.messages)
- self.assertEquals(len(messages), 7)
-
- producer.stop()
-
-
-class TestConsumer(KafkaTestCase):
- @classmethod
- def setUpClass(cls):
- cls.zk = ZookeeperFixture.instance()
- cls.server1 = KafkaFixture.instance(0, cls.zk.host, cls.zk.port)
- cls.server2 = KafkaFixture.instance(1, cls.zk.host, cls.zk.port)
- cls.client = KafkaClient('%s:%d' % (cls.server2.host, cls.server2.port))
-
- @classmethod
- def tearDownClass(cls): # noqa
- cls.client.close()
- cls.server1.close()
- cls.server2.close()
- cls.zk.close()
-
- def test_simple_consumer(self):
- # Produce 100 messages to partition 0
- produce1 = ProduceRequest(self.topic, 0, messages=[
- create_message("Test message 0 %d" % i) for i in range(100)
- ])
-
- for resp in self.client.send_produce_request([produce1]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- # Produce 100 messages to partition 1
- produce2 = ProduceRequest(self.topic, 1, messages=[
- create_message("Test message 1 %d" % i) for i in range(100)
- ])
-
- for resp in self.client.send_produce_request([produce2]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- # Start a consumer
- consumer = SimpleConsumer(self.client, "group1",
- self.topic, auto_commit=False,
- iter_timeout=0)
- all_messages = []
- for message in consumer:
- all_messages.append(message)
-
- self.assertEquals(len(all_messages), 200)
- # Make sure there are no duplicates
- self.assertEquals(len(all_messages), len(set(all_messages)))
-
- consumer.seek(-10, 2)
- all_messages = []
- for message in consumer:
- all_messages.append(message)
-
- self.assertEquals(len(all_messages), 10)
-
- consumer.seek(-13, 2)
- all_messages = []
- for message in consumer:
- all_messages.append(message)
-
- self.assertEquals(len(all_messages), 13)
-
- consumer.stop()
-
- def test_simple_consumer_blocking(self):
- consumer = SimpleConsumer(self.client, "group1",
- self.topic,
- auto_commit=False, iter_timeout=0)
-
- # Blocking API
- start = datetime.now()
- messages = consumer.get_messages(block=True, timeout=5)
- diff = (datetime.now() - start).total_seconds()
- self.assertGreaterEqual(diff, 5)
- self.assertEqual(len(messages), 0)
-
- # Send 10 messages
- produce = ProduceRequest(self.topic, 0, messages=[
- create_message("Test message 0 %d" % i) for i in range(10)
- ])
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- # Fetch 5 messages
- messages = consumer.get_messages(count=5, block=True, timeout=5)
- self.assertEqual(len(messages), 5)
-
- # Fetch 10 messages
- start = datetime.now()
- messages = consumer.get_messages(count=10, block=True, timeout=5)
- self.assertEqual(len(messages), 5)
- diff = (datetime.now() - start).total_seconds()
- self.assertGreaterEqual(diff, 5)
-
- consumer.stop()
-
- def test_simple_consumer_pending(self):
- # Produce 10 messages to partition 0 and 1
-
- produce1 = ProduceRequest(self.topic, 0, messages=[
- create_message("Test message 0 %d" % i) for i in range(10)
- ])
- for resp in self.client.send_produce_request([produce1]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- produce2 = ProduceRequest(self.topic, 1, messages=[
- create_message("Test message 1 %d" % i) for i in range(10)
- ])
- for resp in self.client.send_produce_request([produce2]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- consumer = SimpleConsumer(self.client, "group1", self.topic,
- auto_commit=False, iter_timeout=0)
- self.assertEquals(consumer.pending(), 20)
- self.assertEquals(consumer.pending(partitions=[0]), 10)
- self.assertEquals(consumer.pending(partitions=[1]), 10)
- consumer.stop()
-
- def test_multi_process_consumer(self):
- # Produce 100 messages to partition 0
- produce1 = ProduceRequest(self.topic, 0, messages=[
- create_message("Test message 0 %d" % i) for i in range(100)
- ])
-
- for resp in self.client.send_produce_request([produce1]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- # Produce 100 messages to partition 1
- produce2 = ProduceRequest(self.topic, 1, messages=[
- create_message("Test message 1 %d" % i) for i in range(100)
- ])
-
- for resp in self.client.send_produce_request([produce2]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- # Start a consumer
- consumer = MultiProcessConsumer(self.client, "grp1", self.topic, auto_commit=False)
- all_messages = []
- for message in consumer:
- all_messages.append(message)
-
- self.assertEquals(len(all_messages), 200)
- # Make sure there are no duplicates
- self.assertEquals(len(all_messages), len(set(all_messages)))
-
- # Blocking API
- start = datetime.now()
- messages = consumer.get_messages(block=True, timeout=5)
- diff = (datetime.now() - start).total_seconds()
- self.assertGreaterEqual(diff, 4.999)
- self.assertEqual(len(messages), 0)
-
- # Send 10 messages
- produce = ProduceRequest(self.topic, 0, messages=[
- create_message("Test message 0 %d" % i) for i in range(10)
- ])
-
- for resp in self.client.send_produce_request([produce]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 100)
-
- # Fetch 5 messages
- messages = consumer.get_messages(count=5, block=True, timeout=5)
- self.assertEqual(len(messages), 5)
-
- # Fetch 10 messages
- start = datetime.now()
- messages = consumer.get_messages(count=10, block=True, timeout=5)
- self.assertEqual(len(messages), 5)
- diff = (datetime.now() - start).total_seconds()
- self.assertGreaterEqual(diff, 5)
-
- consumer.stop()
-
- def test_multi_proc_pending(self):
- # Produce 10 messages to partition 0 and 1
- produce1 = ProduceRequest(self.topic, 0, messages=[
- create_message("Test message 0 %d" % i) for i in range(10)
- ])
-
- for resp in self.client.send_produce_request([produce1]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- produce2 = ProduceRequest(self.topic, 1, messages=[
- create_message("Test message 1 %d" % i) for i in range(10)
- ])
-
- for resp in self.client.send_produce_request([produce2]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- consumer = MultiProcessConsumer(self.client, "group1", self.topic, auto_commit=False)
- self.assertEquals(consumer.pending(), 20)
- self.assertEquals(consumer.pending(partitions=[0]), 10)
- self.assertEquals(consumer.pending(partitions=[1]), 10)
-
- consumer.stop()
-
- def test_large_messages(self):
- # Produce 10 "normal" size messages
- messages1 = [create_message(random_string(1024)) for i in range(10)]
- produce1 = ProduceRequest(self.topic, 0, messages1)
-
- for resp in self.client.send_produce_request([produce1]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 0)
-
- # Produce 10 messages that are large (bigger than default fetch size)
- messages2 = [create_message(random_string(5000)) for i in range(10)]
- produce2 = ProduceRequest(self.topic, 0, messages2)
-
- for resp in self.client.send_produce_request([produce2]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 10)
-
- # Consumer should still get all of them
- consumer = SimpleConsumer(self.client, "group1", self.topic,
- auto_commit=False, iter_timeout=0)
- all_messages = messages1 + messages2
- for i, message in enumerate(consumer):
- self.assertEquals(all_messages[i], message.message)
- self.assertEquals(i, 19)
-
- # Produce 1 message that is too large (bigger than max fetch size)
- big_message_size = MAX_FETCH_BUFFER_SIZE_BYTES + 10
- big_message = create_message(random_string(big_message_size))
- produce3 = ProduceRequest(self.topic, 0, [big_message])
- for resp in self.client.send_produce_request([produce3]):
- self.assertEquals(resp.error, 0)
- self.assertEquals(resp.offset, 20)
-
- self.assertRaises(ConsumerFetchSizeTooSmall, consumer.get_message, False, 0.1)
-
- # Create a consumer with no fetch size limit
- big_consumer = SimpleConsumer(self.client, "group1", self.topic,
- max_buffer_size=None, partitions=[0],
- auto_commit=False, iter_timeout=0)
-
- # Seek to the last message
- big_consumer.seek(-1, 2)
-
- # Consume giant message successfully
- message = big_consumer.get_message(block=False, timeout=10)
- self.assertIsNotNone(message)
- self.assertEquals(message.message.value, big_message.value)
-
-
-class TestFailover(KafkaTestCase):
-
- @classmethod
- def setUpClass(cls): # noqa
- zk_chroot = random_string(10)
- replicas = 2
- partitions = 2
-
- # mini zookeeper, 2 kafka brokers
- cls.zk = ZookeeperFixture.instance()
- kk_args = [cls.zk.host, cls.zk.port, zk_chroot, replicas, partitions]
- cls.brokers = [KafkaFixture.instance(i, *kk_args) for i in range(replicas)]
-
- hosts = ['%s:%d' % (b.host, b.port) for b in cls.brokers]
- cls.client = KafkaClient(hosts)
-
- @classmethod
- def tearDownClass(cls):
- cls.client.close()
- for broker in cls.brokers:
- broker.close()
- cls.zk.close()
-
- def test_switch_leader(self):
- key, topic, partition = random_string(5), self.topic, 0
- producer = SimpleProducer(self.client)
-
- for i in range(1, 4):
-
- # XXX unfortunately, the conns dict needs to be warmed for this to work
- # XXX unfortunately, for warming to work, we need at least as many partitions as brokers
- self._send_random_messages(producer, self.topic, 10)
-
- # kill leader for partition 0
- broker = self._kill_leader(topic, partition)
-
- # expect failure; reload metadata
- with self.assertRaises(FailedPayloadsError):
- producer.send_messages(self.topic, 'part 1')
- producer.send_messages(self.topic, 'part 2')
- time.sleep(1)
-
- # send to new leader
- self._send_random_messages(producer, self.topic, 10)
-
- broker.open()
- time.sleep(3)
-
- # count number of messages
- count = self._count_messages('test_switch_leader group %s' % i, topic)
- self.assertIn(count, range(20 * i, 22 * i + 1))
-
- producer.stop()
-
- def test_switch_leader_async(self):
- key, topic, partition = random_string(5), self.topic, 0
- producer = SimpleProducer(self.client, async=True)
-
- for i in range(1, 4):
-
- self._send_random_messages(producer, self.topic, 10)
-
- # kill leader for partition 0
- broker = self._kill_leader(topic, partition)
-
- # expect failure; reload metadata
- producer.send_messages(self.topic, 'part 1')
- producer.send_messages(self.topic, 'part 2')
- time.sleep(1)
-
- # send to new leader
- self._send_random_messages(producer, self.topic, 10)
-
- broker.open()
- time.sleep(3)
-
- # count number of messages
- count = self._count_messages('test_switch_leader_async group %s' % i, topic)
- self.assertIn(count, range(20 * i, 22 * i + 1))
-
- producer.stop()
-
- def _send_random_messages(self, producer, topic, n):
- for j in range(n):
- resp = producer.send_messages(topic, random_string(10))
- if len(resp) > 0:
- self.assertEquals(resp[0].error, 0)
- time.sleep(1) # give it some time
-
- def _kill_leader(self, topic, partition):
- leader = self.client.topics_to_brokers[TopicAndPartition(topic, partition)]
- broker = self.brokers[leader.nodeId]
- broker.close()
- time.sleep(1) # give it some time
- return broker
-
- def _count_messages(self, group, topic):
- hosts = '%s:%d' % (self.brokers[0].host, self.brokers[0].port)
- client = KafkaClient(hosts)
- consumer = SimpleConsumer(client, group, topic, auto_commit=False, iter_timeout=0)
- all_messages = []
- for message in consumer:
- all_messages.append(message)
- consumer.stop()
- client.close()
- return len(all_messages)
-
-if __name__ == "__main__":
- logging.basicConfig(level=logging.DEBUG)
- unittest.main()
diff --git a/test/test_package.py b/test/test_package.py
new file mode 100644
index 0000000..a6a3a14
--- /dev/null
+++ b/test/test_package.py
@@ -0,0 +1,29 @@
+import unittest2
+
+class TestPackage(unittest2.TestCase):
+ def test_top_level_namespace(self):
+ import kafka as kafka1
+ self.assertEquals(kafka1.KafkaClient.__name__, "KafkaClient")
+ self.assertEquals(kafka1.client.__name__, "kafka.client")
+ self.assertEquals(kafka1.codec.__name__, "kafka.codec")
+
+ def test_submodule_namespace(self):
+ import kafka.client as client1
+ self.assertEquals(client1.__name__, "kafka.client")
+ self.assertEquals(client1.KafkaClient.__name__, "KafkaClient")
+
+ from kafka import client as client2
+ self.assertEquals(client2.__name__, "kafka.client")
+ self.assertEquals(client2.KafkaClient.__name__, "KafkaClient")
+
+ from kafka.client import KafkaClient as KafkaClient1
+ self.assertEquals(KafkaClient1.__name__, "KafkaClient")
+
+ from kafka.codec import gzip_encode as gzip_encode1
+ self.assertEquals(gzip_encode1.__name__, "gzip_encode")
+
+ from kafka import KafkaClient as KafkaClient2
+ self.assertEquals(KafkaClient2.__name__, "KafkaClient")
+
+ from kafka.codec import snappy_encode
+ self.assertEquals(snappy_encode.__name__, "snappy_encode")
diff --git a/test/test_producer_integration.py b/test/test_producer_integration.py
new file mode 100644
index 0000000..c69e117
--- /dev/null
+++ b/test/test_producer_integration.py
@@ -0,0 +1,404 @@
+import os
+import time
+import uuid
+
+from kafka import * # noqa
+from kafka.common import * # noqa
+from kafka.codec import has_gzip, has_snappy
+from fixtures import ZookeeperFixture, KafkaFixture
+from testutil import *
+
+class TestKafkaProducerIntegration(KafkaIntegrationTestCase):
+ topic = 'produce_topic'
+
+ @classmethod
+ def setUpClass(cls): # noqa
+ if not os.environ.get('KAFKA_VERSION'):
+ return
+
+ cls.zk = ZookeeperFixture.instance()
+ cls.server = KafkaFixture.instance(0, cls.zk.host, cls.zk.port)
+
+ @classmethod
+ def tearDownClass(cls): # noqa
+ if not os.environ.get('KAFKA_VERSION'):
+ return
+
+ cls.server.close()
+ cls.zk.close()
+
+ @kafka_versions("all")
+ def test_produce_many_simple(self):
+ start_offset = self.current_offset(self.topic, 0)
+
+ self.assert_produce_request(
+ [ create_message("Test message %d" % i) for i in range(100) ],
+ start_offset,
+ 100,
+ )
+
+ self.assert_produce_request(
+ [ create_message("Test message %d" % i) for i in range(100) ],
+ start_offset+100,
+ 100,
+ )
+
+ @kafka_versions("all")
+ def test_produce_10k_simple(self):
+ start_offset = self.current_offset(self.topic, 0)
+
+ self.assert_produce_request(
+ [ create_message("Test message %d" % i) for i in range(10000) ],
+ start_offset,
+ 10000,
+ )
+
+ @kafka_versions("all")
+ def test_produce_many_gzip(self):
+ start_offset = self.current_offset(self.topic, 0)
+
+ message1 = create_gzip_message(["Gzipped 1 %d" % i for i in range(100)])
+ message2 = create_gzip_message(["Gzipped 2 %d" % i for i in range(100)])
+
+ self.assert_produce_request(
+ [ message1, message2 ],
+ start_offset,
+ 200,
+ )
+
+ @kafka_versions("all")
+ def test_produce_many_snappy(self):
+ self.skipTest("All snappy integration tests fail with nosnappyjava")
+ start_offset = self.current_offset(self.topic, 0)
+
+ self.assert_produce_request([
+ create_snappy_message(["Snappy 1 %d" % i for i in range(100)]),
+ create_snappy_message(["Snappy 2 %d" % i for i in range(100)]),
+ ],
+ start_offset,
+ 200,
+ )
+
+ @kafka_versions("all")
+ def test_produce_mixed(self):
+ start_offset = self.current_offset(self.topic, 0)
+
+ msg_count = 1+100
+ messages = [
+ create_message("Just a plain message"),
+ create_gzip_message(["Gzipped %d" % i for i in range(100)]),
+ ]
+
+ # All snappy integration tests fail with nosnappyjava
+ if False and has_snappy():
+ msg_count += 100
+ messages.append(create_snappy_message(["Snappy %d" % i for i in range(100)]))
+
+ self.assert_produce_request(messages, start_offset, msg_count)
+
+ @kafka_versions("all")
+ def test_produce_100k_gzipped(self):
+ start_offset = self.current_offset(self.topic, 0)
+
+ self.assert_produce_request([
+ create_gzip_message(["Gzipped batch 1, message %d" % i for i in range(50000)])
+ ],
+ start_offset,
+ 50000,
+ )
+
+ self.assert_produce_request([
+ create_gzip_message(["Gzipped batch 1, message %d" % i for i in range(50000)])
+ ],
+ start_offset+50000,
+ 50000,
+ )
+
+ ############################
+ # SimpleProducer Tests #
+ ############################
+
+ @kafka_versions("all")
+ def test_simple_producer(self):
+ start_offset0 = self.current_offset(self.topic, 0)
+ start_offset1 = self.current_offset(self.topic, 1)
+ producer = SimpleProducer(self.client)
+
+ # Goes to the first partition (chosen at random).
+ resp = producer.send_messages(self.topic, self.msg("one"), self.msg("two"))
+ self.assert_produce_response(resp, start_offset0)
+
+ # Goes to the next partition (round-robin from the random start).
+ resp = producer.send_messages(self.topic, self.msg("three"))
+ self.assert_produce_response(resp, start_offset1)
+
+ self.assert_fetch_offset(0, start_offset0, [ self.msg("one"), self.msg("two") ])
+ self.assert_fetch_offset(1, start_offset1, [ self.msg("three") ])
+
+ # Goes back to the first partition because there are only two partitions
+ resp = producer.send_messages(self.topic, self.msg("four"), self.msg("five"))
+ self.assert_produce_response(resp, start_offset0+2)
+ self.assert_fetch_offset(0, start_offset0, [ self.msg("one"), self.msg("two"), self.msg("four"), self.msg("five") ])
+
+ producer.stop()
+
+ @kafka_versions("all")
+ def test_producer_random_order(self):
+ producer = SimpleProducer(self.client, random_start=True)
+ resp1 = producer.send_messages(self.topic, self.msg("one"), self.msg("two"))
+ resp2 = producer.send_messages(self.topic, self.msg("three"))
+ resp3 = producer.send_messages(self.topic, self.msg("four"), self.msg("five"))
+
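+ # Round-robin over the topic's two partitions: batches one and three
+ # share a partition, while batch two lands on the other.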
+ self.assertEqual(resp1[0].partition, resp3[0].partition)
+ self.assertNotEqual(resp1[0].partition, resp2[0].partition)
+
+ @kafka_versions("all")
+ def test_producer_ordered_start(self):
+ producer = SimpleProducer(self.client, random_start=False)
+ resp1 = producer.send_messages(self.topic, self.msg("one"), self.msg("two"))
+ resp2 = producer.send_messages(self.topic, self.msg("three"))
+ resp3 = producer.send_messages(self.topic, self.msg("four"), self.msg("five"))
+
+ self.assertEqual(resp1[0].partition, 0)
+ self.assertEqual(resp2[0].partition, 1)
+ self.assertEqual(resp3[0].partition, 0)
+
+ @kafka_versions("all")
+ def test_round_robin_partitioner(self):
+ start_offset0 = self.current_offset(self.topic, 0)
+ start_offset1 = self.current_offset(self.topic, 1)
+
+ producer = KeyedProducer(self.client, partitioner=RoundRobinPartitioner)
+ resp1 = producer.send(self.topic, "key1", self.msg("one"))
+ resp2 = producer.send(self.topic, "key2", self.msg("two"))
+ resp3 = producer.send(self.topic, "key3", self.msg("three"))
+ resp4 = producer.send(self.topic, "key4", self.msg("four"))
+
+ self.assert_produce_response(resp1, start_offset0+0)
+ self.assert_produce_response(resp2, start_offset1+0)
+ self.assert_produce_response(resp3, start_offset0+1)
+ self.assert_produce_response(resp4, start_offset1+1)
+
+ self.assert_fetch_offset(0, start_offset0, [ self.msg("one"), self.msg("three") ])
+ self.assert_fetch_offset(1, start_offset1, [ self.msg("two"), self.msg("four") ])
+
+ producer.stop()
+
+ @kafka_versions("all")
+ def test_hashed_partitioner(self):
+ start_offset0 = self.current_offset(self.topic, 0)
+ start_offset1 = self.current_offset(self.topic, 1)
+
+ producer = KeyedProducer(self.client, partitioner=HashedPartitioner)
+ resp1 = producer.send(self.topic, 1, self.msg("one"))
+ resp2 = producer.send(self.topic, 2, self.msg("two"))
+ resp3 = producer.send(self.topic, 3, self.msg("three"))
+ resp4 = producer.send(self.topic, 3, self.msg("four"))
+ resp5 = producer.send(self.topic, 4, self.msg("five"))
+
+ self.assert_produce_response(resp1, start_offset1+0)
+ self.assert_produce_response(resp2, start_offset0+0)
+ self.assert_produce_response(resp3, start_offset1+1)
+ self.assert_produce_response(resp4, start_offset1+2)
+ self.assert_produce_response(resp5, start_offset0+1)
+
+ self.assert_fetch_offset(0, start_offset0, [ self.msg("two"), self.msg("five") ])
+ self.assert_fetch_offset(1, start_offset1, [ self.msg("one"), self.msg("three"), self.msg("four") ])
+
+ producer.stop()
+
+ @kafka_versions("all")
+ def test_acks_none(self):
+ start_offset0 = self.current_offset(self.topic, 0)
+ start_offset1 = self.current_offset(self.topic, 1)
+
+ producer = SimpleProducer(self.client, req_acks=SimpleProducer.ACK_NOT_REQUIRED)
+ resp = producer.send_messages(self.topic, self.msg("one"))
+ self.assertEquals(len(resp), 0)
+
+ self.assert_fetch_offset(0, start_offset0, [ self.msg("one") ])
+ producer.stop()
+
+ @kafka_versions("all")
+ def test_acks_local_write(self):
+ start_offset0 = self.current_offset(self.topic, 0)
+ start_offset1 = self.current_offset(self.topic, 1)
+
+ producer = SimpleProducer(self.client, req_acks=SimpleProducer.ACK_AFTER_LOCAL_WRITE)
+ resp = producer.send_messages(self.topic, self.msg("one"))
+
+ self.assert_produce_response(resp, start_offset0)
+ self.assert_fetch_offset(0, start_offset0, [ self.msg("one") ])
+
+ producer.stop()
+
+ @kafka_versions("all")
+ def test_acks_cluster_commit(self):
+ start_offset0 = self.current_offset(self.topic, 0)
+ start_offset1 = self.current_offset(self.topic, 1)
+
+ producer = SimpleProducer(
+ self.client,
+ req_acks=SimpleProducer.ACK_AFTER_CLUSTER_COMMIT)
+
+ resp = producer.send_messages(self.topic, self.msg("one"))
+ self.assert_produce_response(resp, start_offset0)
+ self.assert_fetch_offset(0, start_offset0, [ self.msg("one") ])
+
+ producer.stop()
+
+ @kafka_versions("all")
+ def test_batched_simple_producer__triggers_by_message(self):
+ start_offset0 = self.current_offset(self.topic, 0)
+ start_offset1 = self.current_offset(self.topic, 1)
+
+ producer = SimpleProducer(self.client,
+ batch_send=True,
+ batch_send_every_n=5,
+ batch_send_every_t=20)
+
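+ # batch_send_every_n=5: the producer queues messages and flushes only
+ # once five are pending, so the first four stay unsent.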
+ # Send 4 messages and do a fetch
+ resp = producer.send_messages(self.topic,
+ self.msg("one"),
+ self.msg("two"),
+ self.msg("three"),
+ self.msg("four"),
+ )
+
+ # Batch mode is async. No ack
+ self.assertEquals(len(resp), 0)
+
+ # It hasn't sent yet
+ self.assert_fetch_offset(0, start_offset0, [])
+ self.assert_fetch_offset(1, start_offset1, [])
+
+ resp = producer.send_messages(self.topic,
+ self.msg("five"),
+ self.msg("six"),
+ self.msg("seven"),
+ )
+
+ # Batch mode is async. No ack
+ self.assertEquals(len(resp), 0)
+
+ self.assert_fetch_offset(0, start_offset0, [
+ self.msg("one"),
+ self.msg("two"),
+ self.msg("three"),
+ self.msg("four"),
+ ])
+
+ self.assert_fetch_offset(1, start_offset1, [
+ self.msg("five"),
+ # self.msg("six"),
+ # self.msg("seven"),
+ ])
+
+ producer.stop()
+
+ @kafka_versions("all")
+ def test_batched_simple_producer__triggers_by_time(self):
+ start_offset0 = self.current_offset(self.topic, 0)
+ start_offset1 = self.current_offset(self.topic, 1)
+
+ producer = SimpleProducer(self.client,
+ batch_send=True,
+ batch_send_every_n=100,
+ batch_send_every_t=5)
+
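+ # batch_send_every_t=5: queued messages are flushed only after five
+ # seconds have elapsed.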
+ # Send 4 messages and do a fetch
+ resp = producer.send_messages(self.topic,
+ self.msg("one"),
+ self.msg("two"),
+ self.msg("three"),
+ self.msg("four"),
+ )
+
+ # Batch mode is async. No ack
+ self.assertEquals(len(resp), 0)
+
+ # It hasn't sent yet
+ self.assert_fetch_offset(0, start_offset0, [])
+ self.assert_fetch_offset(1, start_offset1, [])
+
+ resp = producer.send_messages(self.topic,
+ self.msg("five"),
+ self.msg("six"),
+ self.msg("seven"),
+ )
+
+ # Batch mode is async. No ack
+ self.assertEquals(len(resp), 0)
+
+ # Wait the timeout out
+ time.sleep(5)
+
+ self.assert_fetch_offset(0, start_offset0, [
+ self.msg("one"),
+ self.msg("two"),
+ self.msg("three"),
+ self.msg("four"),
+ ])
+
+ self.assert_fetch_offset(1, start_offset1, [
+ self.msg("five"),
+ self.msg("six"),
+ self.msg("seven"),
+ ])
+
+ producer.stop()
+
+ @kafka_versions("all")
+ def test_async_simple_producer(self):
+ start_offset0 = self.current_offset(self.topic, 0)
+ start_offset1 = self.current_offset(self.topic, 1)
+
+ producer = SimpleProducer(self.client, async=True)
+ resp = producer.send_messages(self.topic, self.msg("one"))
+ self.assertEquals(len(resp), 0)
+
+ self.assert_fetch_offset(0, start_offset0, [ self.msg("one") ])
+
+ producer.stop()
+
+ @kafka_versions("all")
+ def test_async_keyed_producer(self):
+ start_offset0 = self.current_offset(self.topic, 0)
+ start_offset1 = self.current_offset(self.topic, 1)
+
+ producer = KeyedProducer(self.client, partitioner=RoundRobinPartitioner, async=True)
+
+ resp = producer.send(self.topic, "key1", self.msg("one"))
+ self.assertEquals(len(resp), 0)
+
+ self.assert_fetch_offset(0, start_offset0, [ self.msg("one") ])
+
+ producer.stop()
+
+ def assert_produce_request(self, messages, initial_offset, message_ct):
+ produce = ProduceRequest(self.topic, 0, messages=messages)
+
+ # There should only be one response message from the server.
+ # This will throw an exception if there's more than one.
+ resp = self.client.send_produce_request([ produce ])
+ self.assert_produce_response(resp, initial_offset)
+
+ self.assertEqual(self.current_offset(self.topic, 0), initial_offset + message_ct)
+
+ def assert_produce_response(self, resp, initial_offset):
+ self.assertEqual(len(resp), 1)
+ self.assertEqual(resp[0].error, 0)
+ self.assertEqual(resp[0].offset, initial_offset)
+
+ def assert_fetch_offset(self, partition, start_offset, expected_messages):
+ # There should only be one response message from the server.
+ # This will throw an exception if there's more than one.
+
+ resp, = self.client.send_fetch_request([ FetchRequest(self.topic, partition, start_offset, 1024) ])
+
+ self.assertEquals(resp.error, 0)
+ self.assertEquals(resp.partition, partition)
+ messages = [ x.message.value for x in resp.messages ]
+
+ self.assertEqual(messages, expected_messages)
+ self.assertEquals(resp.highwaterMark, start_offset+len(expected_messages))
diff --git a/test/test_protocol.py b/test/test_protocol.py
new file mode 100644
index 0000000..8bd2f5e
--- /dev/null
+++ b/test/test_protocol.py
@@ -0,0 +1,694 @@
+import struct
+import unittest2
+
+from kafka import KafkaClient
+from kafka.common import (
+ OffsetRequest, OffsetCommitRequest, OffsetFetchRequest,
+ OffsetResponse, OffsetCommitResponse, OffsetFetchResponse,
+ ProduceRequest, FetchRequest, Message, ChecksumError,
+ ConsumerFetchSizeTooSmall, ProduceResponse, FetchResponse,
+ OffsetAndMessage, BrokerMetadata, PartitionMetadata,
+ TopicAndPartition, KafkaUnavailableError, ProtocolError,
+ LeaderUnavailableError, PartitionUnavailableError
+)
+from kafka.codec import (
+ has_snappy, gzip_encode, gzip_decode,
+ snappy_encode, snappy_decode
+)
+from kafka.protocol import (
+ create_gzip_message, create_message, create_snappy_message, KafkaProtocol
+)
+
+class TestProtocol(unittest2.TestCase):
+ def test_create_message(self):
+ payload = "test"
+ key = "key"
+ msg = create_message(payload, key)
+ self.assertEqual(msg.magic, 0)
+ self.assertEqual(msg.attributes, 0)
+ self.assertEqual(msg.key, key)
+ self.assertEqual(msg.value, payload)
+
+ def test_create_gzip(self):
+ payloads = ["v1", "v2"]
+ msg = create_gzip_message(payloads)
+ self.assertEqual(msg.magic, 0)
+ self.assertEqual(msg.attributes, KafkaProtocol.ATTRIBUTE_CODEC_MASK &
+ KafkaProtocol.CODEC_GZIP)
+ self.assertEqual(msg.key, None)
+ # Need to decode to check since gzipped payload is non-deterministic
+ decoded = gzip_decode(msg.value)
+ expect = "".join([
+ struct.pack(">q", 0), # MsgSet offset
+ struct.pack(">i", 16), # MsgSet size
+ struct.pack(">i", 1285512130), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", -1), # -1 indicates a null key
+ struct.pack(">i", 2), # Msg length (bytes)
+ "v1", # Message contents
+
+ struct.pack(">q", 0), # MsgSet offset
+ struct.pack(">i", 16), # MsgSet size
+ struct.pack(">i", -711587208), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", -1), # -1 indicates a null key
+ struct.pack(">i", 2), # Msg length (bytes)
+ "v2", # Message contents
+ ])
+
+ self.assertEqual(decoded, expect)
+
+ @unittest2.skipUnless(has_snappy(), "Snappy not available")
+ def test_create_snappy(self):
+ payloads = ["v1", "v2"]
+ msg = create_snappy_message(payloads)
+ self.assertEqual(msg.magic, 0)
+ self.assertEqual(msg.attributes, KafkaProtocol.ATTRIBUTE_CODEC_MASK &
+ KafkaProtocol.CODEC_SNAPPY)
+ self.assertEqual(msg.key, None)
+ decoded = snappy_decode(msg.value)
+ expect = "".join([
+ struct.pack(">q", 0), # MsgSet offset
+ struct.pack(">i", 16), # MsgSet size
+ struct.pack(">i", 1285512130), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", -1), # -1 indicates a null key
+ struct.pack(">i", 2), # Msg length (bytes)
+ "v1", # Message contents
+
+ struct.pack(">q", 0), # MsgSet offset
+ struct.pack(">i", 16), # MsgSet size
+ struct.pack(">i", -711587208), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", -1), # -1 indicates a null key
+ struct.pack(">i", 2), # Msg length (bytes)
+ "v2", # Message contents
+ ])
+
+ self.assertEqual(decoded, expect)
+
+ def test_encode_message_header(self):
+ expect = "".join([
+ struct.pack(">h", 10), # API Key
+ struct.pack(">h", 0), # API Version
+ struct.pack(">i", 4), # Correlation Id
+ struct.pack(">h", len("client3")), # Length of clientId
+ "client3", # ClientId
+ ])
+
+ encoded = KafkaProtocol._encode_message_header("client3", 4, 10)
+ self.assertEqual(encoded, expect)
+
+ def test_encode_message(self):
+ message = create_message("test", "key")
+ encoded = KafkaProtocol._encode_message(message)
+ expect = "".join([
+ struct.pack(">i", -1427009701), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", 3), # Length of key
+ "key", # key
+ struct.pack(">i", 4), # Length of value
+ "test", # value
+ ])
+
+ self.assertEqual(encoded, expect)
+
+ def test_decode_message(self):
+ encoded = "".join([
+ struct.pack(">i", -1427009701), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", 3), # Length of key
+ "key", # key
+ struct.pack(">i", 4), # Length of value
+ "test", # value
+ ])
+
+ offset = 10
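+ # _decode_message yields (offset, message) pairs; grab the first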
+ (returned_offset, decoded_message) = list(KafkaProtocol._decode_message(encoded, offset))[0]
+
+ self.assertEqual(returned_offset, offset)
+ self.assertEqual(decoded_message, create_message("test", "key"))
+
+ def test_encode_message_failure(self):
+ with self.assertRaises(ProtocolError):
+ KafkaProtocol._encode_message(Message(1, 0, "key", "test"))
+
+ def test_encode_message_set(self):
+ message_set = [
+ create_message("v1", "k1"),
+ create_message("v2", "k2")
+ ]
+
+ encoded = KafkaProtocol._encode_message_set(message_set)
+ expect = "".join([
+ struct.pack(">q", 0), # MsgSet Offset
+ struct.pack(">i", 18), # Msg Size
+ struct.pack(">i", 1474775406), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", 2), # Length of key
+ "k1", # Key
+ struct.pack(">i", 2), # Length of value
+ "v1", # Value
+
+ struct.pack(">q", 0), # MsgSet Offset
+ struct.pack(">i", 18), # Msg Size
+ struct.pack(">i", -16383415), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", 2), # Length of key
+ "k2", # Key
+ struct.pack(">i", 2), # Length of value
+ "v2", # Value
+ ])
+
+ self.assertEqual(encoded, expect)
+
+ def test_decode_message_set(self):
+ encoded = "".join([
+ struct.pack(">q", 0), # MsgSet Offset
+ struct.pack(">i", 18), # Msg Size
+ struct.pack(">i", 1474775406), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", 2), # Length of key
+ "k1", # Key
+ struct.pack(">i", 2), # Length of value
+ "v1", # Value
+
+ struct.pack(">q", 1), # MsgSet Offset
+ struct.pack(">i", 18), # Msg Size
+ struct.pack(">i", -16383415), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", 2), # Length of key
+ "k2", # Key
+ struct.pack(">i", 2), # Length of value
+ "v2", # Value
+ ])
+
+ msgs = list(KafkaProtocol._decode_message_set_iter(encoded))
+ self.assertEqual(len(msgs), 2)
+ msg1, msg2 = msgs
+
+ returned_offset1, decoded_message1 = msg1
+ returned_offset2, decoded_message2 = msg2
+
+ self.assertEqual(returned_offset1, 0)
+ self.assertEqual(decoded_message1, create_message("v1", "k1"))
+
+ self.assertEqual(returned_offset2, 1)
+ self.assertEqual(decoded_message2, create_message("v2", "k2"))
+
+ def test_decode_message_gzip(self):
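+ # Hard-coded wire bytes: a gzip-compressed message set holding "v1" and "v2"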
+ gzip_encoded = ('\xc0\x11\xb2\xf0\x00\x01\xff\xff\xff\xff\x00\x00\x000'
+ '\x1f\x8b\x08\x00\xa1\xc1\xc5R\x02\xffc`\x80\x03\x01'
+ '\x9f\xf9\xd1\x87\x18\x18\xfe\x03\x01\x90\xc7Tf\xc8'
+ '\x80$wu\x1aW\x05\x92\x9c\x11\x00z\xc0h\x888\x00\x00'
+ '\x00')
+ offset = 11
+ messages = list(KafkaProtocol._decode_message(gzip_encoded, offset))
+
+ self.assertEqual(len(messages), 2)
+ msg1, msg2 = messages
+
+ returned_offset1, decoded_message1 = msg1
+ self.assertEqual(returned_offset1, 0)
+ self.assertEqual(decoded_message1, create_message("v1"))
+
+ returned_offset2, decoded_message2 = msg2
+ self.assertEqual(returned_offset2, 0)
+ self.assertEqual(decoded_message2, create_message("v2"))
+
+ @unittest2.skipUnless(has_snappy(), "Snappy not available")
+ def test_decode_message_snappy(self):
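+ # Hard-coded wire bytes: a snappy-compressed message set holding "v1" and "v2"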
+ snappy_encoded = ('\xec\x80\xa1\x95\x00\x02\xff\xff\xff\xff\x00\x00'
+ '\x00,8\x00\x00\x19\x01@\x10L\x9f[\xc2\x00\x00\xff'
+ '\xff\xff\xff\x00\x00\x00\x02v1\x19\x1bD\x00\x10\xd5'
+ '\x96\nx\x00\x00\xff\xff\xff\xff\x00\x00\x00\x02v2')
+ offset = 11
+ messages = list(KafkaProtocol._decode_message(snappy_encoded, offset))
+ self.assertEqual(len(messages), 2)
+
+ msg1, msg2 = messages
+
+ returned_offset1, decoded_message1 = msg1
+ self.assertEqual(returned_offset1, 0)
+ self.assertEqual(decoded_message1, create_message("v1"))
+
+ returned_offset2, decoded_message2 = msg2
+ self.assertEqual(returned_offset2, 0)
+ self.assertEqual(decoded_message2, create_message("v2"))
+
+ def test_decode_message_checksum_error(self):
+ invalid_encoded_message = "This is not a valid encoded message"
+ msg_iter = KafkaProtocol._decode_message(invalid_encoded_message, 0)
+ self.assertRaises(ChecksumError, list, msg_iter)
+
+ # NOTE: The error handling in _decode_message_set_iter() is questionable.
+ # If it's modified, the next two tests might need to be fixed.
+ def test_decode_message_set_fetch_size_too_small(self):
+ with self.assertRaises(ConsumerFetchSizeTooSmall):
+ list(KafkaProtocol._decode_message_set_iter('a'))
+
+ def test_decode_message_set_stop_iteration(self):
+ encoded = "".join([
+ struct.pack(">q", 0), # MsgSet Offset
+ struct.pack(">i", 18), # Msg Size
+ struct.pack(">i", 1474775406), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", 2), # Length of key
+ "k1", # Key
+ struct.pack(">i", 2), # Length of value
+ "v1", # Value
+
+ struct.pack(">q", 1), # MsgSet Offset
+ struct.pack(">i", 18), # Msg Size
+ struct.pack(">i", -16383415), # CRC
+ struct.pack(">bb", 0, 0), # Magic, flags
+ struct.pack(">i", 2), # Length of key
+ "k2", # Key
+ struct.pack(">i", 2), # Length of value
+ "v2", # Value
+ "@1$%(Y!", # Random padding
+ ])
+
+ msgs = list(KafkaProtocol._decode_message_set_iter(encoded))
+ self.assertEqual(len(msgs), 2)
+ msg1, msg2 = msgs
+
+ returned_offset1, decoded_message1 = msg1
+ returned_offset2, decoded_message2 = msg2
+
+ self.assertEqual(returned_offset1, 0)
+ self.assertEqual(decoded_message1, create_message("v1", "k1"))
+
+ self.assertEqual(returned_offset2, 1)
+ self.assertEqual(decoded_message2, create_message("v2", "k2"))
+
+ def test_encode_produce_request(self):
+ requests = [
+ ProduceRequest("topic1", 0, [
+ create_message("a"),
+ create_message("b")
+ ]),
+ ProduceRequest("topic2", 1, [
+ create_message("c")
+ ])
+ ]
+
+ msg_a_binary = KafkaProtocol._encode_message(create_message("a"))
+ msg_b_binary = KafkaProtocol._encode_message(create_message("b"))
+ msg_c_binary = KafkaProtocol._encode_message(create_message("c"))
+
+ header = "".join([
+ struct.pack('>i', 0x94), # The length of the message overall
+ struct.pack('>h', 0), # Msg Header, Message type = Produce
+ struct.pack('>h', 0), # Msg Header, API version
+ struct.pack('>i', 2), # Msg Header, Correlation ID
+ struct.pack('>h7s', 7, "client1"), # Msg Header, The client ID
+ struct.pack('>h', 2), # Num acks required
+ struct.pack('>i', 100), # Request Timeout
+ struct.pack('>i', 2), # The number of requests
+ ])
+
+ total_len = len(msg_a_binary) + len(msg_b_binary)
+ topic1 = "".join([
+ struct.pack('>h6s', 6, 'topic1'), # The topic1
+ struct.pack('>i', 1), # One message set
+ struct.pack('>i', 0), # Partition 0
+ struct.pack('>i', total_len + 24), # Size of the incoming message set
+ struct.pack('>q', 0), # No offset specified
+ struct.pack('>i', len(msg_a_binary)), # Length of message
+ msg_a_binary, # Actual message
+ struct.pack('>q', 0), # No offset specified
+ struct.pack('>i', len(msg_b_binary)), # Length of message
+ msg_b_binary, # Actual message
+ ])
+
+ topic2 = "".join([
+ struct.pack('>h6s', 6, 'topic2'), # The topic2
+ struct.pack('>i', 1), # One message set
+ struct.pack('>i', 1), # Partition 1
+ struct.pack('>i', len(msg_c_binary) + 12), # Size of the incoming message set
+ struct.pack('>q', 0), # No offset specified
+ struct.pack('>i', len(msg_c_binary)), # Length of message
+ msg_c_binary, # Actual message
+ ])
+
+ expected1 = "".join([ header, topic1, topic2 ])
+ expected2 = "".join([ header, topic2, topic1 ])
+
+ encoded = KafkaProtocol.encode_produce_request("client1", 2, requests, 2, 100)
+ self.assertIn(encoded, [ expected1, expected2 ])
+
+ def test_decode_produce_response(self):
+ t1 = "topic1"
+ t2 = "topic2"
+ encoded = struct.pack('>iih%dsiihqihqh%dsiihq' % (len(t1), len(t2)),
+ 2, 2, len(t1), t1, 2, 0, 0, 10L, 1, 1, 20L,
+ len(t2), t2, 1, 0, 0, 30L)
+ responses = list(KafkaProtocol.decode_produce_response(encoded))
+ self.assertEqual(responses,
+ [ProduceResponse(t1, 0, 0, 10L),
+ ProduceResponse(t1, 1, 1, 20L),
+ ProduceResponse(t2, 0, 0, 30L)])
+
+ def test_encode_fetch_request(self):
+ requests = [
+ FetchRequest("topic1", 0, 10, 1024),
+ FetchRequest("topic2", 1, 20, 100),
+ ]
+
+ header = "".join([
+ struct.pack('>i', 89), # The length of the message overall
+ struct.pack('>h', 1), # Msg Header, Message type = Fetch
+ struct.pack('>h', 0), # Msg Header, API version
+ struct.pack('>i', 3), # Msg Header, Correlation ID
+ struct.pack('>h7s', 7, "client1"), # Msg Header, The client ID
+ struct.pack('>i', -1), # Replica Id
+ struct.pack('>i', 2), # Max wait time
+ struct.pack('>i', 100), # Min bytes
+ struct.pack('>i', 2), # Num requests
+ ])
+
+ topic1 = "".join([
+ struct.pack('>h6s', 6, 'topic1'), # Topic
+ struct.pack('>i', 1), # Num Payloads
+ struct.pack('>i', 0), # Partition 0
+ struct.pack('>q', 10), # Offset
+ struct.pack('>i', 1024), # Max Bytes
+ ])
+
+ topic2 = "".join([
+ struct.pack('>h6s', 6, 'topic2'), # Topic
+ struct.pack('>i', 1), # Num Payloads
+ struct.pack('>i', 1), # Partition 1
+ struct.pack('>q', 20), # Offset
+ struct.pack('>i', 100), # Max Bytes
+ ])
+
+ expected1 = "".join([ header, topic1, topic2 ])
+ expected2 = "".join([ header, topic2, topic1 ])
+
+ encoded = KafkaProtocol.encode_fetch_request("client1", 3, requests, 2, 100)
+ self.assertIn(encoded, [ expected1, expected2 ])
+
+ def test_decode_fetch_response(self):
+ t1 = "topic1"
+ t2 = "topic2"
+ msgs = map(create_message, ["message1", "hi", "boo", "foo", "so fun!"])
+ ms1 = KafkaProtocol._encode_message_set([msgs[0], msgs[1]])
+ ms2 = KafkaProtocol._encode_message_set([msgs[2]])
+ ms3 = KafkaProtocol._encode_message_set([msgs[3], msgs[4]])
+
+ encoded = struct.pack('>iih%dsiihqi%dsihqi%dsh%dsiihqi%ds' %
+ (len(t1), len(ms1), len(ms2), len(t2), len(ms3)),
+ 4, 2, len(t1), t1, 2, 0, 0, 10, len(ms1), ms1, 1,
+ 1, 20, len(ms2), ms2, len(t2), t2, 1, 0, 0, 30,
+ len(ms3), ms3)
+
+ responses = list(KafkaProtocol.decode_fetch_response(encoded))
+ def expand_messages(response):
+ return FetchResponse(response.topic, response.partition,
+ response.error, response.highwaterMark,
+ list(response.messages))
+
+ expanded_responses = map(expand_messages, responses)
+ expect = [FetchResponse(t1, 0, 0, 10, [OffsetAndMessage(0, msgs[0]),
+ OffsetAndMessage(0, msgs[1])]),
+ FetchResponse(t1, 1, 1, 20, [OffsetAndMessage(0, msgs[2])]),
+ FetchResponse(t2, 0, 0, 30, [OffsetAndMessage(0, msgs[3]),
+ OffsetAndMessage(0, msgs[4])])]
+ self.assertEqual(expanded_responses, expect)
+
+ def test_encode_metadata_request_no_topics(self):
+ expected = "".join([
+ struct.pack(">i", 17), # Total length of the request
+ struct.pack('>h', 3), # API key metadata fetch
+ struct.pack('>h', 0), # API version
+ struct.pack('>i', 4), # Correlation ID
+ struct.pack('>h3s', 3, "cid"), # The client ID
+ struct.pack('>i', 0), # No topics, give all the data!
+ ])
+
+ encoded = KafkaProtocol.encode_metadata_request("cid", 4)
+
+ self.assertEqual(encoded, expected)
+
+ def test_encode_metadata_request_with_topics(self):
+ expected = "".join([
+ struct.pack(">i", 25), # Total length of the request
+ struct.pack('>h', 3), # API key metadata fetch
+ struct.pack('>h', 0), # API version
+ struct.pack('>i', 4), # Correlation ID
+ struct.pack('>h3s', 3, "cid"), # The client ID
+ struct.pack('>i', 2), # Number of topics in the request
+ struct.pack('>h2s', 2, "t1"), # Topic "t1"
+ struct.pack('>h2s', 2, "t2"), # Topic "t2"
+ ])
+
+ encoded = KafkaProtocol.encode_metadata_request("cid", 4, ["t1", "t2"])
+
+ self.assertEqual(encoded, expected)
+
+ def _create_encoded_metadata_response(self, broker_data, topic_data,
+ topic_errors, partition_errors):
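+ # Helper: packs broker metadata, topic errors, and per-partition metadata
+ # into the wire format that decode_metadata_response consumes.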
+ encoded = struct.pack('>ii', 3, len(broker_data))
+ for node_id, broker in broker_data.iteritems():
+ encoded += struct.pack('>ih%dsi' % len(broker.host), node_id,
+ len(broker.host), broker.host, broker.port)
+
+ encoded += struct.pack('>i', len(topic_data))
+ for topic, partitions in topic_data.iteritems():
+ encoded += struct.pack('>hh%dsi' % len(topic), topic_errors[topic],
+ len(topic), topic, len(partitions))
+ for partition, metadata in partitions.iteritems():
+ encoded += struct.pack('>hiii',
+ partition_errors[(topic, partition)],
+ partition, metadata.leader,
+ len(metadata.replicas))
+ if len(metadata.replicas) > 0:
+ encoded += struct.pack('>%di' % len(metadata.replicas),
+ *metadata.replicas)
+
+ encoded += struct.pack('>i', len(metadata.isr))
+ if len(metadata.isr) > 0:
+ encoded += struct.pack('>%di' % len(metadata.isr),
+ *metadata.isr)
+
+ return encoded
+
+ def test_decode_metadata_response(self):
+ node_brokers = {
+ 0: BrokerMetadata(0, "brokers1.kafka.rdio.com", 1000),
+ 1: BrokerMetadata(1, "brokers1.kafka.rdio.com", 1001),
+ 3: BrokerMetadata(3, "brokers2.kafka.rdio.com", 1000)
+ }
+
+ topic_partitions = {
+ "topic1": {
+ 0: PartitionMetadata("topic1", 0, 1, (0, 2), (2,)),
+ 1: PartitionMetadata("topic1", 1, 3, (0, 1), (0, 1))
+ },
+ "topic2": {
+ 0: PartitionMetadata("topic2", 0, 0, (), ())
+ }
+ }
+ topic_errors = {"topic1": 0, "topic2": 1}
+ partition_errors = {
+ ("topic1", 0): 0,
+ ("topic1", 1): 1,
+ ("topic2", 0): 0
+ }
+ encoded = self._create_encoded_metadata_response(node_brokers,
+ topic_partitions,
+ topic_errors,
+ partition_errors)
+ decoded = KafkaProtocol.decode_metadata_response(encoded)
+ self.assertEqual(decoded, (node_brokers, topic_partitions))
+
+ def test_encode_offset_request(self):
+ expected = "".join([
+ struct.pack(">i", 21), # Total length of the request
+ struct.pack('>h', 2), # Message type = offset fetch
+ struct.pack('>h', 0), # API version
+ struct.pack('>i', 4), # Correlation ID
+ struct.pack('>h3s', 3, "cid"), # The client ID
+ struct.pack('>i', -1), # Replica Id
+ struct.pack('>i', 0), # No topic/partitions
+ ])
+
+ encoded = KafkaProtocol.encode_offset_request("cid", 4)
+
+ self.assertEqual(encoded, expected)
+
+ def test_encode_offset_request__with_payload(self):
+ expected = "".join([
+ struct.pack(">i", 65), # Total length of the request
+
+ struct.pack('>h', 2), # Message type = offset fetch
+ struct.pack('>h', 0), # API version
+ struct.pack('>i', 4), # Correlation ID
+ struct.pack('>h3s', 3, "cid"), # The client ID
+ struct.pack('>i', -1), # Replica Id
+ struct.pack('>i', 1), # Num topics
+ struct.pack(">h6s", 6, "topic1"), # Topic for the request
+ struct.pack(">i", 2), # Two partitions
+
+ struct.pack(">i", 3), # Partition 3
+ struct.pack(">q", -1), # No time offset
+ struct.pack(">i", 1), # One offset requested
+
+ struct.pack(">i", 4), # Partition 3
+ struct.pack(">q", -1), # No time offset
+ struct.pack(">i", 1), # One offset requested
+ ])
+
+ encoded = KafkaProtocol.encode_offset_request("cid", 4, [
+ OffsetRequest('topic1', 3, -1, 1),
+ OffsetRequest('topic1', 4, -1, 1),
+ ])
+
+ self.assertEqual(encoded, expected)
+
+ def test_decode_offset_response(self):
+ encoded = "".join([
+ struct.pack(">i", 42), # Correlation ID
+ struct.pack(">i", 1), # One topics
+ struct.pack(">h6s", 6, "topic1"), # First topic
+ struct.pack(">i", 2), # Two partitions
+
+ struct.pack(">i", 2), # Partition 2
+ struct.pack(">h", 0), # No error
+ struct.pack(">i", 1), # One offset
+ struct.pack(">q", 4), # Offset 4
+
+ struct.pack(">i", 4), # Partition 4
+ struct.pack(">h", 0), # No error
+ struct.pack(">i", 1), # One offset
+ struct.pack(">q", 8), # Offset 8
+ ])
+
+ results = KafkaProtocol.decode_offset_response(encoded)
+ self.assertEqual(set(results), set([
+ OffsetResponse(topic='topic1', partition=2, error=0, offsets=(4,)),
+ OffsetResponse(topic='topic1', partition=4, error=0, offsets=(8,)),
+ ]))
+
+ def test_encode_offset_commit_request(self):
+ header = "".join([
+ struct.pack('>i', 99), # Total message length
+
+ struct.pack('>h', 8), # Message type = offset commit
+ struct.pack('>h', 0), # API version
+ struct.pack('>i', 42), # Correlation ID
+ struct.pack('>h9s', 9, "client_id"), # The client ID
+ struct.pack('>h8s', 8, "group_id"), # The group to commit for
+ struct.pack('>i', 2), # Num topics
+ ])
+
+ topic1 = "".join([
+ struct.pack(">h6s", 6, "topic1"), # Topic for the request
+ struct.pack(">i", 2), # Two partitions
+ struct.pack(">i", 0), # Partition 0
+ struct.pack(">q", 123), # Offset 123
+ struct.pack(">h", -1), # Null metadata
+ struct.pack(">i", 1), # Partition 1
+ struct.pack(">q", 234), # Offset 234
+ struct.pack(">h", -1), # Null metadata
+ ])
+
+ topic2 = "".join([
+ struct.pack(">h6s", 6, "topic2"), # Topic for the request
+ struct.pack(">i", 1), # One partition
+ struct.pack(">i", 2), # Partition 2
+ struct.pack(">q", 345), # Offset 345
+ struct.pack(">h", -1), # Null metadata
+ ])
+
+ expected1 = "".join([ header, topic1, topic2 ])
+ expected2 = "".join([ header, topic2, topic1 ])
+
+ encoded = KafkaProtocol.encode_offset_commit_request("client_id", 42, "group_id", [
+ OffsetCommitRequest("topic1", 0, 123, None),
+ OffsetCommitRequest("topic1", 1, 234, None),
+ OffsetCommitRequest("topic2", 2, 345, None),
+ ])
+
+ self.assertIn(encoded, [ expected1, expected2 ])
+
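(Both candidate byte strings are accepted above because the encoder is assumed to bucket payloads by topic in a plain dict, and CPython 2 dict iteration order is arbitrary. A sketch of that grouping, in the spirit of kafka.util.group_by_topic_and_partition, which is tested later in this diff:)

    from collections import defaultdict

    def group_by_topic(payloads):
        # Dict iteration order is unspecified in Python 2, which is why
        # the test asserts membership in both expected encodings.
        grouped = defaultdict(list)
        for payload in payloads:
            grouped[payload.topic].append(payload)
        return grouped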
+ def test_decode_offset_commit_response(self):
+ encoded = "".join([
+ struct.pack(">i", 42), # Correlation ID
+ struct.pack(">i", 1), # One topic
+ struct.pack(">h6s", 6, "topic1"), # First topic
+ struct.pack(">i", 2), # Two partitions
+
+ struct.pack(">i", 2), # Partition 2
+ struct.pack(">h", 0), # No error
+
+ struct.pack(">i", 4), # Partition 4
+ struct.pack(">h", 0), # No error
+ ])
+
+ results = KafkaProtocol.decode_offset_commit_response(encoded)
+ self.assertEqual(set(results), set([
+            OffsetCommitResponse(topic='topic1', partition=2, error=0),
+            OffsetCommitResponse(topic='topic1', partition=4, error=0),
+ ]))
+
+ def test_encode_offset_fetch_request(self):
+ header = "".join([
+ struct.pack('>i', 69), # Total message length
+ struct.pack('>h', 9), # Message type = offset fetch
+ struct.pack('>h', 0), # API version
+ struct.pack('>i', 42), # Correlation ID
+ struct.pack('>h9s', 9, "client_id"), # The client ID
+ struct.pack('>h8s', 8, "group_id"), # The group to commit for
+ struct.pack('>i', 2), # Num topics
+ ])
+
+ topic1 = "".join([
+ struct.pack(">h6s", 6, "topic1"), # Topic for the request
+ struct.pack(">i", 2), # Two partitions
+ struct.pack(">i", 0), # Partition 0
+ struct.pack(">i", 1), # Partition 1
+ ])
+
+ topic2 = "".join([
+ struct.pack(">h6s", 6, "topic2"), # Topic for the request
+            struct.pack(">i", 1),          # One partition
+ struct.pack(">i", 2), # Partition 2
+ ])
+
+ expected1 = "".join([ header, topic1, topic2 ])
+ expected2 = "".join([ header, topic2, topic1 ])
+
+ encoded = KafkaProtocol.encode_offset_fetch_request("client_id", 42, "group_id", [
+ OffsetFetchRequest("topic1", 0),
+ OffsetFetchRequest("topic1", 1),
+ OffsetFetchRequest("topic2", 2),
+ ])
+
+ self.assertIn(encoded, [ expected1, expected2 ])
+
+ def test_decode_offset_fetch_response(self):
+ encoded = "".join([
+ struct.pack(">i", 42), # Correlation ID
+            struct.pack(">i", 1),          # One topic
+ struct.pack(">h6s", 6, "topic1"), # First topic
+ struct.pack(">i", 2), # Two partitions
+
+ struct.pack(">i", 2), # Partition 2
+ struct.pack(">q", 4), # Offset 4
+ struct.pack(">h4s", 4, "meta"), # Metadata
+ struct.pack(">h", 0), # No error
+
+ struct.pack(">i", 4), # Partition 4
+ struct.pack(">q", 8), # Offset 8
+ struct.pack(">h4s", 4, "meta"), # Metadata
+ struct.pack(">h", 0), # No error
+ ])
+
+ results = KafkaProtocol.decode_offset_fetch_response(encoded)
+ self.assertEqual(set(results), set([
+            OffsetFetchResponse(topic='topic1', partition=2, offset=4, error=0, metadata="meta"),
+            OffsetFetchResponse(topic='topic1', partition=4, offset=8, error=0, metadata="meta"),
+ ]))
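(The decode tests above all share one shape: a correlation id followed by an array of topics, each holding an array of partitions. A hedged sketch of how a decoder might walk the offset-commit variant -- illustrative only, not the kafka.protocol implementation:)

    import struct

    def walk_offset_commit_response(buf):
        correlation_id, num_topics = struct.unpack_from('>ii', buf, 0)
        pos = 8
        for _ in xrange(num_topics):
            (topic_len,) = struct.unpack_from('>h', buf, pos)
            topic = buf[pos + 2:pos + 2 + topic_len]
            pos += 2 + topic_len
            (num_partitions,) = struct.unpack_from('>i', buf, pos)
            pos += 4
            for _ in xrange(num_partitions):
                partition, error = struct.unpack_from('>ih', buf, pos)
                pos += 6
                yield topic, partition, error

    # list(walk_offset_commit_response(encoded)) over the buffer built in
    # test_decode_offset_commit_response should yield ('topic1', 2, 0)
    # and ('topic1', 4, 0).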
diff --git a/test/test_unit.py b/test/test_unit.py
deleted file mode 100644
index 8c0dd00..0000000
--- a/test/test_unit.py
+++ /dev/null
@@ -1,674 +0,0 @@
-import os
-import random
-import struct
-import unittest
-
-from mock import MagicMock, patch
-
-from kafka import KafkaClient
-from kafka.common import (
- ProduceRequest, FetchRequest, Message, ChecksumError,
- ConsumerFetchSizeTooSmall, ProduceResponse, FetchResponse,
- OffsetAndMessage, BrokerMetadata, PartitionMetadata,
- TopicAndPartition, KafkaUnavailableError,
- LeaderUnavailableError, PartitionUnavailableError
-)
-from kafka.codec import (
- has_gzip, has_snappy, gzip_encode, gzip_decode,
- snappy_encode, snappy_decode
-)
-from kafka.protocol import (
- create_gzip_message, create_message, create_snappy_message, KafkaProtocol
-)
-
-ITERATIONS = 1000
-STRLEN = 100
-
-
-def random_string():
- return os.urandom(random.randint(1, STRLEN))
-
-
-class TestPackage(unittest.TestCase):
-
- def test_top_level_namespace(self):
- import kafka as kafka1
- self.assertEquals(kafka1.KafkaClient.__name__, "KafkaClient")
- self.assertEquals(kafka1.client.__name__, "kafka.client")
- self.assertEquals(kafka1.codec.__name__, "kafka.codec")
-
- def test_submodule_namespace(self):
- import kafka.client as client1
- self.assertEquals(client1.__name__, "kafka.client")
- self.assertEquals(client1.KafkaClient.__name__, "KafkaClient")
-
- from kafka import client as client2
- self.assertEquals(client2.__name__, "kafka.client")
- self.assertEquals(client2.KafkaClient.__name__, "KafkaClient")
-
- from kafka.client import KafkaClient as KafkaClient1
- self.assertEquals(KafkaClient1.__name__, "KafkaClient")
-
- from kafka.codec import gzip_encode as gzip_encode1
- self.assertEquals(gzip_encode1.__name__, "gzip_encode")
-
- from kafka import KafkaClient as KafkaClient2
- self.assertEquals(KafkaClient2.__name__, "KafkaClient")
-
- from kafka.codec import snappy_encode
- self.assertEquals(snappy_encode.__name__, "snappy_encode")
-
-
-class TestCodec(unittest.TestCase):
-
- @unittest.skipUnless(has_gzip(), "Gzip not available")
- def test_gzip(self):
- for i in xrange(ITERATIONS):
- s1 = random_string()
- s2 = gzip_decode(gzip_encode(s1))
- self.assertEquals(s1, s2)
-
- @unittest.skipUnless(has_snappy(), "Snappy not available")
- def test_snappy(self):
- for i in xrange(ITERATIONS):
- s1 = random_string()
- s2 = snappy_decode(snappy_encode(s1))
- self.assertEquals(s1, s2)
-
- @unittest.skipUnless(has_snappy(), "Snappy not available")
- def test_snappy_detect_xerial(self):
- import kafka as kafka1
- _detect_xerial_stream = kafka1.codec._detect_xerial_stream
-
- header = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes'
- false_header = b'\x01SNAPPY\x00\x00\x00\x01\x00\x00\x00\x01'
- random_snappy = snappy_encode('SNAPPY' * 50)
- short_data = b'\x01\x02\x03\x04'
-
- self.assertTrue(_detect_xerial_stream(header))
- self.assertFalse(_detect_xerial_stream(b''))
- self.assertFalse(_detect_xerial_stream(b'\x00'))
- self.assertFalse(_detect_xerial_stream(false_header))
- self.assertFalse(_detect_xerial_stream(random_snappy))
- self.assertFalse(_detect_xerial_stream(short_data))
-
- @unittest.skipUnless(has_snappy(), "Snappy not available")
- def test_snappy_decode_xerial(self):
- header = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01'
- random_snappy = snappy_encode('SNAPPY' * 50)
- block_len = len(random_snappy)
- random_snappy2 = snappy_encode('XERIAL' * 50)
- block_len2 = len(random_snappy2)
-
- to_test = header \
- + struct.pack('!i', block_len) + random_snappy \
- + struct.pack('!i', block_len2) + random_snappy2 \
-
- self.assertEquals(snappy_decode(to_test), ('SNAPPY' * 50) + ('XERIAL' * 50))
-
- @unittest.skipUnless(has_snappy(), "Snappy not available")
- def test_snappy_encode_xerial(self):
- to_ensure = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01' + \
- '\x00\x00\x00\x18' + \
- '\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00' + \
- '\x00\x00\x00\x18' + \
- '\xac\x02\x14XERIAL\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00'
-
- to_test = ('SNAPPY' * 50) + ('XERIAL' * 50)
-
- compressed = snappy_encode(to_test, xerial_compatible=True, xerial_blocksize=300)
- self.assertEquals(compressed, to_ensure)
-
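(For reference while reading the deleted xerial tests above: the stream layout they encode is a fixed 16-byte header followed by independently snappy-compressed, length-prefixed blocks. A sketch under that assumption, using python-snappy directly; xerial_frame is an invented name:)

    import struct
    import snappy  # python-snappy, which the skip guards above require

    XERIAL_HEADER = b'\x82SNAPPY\x00' + struct.pack('!ii', 1, 1)  # magic, version, compat

    def xerial_frame(blocks):
        # Compress each block on its own and prefix it with the 4-byte
        # big-endian length of the compressed data.
        parts = [XERIAL_HEADER]
        for block in blocks:
            compressed = snappy.compress(block)
            parts.append(struct.pack('!i', len(compressed)))
            parts.append(compressed)
        return b''.join(parts)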
-class TestProtocol(unittest.TestCase):
-
- def test_create_message(self):
- payload = "test"
- key = "key"
- msg = create_message(payload, key)
- self.assertEqual(msg.magic, 0)
- self.assertEqual(msg.attributes, 0)
- self.assertEqual(msg.key, key)
- self.assertEqual(msg.value, payload)
-
-    @unittest.skipUnless(has_gzip(), "Gzip not available")
- def test_create_gzip(self):
- payloads = ["v1", "v2"]
- msg = create_gzip_message(payloads)
- self.assertEqual(msg.magic, 0)
- self.assertEqual(msg.attributes, KafkaProtocol.ATTRIBUTE_CODEC_MASK &
- KafkaProtocol.CODEC_GZIP)
- self.assertEqual(msg.key, None)
- # Need to decode to check since gzipped payload is non-deterministic
- decoded = gzip_decode(msg.value)
- expect = ("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10L\x9f[\xc2"
- "\x00\x00\xff\xff\xff\xff\x00\x00\x00\x02v1\x00\x00\x00\x00"
- "\x00\x00\x00\x00\x00\x00\x00\x10\xd5\x96\nx\x00\x00\xff\xff"
- "\xff\xff\x00\x00\x00\x02v2")
- self.assertEqual(decoded, expect)
-
- @unittest.skipUnless(has_snappy(), "Snappy not available")
- def test_create_snappy(self):
- payloads = ["v1", "v2"]
- msg = create_snappy_message(payloads)
- self.assertEqual(msg.magic, 0)
- self.assertEqual(msg.attributes, KafkaProtocol.ATTRIBUTE_CODEC_MASK &
- KafkaProtocol.CODEC_SNAPPY)
- self.assertEqual(msg.key, None)
- expect = ("8\x00\x00\x19\x01@\x10L\x9f[\xc2\x00\x00\xff\xff\xff\xff"
- "\x00\x00\x00\x02v1\x19\x1bD\x00\x10\xd5\x96\nx\x00\x00\xff"
- "\xff\xff\xff\x00\x00\x00\x02v2")
- self.assertEqual(msg.value, expect)
-
- def test_encode_message_header(self):
- expect = '\x00\n\x00\x00\x00\x00\x00\x04\x00\x07client3'
- encoded = KafkaProtocol._encode_message_header("client3", 4, 10)
- self.assertEqual(encoded, expect)
-
- def test_encode_message(self):
- message = create_message("test", "key")
- encoded = KafkaProtocol._encode_message(message)
- expect = "\xaa\xf1\x8f[\x00\x00\x00\x00\x00\x03key\x00\x00\x00\x04test"
- self.assertEqual(encoded, expect)
-
- def test_encode_message_failure(self):
- self.assertRaises(Exception, KafkaProtocol._encode_message,
- Message(1, 0, "key", "test"))
-
- def test_encode_message_set(self):
- message_set = [create_message("v1", "k1"), create_message("v2", "k2")]
- encoded = KafkaProtocol._encode_message_set(message_set)
- expect = ("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x12W\xe7In\x00"
- "\x00\x00\x00\x00\x02k1\x00\x00\x00\x02v1\x00\x00\x00\x00"
- "\x00\x00\x00\x00\x00\x00\x00\x12\xff\x06\x02I\x00\x00\x00"
- "\x00\x00\x02k2\x00\x00\x00\x02v2")
- self.assertEqual(encoded, expect)
-
- def test_decode_message(self):
- encoded = "\xaa\xf1\x8f[\x00\x00\x00\x00\x00\x03key\x00\x00\x00\x04test"
- offset = 10
- (returned_offset, decoded_message) = \
- list(KafkaProtocol._decode_message(encoded, offset))[0]
- self.assertEqual(returned_offset, offset)
- self.assertEqual(decoded_message, create_message("test", "key"))
-
- def test_decode_message_set(self):
- encoded = ('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10L\x9f[\xc2'
- '\x00\x00\xff\xff\xff\xff\x00\x00\x00\x02v1\x00\x00\x00\x00'
- '\x00\x00\x00\x00\x00\x00\x00\x10\xd5\x96\nx\x00\x00\xff'
- '\xff\xff\xff\x00\x00\x00\x02v2')
- iter = KafkaProtocol._decode_message_set_iter(encoded)
- decoded = list(iter)
- self.assertEqual(len(decoded), 2)
- (returned_offset1, decoded_message1) = decoded[0]
- self.assertEqual(returned_offset1, 0)
- self.assertEqual(decoded_message1, create_message("v1"))
- (returned_offset2, decoded_message2) = decoded[1]
- self.assertEqual(returned_offset2, 0)
- self.assertEqual(decoded_message2, create_message("v2"))
-
- @unittest.skipUnless(has_gzip(), "Gzip not available")
- def test_decode_message_gzip(self):
- gzip_encoded = ('\xc0\x11\xb2\xf0\x00\x01\xff\xff\xff\xff\x00\x00\x000'
- '\x1f\x8b\x08\x00\xa1\xc1\xc5R\x02\xffc`\x80\x03\x01'
- '\x9f\xf9\xd1\x87\x18\x18\xfe\x03\x01\x90\xc7Tf\xc8'
- '\x80$wu\x1aW\x05\x92\x9c\x11\x00z\xc0h\x888\x00\x00'
- '\x00')
- offset = 11
- decoded = list(KafkaProtocol._decode_message(gzip_encoded, offset))
- self.assertEqual(len(decoded), 2)
- (returned_offset1, decoded_message1) = decoded[0]
- self.assertEqual(returned_offset1, 0)
- self.assertEqual(decoded_message1, create_message("v1"))
- (returned_offset2, decoded_message2) = decoded[1]
- self.assertEqual(returned_offset2, 0)
- self.assertEqual(decoded_message2, create_message("v2"))
-
- @unittest.skipUnless(has_snappy(), "Snappy not available")
- def test_decode_message_snappy(self):
- snappy_encoded = ('\xec\x80\xa1\x95\x00\x02\xff\xff\xff\xff\x00\x00'
- '\x00,8\x00\x00\x19\x01@\x10L\x9f[\xc2\x00\x00\xff'
- '\xff\xff\xff\x00\x00\x00\x02v1\x19\x1bD\x00\x10\xd5'
- '\x96\nx\x00\x00\xff\xff\xff\xff\x00\x00\x00\x02v2')
- offset = 11
- decoded = list(KafkaProtocol._decode_message(snappy_encoded, offset))
- self.assertEqual(len(decoded), 2)
- (returned_offset1, decoded_message1) = decoded[0]
- self.assertEqual(returned_offset1, 0)
- self.assertEqual(decoded_message1, create_message("v1"))
- (returned_offset2, decoded_message2) = decoded[1]
- self.assertEqual(returned_offset2, 0)
- self.assertEqual(decoded_message2, create_message("v2"))
-
- def test_decode_message_checksum_error(self):
- invalid_encoded_message = "This is not a valid encoded message"
- iter = KafkaProtocol._decode_message(invalid_encoded_message, 0)
- self.assertRaises(ChecksumError, list, iter)
-
- # NOTE: The error handling in _decode_message_set_iter() is questionable.
- # If it's modified, the next two tests might need to be fixed.
- def test_decode_message_set_fetch_size_too_small(self):
- iter = KafkaProtocol._decode_message_set_iter('a')
- self.assertRaises(ConsumerFetchSizeTooSmall, list, iter)
-
- def test_decode_message_set_stop_iteration(self):
- encoded = ('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10L\x9f[\xc2'
- '\x00\x00\xff\xff\xff\xff\x00\x00\x00\x02v1\x00\x00\x00\x00'
- '\x00\x00\x00\x00\x00\x00\x00\x10\xd5\x96\nx\x00\x00\xff'
- '\xff\xff\xff\x00\x00\x00\x02v2')
- iter = KafkaProtocol._decode_message_set_iter(encoded + "@#$%(Y!")
- decoded = list(iter)
- self.assertEqual(len(decoded), 2)
- (returned_offset1, decoded_message1) = decoded[0]
- self.assertEqual(returned_offset1, 0)
- self.assertEqual(decoded_message1, create_message("v1"))
- (returned_offset2, decoded_message2) = decoded[1]
- self.assertEqual(returned_offset2, 0)
- self.assertEqual(decoded_message2, create_message("v2"))
-
- def test_encode_produce_request(self):
- requests = [ProduceRequest("topic1", 0, [create_message("a"),
- create_message("b")]),
- ProduceRequest("topic2", 1, [create_message("c")])]
- expect = ('\x00\x00\x00\x94\x00\x00\x00\x00\x00\x00\x00\x02\x00\x07'
- 'client1\x00\x02\x00\x00\x00d\x00\x00\x00\x02\x00\x06topic1'
- '\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x006\x00\x00\x00'
- '\x00\x00\x00\x00\x00\x00\x00\x00\x0fQ\xdf:2\x00\x00\xff\xff'
- '\xff\xff\x00\x00\x00\x01a\x00\x00\x00\x00\x00\x00\x00\x00'
- '\x00\x00\x00\x0f\xc8\xd6k\x88\x00\x00\xff\xff\xff\xff\x00'
- '\x00\x00\x01b\x00\x06topic2\x00\x00\x00\x01\x00\x00\x00\x01'
- '\x00\x00\x00\x1b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
- '\x00\x0f\xbf\xd1[\x1e\x00\x00\xff\xff\xff\xff\x00\x00\x00'
- '\x01c')
- encoded = KafkaProtocol.encode_produce_request("client1", 2, requests,
- 2, 100)
- self.assertEqual(encoded, expect)
-
- def test_decode_produce_response(self):
- t1 = "topic1"
- t2 = "topic2"
- encoded = struct.pack('>iih%dsiihqihqh%dsiihq' % (len(t1), len(t2)),
- 2, 2, len(t1), t1, 2, 0, 0, 10L, 1, 1, 20L,
- len(t2), t2, 1, 0, 0, 30L)
- responses = list(KafkaProtocol.decode_produce_response(encoded))
- self.assertEqual(responses,
- [ProduceResponse(t1, 0, 0, 10L),
- ProduceResponse(t1, 1, 1, 20L),
- ProduceResponse(t2, 0, 0, 30L)])
-
- def test_encode_fetch_request(self):
- requests = [FetchRequest("topic1", 0, 10, 1024),
- FetchRequest("topic2", 1, 20, 100)]
- expect = ('\x00\x00\x00Y\x00\x01\x00\x00\x00\x00\x00\x03\x00\x07'
- 'client1\xff\xff\xff\xff\x00\x00\x00\x02\x00\x00\x00d\x00'
- '\x00\x00\x02\x00\x06topic1\x00\x00\x00\x01\x00\x00\x00\x00'
- '\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x04\x00\x00\x06'
- 'topic2\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00\x00'
- '\x00\x00\x14\x00\x00\x00d')
- encoded = KafkaProtocol.encode_fetch_request("client1", 3, requests, 2,
- 100)
- self.assertEqual(encoded, expect)
-
- def test_decode_fetch_response(self):
- t1 = "topic1"
- t2 = "topic2"
- msgs = map(create_message, ["message1", "hi", "boo", "foo", "so fun!"])
- ms1 = KafkaProtocol._encode_message_set([msgs[0], msgs[1]])
- ms2 = KafkaProtocol._encode_message_set([msgs[2]])
- ms3 = KafkaProtocol._encode_message_set([msgs[3], msgs[4]])
-
- encoded = struct.pack('>iih%dsiihqi%dsihqi%dsh%dsiihqi%ds' %
- (len(t1), len(ms1), len(ms2), len(t2), len(ms3)),
- 4, 2, len(t1), t1, 2, 0, 0, 10, len(ms1), ms1, 1,
- 1, 20, len(ms2), ms2, len(t2), t2, 1, 0, 0, 30,
- len(ms3), ms3)
-
- responses = list(KafkaProtocol.decode_fetch_response(encoded))
- def expand_messages(response):
- return FetchResponse(response.topic, response.partition,
- response.error, response.highwaterMark,
- list(response.messages))
-
- expanded_responses = map(expand_messages, responses)
- expect = [FetchResponse(t1, 0, 0, 10, [OffsetAndMessage(0, msgs[0]),
- OffsetAndMessage(0, msgs[1])]),
- FetchResponse(t1, 1, 1, 20, [OffsetAndMessage(0, msgs[2])]),
- FetchResponse(t2, 0, 0, 30, [OffsetAndMessage(0, msgs[3]),
- OffsetAndMessage(0, msgs[4])])]
- self.assertEqual(expanded_responses, expect)
-
- def test_encode_metadata_request_no_topics(self):
- encoded = KafkaProtocol.encode_metadata_request("cid", 4)
- self.assertEqual(encoded, '\x00\x00\x00\x11\x00\x03\x00\x00\x00\x00'
- '\x00\x04\x00\x03cid\x00\x00\x00\x00')
-
- def test_encode_metadata_request_with_topics(self):
- encoded = KafkaProtocol.encode_metadata_request("cid", 4, ["t1", "t2"])
- self.assertEqual(encoded, '\x00\x00\x00\x19\x00\x03\x00\x00\x00\x00'
- '\x00\x04\x00\x03cid\x00\x00\x00\x02\x00\x02'
- 't1\x00\x02t2')
-
- def _create_encoded_metadata_response(self, broker_data, topic_data,
- topic_errors, partition_errors):
- encoded = struct.pack('>ii', 3, len(broker_data))
- for node_id, broker in broker_data.iteritems():
- encoded += struct.pack('>ih%dsi' % len(broker.host), node_id,
- len(broker.host), broker.host, broker.port)
-
- encoded += struct.pack('>i', len(topic_data))
- for topic, partitions in topic_data.iteritems():
- encoded += struct.pack('>hh%dsi' % len(topic), topic_errors[topic],
- len(topic), topic, len(partitions))
- for partition, metadata in partitions.iteritems():
- encoded += struct.pack('>hiii',
- partition_errors[(topic, partition)],
- partition, metadata.leader,
- len(metadata.replicas))
- if len(metadata.replicas) > 0:
- encoded += struct.pack('>%di' % len(metadata.replicas),
- *metadata.replicas)
-
- encoded += struct.pack('>i', len(metadata.isr))
- if len(metadata.isr) > 0:
- encoded += struct.pack('>%di' % len(metadata.isr),
- *metadata.isr)
-
- return encoded
-
- def test_decode_metadata_response(self):
- node_brokers = {
- 0: BrokerMetadata(0, "brokers1.kafka.rdio.com", 1000),
- 1: BrokerMetadata(1, "brokers1.kafka.rdio.com", 1001),
- 3: BrokerMetadata(3, "brokers2.kafka.rdio.com", 1000)
- }
- topic_partitions = {
- "topic1": {
- 0: PartitionMetadata("topic1", 0, 1, (0, 2), (2,)),
- 1: PartitionMetadata("topic1", 1, 3, (0, 1), (0, 1))
- },
- "topic2": {
- 0: PartitionMetadata("topic2", 0, 0, (), ())
- }
- }
- topic_errors = {"topic1": 0, "topic2": 1}
- partition_errors = {
- ("topic1", 0): 0,
- ("topic1", 1): 1,
- ("topic2", 0): 0
- }
- encoded = self._create_encoded_metadata_response(node_brokers,
- topic_partitions,
- topic_errors,
- partition_errors)
- decoded = KafkaProtocol.decode_metadata_response(encoded)
- self.assertEqual(decoded, (node_brokers, topic_partitions))
-
- @unittest.skip("Not Implemented")
- def test_encode_offset_request(self):
- pass
-
- @unittest.skip("Not Implemented")
- def test_decode_offset_response(self):
- pass
-
-
- @unittest.skip("Not Implemented")
- def test_encode_offset_commit_request(self):
- pass
-
- @unittest.skip("Not Implemented")
- def test_decode_offset_commit_response(self):
- pass
-
- @unittest.skip("Not Implemented")
- def test_encode_offset_fetch_request(self):
- pass
-
- @unittest.skip("Not Implemented")
- def test_decode_offset_fetch_response(self):
- pass
-
-
-class TestKafkaClient(unittest.TestCase):
-
- def test_init_with_list(self):
-
- with patch.object(KafkaClient, 'load_metadata_for_topics'):
- client = KafkaClient(
- hosts=['kafka01:9092', 'kafka02:9092', 'kafka03:9092'])
-
- self.assertItemsEqual(
- [('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)],
- client.hosts)
-
- def test_init_with_csv(self):
-
- with patch.object(KafkaClient, 'load_metadata_for_topics'):
- client = KafkaClient(
- hosts='kafka01:9092,kafka02:9092,kafka03:9092')
-
- self.assertItemsEqual(
- [('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)],
- client.hosts)
-
- def test_init_with_unicode_csv(self):
-
- with patch.object(KafkaClient, 'load_metadata_for_topics'):
- client = KafkaClient(
- hosts=u'kafka01:9092,kafka02:9092,kafka03:9092')
-
- self.assertItemsEqual(
- [('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)],
- client.hosts)
-
- def test_send_broker_unaware_request_fail(self):
- 'Tests that call fails when all hosts are unavailable'
-
- mocked_conns = {
- ('kafka01', 9092): MagicMock(),
- ('kafka02', 9092): MagicMock()
- }
- # inject KafkaConnection side effects
- mocked_conns[('kafka01', 9092)].send.side_effect = RuntimeError("kafka01 went away (unittest)")
- mocked_conns[('kafka02', 9092)].send.side_effect = RuntimeError("Kafka02 went away (unittest)")
-
- def mock_get_conn(host, port):
- return mocked_conns[(host, port)]
-
- # patch to avoid making requests before we want it
- with patch.object(KafkaClient, 'load_metadata_for_topics'):
- with patch.object(KafkaClient, '_get_conn', side_effect=mock_get_conn):
- client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092'])
-
- self.assertRaises(
- KafkaUnavailableError,
- client._send_broker_unaware_request,
- 1, 'fake request')
-
- for key, conn in mocked_conns.iteritems():
- conn.send.assert_called_with(1, 'fake request')
-
- def test_send_broker_unaware_request(self):
-        'Tests that call works when at least one of the hosts is available'
-
- mocked_conns = {
- ('kafka01', 9092): MagicMock(),
- ('kafka02', 9092): MagicMock(),
- ('kafka03', 9092): MagicMock()
- }
- # inject KafkaConnection side effects
- mocked_conns[('kafka01', 9092)].send.side_effect = RuntimeError("kafka01 went away (unittest)")
- mocked_conns[('kafka02', 9092)].recv.return_value = 'valid response'
- mocked_conns[('kafka03', 9092)].send.side_effect = RuntimeError("kafka03 went away (unittest)")
-
- def mock_get_conn(host, port):
- return mocked_conns[(host, port)]
-
- # patch to avoid making requests before we want it
- with patch.object(KafkaClient, 'load_metadata_for_topics'):
- with patch.object(KafkaClient, '_get_conn', side_effect=mock_get_conn):
- client = KafkaClient(hosts='kafka01:9092,kafka02:9092')
-
- resp = client._send_broker_unaware_request(1, 'fake request')
-
- self.assertEqual('valid response', resp)
- mocked_conns[('kafka02', 9092)].recv.assert_called_with(1)
-
- @patch('kafka.client.KafkaConnection')
- @patch('kafka.client.KafkaProtocol')
- def test_load_metadata(self, protocol, conn):
- "Load metadata for all topics"
-
- conn.recv.return_value = 'response' # anything but None
-
- brokers = {}
- brokers[0] = BrokerMetadata(1, 'broker_1', 4567)
- brokers[1] = BrokerMetadata(2, 'broker_2', 5678)
-
- topics = {}
- topics['topic_1'] = {
- 0: PartitionMetadata('topic_1', 0, 1, [1, 2], [1, 2])
- }
- topics['topic_noleader'] = {
- 0: PartitionMetadata('topic_noleader', 0, -1, [], []),
- 1: PartitionMetadata('topic_noleader', 1, -1, [], [])
- }
- topics['topic_no_partitions'] = {}
- topics['topic_3'] = {
- 0: PartitionMetadata('topic_3', 0, 0, [0, 1], [0, 1]),
- 1: PartitionMetadata('topic_3', 1, 1, [1, 0], [1, 0]),
- 2: PartitionMetadata('topic_3', 2, 0, [0, 1], [0, 1])
- }
- protocol.decode_metadata_response.return_value = (brokers, topics)
-
- # client loads metadata at init
- client = KafkaClient(hosts=['broker_1:4567'])
- self.assertDictEqual({
- TopicAndPartition('topic_1', 0): brokers[1],
- TopicAndPartition('topic_noleader', 0): None,
- TopicAndPartition('topic_noleader', 1): None,
- TopicAndPartition('topic_3', 0): brokers[0],
- TopicAndPartition('topic_3', 1): brokers[1],
- TopicAndPartition('topic_3', 2): brokers[0]},
- client.topics_to_brokers)
-
- @patch('kafka.client.KafkaConnection')
- @patch('kafka.client.KafkaProtocol')
- def test_get_leader_for_partitions_reloads_metadata(self, protocol, conn):
-        "Get leader for partitions reloads metadata if it is not available"
-
- conn.recv.return_value = 'response' # anything but None
-
- brokers = {}
- brokers[0] = BrokerMetadata(0, 'broker_1', 4567)
- brokers[1] = BrokerMetadata(1, 'broker_2', 5678)
-
- topics = {'topic_no_partitions': {}}
- protocol.decode_metadata_response.return_value = (brokers, topics)
-
- client = KafkaClient(hosts=['broker_1:4567'])
-
- # topic metadata is loaded but empty
- self.assertDictEqual({}, client.topics_to_brokers)
-
- topics['topic_no_partitions'] = {
- 0: PartitionMetadata('topic_no_partitions', 0, 0, [0, 1], [0, 1])
- }
- protocol.decode_metadata_response.return_value = (brokers, topics)
-
- # calling _get_leader_for_partition (from any broker aware request)
- # will try loading metadata again for the same topic
- leader = client._get_leader_for_partition('topic_no_partitions', 0)
-
- self.assertEqual(brokers[0], leader)
- self.assertDictEqual({
- TopicAndPartition('topic_no_partitions', 0): brokers[0]},
- client.topics_to_brokers)
-
- @patch('kafka.client.KafkaConnection')
- @patch('kafka.client.KafkaProtocol')
- def test_get_leader_for_unassigned_partitions(self, protocol, conn):
-        "Get leader raises if no partitions are defined for a topic"
-
- conn.recv.return_value = 'response' # anything but None
-
- brokers = {}
- brokers[0] = BrokerMetadata(0, 'broker_1', 4567)
- brokers[1] = BrokerMetadata(1, 'broker_2', 5678)
-
- topics = {'topic_no_partitions': {}}
- protocol.decode_metadata_response.return_value = (brokers, topics)
-
- client = KafkaClient(hosts=['broker_1:4567'])
-
- self.assertDictEqual({}, client.topics_to_brokers)
- self.assertRaises(
- PartitionUnavailableError,
- client._get_leader_for_partition,
- 'topic_no_partitions', 0)
-
- @patch('kafka.client.KafkaConnection')
- @patch('kafka.client.KafkaProtocol')
- def test_get_leader_returns_none_when_noleader(self, protocol, conn):
-        "Getting leader for partitions returns None when the partition has no leader"
-
- conn.recv.return_value = 'response' # anything but None
-
- brokers = {}
- brokers[0] = BrokerMetadata(0, 'broker_1', 4567)
- brokers[1] = BrokerMetadata(1, 'broker_2', 5678)
-
- topics = {}
- topics['topic_noleader'] = {
- 0: PartitionMetadata('topic_noleader', 0, -1, [], []),
- 1: PartitionMetadata('topic_noleader', 1, -1, [], [])
- }
- protocol.decode_metadata_response.return_value = (brokers, topics)
-
- client = KafkaClient(hosts=['broker_1:4567'])
- self.assertDictEqual(
- {
- TopicAndPartition('topic_noleader', 0): None,
- TopicAndPartition('topic_noleader', 1): None
- },
- client.topics_to_brokers)
- self.assertIsNone(client._get_leader_for_partition('topic_noleader', 0))
- self.assertIsNone(client._get_leader_for_partition('topic_noleader', 1))
-
- topics['topic_noleader'] = {
- 0: PartitionMetadata('topic_noleader', 0, 0, [0, 1], [0, 1]),
- 1: PartitionMetadata('topic_noleader', 1, 1, [1, 0], [1, 0])
- }
- protocol.decode_metadata_response.return_value = (brokers, topics)
- self.assertEqual(brokers[0], client._get_leader_for_partition('topic_noleader', 0))
- self.assertEqual(brokers[1], client._get_leader_for_partition('topic_noleader', 1))
-
- @patch('kafka.client.KafkaConnection')
- @patch('kafka.client.KafkaProtocol')
- def test_send_produce_request_raises_when_noleader(self, protocol, conn):
- "Send producer request raises LeaderUnavailableError if leader is not available"
-
- conn.recv.return_value = 'response' # anything but None
-
- brokers = {}
- brokers[0] = BrokerMetadata(0, 'broker_1', 4567)
- brokers[1] = BrokerMetadata(1, 'broker_2', 5678)
-
- topics = {}
- topics['topic_noleader'] = {
- 0: PartitionMetadata('topic_noleader', 0, -1, [], []),
- 1: PartitionMetadata('topic_noleader', 1, -1, [], [])
- }
- protocol.decode_metadata_response.return_value = (brokers, topics)
-
- client = KafkaClient(hosts=['broker_1:4567'])
-
- requests = [ProduceRequest(
- "topic_noleader", 0,
- [create_message("a"), create_message("b")])]
-
- self.assertRaises(
- LeaderUnavailableError,
- client.send_produce_request, requests)
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/test/test_util.py b/test/test_util.py
new file mode 100644
index 0000000..8179b01
--- /dev/null
+++ b/test/test_util.py
@@ -0,0 +1,102 @@
+import os
+import random
+import struct
+import unittest2
+import kafka.util
+import kafka.common
+
+class UtilTest(unittest2.TestCase):
+
+ def test_write_int_string(self):
+ self.assertEqual(
+ kafka.util.write_int_string('some string'),
+ '\x00\x00\x00\x0bsome string'
+ )
+
+ def test_write_int_string__empty(self):
+ self.assertEqual(
+ kafka.util.write_int_string(''),
+ '\x00\x00\x00\x00'
+ )
+
+ def test_write_int_string__null(self):
+ self.assertEqual(
+ kafka.util.write_int_string(None),
+ '\xff\xff\xff\xff'
+ )
+
+ def test_read_int_string(self):
+ self.assertEqual(kafka.util.read_int_string('\xff\xff\xff\xff', 0), (None, 4))
+ self.assertEqual(kafka.util.read_int_string('\x00\x00\x00\x00', 0), ('', 4))
+ self.assertEqual(kafka.util.read_int_string('\x00\x00\x00\x0bsome string', 0), ('some string', 15))
+
+ def test_read_int_string__insufficient_data(self):
+ with self.assertRaises(kafka.common.BufferUnderflowError):
+ kafka.util.read_int_string('\x00\x00\x00\x021', 0)
+
+ def test_write_short_string(self):
+ self.assertEqual(
+ kafka.util.write_short_string('some string'),
+ '\x00\x0bsome string'
+ )
+
+ def test_write_short_string__empty(self):
+ self.assertEqual(
+ kafka.util.write_short_string(''),
+ '\x00\x00'
+ )
+
+ def test_write_short_string__null(self):
+ self.assertEqual(
+ kafka.util.write_short_string(None),
+ '\xff\xff'
+ )
+
+ def test_write_short_string__too_long(self):
+ with self.assertRaises(struct.error):
+ kafka.util.write_short_string(' ' * 33000)
+
+ def test_read_short_string(self):
+ self.assertEqual(kafka.util.read_short_string('\xff\xff', 0), (None, 2))
+ self.assertEqual(kafka.util.read_short_string('\x00\x00', 0), ('', 2))
+ self.assertEqual(kafka.util.read_short_string('\x00\x0bsome string', 0), ('some string', 13))
+
+    def test_read_short_string__insufficient_data(self):
+        with self.assertRaises(kafka.common.BufferUnderflowError):
+            kafka.util.read_short_string('\x00\x021', 0)
+
+ def test_relative_unpack(self):
+ self.assertEqual(
+ kafka.util.relative_unpack('>hh', '\x00\x01\x00\x00\x02', 0),
+ ((1, 0), 4)
+ )
+
+    def test_relative_unpack__insufficient_data(self):
+        with self.assertRaises(kafka.common.BufferUnderflowError):
+            kafka.util.relative_unpack('>hh', '\x00', 0)
+
+
+ def test_group_by_topic_and_partition(self):
+ t = kafka.common.TopicAndPartition
+
+        pairs = [
+            t("a", 1),
+            t("a", 1),
+            t("a", 2),
+            t("a", 3),
+            t("b", 3),
+        ]
+
+        self.assertEqual(kafka.util.group_by_topic_and_partition(pairs), {
+ "a" : {
+ 1 : t("a", 1),
+ 2 : t("a", 2),
+ 3 : t("a", 3),
+ },
+ "b" : {
+ 3 : t("b", 3),
+ }
+ })
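(Taken together, these tests pin down the serialization primitives' contracts. A summary sketch of the implied semantics -- names suffixed _sketch to stress this is not the kafka.util source:)

    import struct

    def write_int_string_sketch(s):
        # None -> 4-byte -1; otherwise 4-byte big-endian length + bytes.
        if s is None:
            return struct.pack('>i', -1)
        return struct.pack('>i', len(s)) + s

    def relative_unpack_sketch(fmt, data, cur):
        # Unpack fmt at offset cur, returning (values, new_offset); the
        # real helper raises kafka.common.BufferUnderflowError on short
        # buffers, modelled with a plain ValueError here.
        size = struct.calcsize(fmt)
        if len(data) < cur + size:
            raise ValueError("buffer underflow")
        return struct.unpack(fmt, data[cur:cur + size]), cur + size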
diff --git a/test/testutil.py b/test/testutil.py
new file mode 100644
index 0000000..78e6f7d
--- /dev/null
+++ b/test/testutil.py
@@ -0,0 +1,108 @@
+import functools
+import logging
+import os
+import random
+import socket
+import string
+import time
+import unittest2
+import uuid
+
+from kafka.common import OffsetRequest
+from kafka import KafkaClient
+
+__all__ = [
+ 'random_string',
+ 'ensure_topic_creation',
+ 'get_open_port',
+ 'kafka_versions',
+ 'KafkaIntegrationTestCase',
+ 'Timer',
+]
+
+def random_string(length):
+    s = "".join(random.choice(string.letters) for i in xrange(length))
+    return s
+
+def kafka_versions(*versions):
+    def decorator(func):
+ @functools.wraps(func)
+ def wrapper(self):
+ kafka_version = os.environ.get('KAFKA_VERSION')
+
+ if not kafka_version:
+ self.skipTest("no kafka version specified")
+ elif 'all' not in versions and kafka_version not in versions:
+ self.skipTest("unsupported kafka version")
+
+ return func(self)
+ return wrapper
+    return decorator
+
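(Usage note: tests opt in by listing the broker versions they support, or "all"; anything else is skipped unless KAFKA_VERSION matches. A hypothetical example:)

    class ExampleVersionedTest(unittest2.TestCase):
        @kafka_versions("all")
        def test_any_broker(self):
            pass

        @kafka_versions("0.8.0")  # runs only when KAFKA_VERSION=0.8.0
        def test_specific_broker(self):
            pass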
+def ensure_topic_creation(client, topic_name, timeout=30):
+ start_time = time.time()
+
+ client.load_metadata_for_topics(topic_name)
+ while not client.has_metadata_for_topic(topic_name):
+ if time.time() > start_time + timeout:
+ raise Exception("Unable to create topic %s" % topic_name)
+ client.load_metadata_for_topics(topic_name)
+ time.sleep(1)
+
+def get_open_port():
+ sock = socket.socket()
+ sock.bind(("", 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ return port
+
+class KafkaIntegrationTestCase(unittest2.TestCase):
+ create_client = True
+ topic = None
+
+ def setUp(self):
+ super(KafkaIntegrationTestCase, self).setUp()
+ if not os.environ.get('KAFKA_VERSION'):
+ return
+
+ if not self.topic:
+ self.topic = "%s-%s" % (self.id()[self.id().rindex(".") + 1:], random_string(10))
+
+ if self.create_client:
+ self.client = KafkaClient('%s:%d' % (self.server.host, self.server.port))
+
+ ensure_topic_creation(self.client, self.topic)
+
+ self._messages = {}
+
+ def tearDown(self):
+ super(KafkaIntegrationTestCase, self).tearDown()
+ if not os.environ.get('KAFKA_VERSION'):
+ return
+
+ if self.create_client:
+ self.client.close()
+
+ def current_offset(self, topic, partition):
+ offsets, = self.client.send_offset_request([ OffsetRequest(topic, partition, -1, 1) ])
+ return offsets.offsets[0]
+
+ def msgs(self, iterable):
+ return [ self.msg(x) for x in iterable ]
+
+ def msg(self, s):
+ if s not in self._messages:
+ self._messages[s] = '%s-%s-%s' % (s, self.id(), str(uuid.uuid4()))
+
+ return self._messages[s]
+
+class Timer(object):
+ def __enter__(self):
+ self.start = time.time()
+ return self
+
+ def __exit__(self, *args):
+ self.end = time.time()
+ self.interval = self.end - self.start
+
+logging.basicConfig(level=logging.DEBUG)
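(To close, a hypothetical consumer of these helpers; everything except the helpers themselves is invented for illustration, and self.server is expected to come from a fixture set up in a subclass:)

    from testutil import KafkaIntegrationTestCase, Timer, kafka_versions

    class ExampleIntegrationTest(KafkaIntegrationTestCase):
        @kafka_versions("all")
        def test_offset_lookup_is_fast(self):
            # self.topic and self.client were prepared in setUp().
            with Timer() as t:
                offset = self.current_offset(self.topic, 0)
            self.assertGreaterEqual(offset, 0)
            self.assertLess(t.interval, 5.0)  # generous bound for CI boxes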