blob: 34622fb5d6442b100a4d7845928880932af17d2d [file] [log] [blame]
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Task manager classes for work queues."""
from __future__ import print_function
import abc
import multiprocessing
from chromite.lib import cros_logging as logging
def _ExecuteTask(handler, request_data):
"""Wrapper for the task handler function."""
root_logger = logging.getLogger()
for h in list(root_logger.handlers):
root_logger.removeHandler(h)
try:
return handler(request_data)
except Exception as e:
return e
class TaskManager(object):
"""Abstract base class for task management.
`TaskManager` is responsible for managing individual work queue
requests from the time that they're scheduled to run, until they
complete or are aborted.
"""
__metaclass__ = abc.ABCMeta
def __init__(self, handler, sample_interval):
self.sample_interval = sample_interval
self._handler = handler
@abc.abstractmethod
def StartTick(self):
"""Start the polling cycle in `WorkQueueService.ProcessRequests()`.
The work queue service's server polling loop will call this function
once per loop iteration, to mark the nominal start of the polling
cycle.
"""
@abc.abstractmethod
def HasCapacity(self):
"""Return whether there is capacity to start more tasks.
Returns:
A true value if there is enough capacity for at least one
additional call to `StartTask()`.
"""
return False
@abc.abstractmethod
def StartTask(self, request_id, request_data):
"""Start work on a new task request.
Args:
request_id: Identifier for the task, used by `TerminateTask()`
and `Reap()`.
request_data: Argument to be passed to the task handler.
"""
@abc.abstractmethod
def TerminateTask(self, request_id):
"""Terminate a running task.
A terminated task will be forgotten, and will never be returned
by `Reap()`.
Args:
request_id: Identifier of the task to be terminated.
"""
@abc.abstractmethod
def Reap(self):
"""Generator to return results of all completed tasks.
Yields:
A `(request_id, return_value)` tuple.
"""
pass
class ProcessPoolTaskManager(TaskManager):
"""A task manager implemented with `multiprocessing.Pool`."""
def __init__(self, max_tasks, handler, sample_interval):
super(ProcessPoolTaskManager, self).__init__(handler, sample_interval)
self._pool = multiprocessing.Pool(max_tasks)
self._max_tasks = max_tasks
self._pending_results = {}
def StartTick(self):
pass
def HasCapacity(self):
return len(self._pending_results) < self._max_tasks
def StartTask(self, request_id, request_data):
self._pending_results[request_id] = (
self._pool.apply_async(_ExecuteTask,
(self._handler, request_data)))
def TerminateTask(self, request_id):
del self._pending_results[request_id]
def Reap(self):
for request_id, result in self._pending_results.items():
if result.ready():
del self._pending_results[request_id]
yield request_id, result.get()
def Close(self):
self._pool.terminate()
self._pool.join()