summaryrefslogtreecommitdiff
path: root/taskflow/engines/worker_based/executor.py
diff options
context:
space:
mode:
Diffstat (limited to 'taskflow/engines/worker_based/executor.py')
-rw-r--r--taskflow/engines/worker_based/executor.py90
1 files changed, 23 insertions, 67 deletions
diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py
index cdc6361..cda3745 100644
--- a/taskflow/engines/worker_based/executor.py
+++ b/taskflow/engines/worker_based/executor.py
@@ -15,15 +15,13 @@
# under the License.
import functools
-import threading
-from oslo_utils import reflection
from oslo_utils import timeutils
from taskflow.engines.action_engine import executor
-from taskflow.engines.worker_based import cache
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy
+from taskflow.engines.worker_based import types as wt
from taskflow import exceptions as exc
from taskflow import logging
from taskflow import task as task_atom
@@ -34,35 +32,6 @@ from taskflow.utils import threading_utils as tu
LOG = logging.getLogger(__name__)
-class PeriodicWorker(object):
- """Calls a set of functions when activated periodically.
-
- NOTE(harlowja): the provided timeout object determines the periodicity.
- """
- def __init__(self, timeout, functors):
- self._timeout = timeout
- self._functors = []
- for f in functors:
- self._functors.append((f, reflection.get_callable_name(f)))
-
- def start(self):
- while not self._timeout.is_stopped():
- for (f, f_name) in self._functors:
- LOG.debug("Calling periodic function '%s'", f_name)
- try:
- f()
- except Exception:
- LOG.warn("Failed to call periodic function '%s'", f_name,
- exc_info=True)
- self._timeout.wait()
-
- def stop(self):
- self._timeout.interrupt()
-
- def reset(self):
- self._timeout.reset()
-
-
class WorkerTaskExecutor(executor.TaskExecutor):
"""Executes tasks on remote workers."""
@@ -72,10 +41,9 @@ class WorkerTaskExecutor(executor.TaskExecutor):
retry_options=None):
self._uuid = uuid
self._topics = topics
- self._requests_cache = cache.RequestsCache()
+ self._requests_cache = wt.RequestsCache()
+ self._workers = wt.TopicWorkers()
self._transition_timeout = transition_timeout
- self._workers_cache = cache.WorkersCache()
- self._workers_arrival = threading.Condition()
type_handlers = {
pr.NOTIFY: [
self._process_notify,
@@ -92,8 +60,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
transport_options=transport_options,
retry_options=retry_options)
self._proxy_thread = None
- self._periodic = PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
- [self._notify_topics])
+ self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
+ [self._notify_topics])
self._periodic_thread = None
def _process_notify(self, notify, message):
@@ -104,16 +72,15 @@ class WorkerTaskExecutor(executor.TaskExecutor):
tasks = notify['tasks']
# Add worker info to the cache
- LOG.debug("Received that tasks %s can be processed by topic '%s'",
- tasks, topic)
- with self._workers_arrival:
- self._workers_cache[topic] = tasks
- self._workers_arrival.notify_all()
+ worker = self._workers.add(topic, tasks)
+ LOG.debug("Received notification about worker '%s' (%s"
+ " total workers are currently known)", worker,
+ len(self._workers))
# Publish waiting requests
- for request in self._requests_cache.get_waiting_requests(tasks):
+ for request in self._requests_cache.get_waiting_requests(worker):
if request.transition_and_log_error(pr.PENDING, logger=LOG):
- self._publish_request(request, topic)
+ self._publish_request(request, worker)
def _process_response(self, response, message):
"""Process response from remote side."""
@@ -147,7 +114,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
del self._requests_cache[request.uuid]
request.set_result(**response.data)
else:
- LOG.warning("Unexpected response status: '%s'",
+ LOG.warning("Unexpected response status '%s'",
response.state)
else:
LOG.debug("Request with id='%s' not found", task_uuid)
@@ -196,16 +163,16 @@ class WorkerTaskExecutor(executor.TaskExecutor):
progress_callback)
request.result.add_done_callback(lambda fut: cleaner())
- # Get task's topic and publish request if topic was found.
- topic = self._workers_cache.get_topic_by_task(request.task_cls)
- if topic is not None:
+ # Get task's worker and publish request if worker was found.
+ worker = self._workers.get_worker_for_task(task)
+ if worker is not None:
# NOTE(skudriashev): Make sure request is set to the PENDING state
# before putting it into the requests cache to prevent the notify
# processing thread get list of waiting requests and publish it
# before it is published here, so it wouldn't be published twice.
if request.transition_and_log_error(pr.PENDING, logger=LOG):
self._requests_cache[request.uuid] = request
- self._publish_request(request, topic)
+ self._publish_request(request, worker)
else:
LOG.debug("Delaying submission of '%s', no currently known"
" worker/s available to process it", request)
@@ -213,14 +180,14 @@ class WorkerTaskExecutor(executor.TaskExecutor):
return request.result
- def _publish_request(self, request, topic):
+ def _publish_request(self, request, worker):
"""Publish request to a given topic."""
- LOG.debug("Submitting execution of '%s' to topic '%s' (expecting"
+ LOG.debug("Submitting execution of '%s' to worker '%s' (expecting"
" response identified by reply_to=%s and"
- " correlation_id=%s)", request, topic, self._uuid,
+ " correlation_id=%s)", request, worker, self._uuid,
request.uuid)
try:
- self._proxy.publish(request, topic,
+ self._proxy.publish(request, worker.topic,
reply_to=self._uuid,
correlation_id=request.uuid)
except Exception:
@@ -255,20 +222,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
return how many workers are still needed, otherwise it will
return zero.
"""
- if workers <= 0:
- raise ValueError("Worker amount must be greater than zero")
- w = None
- if timeout is not None:
- w = tt.StopWatch(timeout).start()
- with self._workers_arrival:
- while len(self._workers_cache) < workers:
- if w is not None and w.expired():
- return workers - len(self._workers_cache)
- timeout = None
- if w is not None:
- timeout = w.leftover()
- self._workers_arrival.wait(timeout)
- return 0
+ return self._workers.wait_for_workers(workers=workers, timeout=timeout)
def start(self):
"""Starts proxy thread and associated topic notification thread."""
@@ -291,3 +245,5 @@ class WorkerTaskExecutor(executor.TaskExecutor):
self._proxy.stop()
self._proxy_thread.join()
self._proxy_thread = None
+ self._requests_cache.clear(self._handle_expired_request)
+ self._workers.clear()