Changing the way workers are forked

Branch: pull/13567/head
Author: James Cammarata · 9 years ago
Parent: 3214ef8832
Commit: c8e6461dee
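In broad strokes, this commit replaces the fixed pool of long-lived workers, each polling a shared main_q for work and reading host variables through a proxied manager, with a short-lived WorkerProcess that is forked per task and receives the host, task, variables, play context, and loaders directly through its constructor; only the result queue remains shared. The sketch below (generic multiprocessing code, not Ansible's classes) contrasts the two models:

```python
import multiprocessing

# Before: a fixed pool of long-lived workers, fed over a shared work queue.
def pool_worker(main_q, rslt_q):
    while True:                       # survives across tasks
        host, task = main_q.get()     # work must be serialized onto main_q
        rslt_q.put((host, task, 'ok'))

# After: one short-lived process per task; its inputs arrive at fork time.
class PerTaskWorker(multiprocessing.Process):
    def __init__(self, rslt_q, host, task):
        super(PerTaskWorker, self).__init__()
        self._rslt_q = rslt_q         # the only shared channel
        self._host = host             # inherited by the child at fork,
        self._task = task             # so work items need not be queued

    def run(self):
        self._rslt_q.put((self._host, self._task, 'ok'))

if __name__ == '__main__':
    rslt_q = multiprocessing.Queue()
    w = PerTaskWorker(rslt_q, 'host1', 'ping')
    w.start()
    print(rslt_q.get())               # ('host1', 'ping', 'ok')
    w.join()
```

Because the child inherits its arguments at fork time, no work item has to be serialized onto a queue, which appears to be why the variable-compression path disappears further down.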

```diff
@@ -60,6 +60,7 @@ if __name__ == '__main__':
     try:
         display = Display()
+        display.debug("starting run")

         sub = None
         try:
```

@ -59,14 +59,18 @@ class WorkerProcess(multiprocessing.Process):
for reading later. for reading later.
''' '''
def __init__(self, tqm, main_q, rslt_q, hostvars_manager, loader): def __init__(self, rslt_q, task_vars, host, task, play_context, loader, variable_manager, shared_loader_obj):
super(WorkerProcess, self).__init__() super(WorkerProcess, self).__init__()
# takes a task queue manager as the sole param: # takes a task queue manager as the sole param:
self._main_q = main_q self._rslt_q = rslt_q
self._rslt_q = rslt_q self._task_vars = task_vars
self._hostvars = hostvars_manager self._host = host
self._loader = loader self._task = task
self._play_context = play_context
self._loader = loader
self._variable_manager = variable_manager
self._shared_loader_obj = shared_loader_obj
# dupe stdin, if we have one # dupe stdin, if we have one
self._new_stdin = sys.stdin self._new_stdin = sys.stdin
```diff
@@ -97,73 +101,45 @@ class WorkerProcess(multiprocessing.Process):
         if HAS_ATFORK:
             atfork()

-        while True:
-            task = None
-            try:
-                #debug("waiting for work")
-                (host, task, basedir, zip_vars, compressed_vars, play_context, shared_loader_obj) = self._main_q.get(block=False)
-
-                if compressed_vars:
-                    job_vars = json.loads(zlib.decompress(zip_vars))
-                else:
-                    job_vars = zip_vars
-
-                job_vars['hostvars'] = self._hostvars.hostvars()
-
-                debug("there's work to be done! got a task/handler to work on: %s" % task)
-
-                # because the task queue manager starts workers (forks) before the
-                # playbook is loaded, set the basedir of the loader inherted by
-                # this fork now so that we can find files correctly
-                self._loader.set_basedir(basedir)
-
-                # Serializing/deserializing tasks does not preserve the loader attribute,
-                # since it is passed to the worker during the forking of the process and
-                # would be wasteful to serialize. So we set it here on the task now, and
-                # the task handles updating parent/child objects as needed.
-                task.set_loader(self._loader)
-
-                # execute the task and build a TaskResult from the result
-                debug("running TaskExecutor() for %s/%s" % (host, task))
-                executor_result = TaskExecutor(
-                    host,
-                    task,
-                    job_vars,
-                    play_context,
-                    self._new_stdin,
-                    self._loader,
-                    shared_loader_obj,
-                ).run()
-                debug("done running TaskExecutor() for %s/%s" % (host, task))
-                task_result = TaskResult(host, task, executor_result)
-
-                # put the result on the result queue
-                debug("sending task result")
-                self._rslt_q.put(task_result)
-                debug("done sending task result")
-
-            except queue.Empty:
-                time.sleep(0.0001)
-            except AnsibleConnectionFailure:
-                try:
-                    if task:
-                        task_result = TaskResult(host, task, dict(unreachable=True))
-                        self._rslt_q.put(task_result, block=False)
-                except:
-                    break
-            except Exception as e:
-                if isinstance(e, (IOError, EOFError, KeyboardInterrupt)) and not isinstance(e, TemplateNotFound):
-                    break
-                else:
-                    try:
-                        if task:
-                            task_result = TaskResult(host, task, dict(failed=True, exception=traceback.format_exc(), stdout=''))
-                            self._rslt_q.put(task_result, block=False)
-                    except:
-                        debug("WORKER EXCEPTION: %s" % e)
-                        debug("WORKER EXCEPTION: %s" % traceback.format_exc())
-                        break
+        try:
+            # execute the task and build a TaskResult from the result
+            debug("running TaskExecutor() for %s/%s" % (self._host, self._task))
+            executor_result = TaskExecutor(
+                self._host,
+                self._task,
+                self._task_vars,
+                self._play_context,
+                self._new_stdin,
+                self._loader,
+                self._shared_loader_obj,
+            ).run()
+            debug("done running TaskExecutor() for %s/%s" % (self._host, self._task))
+
+            self._host.vars = dict()
+            self._host.groups = []
+            task_result = TaskResult(self._host, self._task, executor_result)
+
+            # put the result on the result queue
+            debug("sending task result")
+            self._rslt_q.put(task_result)
+            debug("done sending task result")
+
+        except AnsibleConnectionFailure:
+            self._host.vars = dict()
+            self._host.groups = []
+            task_result = TaskResult(self._host, self._task, dict(unreachable=True))
+            self._rslt_q.put(task_result, block=False)
+
+        except Exception as e:
+            if not isinstance(e, (IOError, EOFError, KeyboardInterrupt)) or isinstance(e, TemplateNotFound):
+                try:
+                    self._host.vars = dict()
+                    self._host.groups = []
+                    task_result = TaskResult(self._host, self._task, dict(failed=True, exception=traceback.format_exc(), stdout=''))
+                    self._rslt_q.put(task_result, block=False)
+                except:
+                    debug("WORKER EXCEPTION: %s" % e)
+                    debug("WORKER EXCEPTION: %s" % traceback.format_exc())

         debug("WORKER PROCESS EXITING")
```

```diff
@@ -102,11 +102,7 @@ class TaskQueueManager:
         for i in xrange(num):
             main_q = multiprocessing.Queue()
             rslt_q = multiprocessing.Queue()
-
-            prc = WorkerProcess(self, main_q, rslt_q, self._hostvars_manager, self._loader)
-            prc.start()
-
-            self._workers.append((prc, main_q, rslt_q))
+            self._workers.append([None, main_q, rslt_q])

         self._result_prc = ResultProcess(self._final_q, self._workers)
         self._result_prc.start()
```
```diff
@@ -195,31 +191,12 @@ class TaskQueueManager:
         new_play = play.copy()
         new_play.post_validate(templar)

-        class HostVarsManager(SyncManager):
-            pass
-
-        hostvars = HostVars(
+        self.hostvars = HostVars(
             inventory=self._inventory,
             variable_manager=self._variable_manager,
             loader=self._loader,
         )

-        HostVarsManager.register(
-            'hostvars',
-            callable=lambda: hostvars,
-            # FIXME: this is the list of exposed methods to the DictProxy object, plus our
-            # special ones (set_variable_manager/set_inventory). There's probably a better way
-            # to do this with a proper BaseProxy/DictProxy derivative
-            exposed=(
-                'set_variable_manager', 'set_inventory', '__contains__', '__delitem__',
-                'set_nonpersistent_facts', 'set_host_facts', 'set_host_variable',
-                '__getitem__', '__len__', '__setitem__', 'clear', 'copy', 'get', 'has_key',
-                'items', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values'
-            ),
-        )
-        self._hostvars_manager = HostVarsManager()
-        self._hostvars_manager.start()
-
         # Fork # of forks, # of hosts or serial, whichever is lowest
         contenders = [self._options.forks, play.serial, len(self._inventory.get_hosts(new_play.hosts))]
         contenders = [ v for v in contenders if v is not None and v > 0 ]
```
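The deleted block served a single HostVars instance to the persistent workers through a multiprocessing SyncManager, so every hostvars lookup in a worker was an IPC round-trip to the manager process (hence the long exposed list mirroring the dict interface). With workers forked per task, each child simply inherits self.hostvars, and the manager, its proxy, and the shutdown call below all become unnecessary. For reference, a minimal sketch of the proxy mechanism being removed (generic example, not the Ansible classes):

```python
from multiprocessing.managers import SyncManager

class HostVarsManager(SyncManager):
    pass

shared = {'host1': {'ansible_host': '10.0.0.1'}}
# expose only the dict-style methods the callers need
HostVarsManager.register('hostvars', callable=lambda: shared,
                         exposed=('__getitem__', 'keys'))

if __name__ == '__main__':
    mgr = HostVarsManager()
    mgr.start()              # runs a separate server process (fork start assumed;
                             # a lambda cannot be pickled under spawn)
    proxy = mgr.hostvars()
    print(proxy['host1'])    # every access is an IPC round-trip to the server
    mgr.shutdown()
```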
```diff
@@ -259,7 +236,6 @@ class TaskQueueManager:
         # and run the play using the strategy and cleanup on way out
         play_return = strategy.run(iterator, play_context)
         self._cleanup_processes()
-        self._hostvars_manager.shutdown()
         return play_return

     def cleanup(self):
```
```diff
@@ -275,7 +251,8 @@ class TaskQueueManager:
             for (worker_prc, main_q, rslt_q) in self._workers:
                 rslt_q.close()
                 main_q.close()
-                worker_prc.terminate()
+                if worker_prc and worker_prc.is_alive():
+                    worker_prc.terminate()

     def clear_failed_hosts(self):
         self._failed_hosts = dict()
```

```diff
@@ -31,6 +31,7 @@ from jinja2.exceptions import UndefinedError
 from ansible import constants as C
 from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable
 from ansible.executor.play_iterator import PlayIterator
+from ansible.executor.process.worker import WorkerProcess
 from ansible.executor.task_result import TaskResult
 from ansible.inventory.host import Host
 from ansible.inventory.group import Group
```
```diff
@@ -138,38 +139,29 @@ class StrategyBase:
         display.debug("entering _queue_task() for %s/%s" % (host, task))

+        task_vars['hostvars'] = self._tqm.hostvars
+
         # and then queue the new task
         display.debug("%s - putting task (%s) in queue" % (host, task))
         try:
             display.debug("worker is %d (out of %d available)" % (self._cur_worker+1, len(self._workers)))

-            (worker_prc, main_q, rslt_q) = self._workers[self._cur_worker]
-            self._cur_worker += 1
-            if self._cur_worker >= len(self._workers):
-                self._cur_worker = 0
-
             # create a dummy object with plugin loaders set as an easier
             # way to share them with the forked processes
             shared_loader_obj = SharedPluginLoaderObj()

-            # compress (and convert) the data if so configured, which can
-            # help a lot when the variable dictionary is huge. We pop the
-            # hostvars out of the task variables right now, due to the fact
-            # that they're not JSON serializable
-            compressed_vars = False
-            if C.DEFAULT_VAR_COMPRESSION_LEVEL > 0:
-                zip_vars = zlib.compress(json.dumps(task_vars), C.DEFAULT_VAR_COMPRESSION_LEVEL)
-                compressed_vars = True
-                # we're done with the original dict now, so delete it to
-                # try and reclaim some memory space, which is helpful if the
-                # data contained in the dict is very large
-                del task_vars
-            else:
-                zip_vars = task_vars  # noqa (pyflakes false positive because task_vars is deleted in the conditional above)
-
-            # and queue the task
-            main_q.put((host, task, self._loader.get_basedir(), zip_vars, compressed_vars, play_context, shared_loader_obj))
-            del task_vars
+            while True:
+                (worker_prc, main_q, rslt_q) = self._workers[self._cur_worker]
+                if worker_prc is None or not worker_prc.is_alive():
+                    worker_prc = WorkerProcess(rslt_q, task_vars, host, task, play_context, self._loader, self._variable_manager, shared_loader_obj)
+                    self._workers[self._cur_worker][0] = worker_prc
+                    worker_prc.start()
+                    break
+                self._cur_worker += 1
+                if self._cur_worker >= len(self._workers):
+                    self._cur_worker = 0
+                    time.sleep(0.0001)

             self._pending_results += 1
         except (EOFError, IOError, AssertionError) as e:
             # most likely an abort
```
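_queue_task() now does the forking itself: it scans the worker slots round-robin for one that is empty (None) or whose previous process has exited, forks a new WorkerProcess into that slot, and briefly sleeps after each full unsuccessful pass. The scan only returns once a slot frees up, so the number of concurrently running workers stays capped at the configured fork count. A condensed sketch of the slot scan (spawn_worker is a hypothetical factory standing in for constructing the WorkerProcess):

```python
import multiprocessing
import time

def queue_task(workers, cur_worker, spawn_worker):
    """Scan for a free slot, fork a worker into it, return the slot used."""
    while True:
        worker_prc = workers[cur_worker][0]
        if worker_prc is None or not worker_prc.is_alive():
            worker_prc = spawn_worker()        # hypothetical WorkerProcess factory
            workers[cur_worker][0] = worker_prc
            worker_prc.start()
            return cur_worker                  # cursor stays on the slot just used
        cur_worker += 1
        if cur_worker >= len(workers):
            cur_worker = 0
            time.sleep(0.0001)                 # full pass, all busy: back off

if __name__ == '__main__':
    workers = [[None], [None]]                 # two slots caps us at two live workers
    spawn = lambda: multiprocessing.Process(target=time.sleep, args=(0.01,))
    for _ in range(5):
        queue_task(workers, 0, spawn)          # later calls wait for a slot to free
```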
```diff
@@ -177,7 +169,7 @@ class StrategyBase:
                 return
         display.debug("exiting _queue_task() for %s/%s" % (host, task))

-    def _process_pending_results(self, iterator):
+    def _process_pending_results(self, iterator, one_pass=False):
         '''
         Reads results off the final queue and takes appropriate action
         based on the result (executing callbacks, updating state, etc.).
```
```diff
@@ -247,13 +239,11 @@ class StrategyBase:
                     new_host_info = result_item.get('add_host', dict())
                     self._add_host(new_host_info, iterator)
-                    self._tqm._hostvars_manager.hostvars().set_inventory(self._inventory)

                 elif result[0] == 'add_group':
                     host = result[1]
                     result_item = result[2]
                     self._add_group(host, result_item)
-                    self._tqm._hostvars_manager.hostvars().set_inventory(self._inventory)

                 elif result[0] == 'notify_handler':
                     task_result = result[1]
@@ -283,7 +273,6 @@ class StrategyBase:
                     for target_host in host_list:
                         self._variable_manager.set_nonpersistent_facts(target_host, {var_name: var_value})
-                        self._tqm._hostvars_manager.hostvars().set_nonpersistent_facts(target_host, {var_name: var_value})

                 elif result[0] in ('set_host_var', 'set_host_facts'):
                     host = result[1]
@@ -316,21 +305,22 @@ class StrategyBase:
                     for target_host in host_list:
                         self._variable_manager.set_host_variable(target_host, var_name, var_value)
-                        self._tqm._hostvars_manager.hostvars().set_host_variable(target_host, var_name, var_value)

                 elif result[0] == 'set_host_facts':
                     facts = result[4]
                     if task.action == 'set_fact':
                         self._variable_manager.set_nonpersistent_facts(actual_host, facts)
-                        self._tqm._hostvars_manager.hostvars().set_nonpersistent_facts(actual_host, facts)
                     else:
                         self._variable_manager.set_host_facts(actual_host, facts)
-                        self._tqm._hostvars_manager.hostvars().set_host_facts(actual_host, facts)

                 else:
                     raise AnsibleError("unknown result message received: %s" % result[0])

             except Queue.Empty:
                 time.sleep(0.0001)

+            if one_pass:
+                break
+
         return ret_results

     def _wait_on_pending_results(self, iterator):
```
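The new one_pass flag lets a strategy collect whatever results happen to be ready without sitting in the drain loop: the method returns after a single attempt instead of looping until the queue is empty. A runnable sketch of the pattern on a plain queue (handle stands in for the per-result processing above):

```python
import queue
import time

def process_pending(final_q, handle, one_pass=False):
    """Drain ready results; with one_pass, return after a single attempt."""
    results = []
    while not final_q.empty():
        try:
            results.append(handle(final_q.get(block=False)))
        except queue.Empty:
            time.sleep(0.0001)
        if one_pass:
            break
    return results

if __name__ == '__main__':
    q = queue.Queue()
    for i in range(3):
        q.put(i)
    print(process_pending(q, lambda r: r, one_pass=True))  # [0]
    print(process_pending(q, lambda r: r))                 # [1, 2]
```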
def _wait_on_pending_results(self, iterator): def _wait_on_pending_results(self, iterator):

```diff
@@ -169,6 +169,7 @@ class StrategyModule(StrategyBase):
             skip_rest = False
             choose_step = True

+            results = []
             for (host, task) in host_tasks:
                 if not task:
                     continue
```
```diff
@@ -243,12 +244,14 @@ class StrategyModule(StrategyBase):
                     if run_once:
                         break

+                results += self._process_pending_results(iterator, one_pass=True)
+
             # go to next host/task group
             if skip_rest:
                 continue

             display.debug("done queuing things up, now waiting for results queue to drain")
-            results = self._wait_on_pending_results(iterator)
+            results += self._wait_on_pending_results(iterator)
             host_results.extend(results)

             if not work_to_do and len(iterator.get_failed_hosts()) > 0:
```
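The linear strategy takes advantage of one_pass by draining results opportunistically after queuing each host's task, then accumulating (results +=) rather than overwriting when the final blocking wait returns. That keeps worker slots recycling while tasks are still being queued. A toy, self-contained illustration of the interleaving (placeholder logic, not the strategy code):

```python
import multiprocessing
import queue

def work(rslt_q, item):
    rslt_q.put(item * 2)

if __name__ == '__main__':
    rslt_q = multiprocessing.Queue()
    results = []
    procs = []
    for item in range(4):
        p = multiprocessing.Process(target=work, args=(rslt_q, item))
        p.start()
        procs.append(p)
        try:
            # opportunistic, non-blocking drain while still queuing
            results.append(rslt_q.get(block=False))
        except queue.Empty:
            pass
    for p in procs:
        p.join()
    # final drain once everything has been queued
    while not rslt_q.empty():
        results.append(rslt_q.get())
    print(sorted(results))   # [0, 2, 4, 6]
```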
