Changing the way workers are forked

pull/13567/head
James Cammarata 9 years ago
parent 3214ef8832
commit c8e6461dee

@ -60,6 +60,7 @@ if __name__ == '__main__':
try:
display = Display()
display.debug("starting run")
sub = None
try:

@ -59,14 +59,18 @@ class WorkerProcess(multiprocessing.Process):
for reading later.
'''
def __init__(self, tqm, main_q, rslt_q, hostvars_manager, loader):
def __init__(self, rslt_q, task_vars, host, task, play_context, loader, variable_manager, shared_loader_obj):
super(WorkerProcess, self).__init__()
# takes a task queue manager as the sole param:
self._main_q = main_q
self._rslt_q = rslt_q
self._hostvars = hostvars_manager
self._loader = loader
self._rslt_q = rslt_q
self._task_vars = task_vars
self._host = host
self._task = task
self._play_context = play_context
self._loader = loader
self._variable_manager = variable_manager
self._shared_loader_obj = shared_loader_obj
# dupe stdin, if we have one
self._new_stdin = sys.stdin
@ -97,73 +101,45 @@ class WorkerProcess(multiprocessing.Process):
if HAS_ATFORK:
atfork()
while True:
task = None
try:
#debug("waiting for work")
(host, task, basedir, zip_vars, compressed_vars, play_context, shared_loader_obj) = self._main_q.get(block=False)
if compressed_vars:
job_vars = json.loads(zlib.decompress(zip_vars))
else:
job_vars = zip_vars
job_vars['hostvars'] = self._hostvars.hostvars()
debug("there's work to be done! got a task/handler to work on: %s" % task)
# because the task queue manager starts workers (forks) before the
# playbook is loaded, set the basedir of the loader inherted by
# this fork now so that we can find files correctly
self._loader.set_basedir(basedir)
# Serializing/deserializing tasks does not preserve the loader attribute,
# since it is passed to the worker during the forking of the process and
# would be wasteful to serialize. So we set it here on the task now, and
# the task handles updating parent/child objects as needed.
task.set_loader(self._loader)
# execute the task and build a TaskResult from the result
debug("running TaskExecutor() for %s/%s" % (host, task))
executor_result = TaskExecutor(
host,
task,
job_vars,
play_context,
self._new_stdin,
self._loader,
shared_loader_obj,
).run()
debug("done running TaskExecutor() for %s/%s" % (host, task))
task_result = TaskResult(host, task, executor_result)
# put the result on the result queue
debug("sending task result")
self._rslt_q.put(task_result)
debug("done sending task result")
except queue.Empty:
time.sleep(0.0001)
except AnsibleConnectionFailure:
try:
# execute the task and build a TaskResult from the result
debug("running TaskExecutor() for %s/%s" % (self._host, self._task))
executor_result = TaskExecutor(
self._host,
self._task,
self._task_vars,
self._play_context,
self._new_stdin,
self._loader,
self._shared_loader_obj,
).run()
debug("done running TaskExecutor() for %s/%s" % (self._host, self._task))
self._host.vars = dict()
self._host.groups = []
task_result = TaskResult(self._host, self._task, executor_result)
# put the result on the result queue
debug("sending task result")
self._rslt_q.put(task_result)
debug("done sending task result")
except AnsibleConnectionFailure:
self._host.vars = dict()
self._host.groups = []
task_result = TaskResult(self._host, self._task, dict(unreachable=True))
self._rslt_q.put(task_result, block=False)
except Exception as e:
if not isinstance(e, (IOError, EOFError, KeyboardInterrupt)) or isinstance(e, TemplateNotFound):
try:
if task:
task_result = TaskResult(host, task, dict(unreachable=True))
self._rslt_q.put(task_result, block=False)
self._host.vars = dict()
self._host.groups = []
task_result = TaskResult(self._host, self._task, dict(failed=True, exception=traceback.format_exc(), stdout=''))
self._rslt_q.put(task_result, block=False)
except:
break
except Exception as e:
if isinstance(e, (IOError, EOFError, KeyboardInterrupt)) and not isinstance(e, TemplateNotFound):
break
else:
try:
if task:
task_result = TaskResult(host, task, dict(failed=True, exception=traceback.format_exc(), stdout=''))
self._rslt_q.put(task_result, block=False)
except:
debug("WORKER EXCEPTION: %s" % e)
debug("WORKER EXCEPTION: %s" % traceback.format_exc())
break
debug("WORKER EXCEPTION: %s" % e)
debug("WORKER EXCEPTION: %s" % traceback.format_exc())
debug("WORKER PROCESS EXITING")

@ -102,11 +102,7 @@ class TaskQueueManager:
for i in xrange(num):
main_q = multiprocessing.Queue()
rslt_q = multiprocessing.Queue()
prc = WorkerProcess(self, main_q, rslt_q, self._hostvars_manager, self._loader)
prc.start()
self._workers.append((prc, main_q, rslt_q))
self._workers.append([None, main_q, rslt_q])
self._result_prc = ResultProcess(self._final_q, self._workers)
self._result_prc.start()
@ -195,31 +191,12 @@ class TaskQueueManager:
new_play = play.copy()
new_play.post_validate(templar)
class HostVarsManager(SyncManager):
pass
hostvars = HostVars(
self.hostvars = HostVars(
inventory=self._inventory,
variable_manager=self._variable_manager,
loader=self._loader,
)
HostVarsManager.register(
'hostvars',
callable=lambda: hostvars,
# FIXME: this is the list of exposed methods to the DictProxy object, plus our
# special ones (set_variable_manager/set_inventory). There's probably a better way
# to do this with a proper BaseProxy/DictProxy derivative
exposed=(
'set_variable_manager', 'set_inventory', '__contains__', '__delitem__',
'set_nonpersistent_facts', 'set_host_facts', 'set_host_variable',
'__getitem__', '__len__', '__setitem__', 'clear', 'copy', 'get', 'has_key',
'items', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values'
),
)
self._hostvars_manager = HostVarsManager()
self._hostvars_manager.start()
# Fork # of forks, # of hosts or serial, whichever is lowest
contenders = [self._options.forks, play.serial, len(self._inventory.get_hosts(new_play.hosts))]
contenders = [ v for v in contenders if v is not None and v > 0 ]
@ -259,7 +236,6 @@ class TaskQueueManager:
# and run the play using the strategy and cleanup on way out
play_return = strategy.run(iterator, play_context)
self._cleanup_processes()
self._hostvars_manager.shutdown()
return play_return
def cleanup(self):
@ -275,7 +251,8 @@ class TaskQueueManager:
for (worker_prc, main_q, rslt_q) in self._workers:
rslt_q.close()
main_q.close()
worker_prc.terminate()
if worker_prc and worker_prc.is_alive():
worker_prc.terminate()
def clear_failed_hosts(self):
self._failed_hosts = dict()

@ -31,6 +31,7 @@ from jinja2.exceptions import UndefinedError
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable
from ansible.executor.play_iterator import PlayIterator
from ansible.executor.process.worker import WorkerProcess
from ansible.executor.task_result import TaskResult
from ansible.inventory.host import Host
from ansible.inventory.group import Group
@ -138,38 +139,29 @@ class StrategyBase:
display.debug("entering _queue_task() for %s/%s" % (host, task))
task_vars['hostvars'] = self._tqm.hostvars
# and then queue the new task
display.debug("%s - putting task (%s) in queue" % (host, task))
try:
display.debug("worker is %d (out of %d available)" % (self._cur_worker+1, len(self._workers)))
(worker_prc, main_q, rslt_q) = self._workers[self._cur_worker]
self._cur_worker += 1
if self._cur_worker >= len(self._workers):
self._cur_worker = 0
# create a dummy object with plugin loaders set as an easier
# way to share them with the forked processes
shared_loader_obj = SharedPluginLoaderObj()
# compress (and convert) the data if so configured, which can
# help a lot when the variable dictionary is huge. We pop the
# hostvars out of the task variables right now, due to the fact
# that they're not JSON serializable
compressed_vars = False
if C.DEFAULT_VAR_COMPRESSION_LEVEL > 0:
zip_vars = zlib.compress(json.dumps(task_vars), C.DEFAULT_VAR_COMPRESSION_LEVEL)
compressed_vars = True
# we're done with the original dict now, so delete it to
# try and reclaim some memory space, which is helpful if the
# data contained in the dict is very large
del task_vars
else:
zip_vars = task_vars # noqa (pyflakes false positive because task_vars is deleted in the conditional above)
# and queue the task
main_q.put((host, task, self._loader.get_basedir(), zip_vars, compressed_vars, play_context, shared_loader_obj))
while True:
(worker_prc, main_q, rslt_q) = self._workers[self._cur_worker]
if worker_prc is None or not worker_prc.is_alive():
worker_prc = WorkerProcess(rslt_q, task_vars, host, task, play_context, self._loader, self._variable_manager, shared_loader_obj)
self._workers[self._cur_worker][0] = worker_prc
worker_prc.start()
break
self._cur_worker += 1
if self._cur_worker >= len(self._workers):
self._cur_worker = 0
time.sleep(0.0001)
del task_vars
self._pending_results += 1
except (EOFError, IOError, AssertionError) as e:
# most likely an abort
@ -177,7 +169,7 @@ class StrategyBase:
return
display.debug("exiting _queue_task() for %s/%s" % (host, task))
def _process_pending_results(self, iterator):
def _process_pending_results(self, iterator, one_pass=False):
'''
Reads results off the final queue and takes appropriate action
based on the result (executing callbacks, updating state, etc.).
@ -247,13 +239,11 @@ class StrategyBase:
new_host_info = result_item.get('add_host', dict())
self._add_host(new_host_info, iterator)
self._tqm._hostvars_manager.hostvars().set_inventory(self._inventory)
elif result[0] == 'add_group':
host = result[1]
result_item = result[2]
self._add_group(host, result_item)
self._tqm._hostvars_manager.hostvars().set_inventory(self._inventory)
elif result[0] == 'notify_handler':
task_result = result[1]
@ -283,7 +273,6 @@ class StrategyBase:
for target_host in host_list:
self._variable_manager.set_nonpersistent_facts(target_host, {var_name: var_value})
self._tqm._hostvars_manager.hostvars().set_nonpersistent_facts(target_host, {var_name: var_value})
elif result[0] in ('set_host_var', 'set_host_facts'):
host = result[1]
@ -316,21 +305,22 @@ class StrategyBase:
for target_host in host_list:
self._variable_manager.set_host_variable(target_host, var_name, var_value)
self._tqm._hostvars_manager.hostvars().set_host_variable(target_host, var_name, var_value)
elif result[0] == 'set_host_facts':
facts = result[4]
if task.action == 'set_fact':
self._variable_manager.set_nonpersistent_facts(actual_host, facts)
self._tqm._hostvars_manager.hostvars().set_nonpersistent_facts(actual_host, facts)
else:
self._variable_manager.set_host_facts(actual_host, facts)
self._tqm._hostvars_manager.hostvars().set_host_facts(actual_host, facts)
else:
raise AnsibleError("unknown result message received: %s" % result[0])
except Queue.Empty:
time.sleep(0.0001)
if one_pass:
break
return ret_results
def _wait_on_pending_results(self, iterator):

@ -169,6 +169,7 @@ class StrategyModule(StrategyBase):
skip_rest = False
choose_step = True
results = []
for (host, task) in host_tasks:
if not task:
continue
@ -243,12 +244,14 @@ class StrategyModule(StrategyBase):
if run_once:
break
results += self._process_pending_results(iterator, one_pass=True)
# go to next host/task group
if skip_rest:
continue
display.debug("done queuing things up, now waiting for results queue to drain")
results = self._wait_on_pending_results(iterator)
results += self._wait_on_pending_results(iterator)
host_results.extend(results)
if not work_to_do and len(iterator.get_failed_hosts()) > 0:

Loading…
Cancel
Save