summaryrefslogtreecommitdiff
path: root/tests/conftest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/conftest.py')
-rw-r--r--tests/conftest.py101
1 files changed, 84 insertions, 17 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 31d3fbd..ddc0834 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,8 +3,10 @@ from redis.retry import Retry
import pytest
import random
import redis
+import time
from distutils.version import LooseVersion
from redis.connection import parse_url
+from redis.exceptions import RedisClusterException
from unittest.mock import Mock
from urllib.parse import urlparse
@@ -13,6 +15,7 @@ REDIS_INFO = {}
default_redis_url = "redis://localhost:6379/9"
default_redismod_url = "redis://localhost:36379"
+default_cluster_nodes = 6
def pytest_addoption(parser):
@@ -27,11 +30,18 @@ def pytest_addoption(parser):
" with loaded modules,"
" defaults to `%(default)s`")
+ parser.addoption('--redis-cluster-nodes', default=default_cluster_nodes,
+ action="store",
+ help="The number of cluster nodes that need to be "
+ "available before the test can start,"
+ " defaults to `%(default)s`")
+
def _get_info(redis_url):
client = redis.Redis.from_url(redis_url)
info = client.info()
- if 'dping' in client.__commands__:
+ cmds = [command.upper() for command in client.command().keys()]
+ if 'dping' in cmds:
info["enterprise"] = True
else:
info["enterprise"] = False
@@ -44,8 +54,10 @@ def pytest_sessionstart(session):
info = _get_info(redis_url)
version = info["redis_version"]
arch_bits = info["arch_bits"]
+ cluster_enabled = info["cluster_enabled"]
REDIS_INFO["version"] = version
REDIS_INFO["arch_bits"] = arch_bits
+ REDIS_INFO["cluster_enabled"] = cluster_enabled
REDIS_INFO["enterprise"] = info["enterprise"]
# module info, if the second redis is running
@@ -58,6 +70,42 @@ def pytest_sessionstart(session):
except KeyError:
pass
+ if cluster_enabled:
+ cluster_nodes = session.config.getoption("--redis-cluster-nodes")
+ wait_for_cluster_creation(redis_url, cluster_nodes)
+
+
+def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=20):
+ """
+ Waits for the cluster creation to complete.
+ As soon as all :cluster_nodes: nodes become available, the cluster will be
+ considered ready.
+ :param redis_url: the cluster's url, e.g. redis://localhost:16379/0
+ :param cluster_nodes: The number of nodes in the cluster
+ :param timeout: the amount of time to wait (in seconds)
+ """
+ now = time.time()
+ end_time = now + timeout
+ client = None
+ print("Waiting for {0} cluster nodes to become available".
+ format(cluster_nodes))
+ while now < end_time:
+ try:
+ client = redis.RedisCluster.from_url(redis_url)
+ if len(client.get_nodes()) == cluster_nodes:
+ print("All nodes are available!")
+ break
+ except RedisClusterException:
+ pass
+ time.sleep(1)
+ now = time.time()
+ if now >= end_time:
+ available_nodes = 0 if client is None else len(client.get_nodes())
+ raise RedisClusterException(
+ "The cluster did not become available after {0} seconds. "
+ "Only {1} nodes out of {2} are available".format(
+ timeout, available_nodes, cluster_nodes))
+
def skip_if_server_version_lt(min_version):
redis_version = REDIS_INFO["version"]
@@ -101,13 +149,12 @@ def skip_ifmodversion_lt(min_version: str, module_name: str):
def skip_if_redis_enterprise(func):
check = REDIS_INFO["enterprise"] is True
- return pytest.mark.skipif(check, reason="Redis enterprise"
- )
+ return pytest.mark.skipif(check, reason="Redis enterprise")
def skip_ifnot_redis_enterprise(func):
check = REDIS_INFO["enterprise"] is False
- return pytest.mark.skipif(check, reason="Redis enterprise")
+ return pytest.mark.skipif(check, reason="Not running in redis enterprise")
def _get_client(cls, request, single_connection_client=True, flushdb=True,
@@ -124,27 +171,47 @@ def _get_client(cls, request, single_connection_client=True, flushdb=True,
redis_url = request.config.getoption("--redis-url")
else:
redis_url = from_url
- url_options = parse_url(redis_url)
- url_options.update(kwargs)
- pool = redis.ConnectionPool(**url_options)
- client = cls(connection_pool=pool)
+ cluster_mode = REDIS_INFO["cluster_enabled"]
+ if not cluster_mode:
+ url_options = parse_url(redis_url)
+ url_options.update(kwargs)
+ pool = redis.ConnectionPool(**url_options)
+ client = cls(connection_pool=pool)
+ else:
+ client = redis.RedisCluster.from_url(redis_url, **kwargs)
+ single_connection_client = False
if single_connection_client:
client = client.client()
if request:
def teardown():
- if flushdb:
- try:
- client.flushdb()
- except redis.ConnectionError:
- # handle cases where a test disconnected a client
- # just manually retry the flushdb
- client.flushdb()
- client.close()
- client.connection_pool.disconnect()
+ if not cluster_mode:
+ if flushdb:
+ try:
+ client.flushdb()
+ except redis.ConnectionError:
+ # handle cases where a test disconnected a client
+ # just manually retry the flushdb
+ client.flushdb()
+ client.close()
+ client.connection_pool.disconnect()
+ else:
+ cluster_teardown(client, flushdb)
request.addfinalizer(teardown)
return client
+def cluster_teardown(client, flushdb):
+ if flushdb:
+ try:
+ client.flushdb(target_nodes='primaries')
+ except redis.ConnectionError:
+ # handle cases where a test disconnected a client
+ # just manually retry the flushdb
+ client.flushdb(target_nodes='primaries')
+ client.close()
+ client.disconnect_connection_pools()
+
+
# specifically set to the zero database, because creating
# an index on db != 0 raises a ResponseError in redis
@pytest.fixture()