summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSelwin Ong <selwin.ong@gmail.com>2023-04-25 19:38:44 +0700
committerGitHub <noreply@github.com>2023-04-25 19:38:44 +0700
commit77e926c4248590933e80ecd527e9c2f487261d2c (patch)
tree1c48313a1f27887fddec84beccc6425a3a8f6762
parent95983cfcacd7178004fa7c4325cd340ef360d5c0 (diff)
downloadrq-77e926c4248590933e80ecd527e9c2f487261d2c.tar.gz
Added parse_connection function (#1884)
* Added parse_connection function * feat: allow custom connection pool class (#1885) * Added test for SSL --------- Co-authored-by: Cyril Chapellier <tchapi@users.noreply.github.com>
-rw-r--r--rq/connections.py32
-rw-r--r--rq/scheduler.py49
-rw-r--r--tests/test_connection.py12
-rw-r--r--tests/test_scheduler.py44
4 files changed, 101 insertions, 36 deletions
diff --git a/rq/connections.py b/rq/connections.py
index 0d39e43..dfb590a 100644
--- a/rq/connections.py
+++ b/rq/connections.py
@@ -1,8 +1,8 @@
import warnings
from contextlib import contextmanager
-from typing import Optional
+from typing import Any, Optional, Tuple, Type
-from redis import Redis
+from redis import Connection as RedisConnection, Redis, SSLConnection, UnixDomainSocketConnection
from .local import LocalStack
@@ -42,10 +42,9 @@ def Connection(connection: Optional['Redis'] = None): # noqa
yield
finally:
popped = pop_connection()
- assert popped == connection, (
- 'Unexpected Redis connection was popped off the stack. '
- 'Check your Redis connection setup.'
- )
+ assert (
+ popped == connection
+ ), 'Unexpected Redis connection was popped off the stack. Check your Redis connection setup.'
def push_connection(redis: 'Redis'):
@@ -118,8 +117,27 @@ def resolve_connection(connection: Optional['Redis'] = None) -> 'Redis':
return connection
+def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]:
+ connection_kwargs = connection.connection_pool.connection_kwargs.copy()
+ # Redis does not accept parser_class argument which is sometimes present
+ # on connection_pool kwargs, for example when hiredis is used
+ connection_kwargs.pop('parser_class', None)
+ connection_pool_class = connection.connection_pool.connection_class
+ if issubclass(connection_pool_class, SSLConnection):
+ connection_kwargs['ssl'] = True
+ if issubclass(connection_pool_class, UnixDomainSocketConnection):
+ # The connection keyword arguments are obtained from
+ # `UnixDomainSocketConnection`, which expects `path`, but passed to
+ # `redis.client.Redis`, which expects `unix_socket_path`, renaming
+ # the key is necessary.
+ # `path` is not left in the dictionary as that keyword argument is
+ # not expected by `redis.client.Redis` and would raise an exception.
+ connection_kwargs['unix_socket_path'] = connection_kwargs.pop('path')
+
+ return connection.__class__, connection_pool_class, connection_kwargs
+
+
_connection_stack = LocalStack()
__all__ = ['Connection', 'get_current_connection', 'push_connection', 'pop_connection']
-
diff --git a/rq/scheduler.py b/rq/scheduler.py
index 86566ba..ec54c1a 100644
--- a/rq/scheduler.py
+++ b/rq/scheduler.py
@@ -7,8 +7,9 @@ from datetime import datetime
from enum import Enum
from multiprocessing import Process
-from redis import SSLConnection, UnixDomainSocketConnection
+from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection
+from .connections import parse_connection
from .defaults import DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, DEFAULT_SCHEDULER_FALLBACK_PERIOD
from .job import Job
from .logutils import setup_loghandlers
@@ -35,37 +36,24 @@ class RQScheduler:
Status = SchedulerStatus
def __init__(
- self,
- queues,
- connection,
- interval=1,
- logging_level=logging.INFO,
- date_format=DEFAULT_LOGGING_DATE_FORMAT,
- log_format=DEFAULT_LOGGING_FORMAT,
- serializer=None,
+ self,
+ queues,
+ connection: Redis,
+ interval=1,
+ logging_level=logging.INFO,
+ date_format=DEFAULT_LOGGING_DATE_FORMAT,
+ log_format=DEFAULT_LOGGING_FORMAT,
+ serializer=None,
):
self._queue_names = set(parse_names(queues))
self._acquired_locks = set()
self._scheduled_job_registries = []
self.lock_acquisition_time = None
- # Copy the connection kwargs before mutating them in order to not change the arguments
- # used by the current connection pool to create new connections
- self._connection_kwargs = connection.connection_pool.connection_kwargs.copy()
- # Redis does not accept parser_class argument which is sometimes present
- # on connection_pool kwargs, for example when hiredis is used
- self._connection_kwargs.pop('parser_class', None)
- self._connection_class = connection.__class__ # client
- connection_class = connection.connection_pool.connection_class
- if issubclass(connection_class, SSLConnection):
- self._connection_kwargs['ssl'] = True
- if issubclass(connection_class, UnixDomainSocketConnection):
- # The connection keyword arguments are obtained from
- # `UnixDomainSocketConnection`, which expects `path`, but passed to
- # `redis.client.Redis`, which expects `unix_socket_path`, renaming
- # the key is necessary.
- # `path` is not left in the dictionary as that keyword argument is
- # not expected by `redis.client.Redis` and would raise an exception.
- self._connection_kwargs['unix_socket_path'] = self._connection_kwargs.pop('path')
+ (
+ self._connection_class,
+ self._connection_pool_class,
+ self._connection_kwargs,
+ ) = parse_connection(connection)
self.serializer = resolve_serializer(serializer)
self._connection = None
@@ -85,7 +73,12 @@ class RQScheduler:
def connection(self):
if self._connection:
return self._connection
- self._connection = self._connection_class(**self._connection_kwargs)
+ self._connection = self._connection_class(
+ connection_pool=ConnectionPool(
+ connection_class=self._connection_pool_class,
+ **self._connection_kwargs
+ )
+ )
return self._connection
@property
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 393c20d..4b4ba8e 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -1,6 +1,7 @@
-from redis import Redis
+from redis import ConnectionPool, Redis, UnixDomainSocketConnection
from rq import Connection, Queue
+from rq.connections import parse_connection
from tests import RQTestCase, find_empty_redis_database
from tests.fixtures import do_nothing
@@ -35,3 +36,12 @@ class TestConnectionInheritance(RQTestCase):
job2 = q2.enqueue(do_nothing)
self.assertEqual(q1.connection, job1.connection)
self.assertEqual(q2.connection, job2.connection)
+
+ def test_parse_connection(self):
+ """Test parsing `ssl` and UnixDomainSocketConnection"""
+ _, _, kwargs = parse_connection(Redis(ssl=True))
+ self.assertTrue(kwargs['ssl'])
+ path = '/tmp/redis.sock'
+ pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path)
+ _, _, kwargs = parse_connection(Redis(connection_pool=pool))
+ self.assertTrue(kwargs['unix_socket_path'], path)
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index a907ff5..c417554 100644
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -1,4 +1,6 @@
import os
+import redis
+
from datetime import datetime, timedelta, timezone
from multiprocessing import Process
from unittest import mock
@@ -16,6 +18,17 @@ from tests import RQTestCase, find_empty_redis_database, ssl_test
from .fixtures import kill_worker, say_hello
+class CustomRedisConnection(redis.Connection):
+ """Custom redis connection with a custom arg, used in test_custom_connection_pool"""
+
+ def __init__(self, *args, custom_arg=None, **kwargs):
+ self.custom_arg = custom_arg
+ super().__init__(*args, **kwargs)
+
+ def get_custom_arg(self):
+ return self.custom_arg
+
+
class TestScheduledJobRegistry(RQTestCase):
def test_get_jobs_to_enqueue(self):
@@ -431,3 +444,34 @@ class TestQueue(RQTestCase):
job = queue.enqueue_in(timedelta(seconds=30), say_hello, retry=Retry(3, [2]))
self.assertEqual(job.retries_left, 3)
self.assertEqual(job.retry_intervals, [2])
+
+ def test_custom_connection_pool(self):
+ """Connection pool customizing. Ensure that we can properly set a
+ custom connection pool class and pass extra arguments"""
+ custom_conn = redis.Redis(
+ connection_pool=redis.ConnectionPool(
+ connection_class=CustomRedisConnection,
+ db=4,
+ custom_arg="foo",
+ )
+ )
+
+ queue = Queue(connection=custom_conn)
+ scheduler = RQScheduler([queue], connection=custom_conn)
+
+ scheduler_connection = scheduler.connection.connection_pool.get_connection('info')
+
+ self.assertEqual(scheduler_connection.__class__, CustomRedisConnection)
+ self.assertEqual(scheduler_connection.get_custom_arg(), "foo")
+
+ def test_no_custom_connection_pool(self):
+ """Connection pool customizing must not interfere if we're using a standard
+ connection (non-pooled)"""
+ standard_conn = redis.Redis(db=5)
+
+ queue = Queue(connection=standard_conn)
+ scheduler = RQScheduler([queue], connection=standard_conn)
+
+ scheduler_connection = scheduler.connection.connection_pool.get_connection('info')
+
+ self.assertEqual(scheduler_connection.__class__, redis.Connection)