summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Dana Powers <dana.powers@rd.io> 2015-03-08 16:03:06 -0700
committer: Dana Powers <dana.powers@rd.io> 2015-03-08 16:03:06 -0700
commit: 92aa7e94288cbfc4aed0dfbd52021d21694bced4 (patch)
tree: c33e7341b9624eb22f89051887e83449e8146a98
parent: 6ef982c5e8bde6ea50f721ddb4bb11b7fd51559b (diff)
parent: 5137163fa44b4a6a8a315c30f959e816f657e921 (diff)
download: kafka-python-92aa7e94288cbfc4aed0dfbd52021d21694bced4.tar.gz
Merge branch 'vshlapakov-feature-async-threading'
PR 330: Threading for async batching

Conflicts:
    kafka/producer/base.py
-rw-r--r--  kafka/conn.py           3
-rw-r--r--  kafka/producer/base.py  52
-rw-r--r--  test/test_conn.py       44
3 files changed, 73 insertions, 26 deletions
diff --git a/kafka/conn.py b/kafka/conn.py
index 30debec..ea55481 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -62,6 +62,9 @@ class KafkaConnection(local):
self.reinit()
+ def __getnewargs__(self):
+ return (self.host, self.port, self.timeout)
+
def __repr__(self):
return "<KafkaConnection host=%s port=%d>" % (self.host, self.port)
diff --git a/kafka/producer/base.py b/kafka/producer/base.py
index 695f195..e32b168 100644
--- a/kafka/producer/base.py
+++ b/kafka/producer/base.py
@@ -4,11 +4,12 @@ import logging
import time
try:
- from queue import Empty
+ from queue import Empty, Queue
except ImportError:
- from Queue import Empty
+ from Queue import Empty, Queue
from collections import defaultdict
-from multiprocessing import Queue, Process
+
+from threading import Thread, Event
import six
@@ -26,20 +27,15 @@ STOP_ASYNC_PRODUCER = -1
def _send_upstream(queue, client, codec, batch_time, batch_size,
- req_acks, ack_timeout):
+ req_acks, ack_timeout, stop_event):
"""
Listen on the queue for a specified number of messages or till
a specified timeout and send them upstream to the brokers in one
request
-
- NOTE: Ideally, this should have been a method inside the Producer
- class. However, multiprocessing module has issues in windows. The
- functionality breaks unless this function is kept outside of a class
"""
stop = False
- client.reinit()
- while not stop:
+ while not stop_event.is_set():
timeout = batch_time
count = batch_size
send_at = time.time() + timeout
@@ -56,7 +52,7 @@ def _send_upstream(queue, client, codec, batch_time, batch_size,
# Check if the controller has requested us to stop
if topic_partition == STOP_ASYNC_PRODUCER:
- stop = True
+ stop_event.set()
break
# Adjust the timeout to match the remaining period
@@ -141,18 +137,22 @@ class Producer(object):
log.warning("Current implementation does not retry Failed messages")
log.warning("Use at your own risk! (or help improve with a PR!)")
self.queue = Queue() # Messages are sent through this queue
- self.proc = Process(target=_send_upstream,
- args=(self.queue,
- self.client.copy(),
- self.codec,
- batch_send_every_t,
- batch_send_every_n,
- self.req_acks,
- self.ack_timeout))
-
- # Process will die if main thread exits
- self.proc.daemon = True
- self.proc.start()
+ self.thread_stop_event = Event()
+ self.thread = Thread(target=_send_upstream,
+ args=(self.queue,
+ self.client.copy(),
+ self.codec,
+ batch_send_every_t,
+ batch_send_every_n,
+ self.req_acks,
+ self.ack_timeout,
+ self.thread_stop_event))
+
+ # Thread will die if main thread exits
+ self.thread.daemon = True
+ self.thread.start()
+
+
def send_messages(self, topic, partition, *msg):
"""
@@ -209,10 +209,10 @@ class Producer(object):
"""
if self.async:
self.queue.put((STOP_ASYNC_PRODUCER, None, None))
- self.proc.join(timeout)
+ self.thread.join(timeout)
- if self.proc.is_alive():
- self.proc.terminate()
+ if self.thread.is_alive():
+ self.thread_stop_event.set()
self.stopped = True
def __del__(self):
diff --git a/test/test_conn.py b/test/test_conn.py
index 2c8f3b2..c4f219b 100644
--- a/test/test_conn.py
+++ b/test/test_conn.py
@@ -1,5 +1,6 @@
import socket
import struct
+from threading import Thread
import mock
from . import unittest
@@ -162,3 +163,46 @@ class ConnTest(unittest.TestCase):
self.conn.send(self.config['request_id'], self.config['payload'])
self.assertEqual(self.MockCreateConn.call_count, 1)
self.conn._sock.sendall.assert_called_with(self.config['payload'])
+
+
+class TestKafkaConnection(unittest.TestCase):
+
+ @mock.patch('socket.create_connection')
+ def test_copy(self, socket):
+ """KafkaConnection copies work as expected"""
+
+ conn = KafkaConnection('kafka', 9092)
+ self.assertEqual(socket.call_count, 1)
+
+ copy = conn.copy()
+ self.assertEqual(socket.call_count, 1)
+ self.assertEqual(copy.host, 'kafka')
+ self.assertEqual(copy.port, 9092)
+ self.assertEqual(copy._sock, None)
+
+ copy.reinit()
+ self.assertEqual(socket.call_count, 2)
+ self.assertNotEqual(copy._sock, None)
+
+ @mock.patch('socket.create_connection')
+ def test_copy_thread(self, socket):
+ """KafkaConnection copies work in other threads"""
+
+ err = []
+ copy = KafkaConnection('kafka', 9092).copy()
+
+ def thread_func(err, copy):
+ try:
+ self.assertEqual(copy.host, 'kafka')
+ self.assertEqual(copy.port, 9092)
+ self.assertNotEqual(copy._sock, None)
+ except Exception as e:
+ err.append(e)
+ else:
+ err.append(None)
+ thread = Thread(target=thread_func, args=(err, copy))
+ thread.start()
+ thread.join()
+
+ self.assertEqual(err, [None])
+ self.assertEqual(socket.call_count, 2)