Changing the way workers are forked

Branch: pull/13529/head
Author: James Cammarata, 9 years ago
Parent: ae988ed753
Commit: 120b9a7ac6
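
In short, as the hunks below show: workers are no longer long-lived processes that poll a shared main_q for work. Instead the strategy forks a fresh WorkerProcess per task, handing the host, task, task variables, play context and loaders straight to the constructor, which in turn removes the need for the SyncManager-backed hostvars proxy and for compressing variables before pushing them through a queue.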

@@ -60,6 +60,7 @@ if __name__ == '__main__':
     try:
         display = Display()
+        display.debug("starting run")

         sub = None
         try:

@@ -59,14 +59,18 @@ class WorkerProcess(multiprocessing.Process):
     for reading later.
     '''

-    def __init__(self, tqm, main_q, rslt_q, hostvars_manager, loader):
+    def __init__(self, rslt_q, task_vars, host, task, play_context, loader, variable_manager, shared_loader_obj):
         super(WorkerProcess, self).__init__()
         # takes a task queue manager as the sole param:
-        self._main_q = main_q
         self._rslt_q = rslt_q
-        self._hostvars = hostvars_manager
+        self._task_vars = task_vars
+        self._host = host
+        self._task = task
+        self._play_context = play_context
         self._loader = loader
+        self._variable_manager = variable_manager
+        self._shared_loader_obj = shared_loader_obj

         # dupe stdin, if we have one
         self._new_stdin = sys.stdin
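
The constructor now carries everything a single task execution needs (result queue, task_vars, host, task, play_context and the loaders) instead of the main queue and hostvars manager a long-lived worker used to poll; the _queue_task() hunk further down shows how the strategy builds and starts one of these per task.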
@@ -97,73 +101,45 @@ class WorkerProcess(multiprocessing.Process):
         if HAS_ATFORK:
             atfork()

-        while True:
-            task = None
-            try:
-                #debug("waiting for work")
-                (host, task, basedir, zip_vars, compressed_vars, play_context, shared_loader_obj) = self._main_q.get(block=False)
-
-                if compressed_vars:
-                    job_vars = json.loads(zlib.decompress(zip_vars))
-                else:
-                    job_vars = zip_vars
-
-                job_vars['hostvars'] = self._hostvars.hostvars()
-
-                debug("there's work to be done! got a task/handler to work on: %s" % task)
-
-                # because the task queue manager starts workers (forks) before the
-                # playbook is loaded, set the basedir of the loader inherted by
-                # this fork now so that we can find files correctly
-                self._loader.set_basedir(basedir)
-
-                # Serializing/deserializing tasks does not preserve the loader attribute,
-                # since it is passed to the worker during the forking of the process and
-                # would be wasteful to serialize. So we set it here on the task now, and
-                # the task handles updating parent/child objects as needed.
-                task.set_loader(self._loader)
-
-                # execute the task and build a TaskResult from the result
-                debug("running TaskExecutor() for %s/%s" % (host, task))
-                executor_result = TaskExecutor(
-                    host,
-                    task,
-                    job_vars,
-                    play_context,
-                    self._new_stdin,
-                    self._loader,
-                    shared_loader_obj,
-                ).run()
-                debug("done running TaskExecutor() for %s/%s" % (host, task))
-                task_result = TaskResult(host, task, executor_result)
-
-                # put the result on the result queue
-                debug("sending task result")
-                self._rslt_q.put(task_result)
-                debug("done sending task result")
-
-            except queue.Empty:
-                time.sleep(0.0001)
-            except AnsibleConnectionFailure:
-                try:
-                    if task:
-                        task_result = TaskResult(host, task, dict(unreachable=True))
-                        self._rslt_q.put(task_result, block=False)
-                except:
-                    break
-            except Exception as e:
-                if isinstance(e, (IOError, EOFError, KeyboardInterrupt)) and not isinstance(e, TemplateNotFound):
-                    break
-                else:
-                    try:
-                        if task:
-                            task_result = TaskResult(host, task, dict(failed=True, exception=traceback.format_exc(), stdout=''))
-                            self._rslt_q.put(task_result, block=False)
-                    except:
-                        debug("WORKER EXCEPTION: %s" % e)
-                        debug("WORKER EXCEPTION: %s" % traceback.format_exc())
-                        break
+        try:
+            # execute the task and build a TaskResult from the result
+            debug("running TaskExecutor() for %s/%s" % (self._host, self._task))
+            executor_result = TaskExecutor(
+                self._host,
+                self._task,
+                self._task_vars,
+                self._play_context,
+                self._new_stdin,
+                self._loader,
+                self._shared_loader_obj,
+            ).run()
+
+            debug("done running TaskExecutor() for %s/%s" % (self._host, self._task))
+            self._host.vars = dict()
+            self._host.groups = []
+            task_result = TaskResult(self._host, self._task, executor_result)
+
+            # put the result on the result queue
+            debug("sending task result")
+            self._rslt_q.put(task_result)
+            debug("done sending task result")
+
+        except AnsibleConnectionFailure:
+            self._host.vars = dict()
+            self._host.groups = []
+            task_result = TaskResult(self._host, self._task, dict(unreachable=True))
+            self._rslt_q.put(task_result, block=False)
+        except Exception as e:
+            if not isinstance(e, (IOError, EOFError, KeyboardInterrupt)) or isinstance(e, TemplateNotFound):
+                try:
+                    self._host.vars = dict()
+                    self._host.groups = []
+                    task_result = TaskResult(self._host, self._task, dict(failed=True, exception=traceback.format_exc(), stdout=''))
+                    self._rslt_q.put(task_result, block=False)
+                except:
+                    debug("WORKER EXCEPTION: %s" % e)
+                    debug("WORKER EXCEPTION: %s" % traceback.format_exc())

         debug("WORKER PROCESS EXITING")
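
run() is now single-shot: it runs TaskExecutor once for the host/task it was constructed with, puts one TaskResult on the result queue and exits, so the main_q polling loop, the variable decompression and the set_basedir()/set_loader() bookkeeping disappear (the fork now happens after the playbook and task are loaded, so the child simply inherits them). The host's vars and groups are cleared before the TaskResult is built, presumably to keep the Host object small when the result is pickled back through the queue. A toy illustration of the new lifecycle (not Ansible code; names are illustrative):

import multiprocessing

def run_single_task(rslt_q, task_vars):
    # stand-in for TaskExecutor(host, task, task_vars, ...).run()
    rslt_q.put({'changed': task_vars['x'] > 0})

if __name__ == '__main__':
    rslt_q = multiprocessing.Queue()
    task_vars = {'x': 1}                      # built by the parent before the worker starts
    prc = multiprocessing.Process(target=run_single_task, args=(rslt_q, task_vars))
    prc.start()                               # one process per task; it exits when done
    print(rslt_q.get())                       # {'changed': True}
    prc.join()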

@@ -102,11 +102,7 @@ class TaskQueueManager:
         for i in xrange(num):
             main_q = multiprocessing.Queue()
             rslt_q = multiprocessing.Queue()
-
-            prc = WorkerProcess(self, main_q, rslt_q, self._hostvars_manager, self._loader)
-            prc.start()
-
-            self._workers.append((prc, main_q, rslt_q))
+            self._workers.append([None, main_q, rslt_q])

         self._result_prc = ResultProcess(self._final_q, self._workers)
         self._result_prc.start()
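
Worker slots are now created up front as mutable [process, main_q, rslt_q] lists with a None placeholder rather than as started processes; the strategy fills slot [0] in place each time it forks a WorkerProcess for a task (see the _queue_task() hunk below), which is why a list replaces the previous tuple.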
@@ -195,31 +191,12 @@ class TaskQueueManager:
         new_play = play.copy()
         new_play.post_validate(templar)

-        class HostVarsManager(SyncManager):
-            pass
-
-        hostvars = HostVars(
+        self.hostvars = HostVars(
             inventory=self._inventory,
             variable_manager=self._variable_manager,
             loader=self._loader,
         )

-        HostVarsManager.register(
-            'hostvars',
-            callable=lambda: hostvars,
-            # FIXME: this is the list of exposed methods to the DictProxy object, plus our
-            # special ones (set_variable_manager/set_inventory). There's probably a better way
-            # to do this with a proper BaseProxy/DictProxy derivative
-            exposed=(
-                'set_variable_manager', 'set_inventory', '__contains__', '__delitem__',
-                'set_nonpersistent_facts', 'set_host_facts', 'set_host_variable',
-                '__getitem__', '__len__', '__setitem__', 'clear', 'copy', 'get', 'has_key',
-                'items', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values'
-            ),
-        )
-        self._hostvars_manager = HostVarsManager()
-        self._hostvars_manager.start()
-
         # Fork # of forks, # of hosts or serial, whichever is lowest
         contenders = [self._options.forks, play.serial, len(self._inventory.get_hosts(new_play.hosts))]
         contenders = [ v for v in contenders if v is not None and v > 0 ]
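
With workers forked per task, the SyncManager proxy around HostVars, and its hand-maintained list of exposed methods, is no longer needed: a plain HostVars object now lives on the TaskQueueManager as self.hostvars and reaches each worker through task_vars['hostvars'] at fork time, as the strategy hunk below shows.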
@@ -259,7 +236,6 @@ class TaskQueueManager:
         # and run the play using the strategy and cleanup on way out
         play_return = strategy.run(iterator, play_context)
         self._cleanup_processes()
-        self._hostvars_manager.shutdown()
         return play_return

     def cleanup(self):
@@ -275,6 +251,7 @@ class TaskQueueManager:
             for (worker_prc, main_q, rslt_q) in self._workers:
                 rslt_q.close()
                 main_q.close()
-                worker_prc.terminate()
+                if worker_prc and worker_prc.is_alive():
+                    worker_prc.terminate()

     def clear_failed_hosts(self):

@@ -31,6 +31,7 @@ from jinja2.exceptions import UndefinedError
 from ansible import constants as C
 from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable
 from ansible.executor.play_iterator import PlayIterator
+from ansible.executor.process.worker import WorkerProcess
 from ansible.executor.task_result import TaskResult
 from ansible.inventory.host import Host
 from ansible.inventory.group import Group
@@ -138,38 +139,29 @@ class StrategyBase:
         display.debug("entering _queue_task() for %s/%s" % (host, task))

+        task_vars['hostvars'] = self._tqm.hostvars
         # and then queue the new task
         display.debug("%s - putting task (%s) in queue" % (host, task))
         try:
             display.debug("worker is %d (out of %d available)" % (self._cur_worker+1, len(self._workers)))

-            (worker_prc, main_q, rslt_q) = self._workers[self._cur_worker]
-            self._cur_worker += 1
-            if self._cur_worker >= len(self._workers):
-                self._cur_worker = 0
-
             # create a dummy object with plugin loaders set as an easier
             # way to share them with the forked processes
             shared_loader_obj = SharedPluginLoaderObj()

-            # compress (and convert) the data if so configured, which can
-            # help a lot when the variable dictionary is huge. We pop the
-            # hostvars out of the task variables right now, due to the fact
-            # that they're not JSON serializable
-            compressed_vars = False
-            if C.DEFAULT_VAR_COMPRESSION_LEVEL > 0:
-                zip_vars = zlib.compress(json.dumps(task_vars), C.DEFAULT_VAR_COMPRESSION_LEVEL)
-                compressed_vars = True
-                # we're done with the original dict now, so delete it to
-                # try and reclaim some memory space, which is helpful if the
-                # data contained in the dict is very large
-                del task_vars
-            else:
-                zip_vars = task_vars # noqa (pyflakes false positive because task_vars is deleted in the conditional above)
-
-            # and queue the task
-            main_q.put((host, task, self._loader.get_basedir(), zip_vars, compressed_vars, play_context, shared_loader_obj))
+            while True:
+                (worker_prc, main_q, rslt_q) = self._workers[self._cur_worker]
+                if worker_prc is None or not worker_prc.is_alive():
+                    worker_prc = WorkerProcess(rslt_q, task_vars, host, task, play_context, self._loader, self._variable_manager, shared_loader_obj)
+                    self._workers[self._cur_worker][0] = worker_prc
+                    worker_prc.start()
+                    break
+                self._cur_worker += 1
+                if self._cur_worker >= len(self._workers):
+                    self._cur_worker = 0
+                    time.sleep(0.0001)
+
+            del task_vars
             self._pending_results += 1
         except (EOFError, IOError, AssertionError) as e:
             # most likely an abort
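
The queueing loop now scans the worker slots round-robin: the first slot whose process is unset or no longer alive gets a freshly forked WorkerProcess for this one task, and if every slot is busy the loop keeps cycling with a short sleep until one frees up. Because the variables are inherited by the fork rather than serialized through main_q, the zlib/json compression path is dropped entirely. A runnable toy version of that scheduling idea (not Ansible code; slot count and names are illustrative):

import multiprocessing
import time

def work(rslt_q, n):
    rslt_q.put(n * n)                         # stand-in for executing one task

if __name__ == '__main__':
    rslt_q = multiprocessing.Queue()
    slots = [None, None]                      # two worker slots, initially empty
    cur = 0
    for n in range(5):                        # five "tasks" to schedule
        while True:                           # find a free (empty or finished) slot
            prc = slots[cur]
            if prc is None or not prc.is_alive():
                prc = multiprocessing.Process(target=work, args=(rslt_q, n))
                slots[cur] = prc
                prc.start()                   # fork a fresh process for this task
                break
            cur = (cur + 1) % len(slots)
            time.sleep(0.0001)                # all slots busy, keep cycling
    for prc in slots:
        if prc is not None:
            prc.join()
    for _ in range(5):
        print(rslt_q.get())                   # one result per scheduled task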
@@ -177,7 +169,7 @@ class StrategyBase:
             return

         display.debug("exiting _queue_task() for %s/%s" % (host, task))

-    def _process_pending_results(self, iterator):
+    def _process_pending_results(self, iterator, one_pass=False):
         '''
         Reads results off the final queue and takes appropriate action
         based on the result (executing callbacks, updating state, etc.).
@@ -247,13 +239,11 @@ class StrategyBase:
                     new_host_info = result_item.get('add_host', dict())
                     self._add_host(new_host_info, iterator)
-                    self._tqm._hostvars_manager.hostvars().set_inventory(self._inventory)

                 elif result[0] == 'add_group':
                     host = result[1]
                     result_item = result[2]
                     self._add_group(host, result_item)
-                    self._tqm._hostvars_manager.hostvars().set_inventory(self._inventory)

                 elif result[0] == 'notify_handler':
                     task_result = result[1]
@@ -283,7 +273,6 @@ class StrategyBase:
                     for target_host in host_list:
                         self._variable_manager.set_nonpersistent_facts(target_host, {var_name: var_value})
-                        self._tqm._hostvars_manager.hostvars().set_nonpersistent_facts(target_host, {var_name: var_value})

                 elif result[0] in ('set_host_var', 'set_host_facts'):
                     host = result[1]
@@ -316,21 +305,22 @@ class StrategyBase:
                         for target_host in host_list:
                             self._variable_manager.set_host_variable(target_host, var_name, var_value)
-                            self._tqm._hostvars_manager.hostvars().set_host_variable(target_host, var_name, var_value)
                     elif result[0] == 'set_host_facts':
                         facts = result[4]
                         if task.action == 'set_fact':
                             self._variable_manager.set_nonpersistent_facts(actual_host, facts)
-                            self._tqm._hostvars_manager.hostvars().set_nonpersistent_facts(actual_host, facts)
                         else:
                             self._variable_manager.set_host_facts(actual_host, facts)
-                            self._tqm._hostvars_manager.hostvars().set_host_facts(actual_host, facts)

                 else:
                     raise AnsibleError("unknown result message received: %s" % result[0])

             except Queue.Empty:
                 time.sleep(0.0001)

+            if one_pass:
+                break
+
         return ret_results

     def _wait_on_pending_results(self, iterator):
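
_process_pending_results() gains a one_pass flag so a strategy can make a single pass over the results queue and return, instead of looping until the queue drains; the linear strategy hunk below uses it to collect results while it is still queueing tasks, which matters now that a busy worker slot blocks further queueing until it frees up.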

@@ -169,6 +169,7 @@ class StrategyModule(StrategyBase):
             skip_rest = False
             choose_step = True

+            results = []
             for (host, task) in host_tasks:
                 if not task:
                     continue
@@ -243,12 +244,14 @@ class StrategyModule(StrategyBase):
                 if run_once:
                     break

+                results += self._process_pending_results(iterator, one_pass=True)
+
             # go to next host/task group
             if skip_rest:
                 continue

             display.debug("done queuing things up, now waiting for results queue to drain")
-            results = self._wait_on_pending_results(iterator)
+            results += self._wait_on_pending_results(iterator)
             host_results.extend(results)

             if not work_to_do and len(iterator.get_failed_hosts()) > 0:
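
Results are now accumulated rather than overwritten: the per-iteration one_pass drain and the final _wait_on_pending_results() both add into the same results list (hence += instead of =), so nothing collected mid-loop is lost.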
