diff options
Diffstat (limited to 'tests/conftest.py')
-rw-r--r-- | tests/conftest.py | 101 |
1 files changed, 84 insertions, 17 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 31d3fbd..ddc0834 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,10 @@ from redis.retry import Retry import pytest import random import redis +import time from distutils.version import LooseVersion from redis.connection import parse_url +from redis.exceptions import RedisClusterException from unittest.mock import Mock from urllib.parse import urlparse @@ -13,6 +15,7 @@ REDIS_INFO = {} default_redis_url = "redis://localhost:6379/9" default_redismod_url = "redis://localhost:36379" +default_cluster_nodes = 6 def pytest_addoption(parser): @@ -27,11 +30,18 @@ def pytest_addoption(parser): " with loaded modules," " defaults to `%(default)s`") + parser.addoption('--redis-cluster-nodes', default=default_cluster_nodes, + action="store", + help="The number of cluster nodes that need to be " + "available before the test can start," + " defaults to `%(default)s`") + def _get_info(redis_url): client = redis.Redis.from_url(redis_url) info = client.info() - if 'dping' in client.__commands__: + cmds = [command.upper() for command in client.command().keys()] + if 'dping' in cmds: info["enterprise"] = True else: info["enterprise"] = False @@ -44,8 +54,10 @@ def pytest_sessionstart(session): info = _get_info(redis_url) version = info["redis_version"] arch_bits = info["arch_bits"] + cluster_enabled = info["cluster_enabled"] REDIS_INFO["version"] = version REDIS_INFO["arch_bits"] = arch_bits + REDIS_INFO["cluster_enabled"] = cluster_enabled REDIS_INFO["enterprise"] = info["enterprise"] # module info, if the second redis is running @@ -58,6 +70,42 @@ def pytest_sessionstart(session): except KeyError: pass + if cluster_enabled: + cluster_nodes = session.config.getoption("--redis-cluster-nodes") + wait_for_cluster_creation(redis_url, cluster_nodes) + + +def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=20): + """ + Waits for the cluster creation to complete. + As soon as all :cluster_nodes: nodes become available, the cluster will be + considered ready. + :param redis_url: the cluster's url, e.g. redis://localhost:16379/0 + :param cluster_nodes: The number of nodes in the cluster + :param timeout: the amount of time to wait (in seconds) + """ + now = time.time() + end_time = now + timeout + client = None + print("Waiting for {0} cluster nodes to become available". + format(cluster_nodes)) + while now < end_time: + try: + client = redis.RedisCluster.from_url(redis_url) + if len(client.get_nodes()) == cluster_nodes: + print("All nodes are available!") + break + except RedisClusterException: + pass + time.sleep(1) + now = time.time() + if now >= end_time: + available_nodes = 0 if client is None else len(client.get_nodes()) + raise RedisClusterException( + "The cluster did not become available after {0} seconds. " + "Only {1} nodes out of {2} are available".format( + timeout, available_nodes, cluster_nodes)) + def skip_if_server_version_lt(min_version): redis_version = REDIS_INFO["version"] @@ -101,13 +149,12 @@ def skip_ifmodversion_lt(min_version: str, module_name: str): def skip_if_redis_enterprise(func): check = REDIS_INFO["enterprise"] is True - return pytest.mark.skipif(check, reason="Redis enterprise" - ) + return pytest.mark.skipif(check, reason="Redis enterprise") def skip_ifnot_redis_enterprise(func): check = REDIS_INFO["enterprise"] is False - return pytest.mark.skipif(check, reason="Redis enterprise") + return pytest.mark.skipif(check, reason="Not running in redis enterprise") def _get_client(cls, request, single_connection_client=True, flushdb=True, @@ -124,27 +171,47 @@ def _get_client(cls, request, single_connection_client=True, flushdb=True, redis_url = request.config.getoption("--redis-url") else: redis_url = from_url - url_options = parse_url(redis_url) - url_options.update(kwargs) - pool = redis.ConnectionPool(**url_options) - client = cls(connection_pool=pool) + cluster_mode = REDIS_INFO["cluster_enabled"] + if not cluster_mode: + url_options = parse_url(redis_url) + url_options.update(kwargs) + pool = redis.ConnectionPool(**url_options) + client = cls(connection_pool=pool) + else: + client = redis.RedisCluster.from_url(redis_url, **kwargs) + single_connection_client = False if single_connection_client: client = client.client() if request: def teardown(): - if flushdb: - try: - client.flushdb() - except redis.ConnectionError: - # handle cases where a test disconnected a client - # just manually retry the flushdb - client.flushdb() - client.close() - client.connection_pool.disconnect() + if not cluster_mode: + if flushdb: + try: + client.flushdb() + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + client.flushdb() + client.close() + client.connection_pool.disconnect() + else: + cluster_teardown(client, flushdb) request.addfinalizer(teardown) return client +def cluster_teardown(client, flushdb): + if flushdb: + try: + client.flushdb(target_nodes='primaries') + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + client.flushdb(target_nodes='primaries') + client.close() + client.disconnect_connection_pools() + + # specifically set to the zero database, because creating # an index on db != 0 raises a ResponseError in redis @pytest.fixture() |