import os import pytest import re import redis import time from unittest import mock from threading import Thread from redis.connection import ssl_available, to_bool from .conftest import ( skip_if_server_version_lt, skip_if_redis_enterprise, _get_client ) from .test_pubsub import wait_for_message class DummyConnection: description_format = "DummyConnection<>" def __init__(self, **kwargs): self.kwargs = kwargs self.pid = os.getpid() def connect(self): pass def can_read(self): return False class TestConnectionPool: def get_pool(self, connection_kwargs=None, max_connections=None, connection_class=redis.Connection): connection_kwargs = connection_kwargs or {} pool = redis.ConnectionPool( connection_class=connection_class, max_connections=max_connections, **connection_kwargs) return pool def test_connection_creation(self): connection_kwargs = {'foo': 'bar', 'biz': 'baz'} pool = self.get_pool(connection_kwargs=connection_kwargs, connection_class=DummyConnection) connection = pool.get_connection('_') assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs def test_multiple_connections(self, master_host): connection_kwargs = {'host': master_host[0], 'port': master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) c1 = pool.get_connection('_') c2 = pool.get_connection('_') assert c1 != c2 def test_max_connections(self, master_host): connection_kwargs = {'host': master_host[0], 'port': master_host[1]} pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) pool.get_connection('_') pool.get_connection('_') with pytest.raises(redis.ConnectionError): pool.get_connection('_') def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {'host': master_host[0], 'port': master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) c1 = pool.get_connection('_') pool.release(c1) c2 = pool.get_connection('_') assert c1 == c2 def test_repr_contains_db_info_tcp(self): connection_kwargs = { 'host': 'localhost', 'port': 6379, 'db': 1, 'client_name': 'test-client' } pool = self.get_pool(connection_kwargs=connection_kwargs, connection_class=redis.Connection) expected = ('ConnectionPool>') assert repr(pool) == expected def test_repr_contains_db_info_unix(self): connection_kwargs = { 'path': '/abc', 'db': 1, 'client_name': 'test-client' } pool = self.get_pool(connection_kwargs=connection_kwargs, connection_class=redis.UnixDomainSocketConnection) expected = ('ConnectionPool>') assert repr(pool) == expected class TestBlockingConnectionPool: def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): connection_kwargs = connection_kwargs or {} pool = redis.BlockingConnectionPool(connection_class=DummyConnection, max_connections=max_connections, timeout=timeout, **connection_kwargs) return pool def test_connection_creation(self, master_host): connection_kwargs = {'foo': 'bar', 'biz': 'baz', 'host': master_host[0], 'port': master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) connection = pool.get_connection('_') assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs def test_multiple_connections(self, master_host): connection_kwargs = {'host': master_host[0], 'port': master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) c1 = pool.get_connection('_') c2 = pool.get_connection('_') assert c1 != c2 def test_connection_pool_blocks_until_timeout(self, master_host): "When out of connections, block for timeout seconds, then raise" connection_kwargs = {'host': master_host[0], 'port': master_host[1]} pool = self.get_pool(max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs) pool.get_connection('_') start = time.time() with pytest.raises(redis.ConnectionError): pool.get_connection('_') # we should have waited at least 0.1 seconds assert time.time() - start >= 0.1 def test_connection_pool_blocks_until_conn_available(self, master_host): """ When out of connections, block until another connection is released to the pool """ connection_kwargs = {'host': master_host[0], 'port': master_host[1]} pool = self.get_pool(max_connections=1, timeout=2, connection_kwargs=connection_kwargs) c1 = pool.get_connection('_') def target(): time.sleep(0.1) pool.release(c1) start = time.time() Thread(target=target).start() pool.get_connection('_') assert time.time() - start >= 0.1 def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {'host': master_host[0], 'port': master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) c1 = pool.get_connection('_') pool.release(c1) c2 = pool.get_connection('_') assert c1 == c2 def test_repr_contains_db_info_tcp(self): pool = redis.ConnectionPool( host='localhost', port=6379, client_name='test-client' ) expected = ('ConnectionPool>') assert repr(pool) == expected def test_repr_contains_db_info_unix(self): pool = redis.ConnectionPool( connection_class=redis.UnixDomainSocketConnection, path='abc', client_name='test-client' ) expected = ('ConnectionPool>') assert repr(pool) == expected class TestConnectionPoolURLParsing: def test_hostname(self): pool = redis.ConnectionPool.from_url('redis://my.host') assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'my.host', } def test_quoted_hostname(self): pool = redis.ConnectionPool.from_url('redis://my %2F host %2B%3D+') assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'my / host +=+', } def test_port(self): pool = redis.ConnectionPool.from_url('redis://localhost:6380') assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'port': 6380, } @skip_if_server_version_lt("6.0.0") def test_username(self): pool = redis.ConnectionPool.from_url('redis://myuser:@localhost') assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'username': 'myuser', } @skip_if_server_version_lt("6.0.0") def test_quoted_username(self): pool = redis.ConnectionPool.from_url( 'redis://%2Fmyuser%2F%2B name%3D%24+:@localhost') assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'username': '/myuser/+ name=$+', } def test_password(self): pool = redis.ConnectionPool.from_url('redis://:mypassword@localhost') assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'password': 'mypassword', } def test_quoted_password(self): pool = redis.ConnectionPool.from_url( 'redis://:%2Fmypass%2F%2B word%3D%24+@localhost') assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'password': '/mypass/+ word=$+', } @skip_if_server_version_lt("6.0.0") def test_username_and_password(self): pool = redis.ConnectionPool.from_url('redis://myuser:mypass@localhost') assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'username': 'myuser', 'password': 'mypass', } def test_db_as_argument(self): pool = redis.ConnectionPool.from_url('redis://localhost', db=1) assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'db': 1, } def test_db_in_path(self): pool = redis.ConnectionPool.from_url('redis://localhost/2', db=1) assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'db': 2, } def test_db_in_querystring(self): pool = redis.ConnectionPool.from_url('redis://localhost/2?db=3', db=1) assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'db': 3, } def test_extra_typed_querystring_options(self): pool = redis.ConnectionPool.from_url( 'redis://localhost/2?socket_timeout=20&socket_connect_timeout=10' '&socket_keepalive=&retry_on_timeout=Yes&max_connections=10' ) assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'db': 2, 'socket_timeout': 20.0, 'socket_connect_timeout': 10.0, 'retry_on_timeout': True, } assert pool.max_connections == 10 def test_boolean_parsing(self): for expected, value in ( (None, None), (None, ''), (False, 0), (False, '0'), (False, 'f'), (False, 'F'), (False, 'False'), (False, 'n'), (False, 'N'), (False, 'No'), (True, 1), (True, '1'), (True, 'y'), (True, 'Y'), (True, 'Yes'), ): assert expected is to_bool(value) def test_client_name_in_querystring(self): pool = redis.ConnectionPool.from_url( 'redis://location?client_name=test-client' ) assert pool.connection_kwargs['client_name'] == 'test-client' def test_invalid_extra_typed_querystring_options(self): with pytest.raises(ValueError): redis.ConnectionPool.from_url( 'redis://localhost/2?socket_timeout=_&' 'socket_connect_timeout=abc' ) def test_extra_querystring_options(self): pool = redis.ConnectionPool.from_url('redis://localhost?a=1&b=2') assert pool.connection_class == redis.Connection assert pool.connection_kwargs == { 'host': 'localhost', 'a': '1', 'b': '2' } def test_calling_from_subclass_returns_correct_instance(self): pool = redis.BlockingConnectionPool.from_url('redis://localhost') assert isinstance(pool, redis.BlockingConnectionPool) def test_client_creates_connection_pool(self): r = redis.Redis.from_url('redis://myhost') assert r.connection_pool.connection_class == redis.Connection assert r.connection_pool.connection_kwargs == { 'host': 'myhost', } def test_invalid_scheme_raises_error(self): with pytest.raises(ValueError) as cm: redis.ConnectionPool.from_url('localhost') assert str(cm.value) == ( 'Redis URL must specify one of the following schemes ' '(redis://, rediss://, unix://)' ) class TestConnectionPoolUnixSocketURLParsing: def test_defaults(self): pool = redis.ConnectionPool.from_url('unix:///socket') assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { 'path': '/socket', } @skip_if_server_version_lt("6.0.0") def test_username(self): pool = redis.ConnectionPool.from_url('unix://myuser:@/socket') assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { 'path': '/socket', 'username': 'myuser', } @skip_if_server_version_lt("6.0.0") def test_quoted_username(self): pool = redis.ConnectionPool.from_url( 'unix://%2Fmyuser%2F%2B name%3D%24+:@/socket') assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { 'path': '/socket', 'username': '/myuser/+ name=$+', } def test_password(self): pool = redis.ConnectionPool.from_url('unix://:mypassword@/socket') assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { 'path': '/socket', 'password': 'mypassword', } def test_quoted_password(self): pool = redis.ConnectionPool.from_url( 'unix://:%2Fmypass%2F%2B word%3D%24+@/socket') assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { 'path': '/socket', 'password': '/mypass/+ word=$+', } def test_quoted_path(self): pool = redis.ConnectionPool.from_url( 'unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket') assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { 'path': '/my/path/to/../+_+=$ocket', 'password': 'mypassword', } def test_db_as_argument(self): pool = redis.ConnectionPool.from_url('unix:///socket', db=1) assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { 'path': '/socket', 'db': 1, } def test_db_in_querystring(self): pool = redis.ConnectionPool.from_url('unix:///socket?db=2', db=1) assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { 'path': '/socket', 'db': 2, } def test_client_name_in_querystring(self): pool = redis.ConnectionPool.from_url( 'redis://location?client_name=test-client' ) assert pool.connection_kwargs['client_name'] == 'test-client' def test_extra_querystring_options(self): pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2') assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == { 'path': '/socket', 'a': '1', 'b': '2' } @pytest.mark.skipif(not ssl_available, reason="SSL not installed") class TestSSLConnectionURLParsing: def test_host(self): pool = redis.ConnectionPool.from_url('rediss://my.host') assert pool.connection_class == redis.SSLConnection assert pool.connection_kwargs == { 'host': 'my.host', } def test_cert_reqs_options(self): import ssl class DummyConnectionPool(redis.ConnectionPool): def get_connection(self, *args, **kwargs): return self.make_connection() pool = DummyConnectionPool.from_url( 'rediss://?ssl_cert_reqs=none') assert pool.get_connection('_').cert_reqs == ssl.CERT_NONE pool = DummyConnectionPool.from_url( 'rediss://?ssl_cert_reqs=optional') assert pool.get_connection('_').cert_reqs == ssl.CERT_OPTIONAL pool = DummyConnectionPool.from_url( 'rediss://?ssl_cert_reqs=required') assert pool.get_connection('_').cert_reqs == ssl.CERT_REQUIRED pool = DummyConnectionPool.from_url( 'rediss://?ssl_check_hostname=False') assert pool.get_connection('_').check_hostname is False pool = DummyConnectionPool.from_url( 'rediss://?ssl_check_hostname=True') assert pool.get_connection('_').check_hostname is True class TestConnection: def test_on_connect_error(self): """ An error in Connection.on_connect should disconnect from the server see for details: https://github.com/andymccurdy/redis-py/issues/368 """ # this assumes the Redis server being tested against doesn't have # 9999 databases ;) bad_connection = redis.Redis(db=9999) # an error should be raised on connect with pytest.raises(redis.RedisError): bad_connection.info() pool = bad_connection.connection_pool assert len(pool._available_connections) == 1 assert not pool._available_connections[0]._sock @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.8') @skip_if_redis_enterprise def test_busy_loading_disconnects_socket(self, r): """ If Redis raises a LOADING error, the connection should be disconnected and a BusyLoadingError raised """ with pytest.raises(redis.BusyLoadingError): r.execute_command('DEBUG', 'ERROR', 'LOADING fake message') assert not r.connection._sock @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.8') @skip_if_redis_enterprise def test_busy_loading_from_pipeline_immediate_command(self, r): """ BusyLoadingErrors should raise from Pipelines that execute a command immediately, like WATCH does. """ pipe = r.pipeline() with pytest.raises(redis.BusyLoadingError): pipe.immediate_execute_command('DEBUG', 'ERROR', 'LOADING fake message') pool = r.connection_pool assert not pipe.connection assert len(pool._available_connections) == 1 assert not pool._available_connections[0]._sock @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.8') @skip_if_redis_enterprise def test_busy_loading_from_pipeline(self, r): """ BusyLoadingErrors should be raised from a pipeline execution regardless of the raise_on_error flag. """ pipe = r.pipeline() pipe.execute_command('DEBUG', 'ERROR', 'LOADING fake message') with pytest.raises(redis.BusyLoadingError): pipe.execute() pool = r.connection_pool assert not pipe.connection assert len(pool._available_connections) == 1 assert not pool._available_connections[0]._sock @skip_if_server_version_lt('2.8.8') @skip_if_redis_enterprise def test_read_only_error(self, r): "READONLY errors get turned in ReadOnlyError exceptions" with pytest.raises(redis.ReadOnlyError): r.execute_command('DEBUG', 'ERROR', 'READONLY blah blah') def test_connect_from_url_tcp(self): connection = redis.Redis.from_url('redis://localhost') pool = connection.connection_pool assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == ( 'ConnectionPool', 'Connection', 'host=localhost,port=6379,db=0', ) def test_connect_from_url_unix(self): connection = redis.Redis.from_url('unix:///path/to/socket') pool = connection.connection_pool assert re.match('(.*)<(.*)<(.*)>>', repr(pool)).groups() == ( 'ConnectionPool', 'UnixDomainSocketConnection', 'path=/path/to/socket,db=0', ) @skip_if_redis_enterprise def test_connect_no_auth_supplied_when_required(self, r): """ AuthenticationError should be raised when the server requires a password but one isn't supplied. """ with pytest.raises(redis.AuthenticationError): r.execute_command('DEBUG', 'ERROR', 'ERR Client sent AUTH, but no password is set') @skip_if_redis_enterprise def test_connect_invalid_password_supplied(self, r): "AuthenticationError should be raised when sending the wrong password" with pytest.raises(redis.AuthenticationError): r.execute_command('DEBUG', 'ERROR', 'ERR invalid password') @pytest.mark.onlynoncluster class TestMultiConnectionClient: @pytest.fixture() def r(self, request): return _get_client(redis.Redis, request, single_connection_client=False) def test_multi_connection_command(self, r): assert not r.connection assert r.set('a', '123') assert r.get('a') == b'123' @pytest.mark.onlynoncluster class TestHealthCheck: interval = 60 @pytest.fixture() def r(self, request): return _get_client(redis.Redis, request, health_check_interval=self.interval) def assert_interval_advanced(self, connection): diff = connection.next_health_check - time.time() assert self.interval > diff > (self.interval - 1) def test_health_check_runs(self, r): r.connection.next_health_check = time.time() - 1 r.connection.check_health() self.assert_interval_advanced(r.connection) def test_arbitrary_command_invokes_health_check(self, r): # invoke a command to make sure the connection is entirely setup r.get('foo') r.connection.next_health_check = time.time() with mock.patch.object(r.connection, 'send_command', wraps=r.connection.send_command) as m: r.get('foo') m.assert_called_with('PING', check_health=False) self.assert_interval_advanced(r.connection) def test_arbitrary_command_advances_next_health_check(self, r): r.get('foo') next_health_check = r.connection.next_health_check r.get('foo') assert next_health_check < r.connection.next_health_check def test_health_check_not_invoked_within_interval(self, r): r.get('foo') with mock.patch.object(r.connection, 'send_command', wraps=r.connection.send_command) as m: r.get('foo') ping_call_spec = (('PING',), {'check_health': False}) assert ping_call_spec not in m.call_args_list def test_health_check_in_pipeline(self, r): with r.pipeline(transaction=False) as pipe: pipe.connection = pipe.connection_pool.get_connection('_') pipe.connection.next_health_check = 0 with mock.patch.object(pipe.connection, 'send_command', wraps=pipe.connection.send_command) as m: responses = pipe.set('foo', 'bar').get('foo').execute() m.assert_any_call('PING', check_health=False) assert responses == [True, b'bar'] def test_health_check_in_transaction(self, r): with r.pipeline(transaction=True) as pipe: pipe.connection = pipe.connection_pool.get_connection('_') pipe.connection.next_health_check = 0 with mock.patch.object(pipe.connection, 'send_command', wraps=pipe.connection.send_command) as m: responses = pipe.set('foo', 'bar').get('foo').execute() m.assert_any_call('PING', check_health=False) assert responses == [True, b'bar'] def test_health_check_in_watched_pipeline(self, r): r.set('foo', 'bar') with r.pipeline(transaction=False) as pipe: pipe.connection = pipe.connection_pool.get_connection('_') pipe.connection.next_health_check = 0 with mock.patch.object(pipe.connection, 'send_command', wraps=pipe.connection.send_command) as m: pipe.watch('foo') # the health check should be called when watching m.assert_called_with('PING', check_health=False) self.assert_interval_advanced(pipe.connection) assert pipe.get('foo') == b'bar' # reset the mock to clear the call list and schedule another # health check m.reset_mock() pipe.connection.next_health_check = 0 pipe.multi() responses = pipe.set('foo', 'not-bar').get('foo').execute() assert responses == [True, b'not-bar'] m.assert_any_call('PING', check_health=False) def test_health_check_in_pubsub_before_subscribe(self, r): "A health check happens before the first [p]subscribe" p = r.pubsub() p.connection = p.connection_pool.get_connection('_') p.connection.next_health_check = 0 with mock.patch.object(p.connection, 'send_command', wraps=p.connection.send_command) as m: assert not p.subscribed p.subscribe('foo') # the connection is not yet in pubsub mode, so the normal # ping/pong within connection.send_command should check # the health of the connection m.assert_any_call('PING', check_health=False) self.assert_interval_advanced(p.connection) subscribe_message = wait_for_message(p) assert subscribe_message['type'] == 'subscribe' def test_health_check_in_pubsub_after_subscribed(self, r): """ Pubsub can handle a new subscribe when it's time to check the connection health """ p = r.pubsub() p.connection = p.connection_pool.get_connection('_') p.connection.next_health_check = 0 with mock.patch.object(p.connection, 'send_command', wraps=p.connection.send_command) as m: p.subscribe('foo') subscribe_message = wait_for_message(p) assert subscribe_message['type'] == 'subscribe' self.assert_interval_advanced(p.connection) # because we weren't subscribed when sending the subscribe # message to 'foo', the connection's standard check_health ran # prior to subscribing. m.assert_any_call('PING', check_health=False) p.connection.next_health_check = 0 m.reset_mock() p.subscribe('bar') # the second subscribe issues exactly only command (the subscribe) # and the health check is not invoked m.assert_called_once_with('SUBSCRIBE', 'bar', check_health=False) # since no message has been read since the health check was # reset, it should still be 0 assert p.connection.next_health_check == 0 subscribe_message = wait_for_message(p) assert subscribe_message['type'] == 'subscribe' assert wait_for_message(p) is None # now that the connection is subscribed, the pubsub health # check should have taken over and include the HEALTH_CHECK_MESSAGE m.assert_any_call('PING', p.HEALTH_CHECK_MESSAGE, check_health=False) self.assert_interval_advanced(p.connection) def test_health_check_in_pubsub_poll(self, r): """ Polling a pubsub connection that's subscribed will regularly check the connection's health. """ p = r.pubsub() p.connection = p.connection_pool.get_connection('_') with mock.patch.object(p.connection, 'send_command', wraps=p.connection.send_command) as m: p.subscribe('foo') subscribe_message = wait_for_message(p) assert subscribe_message['type'] == 'subscribe' self.assert_interval_advanced(p.connection) # polling the connection before the health check interval # doesn't result in another health check m.reset_mock() next_health_check = p.connection.next_health_check assert wait_for_message(p) is None assert p.connection.next_health_check == next_health_check m.assert_not_called() # reset the health check and poll again # we should not receive a pong message, but the next_health_check # should be advanced p.connection.next_health_check = 0 assert wait_for_message(p) is None m.assert_called_with('PING', p.HEALTH_CHECK_MESSAGE, check_health=False) self.assert_interval_advanced(p.connection)