diff options
-rw-r--r-- | kafka/conn.py | 14 | ||||
-rw-r--r-- | test/test_conn.py | 28 |
2 files changed, 37 insertions, 5 deletions
diff --git a/kafka/conn.py b/kafka/conn.py index 21607d9..a8751e9 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -157,6 +157,9 @@ class BrokerConnection(object): self.hostname = host self.port = port self.afi = afi + self._init_host = host + self._init_port = port + self._init_afi = afi self.in_flight_requests = collections.deque() self.config = copy.copy(self.DEFAULT_CONFIG) @@ -208,7 +211,7 @@ class BrokerConnection(object): log.debug('%s: creating new socket', str(self)) # if self.afi is set to AF_UNSPEC, then we need to do a name # resolution and try all available address families - if self.afi == socket.AF_UNSPEC: + if self._init_afi == socket.AF_UNSPEC: if self._gai is None: # XXX: all DNS functions in Python are blocking. If we really # want to be non-blocking here, we need to use a 3rd-party @@ -216,14 +219,15 @@ class BrokerConnection(object): # own thread. This will be subject to the default libc # name resolution timeout (5s on most Linux boxes) try: - self._gai = socket.getaddrinfo(self.host, self.port, + self._gai = socket.getaddrinfo(self._init_host, + self._init_port, socket.AF_UNSPEC, socket.SOCK_STREAM) except socket.gaierror as ex: raise socket.gaierror('getaddrinfo failed for {0}:{1}, ' 'exception was {2}. Is your advertised.listeners (called' 'advertised.host.name before Kafka 9) correct and resolvable?'.format( - self.host, self.port, ex + self._init_host, self._init_port, ex )) self._gai_index = 0 else: @@ -233,7 +237,7 @@ class BrokerConnection(object): while True: if self._gai_index >= len(self._gai): log.error('Unable to connect to any of the names for {0}:{1}'.format( - self.host, self.port + self._init_host, self._init_port )) self.close() return @@ -245,7 +249,7 @@ class BrokerConnection(object): self.host, self.port = sockaddr[:2] self._sock = socket.socket(afi, socket.SOCK_STREAM) else: - self._sock = socket.socket(self.afi, socket.SOCK_STREAM) + self._sock = socket.socket(self._init_afi, socket.SOCK_STREAM) for option in self.config['socket_options']: self._sock.setsockopt(*option) diff --git a/test/test_conn.py b/test/test_conn.py index 4f2b12f..c3e40c0 100644 --- a/test/test_conn.py +++ b/test/test_conn.py @@ -5,6 +5,7 @@ from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET import socket import time +import mock import pytest from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts @@ -264,3 +265,30 @@ def test_collect_hosts__with_spaces(): ('localhost', 1234, socket.AF_UNSPEC), ('localhost', 9092, socket.AF_UNSPEC), ]) + + +def test_lookup_on_connect(): + hostname = 'example.org' + port = 9092 + conn = BrokerConnection(hostname, port, socket.AF_UNSPEC) + assert conn.host == conn.hostname == hostname + ip1 = '127.0.0.1' + mock_return1 = [ + (2, 2, 17, '', (ip1, 9092)), + ] + with mock.patch("socket.getaddrinfo", return_value=mock_return1) as m: + conn.connect() + m.assert_called_once_with(hostname, port, 0, 1) + conn.close() + assert conn.host == ip1 + + ip2 = '127.0.0.2' + mock_return2 = [ + (2, 2, 17, '', (ip2, 9092)), + ] + + with mock.patch("socket.getaddrinfo", return_value=mock_return2) as m: + conn.connect() + m.assert_called_once_with(hostname, port, 0, 1) + conn.close() + assert conn.host == ip2 |