diff options
Diffstat (limited to 'taskflow/engines/worker_based/executor.py')
| -rw-r--r-- | taskflow/engines/worker_based/executor.py | 90 |
1 files changed, 23 insertions, 67 deletions
diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py index cdc6361..cda3745 100644 --- a/taskflow/engines/worker_based/executor.py +++ b/taskflow/engines/worker_based/executor.py @@ -15,15 +15,13 @@ # under the License. import functools -import threading -from oslo_utils import reflection from oslo_utils import timeutils from taskflow.engines.action_engine import executor -from taskflow.engines.worker_based import cache from taskflow.engines.worker_based import protocol as pr from taskflow.engines.worker_based import proxy +from taskflow.engines.worker_based import types as wt from taskflow import exceptions as exc from taskflow import logging from taskflow import task as task_atom @@ -34,35 +32,6 @@ from taskflow.utils import threading_utils as tu LOG = logging.getLogger(__name__) -class PeriodicWorker(object): - """Calls a set of functions when activated periodically. - - NOTE(harlowja): the provided timeout object determines the periodicity. - """ - def __init__(self, timeout, functors): - self._timeout = timeout - self._functors = [] - for f in functors: - self._functors.append((f, reflection.get_callable_name(f))) - - def start(self): - while not self._timeout.is_stopped(): - for (f, f_name) in self._functors: - LOG.debug("Calling periodic function '%s'", f_name) - try: - f() - except Exception: - LOG.warn("Failed to call periodic function '%s'", f_name, - exc_info=True) - self._timeout.wait() - - def stop(self): - self._timeout.interrupt() - - def reset(self): - self._timeout.reset() - - class WorkerTaskExecutor(executor.TaskExecutor): """Executes tasks on remote workers.""" @@ -72,10 +41,9 @@ class WorkerTaskExecutor(executor.TaskExecutor): retry_options=None): self._uuid = uuid self._topics = topics - self._requests_cache = cache.RequestsCache() + self._requests_cache = wt.RequestsCache() + self._workers = wt.TopicWorkers() self._transition_timeout = transition_timeout - self._workers_cache = cache.WorkersCache() - self._workers_arrival = threading.Condition() type_handlers = { pr.NOTIFY: [ self._process_notify, @@ -92,8 +60,8 @@ class WorkerTaskExecutor(executor.TaskExecutor): transport_options=transport_options, retry_options=retry_options) self._proxy_thread = None - self._periodic = PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD), - [self._notify_topics]) + self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD), + [self._notify_topics]) self._periodic_thread = None def _process_notify(self, notify, message): @@ -104,16 +72,15 @@ class WorkerTaskExecutor(executor.TaskExecutor): tasks = notify['tasks'] # Add worker info to the cache - LOG.debug("Received that tasks %s can be processed by topic '%s'", - tasks, topic) - with self._workers_arrival: - self._workers_cache[topic] = tasks - self._workers_arrival.notify_all() + worker = self._workers.add(topic, tasks) + LOG.debug("Received notification about worker '%s' (%s" + " total workers are currently known)", worker, + len(self._workers)) # Publish waiting requests - for request in self._requests_cache.get_waiting_requests(tasks): + for request in self._requests_cache.get_waiting_requests(worker): if request.transition_and_log_error(pr.PENDING, logger=LOG): - self._publish_request(request, topic) + self._publish_request(request, worker) def _process_response(self, response, message): """Process response from remote side.""" @@ -147,7 +114,7 @@ class WorkerTaskExecutor(executor.TaskExecutor): del self._requests_cache[request.uuid] request.set_result(**response.data) else: - LOG.warning("Unexpected response status: '%s'", + LOG.warning("Unexpected response status '%s'", response.state) else: LOG.debug("Request with id='%s' not found", task_uuid) @@ -196,16 +163,16 @@ class WorkerTaskExecutor(executor.TaskExecutor): progress_callback) request.result.add_done_callback(lambda fut: cleaner()) - # Get task's topic and publish request if topic was found. - topic = self._workers_cache.get_topic_by_task(request.task_cls) - if topic is not None: + # Get task's worker and publish request if worker was found. + worker = self._workers.get_worker_for_task(task) + if worker is not None: # NOTE(skudriashev): Make sure request is set to the PENDING state # before putting it into the requests cache to prevent the notify # processing thread get list of waiting requests and publish it # before it is published here, so it wouldn't be published twice. if request.transition_and_log_error(pr.PENDING, logger=LOG): self._requests_cache[request.uuid] = request - self._publish_request(request, topic) + self._publish_request(request, worker) else: LOG.debug("Delaying submission of '%s', no currently known" " worker/s available to process it", request) @@ -213,14 +180,14 @@ class WorkerTaskExecutor(executor.TaskExecutor): return request.result - def _publish_request(self, request, topic): + def _publish_request(self, request, worker): """Publish request to a given topic.""" - LOG.debug("Submitting execution of '%s' to topic '%s' (expecting" + LOG.debug("Submitting execution of '%s' to worker '%s' (expecting" " response identified by reply_to=%s and" - " correlation_id=%s)", request, topic, self._uuid, + " correlation_id=%s)", request, worker, self._uuid, request.uuid) try: - self._proxy.publish(request, topic, + self._proxy.publish(request, worker.topic, reply_to=self._uuid, correlation_id=request.uuid) except Exception: @@ -255,20 +222,7 @@ class WorkerTaskExecutor(executor.TaskExecutor): return how many workers are still needed, otherwise it will return zero. """ - if workers <= 0: - raise ValueError("Worker amount must be greater than zero") - w = None - if timeout is not None: - w = tt.StopWatch(timeout).start() - with self._workers_arrival: - while len(self._workers_cache) < workers: - if w is not None and w.expired(): - return workers - len(self._workers_cache) - timeout = None - if w is not None: - timeout = w.leftover() - self._workers_arrival.wait(timeout) - return 0 + return self._workers.wait_for_workers(workers=workers, timeout=timeout) def start(self): """Starts proxy thread and associated topic notification thread.""" @@ -291,3 +245,5 @@ class WorkerTaskExecutor(executor.TaskExecutor): self._proxy.stop() self._proxy_thread.join() self._proxy_thread = None + self._requests_cache.clear(self._handle_expired_request) + self._workers.clear() |
