summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- kafka/client_async.py 45
-rw-r--r-- kafka/conn.py 50
-rw-r--r-- kafka/consumer/fetcher.py 3
-rw-r--r-- kafka/coordinator/base.py 22
-rw-r--r-- test/test_client_async.py 84
-rw-r--r-- test/test_conn.py 318
-rw-r--r-- test/test_conn_legacy.py 242
7 files changed, 420 insertions, 344 deletions
diff --git a/kafka/client_async.py b/kafka/client_async.py
index 5a1d624..d70e4f2 100644
--- a/kafka/client_async.py
+++ b/kafka/client_async.py
@@ -152,8 +152,8 @@ class KafkaClient(object):
conn = self._conns[node_id]
return conn.state is ConnectionStates.DISCONNECTED and not conn.blacked_out()
- def _initiate_connect(self, node_id):
- """Initiate a connection to the given node (must be in metadata)"""
+ def _maybe_connect(self, node_id):
+ """Idempotent non-blocking connection attempt to the given node id."""
if node_id not in self._conns:
broker = self.cluster.broker_metadata(node_id)
assert broker, 'Broker id %s not in current metadata' % node_id
@@ -164,22 +164,21 @@ class KafkaClient(object):
host, port, afi = get_ip_port_afi(broker.host)
self._conns[node_id] = BrokerConnection(host, broker.port, afi,
**self.config)
- return self._finish_connect(node_id)
-
- def _finish_connect(self, node_id):
- assert node_id in self._conns, '%s is not in current conns' % node_id
state = self._conns[node_id].connect()
if state is ConnectionStates.CONNECTING:
self._connecting.add(node_id)
+
+ # Whether CONNECTED or DISCONNECTED, we need to remove from connecting
elif node_id in self._connecting:
log.debug("Node %s connection state is %s", node_id, state)
self._connecting.remove(node_id)
+ # Connection failures imply that our metadata is stale, so let's refresh
if state is ConnectionStates.DISCONNECTED:
log.warning("Node %s connect failed -- refreshing metadata", node_id)
self.cluster.request_update()
- return state
+ return self._conns[node_id].connected()
def ready(self, node_id):
"""Check whether a node is connected and ok to send more requests.
@@ -190,19 +189,15 @@ class KafkaClient(object):
Returns:
bool: True if we are ready to send to the given node
"""
- if self.is_ready(node_id):
- return True
-
- if self._can_connect(node_id):
- # if we are interested in sending to a node
- # and we don't have a connection to it, initiate one
- self._initiate_connect(node_id)
-
- if node_id in self._connecting:
- self._finish_connect(node_id)
-
+ self._maybe_connect(node_id)
return self.is_ready(node_id)
+ def connected(self, node_id):
+ """Return True iff the node_id is connected."""
+ if node_id not in self._conns:
+ return False
+ return self._conns[node_id].connected()
+
def close(self, node_id=None):
"""Closes the connection to a particular node (if there is one).
@@ -295,15 +290,13 @@ class KafkaClient(object):
request (Struct): request object (not-encoded)
Raises:
- NodeNotReadyError: if node_id is not ready
+ AssertionError: if node_id is not in current cluster metadata
Returns:
- Future: resolves to Response struct
+ Future: resolves to Response struct or Error
"""
- if not self._can_send_request(node_id):
- raise Errors.NodeNotReadyError("Attempt to send a request to node"
- " which is not ready (node id %s)."
- % node_id)
+ if not self._maybe_connect(node_id):
+ return Future().failure(Errors.NodeNotReadyError(node_id))
# Every request gets a response, except one special case:
expect_response = True
@@ -341,7 +334,7 @@ class KafkaClient(object):
# Attempt to complete pending connections
for node_id in list(self._connecting):
- self._finish_connect(node_id)
+ self._maybe_connect(node_id)
# Send a metadata request if needed
metadata_timeout_ms = self._maybe_refresh_metadata()
@@ -557,7 +550,7 @@ class KafkaClient(object):
elif self._can_connect(node_id):
log.debug("Initializing connection to node %s for metadata request", node_id)
- self._initiate_connect(node_id)
+ self._maybe_connect(node_id)
return 0
diff --git a/kafka/conn.py b/kafka/conn.py
index 0ce469d..2b82b6d 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -106,24 +106,22 @@ class BrokerConnection(object):
# in non-blocking mode, use repeated calls to socket.connect_ex
# to check connection status
request_timeout = self.config['request_timeout_ms'] / 1000.0
- if time.time() > request_timeout + self.last_attempt:
+ try:
+ ret = self._sock.connect_ex((self.host, self.port))
+ except socket.error as ret:
+ pass
+ if not ret or ret == errno.EISCONN:
+ self.state = ConnectionStates.CONNECTED
+ elif ret not in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK, 10022):
+ log.error('Connect attempt to %s returned error %s.'
+ ' Disconnecting.', self, ret)
+ self.close()
+ self.last_failure = time.time()
+ elif time.time() > request_timeout + self.last_attempt:
log.error('Connection attempt to %s timed out', self)
self.close() # error=TimeoutError ?
self.last_failure = time.time()
- else:
- try:
- ret = self._sock.connect_ex((self.host, self.port))
- except socket.error as ret:
- pass
- if not ret or ret == errno.EISCONN:
- self.state = ConnectionStates.CONNECTED
- # WSAEINVAL == 10022, but errno.WSAEINVAL is not available on non-win systems
- elif ret not in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK, 10022):
- log.error('Connect attempt to %s returned error %s.'
- ' Disconnecting.', self, ret)
- self.close()
- self.last_failure = time.time()
return self.state
def blacked_out(self):
@@ -141,6 +139,10 @@ class BrokerConnection(object):
"""Return True iff socket is connected."""
return self.state is ConnectionStates.CONNECTED
+ def connecting(self):
+ """Return True iff socket is in intermediate connecting state."""
+ return self.state is ConnectionStates.CONNECTING
+
def close(self, error=None):
"""Close socket and fail all in-flight-requests.
@@ -158,7 +160,7 @@ class BrokerConnection(object):
self._rbuffer.seek(0)
self._rbuffer.truncate()
if error is None:
- error = Errors.ConnectionError()
+ error = Errors.ConnectionError(str(self))
while self.in_flight_requests:
ifr = self.in_flight_requests.popleft()
ifr.future.failure(error)
@@ -169,10 +171,12 @@ class BrokerConnection(object):
Can block on network if request is larger than send_buffer_bytes
"""
future = Future()
- if not self.connected():
- return future.failure(Errors.ConnectionError())
- if not self.can_send_more():
- return future.failure(Errors.TooManyInFlightRequests())
+ if self.connecting():
+ return future.failure(Errors.NodeNotReadyError(str(self)))
+ elif not self.connected():
+ return future.failure(Errors.ConnectionError(str(self)))
+ elif not self.can_send_more():
+ return future.failure(Errors.TooManyInFlightRequests(str(self)))
correlation_id = self._next_correlation_id()
header = RequestHeader(request,
correlation_id=correlation_id,
@@ -191,7 +195,7 @@ class BrokerConnection(object):
self._sock.setblocking(False)
except (AssertionError, ConnectionError) as e:
log.exception("Error sending %s to %s", request, self)
- error = Errors.ConnectionError(e)
+ error = Errors.ConnectionError("%s: %s" % (str(self), e))
self.close(error=error)
return future.failure(error)
log.debug('%s Request %d: %s', self, correlation_id, request)
@@ -324,11 +328,9 @@ class BrokerConnection(object):
' initialized on the broker')
elif ifr.correlation_id != recv_correlation_id:
-
-
error = Errors.CorrelationIdError(
- 'Correlation ids do not match: sent %d, recv %d'
- % (ifr.correlation_id, recv_correlation_id))
+ '%s: Correlation ids do not match: sent %d, recv %d'
+ % (str(self), ifr.correlation_id, recv_correlation_id))
ifr.future.failure(error)
self.close()
self._processing = False
diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py
index f406a30..7112c7e 100644
--- a/kafka/consumer/fetcher.py
+++ b/kafka/consumer/fetcher.py
@@ -479,9 +479,6 @@ class Fetcher(six.Iterator):
# so create a separate future and attach a callback to update it
# based on response error codes
future = Future()
- if not self._client.ready(node_id):
- return future.failure(Errors.NodeNotReadyError(node_id))
-
_f = self._client.send(node_id, request)
_f.add_callback(self._handle_offset_response, partition, future)
_f.add_errback(lambda e: future.failure(e))
diff --git a/kafka/coordinator/base.py b/kafka/coordinator/base.py
index dca809e..b0a0981 100644
--- a/kafka/coordinator/base.py
+++ b/kafka/coordinator/base.py
@@ -186,7 +186,7 @@ class BaseCoordinator(object):
self.coordinator_dead()
return True
- return not self._client.ready(self.coordinator_id)
+ return False
def ensure_coordinator_known(self):
"""Block until the coordinator for this group is known
@@ -288,9 +288,13 @@ class BaseCoordinator(object):
return future
def _failed_request(self, node_id, request, future, error):
- log.error('Error sending %s to node %s [%s] -- marking coordinator dead',
+ log.error('Error sending %s to node %s [%s]',
request.__class__.__name__, node_id, error)
- self.coordinator_dead()
+ # Marking coordinator dead
+ # unless the error is caused by internal client pipelining
+ if not isinstance(error, (Errors.NodeNotReadyError,
+ Errors.TooManyInFlightRequests)):
+ self.coordinator_dead()
future.failure(error)
def _handle_join_group_response(self, future, response):
@@ -388,7 +392,8 @@ class BaseCoordinator(object):
def _send_sync_group_request(self, request):
if self.coordinator_unknown():
- return Future().failure(Errors.GroupCoordinatorNotAvailableError())
+ e = Errors.GroupCoordinatorNotAvailableError(self.coordinator_id)
+ return Future().failure(e)
future = Future()
_f = self._client.send(self.coordinator_id, request)
_f.add_callback(self._handle_sync_group_response, future)
@@ -439,7 +444,7 @@ class BaseCoordinator(object):
Future: resolves to the node id of the coordinator
"""
node_id = self._client.least_loaded_node()
- if node_id is None or not self._client.ready(node_id):
+ if node_id is None:
return Future().failure(Errors.NoBrokersAvailable())
log.debug("Issuing group metadata request to broker %s", node_id)
@@ -490,8 +495,8 @@ class BaseCoordinator(object):
def coordinator_dead(self, error=None):
"""Mark the current coordinator as dead."""
if self.coordinator_id is not None:
- log.info("Marking the coordinator dead (node %s): %s.",
- self.coordinator_id, error)
+ log.warning("Marking the coordinator dead (node %s): %s.",
+ self.coordinator_id, error)
self.coordinator_id = None
def close(self):
@@ -501,6 +506,7 @@ class BaseCoordinator(object):
self._client.unschedule(self.heartbeat_task)
except KeyError:
pass
+
if not self.coordinator_unknown() and self.generation > 0:
# this is a minimal effort attempt to leave the group. we do not
# attempt any resending if the request fails or times out.
@@ -634,7 +640,7 @@ class HeartbeatTask(object):
self._client.schedule(self, time.time() + ttl)
def _handle_heartbeat_failure(self, e):
- log.warning("Heartbeat failed; retrying")
+ log.warning("Heartbeat failed (%s); retrying", e)
self._request_in_flight = False
etd = time.time() + self._coordinator.config['retry_backoff_ms'] / 1000.0
self._client.schedule(self, etd)
diff --git a/test/test_client_async.py b/test/test_client_async.py
index e0b98c4..884686d 100644
--- a/test/test_client_async.py
+++ b/test/test_client_async.py
@@ -41,7 +41,8 @@ def conn(mocker):
[(0, 'foo', 12), (1, 'bar', 34)], # brokers
[])) # topics
conn.blacked_out.return_value = False
- conn.connect.return_value = conn.state
+ conn.connect.side_effect = lambda: conn.state
+ conn.connected = lambda: conn.connect() is ConnectionStates.CONNECTED
return conn
@@ -76,7 +77,7 @@ def test_can_connect(conn):
assert cli._can_connect(0)
# Node is connected, can't reconnect
- cli._initiate_connect(0)
+ assert cli._maybe_connect(0) is True
assert not cli._can_connect(0)
# Node is disconnected, can connect
@@ -87,60 +88,47 @@ def test_can_connect(conn):
conn.blacked_out.return_value = True
assert not cli._can_connect(0)
-def test_initiate_connect(conn):
+def test_maybe_connect(conn):
cli = KafkaClient()
try:
# Node not in metadata, raises AssertionError
- cli._initiate_connect(2)
+ cli._maybe_connect(2)
except AssertionError:
pass
else:
assert False, 'Exception not raised'
assert 0 not in cli._conns
- state = cli._initiate_connect(0)
+ conn.state = ConnectionStates.DISCONNECTED
+ conn.connect.side_effect = lambda: ConnectionStates.CONNECTING
+ assert cli._maybe_connect(0) is False
assert cli._conns[0] is conn
- assert state is conn.state
-
-
-def test_finish_connect(conn):
- cli = KafkaClient()
- try:
- # Node not in metadata, raises AssertionError
- cli._initiate_connect(2)
- except AssertionError:
- pass
- else:
- assert False, 'Exception not raised'
-
- assert 0 not in cli._conns
- cli._initiate_connect(0)
-
- conn.connect.return_value = ConnectionStates.CONNECTING
- state = cli._finish_connect(0)
assert 0 in cli._connecting
- assert state is ConnectionStates.CONNECTING
- conn.connect.return_value = ConnectionStates.CONNECTED
- state = cli._finish_connect(0)
+ conn.state = ConnectionStates.CONNECTING
+ conn.connect.side_effect = lambda: ConnectionStates.CONNECTED
+ assert cli._maybe_connect(0) is True
assert 0 not in cli._connecting
- assert state is ConnectionStates.CONNECTED
# Failure to connect should trigger metadata update
- assert not cli.cluster._need_update
+ assert cli.cluster._need_update is False
cli._connecting.add(0)
- conn.connect.return_value = ConnectionStates.DISCONNECTED
- state = cli._finish_connect(0)
+ conn.state = ConnectionStates.CONNECTING
+ conn.connect.side_effect = lambda: ConnectionStates.DISCONNECTED
+ assert cli._maybe_connect(0) is False
assert 0 not in cli._connecting
- assert state is ConnectionStates.DISCONNECTED
- assert cli.cluster._need_update
+ assert cli.cluster._need_update is True
def test_ready(conn):
cli = KafkaClient()
- # Node not in metadata
- assert not cli.ready(2)
+ # Node not in metadata raises Exception
+ try:
+ cli.ready(2)
+ assert False, 'Exception not raised'
+ except AssertionError:
+ pass
# Node in metadata will connect
assert 0 not in cli._conns
@@ -176,13 +164,13 @@ def test_ready(conn):
# disconnected nodes, not ready
assert cli.ready(0)
assert cli.ready(1)
- conn.connected.return_value = False
+ conn.state = ConnectionStates.DISCONNECTED
assert not cli.ready(0)
- conn.connected.return_value = True
# connecting node connects
cli._connecting.add(0)
- conn.connected.return_value = False
+ conn.state = ConnectionStates.CONNECTING
+ conn.connect.side_effect = lambda: ConnectionStates.CONNECTED
cli.ready(0)
assert 0 not in cli._connecting
assert cli._conns[0].connect.called_with()
@@ -195,13 +183,13 @@ def test_close(conn):
cli.close(2)
# Single node close
- cli._initiate_connect(0)
+ cli._maybe_connect(0)
assert not conn.close.call_count
cli.close(0)
assert conn.close.call_count == 1
# All node close
- cli._initiate_connect(1)
+ cli._maybe_connect(1)
cli.close()
assert conn.close.call_count == 3
@@ -213,7 +201,7 @@ def test_is_disconnected(conn):
conn.state = ConnectionStates.DISCONNECTED
assert not cli.is_disconnected(0)
- cli._initiate_connect(0)
+ cli._maybe_connect(0)
assert cli.is_disconnected(0)
conn.state = ConnectionStates.CONNECTING
@@ -225,14 +213,22 @@ def test_is_disconnected(conn):
def test_send(conn):
cli = KafkaClient()
+
+ # Send to unknown node => raises AssertionError
try:
cli.send(2, None)
- except Errors.NodeNotReadyError:
+ assert False, 'Exception not raised'
+ except AssertionError:
pass
- else:
- assert False, 'NodeNotReadyError not raised'
- cli._initiate_connect(0)
+ # Send to disconnected node => NodeNotReady
+ conn.state = ConnectionStates.DISCONNECTED
+ f = cli.send(0, None)
+ assert f.failed()
+ assert isinstance(f.exception, Errors.NodeNotReadyError)
+
+ conn.state = ConnectionStates.CONNECTED
+ cli._maybe_connect(0)
# ProduceRequest w/ 0 required_acks -> no response
request = ProduceRequest(0, 0, [])
ret = cli.send(0, request)
diff --git a/test/test_conn.py b/test/test_conn.py
index f0ef8fb..d394f74 100644
--- a/test/test_conn.py
+++ b/test/test_conn.py
@@ -1,242 +1,82 @@
+# pylint: skip-file
+from __future__ import absolute_import
+
+from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET
import socket
-import struct
-from threading import Thread
-
-import mock
-from . import unittest
-
-from kafka.common import ConnectionError
-from kafka.conn import KafkaConnection, collect_hosts, DEFAULT_SOCKET_TIMEOUT_SECONDS
-
-class ConnTest(unittest.TestCase):
- def setUp(self):
-
- self.config = {
- 'host': 'localhost',
- 'port': 9090,
- 'request_id': 0,
- 'payload': b'test data',
- 'payload2': b'another packet'
- }
-
- # Mocking socket.create_connection will cause _sock to always be a
- # MagicMock()
- patcher = mock.patch('socket.create_connection', spec=True)
- self.MockCreateConn = patcher.start()
- self.addCleanup(patcher.stop)
-
- # Also mock socket.sendall() to appear successful
- self.MockCreateConn().sendall.return_value = None
-
- # And mock socket.recv() to return two payloads, then '', then raise
- # Note that this currently ignores the num_bytes parameter to sock.recv()
- payload_size = len(self.config['payload'])
- payload2_size = len(self.config['payload2'])
- self.MockCreateConn().recv.side_effect = [
- struct.pack('>i', payload_size),
- struct.pack('>%ds' % payload_size, self.config['payload']),
- struct.pack('>i', payload2_size),
- struct.pack('>%ds' % payload2_size, self.config['payload2']),
- b''
- ]
-
- # Create a connection object
- self.conn = KafkaConnection(self.config['host'], self.config['port'])
-
- # Reset any mock counts caused by __init__
- self.MockCreateConn.reset_mock()
-
- def test_collect_hosts__happy_path(self):
- hosts = "localhost:1234,localhost"
- results = collect_hosts(hosts)
-
- self.assertEqual(set(results), set([
- ('localhost', 1234, socket.AF_INET),
- ('localhost', 9092, socket.AF_INET),
- ]))
-
- def test_collect_hosts__ipv6(self):
- hosts = "[localhost]:1234,[2001:1000:2000::1],[2001:1000:2000::1]:1234"
- results = collect_hosts(hosts)
-
- self.assertEqual(set(results), set([
- ('localhost', 1234, socket.AF_INET6),
- ('2001:1000:2000::1', 9092, socket.AF_INET6),
- ('2001:1000:2000::1', 1234, socket.AF_INET6),
- ]))
-
- def test_collect_hosts__string_list(self):
- hosts = [
- 'localhost:1234',
- 'localhost',
- '[localhost]',
- '2001::1',
- '[2001::1]:1234',
- ]
-
- results = collect_hosts(hosts)
-
- self.assertEqual(set(results), set([
- ('localhost', 1234, socket.AF_INET),
- ('localhost', 9092, socket.AF_INET),
- ('localhost', 9092, socket.AF_INET6),
- ('2001::1', 9092, socket.AF_INET6),
- ('2001::1', 1234, socket.AF_INET6),
- ]))
-
- def test_collect_hosts__with_spaces(self):
- hosts = "localhost:1234, localhost"
- results = collect_hosts(hosts)
-
- self.assertEqual(set(results), set([
- ('localhost', 1234, socket.AF_INET),
- ('localhost', 9092, socket.AF_INET),
- ]))
-
-
- def test_send(self):
- self.conn.send(self.config['request_id'], self.config['payload'])
- self.conn._sock.sendall.assert_called_with(self.config['payload'])
-
- def test_init_creates_socket_connection(self):
- KafkaConnection(self.config['host'], self.config['port'])
- self.MockCreateConn.assert_called_with((self.config['host'], self.config['port']), DEFAULT_SOCKET_TIMEOUT_SECONDS)
-
- def test_init_failure_raises_connection_error(self):
-
- def raise_error(*args):
- raise socket.error
-
- assert socket.create_connection is self.MockCreateConn
- socket.create_connection.side_effect=raise_error
- with self.assertRaises(ConnectionError):
- KafkaConnection(self.config['host'], self.config['port'])
-
- def test_send__reconnects_on_dirty_conn(self):
-
- # Dirty the connection
- try:
- self.conn._raise_connection_error()
- except ConnectionError:
- pass
-
- # Now test that sending attempts to reconnect
- self.assertEqual(self.MockCreateConn.call_count, 0)
- self.conn.send(self.config['request_id'], self.config['payload'])
- self.assertEqual(self.MockCreateConn.call_count, 1)
-
- def test_send__failure_sets_dirty_connection(self):
-
- def raise_error(*args):
- raise socket.error
-
- assert isinstance(self.conn._sock, mock.Mock)
- self.conn._sock.sendall.side_effect=raise_error
- try:
- self.conn.send(self.config['request_id'], self.config['payload'])
- except ConnectionError:
- self.assertIsNone(self.conn._sock)
-
- def test_recv(self):
-
- self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload'])
-
- def test_recv__reconnects_on_dirty_conn(self):
-
- # Dirty the connection
- try:
- self.conn._raise_connection_error()
- except ConnectionError:
- pass
-
- # Now test that recv'ing attempts to reconnect
- self.assertEqual(self.MockCreateConn.call_count, 0)
- self.conn.recv(self.config['request_id'])
- self.assertEqual(self.MockCreateConn.call_count, 1)
-
- def test_recv__failure_sets_dirty_connection(self):
-
- def raise_error(*args):
- raise socket.error
-
- # test that recv'ing attempts to reconnect
- assert isinstance(self.conn._sock, mock.Mock)
- self.conn._sock.recv.side_effect=raise_error
- try:
- self.conn.recv(self.config['request_id'])
- except ConnectionError:
- self.assertIsNone(self.conn._sock)
-
- def test_recv__doesnt_consume_extra_data_in_stream(self):
-
- # Here just test that each call to recv will return a single payload
- self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload'])
- self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload2'])
-
- def test_get_connected_socket(self):
- s = self.conn.get_connected_socket()
-
- self.assertEqual(s, self.MockCreateConn())
-
- def test_get_connected_socket_on_dirty_conn(self):
- # Dirty the connection
- try:
- self.conn._raise_connection_error()
- except ConnectionError:
- pass
-
- # Test that get_connected_socket tries to connect
- self.assertEqual(self.MockCreateConn.call_count, 0)
- self.conn.get_connected_socket()
- self.assertEqual(self.MockCreateConn.call_count, 1)
+import time
+
+import pytest
+
+from kafka.conn import BrokerConnection, ConnectionStates
+
+
+@pytest.fixture
+def socket(mocker):
+ socket = mocker.MagicMock()
+ socket.connect_ex.return_value = 0
+ mocker.patch('socket.socket', return_value=socket)
+ return socket
+
+
+@pytest.fixture
+def conn(socket):
+ conn = BrokerConnection('localhost', 9092, socket.AF_INET)
+ return conn
+
+
+@pytest.mark.parametrize("states", [
+ (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING),),
+ (([EALREADY, EALREADY], ConnectionStates.CONNECTING),),
+ (([0], ConnectionStates.CONNECTED),),
+ (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING),
+ ([ECONNRESET], ConnectionStates.DISCONNECTED)),
+ (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING),
+ ([EALREADY], ConnectionStates.CONNECTING),
+ ([EISCONN], ConnectionStates.CONNECTED)),
+])
+def test_connect(socket, conn, states):
+ assert conn.state is ConnectionStates.DISCONNECTED
+
+ for errno, state in states:
+ socket.connect_ex.side_effect = errno
+ conn.connect()
+ assert conn.state is state
+
+
+def test_connect_timeout(socket, conn):
+ assert conn.state is ConnectionStates.DISCONNECTED
+
+ # Initial connect returns EINPROGRESS
+ # immediate inline connect returns EALREADY
+ # second explicit connect returns EALREADY
+ # third explicit connect returns EALREADY and times out via last_attempt
+ socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY]
+ conn.connect()
+ assert conn.state is ConnectionStates.CONNECTING
+ conn.connect()
+ assert conn.state is ConnectionStates.CONNECTING
+ conn.last_attempt = 0
+ conn.connect()
+ assert conn.state is ConnectionStates.DISCONNECTED
+
+
+def test_blacked_out(conn):
+ assert not conn.blacked_out()
+ conn.last_attempt = time.time()
+ assert conn.blacked_out()
+
+
+def test_connected(conn):
+ assert not conn.connected()
+ conn.state = ConnectionStates.CONNECTED
+ assert conn.connected()
+
- def test_close__object_is_reusable(self):
+def test_connecting(conn):
+ assert not conn.connecting()
+ conn.state = ConnectionStates.CONNECTING
+ assert conn.connecting()
+ conn.state = ConnectionStates.CONNECTED
+ assert not conn.connecting()
- # test that sending to a closed connection
- # will re-connect and send data to the socket
- self.conn.close()
- self.conn.send(self.config['request_id'], self.config['payload'])
- self.assertEqual(self.MockCreateConn.call_count, 1)
- self.conn._sock.sendall.assert_called_with(self.config['payload'])
-
-
-class TestKafkaConnection(unittest.TestCase):
- @mock.patch('socket.create_connection')
- def test_copy(self, socket):
- """KafkaConnection copies work as expected"""
-
- conn = KafkaConnection('kafka', 9092)
- self.assertEqual(socket.call_count, 1)
-
- copy = conn.copy()
- self.assertEqual(socket.call_count, 1)
- self.assertEqual(copy.host, 'kafka')
- self.assertEqual(copy.port, 9092)
- self.assertEqual(copy._sock, None)
-
- copy.reinit()
- self.assertEqual(socket.call_count, 2)
- self.assertNotEqual(copy._sock, None)
-
- @mock.patch('socket.create_connection')
- def test_copy_thread(self, socket):
- """KafkaConnection copies work in other threads"""
-
- err = []
- copy = KafkaConnection('kafka', 9092).copy()
-
- def thread_func(err, copy):
- try:
- self.assertEqual(copy.host, 'kafka')
- self.assertEqual(copy.port, 9092)
- self.assertNotEqual(copy._sock, None)
- except Exception as e:
- err.append(e)
- else:
- err.append(None)
- thread = Thread(target=thread_func, args=(err, copy))
- thread.start()
- thread.join()
-
- self.assertEqual(err, [None])
- self.assertEqual(socket.call_count, 2)
+# TODO: test_send, test_recv, test_can_send_more, test_close
diff --git a/test/test_conn_legacy.py b/test/test_conn_legacy.py
new file mode 100644
index 0000000..f0ef8fb
--- /dev/null
+++ b/test/test_conn_legacy.py
@@ -0,0 +1,242 @@
+import socket
+import struct
+from threading import Thread
+
+import mock
+from . import unittest
+
+from kafka.common import ConnectionError
+from kafka.conn import KafkaConnection, collect_hosts, DEFAULT_SOCKET_TIMEOUT_SECONDS
+
+class ConnTest(unittest.TestCase):
+ def setUp(self):
+
+ self.config = {
+ 'host': 'localhost',
+ 'port': 9090,
+ 'request_id': 0,
+ 'payload': b'test data',
+ 'payload2': b'another packet'
+ }
+
+ # Mocking socket.create_connection will cause _sock to always be a
+ # MagicMock()
+ patcher = mock.patch('socket.create_connection', spec=True)
+ self.MockCreateConn = patcher.start()
+ self.addCleanup(patcher.stop)
+
+ # Also mock socket.sendall() to appear successful
+ self.MockCreateConn().sendall.return_value = None
+
+ # And mock socket.recv() to return two payloads, then '', then raise
+ # Note that this currently ignores the num_bytes parameter to sock.recv()
+ payload_size = len(self.config['payload'])
+ payload2_size = len(self.config['payload2'])
+ self.MockCreateConn().recv.side_effect = [
+ struct.pack('>i', payload_size),
+ struct.pack('>%ds' % payload_size, self.config['payload']),
+ struct.pack('>i', payload2_size),
+ struct.pack('>%ds' % payload2_size, self.config['payload2']),
+ b''
+ ]
+
+ # Create a connection object
+ self.conn = KafkaConnection(self.config['host'], self.config['port'])
+
+ # Reset any mock counts caused by __init__
+ self.MockCreateConn.reset_mock()
+
+ def test_collect_hosts__happy_path(self):
+ hosts = "localhost:1234,localhost"
+ results = collect_hosts(hosts)
+
+ self.assertEqual(set(results), set([
+ ('localhost', 1234, socket.AF_INET),
+ ('localhost', 9092, socket.AF_INET),
+ ]))
+
+ def test_collect_hosts__ipv6(self):
+ hosts = "[localhost]:1234,[2001:1000:2000::1],[2001:1000:2000::1]:1234"
+ results = collect_hosts(hosts)
+
+ self.assertEqual(set(results), set([
+ ('localhost', 1234, socket.AF_INET6),
+ ('2001:1000:2000::1', 9092, socket.AF_INET6),
+ ('2001:1000:2000::1', 1234, socket.AF_INET6),
+ ]))
+
+ def test_collect_hosts__string_list(self):
+ hosts = [
+ 'localhost:1234',
+ 'localhost',
+ '[localhost]',
+ '2001::1',
+ '[2001::1]:1234',
+ ]
+
+ results = collect_hosts(hosts)
+
+ self.assertEqual(set(results), set([
+ ('localhost', 1234, socket.AF_INET),
+ ('localhost', 9092, socket.AF_INET),
+ ('localhost', 9092, socket.AF_INET6),
+ ('2001::1', 9092, socket.AF_INET6),
+ ('2001::1', 1234, socket.AF_INET6),
+ ]))
+
+ def test_collect_hosts__with_spaces(self):
+ hosts = "localhost:1234, localhost"
+ results = collect_hosts(hosts)
+
+ self.assertEqual(set(results), set([
+ ('localhost', 1234, socket.AF_INET),
+ ('localhost', 9092, socket.AF_INET),
+ ]))
+
+
+ def test_send(self):
+ self.conn.send(self.config['request_id'], self.config['payload'])
+ self.conn._sock.sendall.assert_called_with(self.config['payload'])
+
+ def test_init_creates_socket_connection(self):
+ KafkaConnection(self.config['host'], self.config['port'])
+ self.MockCreateConn.assert_called_with((self.config['host'], self.config['port']), DEFAULT_SOCKET_TIMEOUT_SECONDS)
+
+ def test_init_failure_raises_connection_error(self):
+
+ def raise_error(*args):
+ raise socket.error
+
+ assert socket.create_connection is self.MockCreateConn
+ socket.create_connection.side_effect=raise_error
+ with self.assertRaises(ConnectionError):
+ KafkaConnection(self.config['host'], self.config['port'])
+
+ def test_send__reconnects_on_dirty_conn(self):
+
+ # Dirty the connection
+ try:
+ self.conn._raise_connection_error()
+ except ConnectionError:
+ pass
+
+ # Now test that sending attempts to reconnect
+ self.assertEqual(self.MockCreateConn.call_count, 0)
+ self.conn.send(self.config['request_id'], self.config['payload'])
+ self.assertEqual(self.MockCreateConn.call_count, 1)
+
+ def test_send__failure_sets_dirty_connection(self):
+
+ def raise_error(*args):
+ raise socket.error
+
+ assert isinstance(self.conn._sock, mock.Mock)
+ self.conn._sock.sendall.side_effect=raise_error
+ try:
+ self.conn.send(self.config['request_id'], self.config['payload'])
+ except ConnectionError:
+ self.assertIsNone(self.conn._sock)
+
+ def test_recv(self):
+
+ self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload'])
+
+ def test_recv__reconnects_on_dirty_conn(self):
+
+ # Dirty the connection
+ try:
+ self.conn._raise_connection_error()
+ except ConnectionError:
+ pass
+
+ # Now test that recv'ing attempts to reconnect
+ self.assertEqual(self.MockCreateConn.call_count, 0)
+ self.conn.recv(self.config['request_id'])
+ self.assertEqual(self.MockCreateConn.call_count, 1)
+
+ def test_recv__failure_sets_dirty_connection(self):
+
+ def raise_error(*args):
+ raise socket.error
+
+ # test that recv'ing attempts to reconnect
+ assert isinstance(self.conn._sock, mock.Mock)
+ self.conn._sock.recv.side_effect=raise_error
+ try:
+ self.conn.recv(self.config['request_id'])
+ except ConnectionError:
+ self.assertIsNone(self.conn._sock)
+
+ def test_recv__doesnt_consume_extra_data_in_stream(self):
+
+ # Here just test that each call to recv will return a single payload
+ self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload'])
+ self.assertEqual(self.conn.recv(self.config['request_id']), self.config['payload2'])
+
+ def test_get_connected_socket(self):
+ s = self.conn.get_connected_socket()
+
+ self.assertEqual(s, self.MockCreateConn())
+
+ def test_get_connected_socket_on_dirty_conn(self):
+ # Dirty the connection
+ try:
+ self.conn._raise_connection_error()
+ except ConnectionError:
+ pass
+
+ # Test that get_connected_socket tries to connect
+ self.assertEqual(self.MockCreateConn.call_count, 0)
+ self.conn.get_connected_socket()
+ self.assertEqual(self.MockCreateConn.call_count, 1)
+
+ def test_close__object_is_reusable(self):
+
+ # test that sending to a closed connection
+ # will re-connect and send data to the socket
+ self.conn.close()
+ self.conn.send(self.config['request_id'], self.config['payload'])
+ self.assertEqual(self.MockCreateConn.call_count, 1)
+ self.conn._sock.sendall.assert_called_with(self.config['payload'])
+
+
+class TestKafkaConnection(unittest.TestCase):
+ @mock.patch('socket.create_connection')
+ def test_copy(self, socket):
+ """KafkaConnection copies work as expected"""
+
+ conn = KafkaConnection('kafka', 9092)
+ self.assertEqual(socket.call_count, 1)
+
+ copy = conn.copy()
+ self.assertEqual(socket.call_count, 1)
+ self.assertEqual(copy.host, 'kafka')
+ self.assertEqual(copy.port, 9092)
+ self.assertEqual(copy._sock, None)
+
+ copy.reinit()
+ self.assertEqual(socket.call_count, 2)
+ self.assertNotEqual(copy._sock, None)
+
+ @mock.patch('socket.create_connection')
+ def test_copy_thread(self, socket):
+ """KafkaConnection copies work in other threads"""
+
+ err = []
+ copy = KafkaConnection('kafka', 9092).copy()
+
+ def thread_func(err, copy):
+ try:
+ self.assertEqual(copy.host, 'kafka')
+ self.assertEqual(copy.port, 9092)
+ self.assertNotEqual(copy._sock, None)
+ except Exception as e:
+ err.append(e)
+ else:
+ err.append(None)
+ thread = Thread(target=thread_func, args=(err, copy))
+ thread.start()
+ thread.join()
+
+ self.assertEqual(err, [None])
+ self.assertEqual(socket.call_count, 2)