Config continued (#31024)

* included inventory and callback in new config

allow inventory to be configurable
updated connection options settings
also updated winrm to work with new configs
removed now obsolete set_host_overrides
added notes for future bcoca, current one is just punting, it's future's problem
updated docs per feedback
added remove group/host methods to inv data
moved fact cache from data to constructed
cleaner/better options
fix when vars are added
extended ignore list to config dicts
updated paramiko connection docs
removed options from base that paramiko already handles
left the look option as it is used by other plugin types
resolve delegation
updated cache doc options
fixed test_script
better fragment merge for options
fixed proxy command
restore ini for proxy
normalized options
moved pipelining to class
updates for host_key_checking
restructured mixins

* fix typo
pull/32991/head
Brian Coca 7 years ago committed by GitHub
parent 46c4f6311a
commit 23b1dbacaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -65,6 +65,7 @@ class ConnectionProcess(object):
self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file) self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file)
self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null') self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null')
self.connection.set_options()
self.connection._connect() self.connection._connect()
self.srv.register(self.connection) self.srv.register(self.connection)
messages.append('connection to remote device started successfully') messages.append('connection to remote device started successfully')
@ -143,7 +144,7 @@ class ConnectionProcess(object):
if self.connection: if self.connection:
self.connection.close() self.connection.close()
except: except Exception:
pass pass
finally: finally:
@ -271,7 +272,7 @@ def main():
wfd = os.fdopen(w, 'w') wfd = os.fdopen(w, 'w')
process = ConnectionProcess(wfd, play_context, socket_path, original_path) process = ConnectionProcess(wfd, play_context, socket_path, original_path)
process.start() process.start()
except Exception as exc: except Exception:
messages.append(traceback.format_exc()) messages.append(traceback.format_exc())
rc = 1 rc = 1

@ -44,6 +44,9 @@ class DocCLI(CLI):
provides a printout of their DOCUMENTATION strings, provides a printout of their DOCUMENTATION strings,
and it can create a short "snippet" which can be pasted into a playbook. ''' and it can create a short "snippet" which can be pasted into a playbook. '''
# default ignore list for detailed views
IGNORE = ('module', 'docuri', 'version_added', 'short_description', 'now_date', 'plainexamples', 'returndocs')
def __init__(self, args): def __init__(self, args):
super(DocCLI, self).__init__(args) super(DocCLI, self).__init__(args)
@ -394,6 +397,10 @@ class DocCLI(CLI):
for config in ('env', 'ini', 'yaml', 'vars'): for config in ('env', 'ini', 'yaml', 'vars'):
if config in opt and opt[config]: if config in opt and opt[config]:
conf[config] = opt.pop(config) conf[config] = opt.pop(config)
for ignore in self.IGNORE:
for item in conf[config]:
if ignore in item:
del item[ignore]
if conf: if conf:
text.append(self._dump_yaml({'set_via': conf}, opt_indent)) text.append(self._dump_yaml({'set_via': conf}, opt_indent))
@ -441,7 +448,7 @@ class DocCLI(CLI):
def get_man_text(self, doc): def get_man_text(self, doc):
IGNORE = frozenset(['module', 'docuri', 'version_added', 'short_description', 'now_date', 'plainexamples', 'returndocs', self.options.type]) self.IGNORE = self.IGNORE + (self.options.type,)
opt_indent = " " opt_indent = " "
text = [] text = []
pad = display.columns * 0.20 pad = display.columns * 0.20
@ -492,7 +499,7 @@ class DocCLI(CLI):
# Generic handler # Generic handler
for k in sorted(doc): for k in sorted(doc):
if k in IGNORE or not doc[k]: if k in self.IGNORE or not doc[k]:
continue continue
if isinstance(doc[k], string_types): if isinstance(doc[k], string_types):
text.append('%s: %s' % (k.upper(), textwrap.fill(CLI.tty_ify(doc[k]), limit - (len(k) + 2), subsequent_indent=opt_indent))) text.append('%s: %s' % (k.upper(), textwrap.fill(CLI.tty_ify(doc[k]), limit - (len(k) + 2), subsequent_indent=opt_indent)))

@ -1319,36 +1319,13 @@ PARAMIKO_HOST_KEY_AUTO_ADD:
- {key: host_key_auto_add, section: paramiko_connection} - {key: host_key_auto_add, section: paramiko_connection}
type: boolean type: boolean
PARAMIKO_LOOK_FOR_KEYS: PARAMIKO_LOOK_FOR_KEYS:
# TODO: move to plugin name: look for keys
default: True default: True
description: 'TODO: write it' description: 'TODO: write it'
env: [{name: ANSIBLE_PARAMIKO_LOOK_FOR_KEYS}] env: [{name: ANSIBLE_PARAMIKO_LOOK_FOR_KEYS}]
ini: ini:
- {key: look_for_keys, section: paramiko_connection} - {key: look_for_keys, section: paramiko_connection}
type: boolean type: boolean
PARAMIKO_PROXY_COMMAND:
# TODO: move to plugin
default:
description: 'TODO: write it'
env: [{name: ANSIBLE_PARAMIKO_PROXY_COMMAND}]
ini:
- {key: proxy_command, section: paramiko_connection}
PARAMIKO_PTY:
# TODO: move to plugin
default: True
description: 'TODO: write it'
env: [{name: ANSIBLE_PARAMIKO_PTY}]
ini:
- {key: pty, section: paramiko_connection}
type: boolean
PARAMIKO_RECORD_HOST_KEYS:
# TODO: move to plugin
default: True
description: 'TODO: write it'
env: [{name: ANSIBLE_PARAMIKO_RECORD_HOST_KEYS}]
ini:
- {key: record_host_keys, section: paramiko_connection}
type: boolean
PERSISTENT_CONTROL_PATH_DIR: PERSISTENT_CONTROL_PATH_DIR:
name: Persistence socket path name: Persistence socket path
default: ~/.ansible/pc default: ~/.ansible/pc

@ -128,7 +128,7 @@ def get_ini_config_value(p, entry):
if p is not None: if p is not None:
try: try:
value = p.get(entry.get('section', 'defaults'), entry.get('key', ''), raw=True) value = p.get(entry.get('section', 'defaults'), entry.get('key', ''), raw=True)
except: # FIXME: actually report issues here except Exception: # FIXME: actually report issues here
pass pass
return value return value
@ -224,15 +224,24 @@ class ConfigManager(object):
''' Load YAML Config Files in order, check merge flags, keep origin of settings''' ''' Load YAML Config Files in order, check merge flags, keep origin of settings'''
pass pass
def get_plugin_options(self, plugin_type, name, variables=None): def get_plugin_options(self, plugin_type, name, keys=None, variables=None):
options = {} options = {}
defs = self.get_configuration_definitions(plugin_type, name) defs = self.get_configuration_definitions(plugin_type, name)
for option in defs: for option in defs:
options[option] = self.get_config_value(option, plugin_type=plugin_type, plugin_name=name, variables=variables) options[option] = self.get_config_value(option, plugin_type=plugin_type, plugin_name=name, keys=keys, variables=variables)
return options return options
def get_plugin_vars(self, plugin_type, name):
pvars = []
for pdef in self.get_configuration_definitions(plugin_type, name).values():
if 'vars' in pdef and pdef['vars']:
for var_entry in pdef['vars']:
pvars.append(var_entry['name'])
return pvars
def get_configuration_definitions(self, plugin_type=None, name=None): def get_configuration_definitions(self, plugin_type=None, name=None):
''' just list the possible settings, either base or for specific plugins or plugin ''' ''' just list the possible settings, either base or for specific plugins or plugin '''
@ -264,12 +273,12 @@ class ConfigManager(object):
return value, origin return value, origin
def get_config_value(self, config, cfile=None, plugin_type=None, plugin_name=None, variables=None): def get_config_value(self, config, cfile=None, plugin_type=None, plugin_name=None, keys=None, variables=None):
''' wrapper ''' ''' wrapper '''
value, _drop = self.get_config_value_and_origin(config, cfile=cfile, plugin_type=plugin_type, plugin_name=plugin_name, variables=variables) value, _drop = self.get_config_value_and_origin(config, cfile=cfile, plugin_type=plugin_type, plugin_name=plugin_name, keys=keys, variables=variables)
return value return value
def get_config_value_and_origin(self, config, cfile=None, plugin_type=None, plugin_name=None, variables=None): def get_config_value_and_origin(self, config, cfile=None, plugin_type=None, plugin_name=None, keys=None, variables=None):
''' Given a config key figure out the actual value and report on the origin of the settings ''' ''' Given a config key figure out the actual value and report on the origin of the settings '''
if cfile is None: if cfile is None:
@ -290,10 +299,15 @@ class ConfigManager(object):
if config in defs: if config in defs:
# Use 'variable overrides' if present, highest precedence, but only present when querying running play # Use 'variable overrides' if present, highest precedence, but only present when querying running play
if variables: if variables and defs[config].get('vars'):
value, origin = self._loop_entries(variables, defs[config]['vars']) value, origin = self._loop_entries(variables, defs[config]['vars'])
origin = 'var: %s' % origin origin = 'var: %s' % origin
# use playbook keywords if you have em
if value is None and keys:
value, origin = self._loop_entries(keys, defs[config]['keywords'])
origin = 'keyword: %s' % origin
# env vars are next precedence # env vars are next precedence
if value is None and defs[config].get('env'): if value is None and defs[config].get('env'):
value, origin = self._loop_entries(os.environ, defs[config]['env']) value, origin = self._loop_entries(os.environ, defs[config]['env'])
@ -319,13 +333,6 @@ class ConfigManager(object):
# FIXME: implement, also , break down key from defs (. notation???) # FIXME: implement, also , break down key from defs (. notation???)
origin = cfile origin = cfile
'''
# for plugins, try using existing constants, this is for backwards compatiblity
if plugin_name and defs[config].get('constants'):
value, origin = self._loop_entries(self.data, defs[config]['constants'])
origin = 'constant: %s' % origin
'''
# set default if we got here w/o a value # set default if we got here w/o a value
if value is None: if value is None:
value = defs[config].get('default') value = defs[config].get('default')

@ -19,11 +19,11 @@ from ansible.module_utils.six.moves import cPickle
from ansible.module_utils._text import to_text from ansible.module_utils._text import to_text
from ansible.playbook.conditional import Conditional from ansible.playbook.conditional import Conditional
from ansible.playbook.task import Task from ansible.playbook.task import Task
from ansible.plugins.connection import ConnectionBase
from ansible.template import Templar from ansible.template import Templar
from ansible.utils.listify import listify_lookup_plugin_terms from ansible.utils.listify import listify_lookup_plugin_terms
from ansible.utils.unsafe_proxy import UnsafeProxy, wrap_var from ansible.utils.unsafe_proxy import UnsafeProxy, wrap_var
from ansible.vars.clean import namespace_facts, clean_facts from ansible.vars.clean import namespace_facts, clean_facts
from ansible.utils.vars import combine_vars
try: try:
from __main__ import display from __main__ import display
@ -480,18 +480,14 @@ class TaskExecutor:
not getattr(self._connection, 'connected', False) or not getattr(self._connection, 'connected', False) or
self._play_context.remote_addr != self._connection._play_context.remote_addr): self._play_context.remote_addr != self._connection._play_context.remote_addr):
self._connection = self._get_connection(variables=variables, templar=templar) self._connection = self._get_connection(variables=variables, templar=templar)
if getattr(self._connection, '_socket_path'):
variables['ansible_socket'] = self._connection._socket_path
# only template the vars if the connection actually implements set_host_overrides
# NB: this is expensive, and should be removed once connection-specific vars are being handled by play_context
sho_impl = getattr(type(self._connection), 'set_host_overrides', None)
if sho_impl and sho_impl != ConnectionBase.set_host_overrides:
self._connection.set_host_overrides(self._host, variables, templar)
else: else:
# if connection is reused, its _play_context is no longer valid and needs # if connection is reused, its _play_context is no longer valid and needs
# to be replaced with the one templated above, in case other data changed # to be replaced with the one templated above, in case other data changed
self._connection._play_context = self._play_context self._connection._play_context = self._play_context
self._set_connection_options(variables, templar)
# get handler
self._handler = self._get_action_handler(connection=self._connection, templar=templar) self._handler = self._get_action_handler(connection=self._connection, templar=templar)
# And filter out any fields which were set to default(omit), and got the omit token value # And filter out any fields which were set to default(omit), and got the omit token value
@ -734,6 +730,7 @@ class TaskExecutor:
if not connection: if not connection:
raise AnsibleError("the connection plugin '%s' was not found" % conn_type) raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
# FIXME: remove once all plugins pull all data from self._options
self._play_context.set_options_from_plugin(connection) self._play_context.set_options_from_plugin(connection)
if any(((connection.supports_persistence and C.USE_PERSISTENT_CONNECTIONS), connection.force_persistence)): if any(((connection.supports_persistence and C.USE_PERSISTENT_CONNECTIONS), connection.force_persistence)):
@ -745,6 +742,29 @@ class TaskExecutor:
return connection return connection
def _set_connection_options(self, variables, templar):
# create copy with delegation built in
final_vars = combine_vars(variables, variables.get('ansible_delegated_vars', dict()).get(self._task.delegate_to, dict()))
# grab list of usable vars for this plugin
option_vars = C.config.get_plugin_vars('connection', self._connection._load_name)
# create dict of 'templated vars'
options = {'_extras': {}}
for k in option_vars:
if k in final_vars:
options[k] = templar.template(final_vars[k])
# add extras if plugin supports them
if getattr(self._connection, 'allow_extras', False):
for k in final_vars:
if k.startswith('ansible_%s_' % self._connection._load_name) and k not in options:
options['_extras'][k] = templar.template(final_vars[k])
# set options with 'templated vars' specific to this plugin
self._connection.set_options(var_options=options)
def _get_action_handler(self, connection, templar): def _get_action_handler(self, connection, templar):
''' '''
Returns the correct action plugin to handle the requestion task action Returns the correct action plugin to handle the requestion task action

@ -178,7 +178,7 @@ class TaskQueueManager:
else: else:
self._stdout_callback = callback_loader.get(self._stdout_callback) self._stdout_callback = callback_loader.get(self._stdout_callback)
try: try:
self._stdout_callback.set_options(C.config.get_plugin_options('callback', self._stdout_callback._load_name)) self._stdout_callback.set_options()
except AttributeError: except AttributeError:
display.deprecated("%s stdout callback, does not support setting 'options', it will work for now, " display.deprecated("%s stdout callback, does not support setting 'options', it will work for now, "
" but this will be required in the future and should be updated," " but this will be required in the future and should be updated,"
@ -207,7 +207,7 @@ class TaskQueueManager:
callback_obj = callback_plugin() callback_obj = callback_plugin()
try: try:
callback_obj .set_options(C.config.get_plugin_options('callback', callback_plugin._load_name)) callback_obj.set_options()
except AttributeError: except AttributeError:
display.deprecated("%s callback, does not support setting 'options', it will work for now, " display.deprecated("%s callback, does not support setting 'options', it will work for now, "
" but this will be required in the future and should be updated, " " but this will be required in the future and should be updated, "

@ -173,6 +173,17 @@ class InventoryData(object):
else: else:
display.debug("group %s already in inventory" % group) display.debug("group %s already in inventory" % group)
def remove_group(self, group):
if group in self.groups:
del self.groups[group]
display.debug("Removed group %s from inventory" % group)
self._groups_dict_cache = {}
for host in self.hosts:
h = self.hosts[host]
h.remove_group(group)
def add_host(self, host, group=None, port=None): def add_host(self, host, group=None, port=None):
''' adds a host to inventory and possibly a group if not there already ''' ''' adds a host to inventory and possibly a group if not there already '''
@ -209,6 +220,15 @@ class InventoryData(object):
self._groups_dict_cache = {} self._groups_dict_cache = {}
display.debug("Added host %s to group %s" % (host, group)) display.debug("Added host %s to group %s" % (host, group))
def remove_host(self, host):
if host in self.hosts:
del self.hosts[host]
for group in self.groups:
g = self.groups[group]
g.remove_host(host)
def set_variable(self, entity, varname, value): def set_variable(self, entity, varname, value):
''' sets a varible for an inventory object ''' ''' sets a varible for an inventory object '''

@ -183,6 +183,7 @@ class InventoryManager(object):
for name in C.INVENTORY_ENABLED: for name in C.INVENTORY_ENABLED:
plugin = inventory_loader.get(name) plugin = inventory_loader.get(name)
if plugin: if plugin:
plugin.set_options()
self._inventory_plugins.append(plugin) self._inventory_plugins.append(plugin)
else: else:
display.warning('Failed to load inventory plugin, skipping %s' % name) display.warning('Failed to load inventory plugin, skipping %s' % name)
@ -282,7 +283,8 @@ class InventoryManager(object):
else: else:
for fail in failures: for fail in failures:
display.warning(u'\n* Failed to parse %s with %s plugin: %s' % (to_text(fail['src']), fail['plugin'], to_text(fail['exc']))) display.warning(u'\n* Failed to parse %s with %s plugin: %s' % (to_text(fail['src']), fail['plugin'], to_text(fail['exc'])))
display.vvv(to_text(fail['exc'].tb)) if hasattr(fail['exc'], 'tb'):
display.vvv(to_text(fail['exc'].tb))
if not parsed: if not parsed:
display.warning("Unable to parse %s as an inventory source" % to_text(source)) display.warning("Unable to parse %s as an inventory source" % to_text(source))

@ -47,6 +47,9 @@ def get_plugin_class(obj):
class AnsiblePlugin(with_metaclass(ABCMeta, object)): class AnsiblePlugin(with_metaclass(ABCMeta, object)):
# allow extra passthrough parameters
allow_extras = False
def __init__(self): def __init__(self):
self._options = {} self._options = {}
@ -59,8 +62,28 @@ class AnsiblePlugin(with_metaclass(ABCMeta, object)):
def set_option(self, option, value): def set_option(self, option, value):
self._options[option] = value self._options[option] = value
def set_options(self, options): def set_options(self, task_keys=None, var_options=None, direct=None):
self._options = options '''
Sets the _options attribute with the configuration/keyword information for this plugin
:arg task_keys: Dict with playbook keywords that affect this option
:arg var_options: Dict with either 'conneciton variables'
:arg direct: Dict with 'direct assignment'
'''
if not self._options:
# load config options if we have not done so already, if vars provided we should be mostly done
self._options = C.config.get_plugin_options(get_plugin_class(self), self._load_name, keys=task_keys, variables=var_options)
# they can be direct options overriding config
if direct:
for k in self._options:
if k in direct:
self.set_option(k, direct[k])
# allow extras/wildcards from vars that are not directly consumed in configuration
if self.allow_extras and var_options and '_extras' in var_options:
self.set_option('_extras', var_options['_extras'])
def _check_required(self): def _check_required(self):
# FIXME: standarize required check based on config # FIXME: standarize required check based on config

@ -612,6 +612,21 @@ class ActionBase(with_metaclass(ABCMeta, object)):
# make sure all commands use the designated shell executable # make sure all commands use the designated shell executable
module_args['_ansible_shell_executable'] = self._play_context.executable module_args['_ansible_shell_executable'] = self._play_context.executable
def _update_connection_options(self, options, variables=None):
''' ensures connections have the appropriate information '''
update = {}
if getattr(self.connection, 'glob_option_vars', False):
# if the connection allows for it, pass any variables matching it.
if variables is not None:
for varname in variables:
if varname.match('ansible_%s_' % self.connection._load_name):
update[varname] = variables[varname]
# always override existing with options
update.update(options)
self.connection.set_options(update)
def _execute_module(self, module_name=None, module_args=None, tmp=None, task_vars=None, persist_files=False, delete_remote_tmp=True, wrap_async=False): def _execute_module(self, module_name=None, module_args=None, tmp=None, task_vars=None, persist_files=False, delete_remote_tmp=True, wrap_async=False):
''' '''
Transfer and run a module along with its arguments. Transfer and run a module along with its arguments.

@ -246,6 +246,9 @@ class FactCache(MutableMapping):
# Backwards compat: self._display isn't really needed, just import the global display and use that. # Backwards compat: self._display isn't really needed, just import the global display and use that.
self._display = display self._display = display
# in memory cache so plugins don't expire keys mid run
self._cache = {}
def __getitem__(self, key): def __getitem__(self, key):
if not self._plugin.contains(key): if not self._plugin.contains(key):
raise KeyError raise KeyError

@ -26,7 +26,7 @@ import warnings
from copy import deepcopy from copy import deepcopy
from ansible import constants as C from ansible import constants as C
from ansible.plugins import AnsiblePlugin from ansible.plugins import AnsiblePlugin, get_plugin_class
from ansible.module_utils._text import to_text from ansible.module_utils._text import to_text
from ansible.utils.color import stringc from ansible.utils.color import stringc
from ansible.vars.clean import strip_internal_keys from ansible.vars.clean import strip_internal_keys
@ -81,8 +81,22 @@ class CallbackBase(AnsiblePlugin):
''' helper for callbacks, so they don't all have to include deepcopy ''' ''' helper for callbacks, so they don't all have to include deepcopy '''
_copy_result = deepcopy _copy_result = deepcopy
def set_options(self, options): def set_option(self, k, v):
self._plugin_options = options self._plugin_options[k] = v
def set_options(self, task_keys=None, var_options=None, direct=None):
''' This is different than the normal plugin method as callbacks get called early and really don't accept keywords.
Also _options was already taken for CLI args and callbacks use _plugin_options instead.
'''
# load from config
self._plugin_options = C.config.get_plugin_options(get_plugin_class(self), self._load_name, keys=task_keys, variables=var_options)
# or parse specific options
if direct:
for k in direct:
if k in self._plugin_options:
self.set_option(k, direct[k])
def _dump_results(self, result, indent=None, sort_keys=True, keep_invocation=False): def _dump_results(self, result, indent=None, sort_keys=True, keep_invocation=False):

@ -249,9 +249,6 @@ class CallbackModule(CallbackBase):
# self.set_options({'api': 'data.logentries.com', 'port': 80, # self.set_options({'api': 'data.logentries.com', 'port': 80,
# 'tls_port': 10000, 'use_tls': True, 'flatten': False, 'token': 'ae693734-4c5b-4a44-8814-1d2feb5c8241'}) # 'tls_port': 10000, 'use_tls': True, 'flatten': False, 'token': 'ae693734-4c5b-4a44-8814-1d2feb5c8241'})
def set_option(self, name, value):
raise AnsibleError("The Logentries callabck plugin does not suport setting individual options.")
def set_options(self, options): def set_options(self, options):
super(CallbackModule, self).set_options(options) super(CallbackModule, self).set_options(options)

@ -116,17 +116,6 @@ class ConnectionBase(AnsiblePlugin):
raise AnsibleError("Internal Error: this connection module does not support running commands via %s" % self._play_context.become_method) raise AnsibleError("Internal Error: this connection module does not support running commands via %s" % self._play_context.become_method)
def set_host_overrides(self, host, hostvars=None):
'''
An optional method, which can be used to set connection plugin parameters
from variables set on the host (or groups to which the host belongs)
Any connection plugin using this should first initialize its attributes in
an overridden `def __init__(self):`, and then use `host.get_vars()` to find
variables which may be used to set those attributes in this method.
'''
pass
@staticmethod @staticmethod
def _split_ssh_args(argstring): def _split_ssh_args(argstring):
""" """

@ -49,10 +49,11 @@ except ImportError:
class Connection(object): class Connection(object):
''' Func-based connections ''' ''' Func-based connections '''
has_pipelining = False
def __init__(self, runner, host, port, *args, **kwargs): def __init__(self, runner, host, port, *args, **kwargs):
self.runner = runner self.runner = runner
self.host = host self.host = host
self.has_pipelining = False
# port is unused, this go on func # port is unused, this go on func
self.port = port self.port = port

@ -48,23 +48,18 @@ import json
import logging import logging
import re import re
import os import os
import signal
import socket import socket
import traceback import traceback
from collections import Sequence
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleConnectionFailure from ansible.errors import AnsibleConnectionFailure
from ansible.module_utils.six import PY3, BytesIO, binary_type from ansible.module_utils.six import BytesIO, PY3
from ansible.module_utils.six.moves import cPickle from ansible.module_utils.six.moves import cPickle
from ansible.module_utils._text import to_bytes, to_text from ansible.module_utils._text import to_bytes, to_text
from ansible.playbook.play_context import PlayContext from ansible.playbook.play_context import PlayContext
from ansible.plugins.loader import cliconf_loader, terminal_loader, connection_loader from ansible.plugins.loader import cliconf_loader, terminal_loader, connection_loader
from ansible.plugins.connection import ConnectionBase from ansible.plugins.connection import ConnectionBase
from ansible.plugins.connection.local import Connection as LocalConnection from ansible.utils.path import unfrackpath
from ansible.plugins.connection.paramiko_ssh import Connection as ParamikoSshConnection
from ansible.utils.path import unfrackpath, makedirs_safe
try: try:
from __main__ import display from __main__ import display
@ -91,7 +86,8 @@ class Connection(ConnectionBase):
self._last_response = None self._last_response = None
self._history = list() self._history = list()
self._local = LocalConnection(play_context, new_stdin, *args, **kwargs) self._local = connection_loader.get('local', play_context, '/dev/null')
self._local.set_options()
self._terminal = None self._terminal = None
self._cliconf = None self._cliconf = None
@ -166,10 +162,9 @@ class Connection(ConnectionBase):
if self.connected: if self.connected:
return return
if self._play_context.password and not self._play_context.private_key_file: p = connection_loader.get('paramiko', self._play_context, '/dev/null')
C.PARAMIKO_LOOK_FOR_KEYS = False p.set_options(direct={'look_for_keys': bool(self._play_context.password and not self._play_context.private_key_file)})
ssh = p._connect()
ssh = ParamikoSshConnection(self._play_context, '/dev/null')._connect()
self.ssh = ssh.ssh self.ssh = ssh.ssh
display.vvvv('ssh connection done, setting terminal', host=self._play_context.remote_addr) display.vvvv('ssh connection done, setting terminal', host=self._play_context.remote_addr)

@ -15,6 +15,7 @@ DOCUMENTATION = """
- This is needed on the Ansible control machine to be reasonably efficient with connections. - This is needed on the Ansible control machine to be reasonably efficient with connections.
Thus paramiko is faster for most users on these platforms. Thus paramiko is faster for most users on these platforms.
Users with ControlPersist capability can consider using -c ssh or configuring the transport in the configuration file. Users with ControlPersist capability can consider using -c ssh or configuring the transport in the configuration file.
- This plugin also borrows a lot of settings from the ssh plugin as they both cover the same protocol.
version_added: "0.1" version_added: "0.1"
options: options:
remote_addr: remote_addr:
@ -28,26 +29,94 @@ DOCUMENTATION = """
remote_user: remote_user:
description: description:
- User to login/authenticate as - User to login/authenticate as
- Can be set from the CLI via the ``--user`` or ``-u`` options.
vars: vars:
- name: ansible_user - name: ansible_user
- name: ansible_ssh_user - name: ansible_ssh_user
- name: ansible_paramiko_user - name: ansible_paramiko_user
env:
- name: ANSIBLE_REMOTE_USER
- name: ANSIBLE_PARAMIKO_REMOTE_USER
version_added: '2.5'
ini:
- section: defaults
key: remote_user
- section: paramiko_connection
key: remote_user
version_added: '2.5'
password:
description:
- Secret used to either login the ssh server or as a passphrase for ssh keys that require it
- Can be set from the CLI via the ``--ask-pass`` option.
vars:
- name: ansible_password
- name: ansible_ssh_pass
- name: ansible_paramiko_pass
version_added: '2.5'
host_key_auto_add:
description: 'TODO: write it'
env: [{name: ANSIBLE_PARAMIKO_HOST_KEY_AUTO_ADD}]
ini:
- {key: host_key_auto_add, section: paramiko_connection}
type: boolean
look_for_keys:
default: True
description: 'TODO: write it'
env: [{name: ANSIBLE_PARAMIKO_LOOK_FOR_KEYS}]
ini:
- {key: look_for_keys, section: paramiko_connection}
type: boolean
proxy_command:
default: ''
description:
- Proxy information for running the connection via a jumphost
- Also this plugin will scan 'ssh_args', 'ssh_extra_args' and 'ssh_common_args' from the 'ssh' plugin settings for proxy information if set.
env: [{name: ANSIBLE_PARAMIKO_PROXY_COMMAND}]
ini:
- {key: proxy_command, section: paramiko_connection}
pty:
default: True
description: 'TODO: write it'
env:
- name: ANSIBLE_PARAMIKO_PTY
ini:
- section: paramiko_connection
key: pty
type: boolean
record_host_keys:
default: True
description: 'TODO: write it'
env: [{name: ANSIBLE_PARAMIKO_RECORD_HOST_KEYS}]
ini:
- section: paramiko_connection
key: record_host_keys
type: boolean
host_key_checking:
description: 'Set this to "False" if you want to avoid host key checking by the underlying tools Ansible uses to connect to the host'
type: boolean
default: True
env:
- name: ANSIBLE_HOST_KEY_CHECKING
- name: ANSIBLE_SSH_HOST_KEY_CHECKING
version_added: '2.5'
- name: ANSIBLE_PARAMIKO_HOST_KEY_CHECKING
version_added: '2.5'
ini:
- section: defaults
key: host_key_checking
- section: paramiko_connection
key: host_key_checking
version_added: '2.5'
vars:
- name: ansible_host_key_checking
version_added: '2.5'
- name: ansible_ssh_host_key_checking
version_added: '2.5'
- name: ansible_paramiko_host_key_checking
version_added: '2.5'
# TODO: # TODO:
#getattr(self._play_context, 'ssh_extra_args', '') or '',
#getattr(self._play_context, 'ssh_common_args', '') or '',
#getattr(self._play_context, 'ssh_args', '') or '',
#C.HOST_KEY_CHECKING
#C.PARAMIKO_HOST_KEY_AUTO_ADD
#C.USE_PERSISTENT_CONNECTIONS: #C.USE_PERSISTENT_CONNECTIONS:
# ssh.connect( #timeout=self._play_context.timeout,
# look_for_keys=C.PARAMIKO_LOOK_FOR_KEYS,
# key_filename,
# password=self._play_context.password,
# timeout=self._play_context.timeout,
# port=port,
#proxy_command = proxy_command or C.PARAMIKO_PROXY_COMMAND
#C.PARAMIKO_PTY
#C.PARAMIKO_RECORD_HOST_KEYS
""" """
import warnings import warnings
@ -110,10 +179,11 @@ class MyAddPolicy(object):
def __init__(self, new_stdin, connection): def __init__(self, new_stdin, connection):
self._new_stdin = new_stdin self._new_stdin = new_stdin
self.connection = connection self.connection = connection
self._options = connection._options
def missing_host_key(self, client, hostname, key): def missing_host_key(self, client, hostname, key):
if all((C.HOST_KEY_CHECKING, not C.PARAMIKO_HOST_KEY_AUTO_ADD)): if all((self._options['host_key_checking'], not self._options['host_key_auto_add'])):
fingerprint = hexlify(key.get_fingerprint()) fingerprint = hexlify(key.get_fingerprint())
ktype = key.get_name() ktype = key.get_name()
@ -194,7 +264,7 @@ class Connection(ConnectionBase):
if proxy_command: if proxy_command:
break break
proxy_command = proxy_command or C.PARAMIKO_PROXY_COMMAND proxy_command = proxy_command or self._options['proxy_command']
sock_kwarg = {} sock_kwarg = {}
if proxy_command: if proxy_command:
@ -229,7 +299,7 @@ class Connection(ConnectionBase):
self.keyfile = os.path.expanduser("~/.ssh/known_hosts") self.keyfile = os.path.expanduser("~/.ssh/known_hosts")
if C.HOST_KEY_CHECKING: if self._options['host_key_checking']:
for ssh_known_hosts in ("/etc/ssh/ssh_known_hosts", "/etc/openssh/ssh_known_hosts"): for ssh_known_hosts in ("/etc/ssh/ssh_known_hosts", "/etc/openssh/ssh_known_hosts"):
try: try:
# TODO: check if we need to look at several possible locations, possible for loop # TODO: check if we need to look at several possible locations, possible for loop
@ -257,7 +327,7 @@ class Connection(ConnectionBase):
self._play_context.remote_addr, self._play_context.remote_addr,
username=self._play_context.remote_user, username=self._play_context.remote_user,
allow_agent=allow_agent, allow_agent=allow_agent,
look_for_keys=C.PARAMIKO_LOOK_FOR_KEYS, look_for_keys=self._options['look_for_keys'],
key_filename=key_filename, key_filename=key_filename,
password=self._play_context.password, password=self._play_context.password,
timeout=self._play_context.timeout, timeout=self._play_context.timeout,
@ -301,7 +371,7 @@ class Connection(ConnectionBase):
# sudo usually requires a PTY (cf. requiretty option), therefore # sudo usually requires a PTY (cf. requiretty option), therefore
# we give it one by default (pty=True in ansble.cfg), and we try # we give it one by default (pty=True in ansble.cfg), and we try
# to initialise from the calling environment when sudoable is enabled # to initialise from the calling environment when sudoable is enabled
if C.PARAMIKO_PTY and sudoable: if self._options['pty'] and sudoable:
chan.get_pty(term=os.getenv('TERM', 'vt100'), width=int(os.getenv('COLUMNS', 0)), height=int(os.getenv('LINES', 0))) chan.get_pty(term=os.getenv('TERM', 'vt100'), width=int(os.getenv('COLUMNS', 0)), height=int(os.getenv('LINES', 0)))
display.vvv("EXEC %s" % cmd, host=self._play_context.remote_addr) display.vvv("EXEC %s" % cmd, host=self._play_context.remote_addr)
@ -454,7 +524,7 @@ class Connection(ConnectionBase):
if self.sftp is not None: if self.sftp is not None:
self.sftp.close() self.sftp.close()
if C.HOST_KEY_CHECKING and C.PARAMIKO_RECORD_HOST_KEYS and self._any_keys_added(): if self._options['host_key_checking'] and self._options['record_host_keys'] and self._any_keys_added():
# add any new SSH host keys -- warning -- this could be slow # add any new SSH host keys -- warning -- this could be slow
# (This doesn't acquire the connection lock because it needs # (This doesn't acquire the connection lock because it needs

@ -14,13 +14,11 @@ DOCUMENTATION = """
version_added: "2.3" version_added: "2.3"
""" """
import os import os
import sys
import pty import pty
import json import json
import subprocess import subprocess
from ansible import constants as C from ansible import constants as C
from ansible.plugins.loader import connection_loader
from ansible.plugins.connection import ConnectionBase from ansible.plugins.connection import ConnectionBase
from ansible.module_utils._text import to_text from ansible.module_utils._text import to_text
from ansible.module_utils.six.moves import cPickle from ansible.module_utils.six.moves import cPickle

@ -22,14 +22,23 @@ DOCUMENTATION = '''
- name: ansible_host - name: ansible_host
- name: ansible_ssh_host - name: ansible_ssh_host
host_key_checking: host_key_checking:
#constant: HOST_KEY_CHECKING
description: Determines if ssh should check host keys description: Determines if ssh should check host keys
type: boolean type: boolean
ini: ini:
- section: defaults - section: defaults
key: 'host_key_checking' key: 'host_key_checking'
- section: ssh_connection
key: 'host_key_checking'
version_added: '2.5'
env: env:
- name: ANSIBLE_HOST_KEY_CHECKING - name: ANSIBLE_HOST_KEY_CHECKING
- name: ANSIBLE_SSH_HOST_KEY_CHECKING
version_added: '2.5'
vars:
- name: ansible_host_key_checking
version_added: '2.5'
- name: ansible_ssh_host_key_checking
version_added: '2.5'
password: password:
description: Authentication password for the C(remote_user). Can be supplied as CLI option. description: Authentication password for the C(remote_user). Can be supplied as CLI option.
vars: vars:

@ -11,8 +11,13 @@ DOCUMENTATION = """
short_description: Run tasks over Microsoft's WinRM short_description: Run tasks over Microsoft's WinRM
description: description:
- Run commands or put/fetch on a target via WinRM - Run commands or put/fetch on a target via WinRM
- This plugin allows extra arguments to be passed that are supported by the protocol but not explicitly defined here.
They should take the form of variables declared with the following pattern `ansible_winrm_<option>`.
version_added: "2.0" version_added: "2.0"
requirements:
- pywinrm (python library)
options: options:
# figure out more elegant 'delegation'
remote_addr: remote_addr:
description: description:
- Address of the windows machine - Address of the windows machine
@ -21,11 +26,58 @@ DOCUMENTATION = """
- name: ansible_host - name: ansible_host
- name: ansible_winrm_host - name: ansible_winrm_host
remote_user: remote_user:
keywords:
- name: user
- name: remote_user
description: description:
- The user to log in as to the Windows machine - The user to log in as to the Windows machine
vars: vars:
- name: ansible_user - name: ansible_user
- name: ansible_winrm_user - name: ansible_winrm_user
port:
description:
- port for winrm to connect on remote target
- The default is the https (5896) port, if using http it should be 5895
vars:
- name: ansible_port
- name: ansible_winrm_port
default: 5986
keywords:
- name: port
type: integer
scheme:
description:
- URI scheme to use
choices: [http, https]
default: https
vars:
- name: ansible_winrm_scheme
path:
description: URI path to connect to
default: '/wsman'
vars:
- name: ansible_winrm_path
transport:
description:
- List of winrm transports to attempt to to use (ssl, plaintext, kerberos, etc)
- If None (the default) the plugin will try to automatically guess the correct list
- The choices avialable depend on your version of pywinrm
type: list
vars:
- name: ansible_winrm_transport
kerberos_command:
description: kerberos command to use to request a authentication ticket
default: kinit
vars:
- name: ansible_winrm_kinit_cmd
kerberos_mode:
description:
- kerberos usage mode.
- The managed option means Ansible will obtain kerberos ticket.
- While the manual one means a ticket must already have been obtained by the user.
choices: [managed, manual]
vars:
- name: ansible_winrm_kinit_mode
""" """
import base64 import base64
@ -84,43 +136,40 @@ class Connection(ConnectionBase):
module_implementation_preferences = ('.ps1', '.exe', '') module_implementation_preferences = ('.ps1', '.exe', '')
become_methods = ['runas'] become_methods = ['runas']
allow_executable = False allow_executable = False
has_pipelining = True
allow_extras = True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.has_pipelining = True
self.always_pipeline_modules = True self.always_pipeline_modules = True
self.has_native_async = True self.has_native_async = True
self.protocol = None self.protocol = None
self.shell_id = None self.shell_id = None
self.delegate = None self.delegate = None
self._shell_type = 'powershell' self._shell_type = 'powershell'
# FUTURE: Add runas support
super(Connection, self).__init__(*args, **kwargs) super(Connection, self).__init__(*args, **kwargs)
def set_host_overrides(self, host, variables, templar): def set_options(self, task_keys=None, var_options=None, direct=None):
'''
Override WinRM-specific options from host variables.
'''
if not HAS_WINRM: if not HAS_WINRM:
return return
hostvars = {} super(Connection, self).set_options(task_keys=None, var_options=var_options, direct=direct)
for k in variables:
if k.startswith('ansible_winrm'):
hostvars[k] = templar.template(variables[k])
self._winrm_host = self._play_context.remote_addr self._winrm_host = self._play_context.remote_addr
self._winrm_port = int(self._play_context.port or 5986)
self._winrm_scheme = hostvars.get('ansible_winrm_scheme', 'http' if self._winrm_port == 5985 else 'https')
self._winrm_path = hostvars.get('ansible_winrm_path', '/wsman')
self._winrm_user = self._play_context.remote_user self._winrm_user = self._play_context.remote_user
self._winrm_pass = self._play_context.password self._winrm_pass = self._play_context.password
self._become_method = self._play_context.become_method self._become_method = self._play_context.become_method
self._become_user = self._play_context.become_user self._become_user = self._play_context.become_user
self._become_pass = self._play_context.become_pass self._become_pass = self._play_context.become_pass
self._kinit_cmd = hostvars.get('ansible_winrm_kinit_cmd', 'kinit') self._winrm_port = self._options['port']
self._winrm_scheme = self._options['scheme']
self._winrm_path = self._options['path']
self._kinit_cmd = self._options['kerberos_command']
self._winrm_transport = self._options['transport']
if hasattr(winrm, 'FEATURE_SUPPORTED_AUTHTYPES'): if hasattr(winrm, 'FEATURE_SUPPORTED_AUTHTYPES'):
self._winrm_supported_authtypes = set(winrm.FEATURE_SUPPORTED_AUTHTYPES) self._winrm_supported_authtypes = set(winrm.FEATURE_SUPPORTED_AUTHTYPES)
@ -128,16 +177,15 @@ class Connection(ConnectionBase):
# for legacy versions of pywinrm, use the values we know are supported # for legacy versions of pywinrm, use the values we know are supported
self._winrm_supported_authtypes = set(['plaintext', 'ssl', 'kerberos']) self._winrm_supported_authtypes = set(['plaintext', 'ssl', 'kerberos'])
# TODO: figure out what we want to do with auto-transport selection in the face of NTLM/Kerb/CredSSP/Cert/Basic # calculate transport if needed
transport_selector = 'ssl' if self._winrm_scheme == 'https' else 'plaintext' if self._winrm_transport is None or self._winrm_transport[0] is None:
# TODO: figure out what we want to do with auto-transport selection in the face of NTLM/Kerb/CredSSP/Cert/Basic
transport_selector = ['ssl'] if self._winrm_scheme == 'https' else ['plaintext']
if HAVE_KERBEROS and ((self._winrm_user and '@' in self._winrm_user)): if HAVE_KERBEROS and ((self._winrm_user and '@' in self._winrm_user)):
self._winrm_transport = 'kerberos,%s' % transport_selector self._winrm_transport = ['kerberos'] + transport_selector
else: else:
self._winrm_transport = transport_selector self._winrm_transport = transport_selector
self._winrm_transport = hostvars.get('ansible_winrm_transport', self._winrm_transport)
if isinstance(self._winrm_transport, string_types):
self._winrm_transport = [x.strip() for x in self._winrm_transport.split(',') if x.strip()]
unsupported_transports = set(self._winrm_transport).difference(self._winrm_supported_authtypes) unsupported_transports = set(self._winrm_transport).difference(self._winrm_supported_authtypes)
@ -145,16 +193,14 @@ class Connection(ConnectionBase):
raise AnsibleError('The installed version of WinRM does not support transport(s) %s' % list(unsupported_transports)) raise AnsibleError('The installed version of WinRM does not support transport(s) %s' % list(unsupported_transports))
# if kerberos is among our transports and there's a password specified, we're managing the tickets # if kerberos is among our transports and there's a password specified, we're managing the tickets
kinit_mode = to_text(hostvars.get('ansible_winrm_kinit_mode', '')).strip() kinit_mode = self._options['kerberos_mode']
if kinit_mode == "": if kinit_mode is None:
# HACK: ideally, remove multi-transport stuff # HACK: ideally, remove multi-transport stuff
self._kerb_managed = "kerberos" in self._winrm_transport and self._winrm_pass self._kerb_managed = "kerberos" in self._winrm_transport and self._winrm_pass
elif kinit_mode == "managed": elif kinit_mode == "managed":
self._kerb_managed = True self._kerb_managed = True
elif kinit_mode == "manual": elif kinit_mode == "manual":
self._kerb_managed = False self._kerb_managed = False
else:
raise AnsibleError('Unknown ansible_winrm_kinit_mode value: "%s" (must be "managed" or "manual")' % kinit_mode)
# arg names we're going passing directly # arg names we're going passing directly
internal_kwarg_mask = set(['self', 'endpoint', 'transport', 'username', 'password', 'scheme', 'path', 'kinit_mode', 'kinit_cmd']) internal_kwarg_mask = set(['self', 'endpoint', 'transport', 'username', 'password', 'scheme', 'path', 'kinit_mode', 'kinit_cmd'])
@ -163,16 +209,16 @@ class Connection(ConnectionBase):
argspec = inspect.getargspec(Protocol.__init__) argspec = inspect.getargspec(Protocol.__init__)
supported_winrm_args = set(argspec.args) supported_winrm_args = set(argspec.args)
supported_winrm_args.update(internal_kwarg_mask) supported_winrm_args.update(internal_kwarg_mask)
passed_winrm_args = set([v.replace('ansible_winrm_', '') for v in hostvars if v.startswith('ansible_winrm_')]) passed_winrm_args = set([v.replace('ansible_winrm_', '') for v in self._options['_extras']])
unsupported_args = passed_winrm_args.difference(supported_winrm_args) unsupported_args = passed_winrm_args.difference(supported_winrm_args)
# warn for kwargs unsupported by the installed version of pywinrm # warn for kwargs unsupported by the installed version of pywinrm
for arg in unsupported_args: for arg in unsupported_args:
display.warning("ansible_winrm_{0} unsupported by pywinrm (is an up-to-date version of pywinrm installed?)".format(arg)) display.warning("ansible_winrm_{0} unsupported by pywinrm (is an up-to-date version of pywinrm installed?)".format(arg))
# pass through matching kwargs, excluding the list we want to treat specially # pass through matching extras, excluding the list we want to treat specially
for arg in passed_winrm_args.difference(internal_kwarg_mask).intersection(supported_winrm_args): for arg in passed_winrm_args.difference(internal_kwarg_mask).intersection(supported_winrm_args):
self._winrm_kwargs[arg] = hostvars['ansible_winrm_%s' % arg] self._winrm_kwargs[arg] = self._options['_extras']['ansible_winrm_%s' % arg]
# Until pykerberos has enough goodies to implement a rudimentary kinit/klist, simplest way is to let each connection # Until pykerberos has enough goodies to implement a rudimentary kinit/klist, simplest way is to let each connection
# auth itself with a private CCACHE. # auth itself with a private CCACHE.

@ -19,13 +19,15 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from collections import MutableMapping
import hashlib import hashlib
import os import os
import re import re
import string import string
from collections import MutableMapping
from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError
from ansible.plugins import AnsiblePlugin
from ansible.module_utils._text import to_bytes, to_native from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.parsing.convert_bool import boolean from ansible.module_utils.parsing.convert_bool import boolean
from ansible.module_utils.six import string_types from ansible.module_utils.six import string_types
@ -40,16 +42,106 @@ except ImportError:
_SAFE_GROUP = re.compile("[^A-Za-z0-9\_]") _SAFE_GROUP = re.compile("[^A-Za-z0-9\_]")
class BaseInventoryPlugin(object): # Helper methods
def to_safe_group_name(name):
''' Converts 'bad' characters in a string to underscores so they can be used as Ansible hosts or groups '''
return _SAFE_GROUP.sub("_", name)
def detect_range(line=None):
'''
A helper function that checks a given host line to see if it contains
a range pattern described in the docstring above.
Returns True if the given line contains a pattern, else False.
'''
return '[' in line
def expand_hostname_range(line=None):
'''
A helper function that expands a given line that contains a pattern
specified in top docstring, and returns a list that consists of the
expanded version.
The '[' and ']' characters are used to maintain the pseudo-code
appearance. They are replaced in this function with '|' to ease
string splitting.
References: http://ansible.github.com/patterns.html#hosts-and-groups
'''
all_hosts = []
if line:
# A hostname such as db[1:6]-node is considered to consists
# three parts:
# head: 'db'
# nrange: [1:6]; range() is a built-in. Can't use the name
# tail: '-node'
# Add support for multiple ranges in a host so:
# db[01:10:3]node-[01:10]
# - to do this we split off at the first [...] set, getting the list
# of hosts and then repeat until none left.
# - also add an optional third parameter which contains the step. (Default: 1)
# so range can be [01:10:2] -> 01 03 05 07 09
(head, nrange, tail) = line.replace('[', '|', 1).replace(']', '|', 1).split('|')
bounds = nrange.split(":")
if len(bounds) != 2 and len(bounds) != 3:
raise AnsibleError("host range must be begin:end or begin:end:step")
beg = bounds[0]
end = bounds[1]
if len(bounds) == 2:
step = 1
else:
step = bounds[2]
if not beg:
beg = "0"
if not end:
raise AnsibleError("host range must specify end value")
if beg[0] == '0' and len(beg) > 1:
rlen = len(beg) # range length formatting hint
if rlen != len(end):
raise AnsibleError("host range must specify equal-length begin and end formats")
def fill(x):
return str(x).zfill(rlen) # range sequence
else:
fill = str
try:
i_beg = string.ascii_letters.index(beg)
i_end = string.ascii_letters.index(end)
if i_beg > i_end:
raise AnsibleError("host range must have begin <= end")
seq = list(string.ascii_letters[i_beg:i_end + 1:int(step)])
except ValueError: # not an alpha range
seq = range(int(beg), int(end) + 1, int(step))
for rseq in seq:
hname = ''.join((head, fill(rseq), tail))
if detect_range(hname):
all_hosts.extend(expand_hostname_range(hname))
else:
all_hosts.append(hname)
return all_hosts
class BaseInventoryPlugin(AnsiblePlugin):
""" Parses an Inventory Source""" """ Parses an Inventory Source"""
TYPE = 'generator' TYPE = 'generator'
def __init__(self): def __init__(self):
super(BaseInventoryPlugin, self).__init__()
self._options = {}
self.inventory = None self.inventory = None
self.display = display self.display = display
self._cache = {}
def parse(self, inventory, loader, path, cache=True): def parse(self, inventory, loader, path, cache=True):
''' Populates self.groups from the given data. Raises an error on any parse failure. ''' ''' Populates self.groups from the given data. Raises an error on any parse failure. '''
@ -64,7 +156,64 @@ class BaseInventoryPlugin(object):
b_path = to_bytes(path, errors='surrogate_or_strict') b_path = to_bytes(path, errors='surrogate_or_strict')
return (os.path.exists(b_path) and os.access(b_path, os.R_OK)) return (os.path.exists(b_path) and os.access(b_path, os.R_OK))
def get_cache_prefix(self, path): def _populate_host_vars(self, hosts, variables, group=None, port=None):
if not isinstance(variables, MutableMapping):
raise AnsibleParserError("Invalid data from file, expected dictionary and got:\n\n%s" % to_native(variables))
for host in hosts:
self.inventory.add_host(host, group=group, port=port)
for k in variables:
self.inventory.set_variable(host, k, variables[k])
def _read_config_data(self, path):
''' validate config and set options as appropriate '''
config = {}
try:
config = self.loader.load_from_file(path)
except Exception as e:
raise AnsibleParserError(to_native(e))
if not config:
# no data
raise AnsibleParserError("%s is empty" % (to_native(path)))
elif config.get('plugin') != self.NAME:
# this is not my config file
raise AnsibleParserError("Incorrect plugin name in file: %s" % config.get('plugin', 'none found'))
elif not isinstance(config, MutableMapping):
# configs are dictionaries
raise AnsibleParserError('inventory source has invalid structure, it should be a dictionary, got: %s' % type(config))
self.set_options(direct=config)
return config
def _consume_options(self, data):
''' update existing options from file data'''
for k in self._options:
if k in data:
self._options[k] = data.pop(k)
def clear_cache(self):
pass
class BaseFileInventoryPlugin(BaseInventoryPlugin):
""" Parses a File based Inventory Source"""
TYPE = 'storage'
def __init__(self):
super(BaseFileInventoryPlugin, self).__init__()
class Cacheable(object):
_cache = {}
def _get_cache_prefix(self, path):
''' create predictable unique prefix for plugin/inventory ''' ''' create predictable unique prefix for plugin/inventory '''
m = hashlib.sha1() m = hashlib.sha1()
@ -78,16 +227,10 @@ class BaseInventoryPlugin(object):
return 's_'.join([d1[:5], d2[:5]]) return 's_'.join([d1[:5], d2[:5]])
def clear_cache(self): def clear_cache(self):
pass self._cache = {}
def populate_host_vars(self, hosts, variables, group=None, port=None):
if not isinstance(variables, MutableMapping):
raise AnsibleParserError("Invalid data from file, expected dictionary and got:\n\n%s" % to_native(variables))
for host in hosts: class Constructable(object):
self.inventory.add_host(host, group=group, port=port)
for k in variables:
self.inventory.set_variable(host, k, variables[k])
def _compose(self, template, variables): def _compose(self, template, variables):
''' helper method for pluigns to compose variables for Ansible based on jinja2 expression and inventory vars''' ''' helper method for pluigns to compose variables for Ansible based on jinja2 expression and inventory vars'''
@ -153,101 +296,3 @@ class BaseInventoryPlugin(object):
raise AnsibleOptionsError("No key supplied, invalid entry") raise AnsibleOptionsError("No key supplied, invalid entry")
else: else:
raise AnsibleOptionsError("Invalid keyed group entry, it must be a dictionary: %s " % keyed) raise AnsibleOptionsError("Invalid keyed group entry, it must be a dictionary: %s " % keyed)
class BaseFileInventoryPlugin(BaseInventoryPlugin):
""" Parses a File based Inventory Source"""
TYPE = 'storage'
def __init__(self):
super(BaseFileInventoryPlugin, self).__init__()
# Helper methods
def to_safe_group_name(name):
''' Converts 'bad' characters in a string to underscores so they can be used as Ansible hosts or groups '''
return _SAFE_GROUP.sub("_", name)
def detect_range(line=None):
'''
A helper function that checks a given host line to see if it contains
a range pattern described in the docstring above.
Returns True if the given line contains a pattern, else False.
'''
return '[' in line
def expand_hostname_range(line=None):
'''
A helper function that expands a given line that contains a pattern
specified in top docstring, and returns a list that consists of the
expanded version.
The '[' and ']' characters are used to maintain the pseudo-code
appearance. They are replaced in this function with '|' to ease
string splitting.
References: http://ansible.github.com/patterns.html#hosts-and-groups
'''
all_hosts = []
if line:
# A hostname such as db[1:6]-node is considered to consists
# three parts:
# head: 'db'
# nrange: [1:6]; range() is a built-in. Can't use the name
# tail: '-node'
# Add support for multiple ranges in a host so:
# db[01:10:3]node-[01:10]
# - to do this we split off at the first [...] set, getting the list
# of hosts and then repeat until none left.
# - also add an optional third parameter which contains the step. (Default: 1)
# so range can be [01:10:2] -> 01 03 05 07 09
(head, nrange, tail) = line.replace('[', '|', 1).replace(']', '|', 1).split('|')
bounds = nrange.split(":")
if len(bounds) != 2 and len(bounds) != 3:
raise AnsibleError("host range must be begin:end or begin:end:step")
beg = bounds[0]
end = bounds[1]
if len(bounds) == 2:
step = 1
else:
step = bounds[2]
if not beg:
beg = "0"
if not end:
raise AnsibleError("host range must specify end value")
if beg[0] == '0' and len(beg) > 1:
rlen = len(beg) # range length formatting hint
if rlen != len(end):
raise AnsibleError("host range must specify equal-length begin and end formats")
def fill(x):
return str(x).zfill(rlen) # range sequence
else:
fill = str
try:
i_beg = string.ascii_letters.index(beg)
i_end = string.ascii_letters.index(end)
if i_beg > i_end:
raise AnsibleError("host range must have begin <= end")
seq = list(string.ascii_letters[i_beg:i_end + 1:int(step)])
except ValueError: # not an alpha range
seq = range(int(beg), int(end) + 1, int(step))
for rseq in seq:
hname = ''.join((head, fill(rseq), tail))
if detect_range(hname):
all_hosts.extend(expand_hostname_range(hname))
else:
all_hosts.append(hname)
return all_hosts

@ -54,17 +54,15 @@ EXAMPLES = '''
import os import os
from collections import MutableMapping
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
from ansible.plugins.cache import FactCache from ansible.plugins.cache import FactCache
from ansible.plugins.inventory import BaseInventoryPlugin from ansible.plugins.inventory import BaseInventoryPlugin, Constructable
from ansible.module_utils._text import to_native from ansible.module_utils._text import to_native
from ansible.utils.vars import combine_vars from ansible.utils.vars import combine_vars
class InventoryModule(BaseInventoryPlugin): class InventoryModule(BaseInventoryPlugin, Constructable):
""" constructs groups and vars using Jinaj2 template expressions """ """ constructs groups and vars using Jinaj2 template expressions """
NAME = 'constructed' NAME = 'constructed'
@ -91,30 +89,21 @@ class InventoryModule(BaseInventoryPlugin):
super(InventoryModule, self).parse(inventory, loader, path, cache=cache) super(InventoryModule, self).parse(inventory, loader, path, cache=cache)
try: self._read_config_data(path)
data = self.loader.load_from_file(path)
except Exception as e:
raise AnsibleParserError("Unable to parse %s: %s" % (to_native(path), to_native(e)))
if not data:
raise AnsibleParserError("%s is empty" % (to_native(path)))
elif not isinstance(data, MutableMapping):
raise AnsibleParserError('inventory source has invalid structure, it should be a dictionary, got: %s' % type(data))
elif data.get('plugin') != self.NAME:
raise AnsibleParserError("%s is not a constructed groups config file, plugin entry must be 'constructed'" % (to_native(path)))
strict = data.get('strict', False) strict = self._options['strict']
fact_cache = FactCache()
try: try:
# Go over hosts (less var copies) # Go over hosts (less var copies)
for host in inventory.hosts: for host in inventory.hosts:
# get available variables to templar # get available variables to templar
hostvars = inventory.hosts[host].get_vars() hostvars = inventory.hosts[host].get_vars()
if host in self._cache: # adds facts if cache is active if host in fact_cache: # adds facts if cache is active
hostvars = combine_vars(hostvars, self._cache[host]) hostvars = combine_vars(hostvars, fact_cache[host])
# create composite vars # create composite vars
self._set_composite_vars(data.get('compose'), hostvars, host, strict=strict) self._set_composite_vars(self._options['compose'], hostvars, host, strict=strict)
# refetch host vars in case new ones have been created above # refetch host vars in case new ones have been created above
hostvars = inventory.hosts[host].get_vars() hostvars = inventory.hosts[host].get_vars()
@ -122,10 +111,10 @@ class InventoryModule(BaseInventoryPlugin):
hostvars = combine_vars(hostvars, self._cache[host]) hostvars = combine_vars(hostvars, self._cache[host])
# constructed groups based on conditionals # constructed groups based on conditionals
self._add_host_to_composed_groups(data.get('groups'), hostvars, host, strict=strict) self._add_host_to_composed_groups(self._options['groups'], hostvars, host, strict=strict)
# constructed groups based variable values # constructed groups based variable values
self._add_host_to_keyed_groups(data.get('keyed_groups'), hostvars, host, strict=strict) self._add_host_to_keyed_groups(self._options['keyed_groups'], hostvars, host, strict=strict)
except Exception as e: except Exception as e:
raise AnsibleParserError("failed to parse %s: %s " % (to_native(path), to_native(e))) raise AnsibleParserError("failed to parse %s: %s " % (to_native(path), to_native(e)))

@ -208,7 +208,7 @@ class InventoryModule(BaseFileInventoryPlugin):
# the current group. # the current group.
if state == 'hosts': if state == 'hosts':
hosts, port, variables = self._parse_host_definition(line) hosts, port, variables = self._parse_host_definition(line)
self.populate_host_vars(hosts, variables, groupname, port) self._populate_host_vars(hosts, variables, groupname, port)
# [groupname:vars] contains variable definitions that must be # [groupname:vars] contains variable definitions that must be
# applied to the current group. # applied to the current group.

@ -104,7 +104,7 @@ simple_config_file:
import collections import collections
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
from ansible.plugins.inventory import BaseInventoryPlugin from ansible.plugins.inventory import BaseInventoryPlugin, Constructable, Cacheable
try: try:
import os_client_config import os_client_config
@ -115,7 +115,7 @@ except ImportError:
HAS_SHADE = False HAS_SHADE = False
class InventoryModule(BaseInventoryPlugin): class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
''' Host inventory provider for ansible using OpenStack clouds. ''' ''' Host inventory provider for ansible using OpenStack clouds. '''
NAME = 'openstack' NAME = 'openstack'
@ -124,13 +124,10 @@ class InventoryModule(BaseInventoryPlugin):
super(InventoryModule, self).parse(inventory, loader, path) super(InventoryModule, self).parse(inventory, loader, path)
cache_key = self.get_cache_prefix(path) cache_key = self._get_cache_prefix(path)
# file is config file # file is config file
try: self._config_data = self._read_config_data(path)
self._config_data = self.loader.load_from_file(path)
except Exception as e:
raise AnsibleParserError(e)
msg = '' msg = ''
if not self._config_data: if not self._config_data:

@ -9,10 +9,21 @@ DOCUMENTATION = '''
inventory: script inventory: script
version_added: "2.4" version_added: "2.4"
short_description: Executes an inventory script that returns JSON short_description: Executes an inventory script that returns JSON
options:
cache:
description: Toggle the usage of the configured Cache plugin.
default: False
type: boolean
ini:
- section: inventory_plugin_script
key: cache
env:
- name: ANSIBLE_INVENTORY_PLUGIN_SCRIPT_CACHE
description: description:
- The source provided must an executable that returns Ansible inventory JSON - The source provided must an executable that returns Ansible inventory JSON
- The source must accept C(--list) and C(--host <hostname>) as arguments. - The source must accept C(--list) and C(--host <hostname>) as arguments.
C(--host) will only be used if no C(_meta) key is present (performance optimization) C(--host) will only be used if no C(_meta) key is present.
This is a performance optimization as the script would be called per host otherwise.
notes: notes:
- It takes the place of the previously hardcoded script inventory. - It takes the place of the previously hardcoded script inventory.
- To function it requires being whitelisted in configuration, which is true by default. - To function it requires being whitelisted in configuration, which is true by default.
@ -26,10 +37,10 @@ from ansible.errors import AnsibleError, AnsibleParserError
from ansible.module_utils.basic import json_dict_bytes_to_unicode from ansible.module_utils.basic import json_dict_bytes_to_unicode
from ansible.module_utils.six import iteritems from ansible.module_utils.six import iteritems
from ansible.module_utils._text import to_native, to_text from ansible.module_utils._text import to_native, to_text
from ansible.plugins.inventory import BaseInventoryPlugin from ansible.plugins.inventory import BaseInventoryPlugin, Cacheable
class InventoryModule(BaseInventoryPlugin): class InventoryModule(BaseInventoryPlugin, Cacheable):
''' Host inventory parser for ansible using external inventory scripts. ''' ''' Host inventory parser for ansible using external inventory scripts. '''
NAME = 'script' NAME = 'script'
@ -61,17 +72,20 @@ class InventoryModule(BaseInventoryPlugin):
return valid return valid
def parse(self, inventory, loader, path, cache=True): def parse(self, inventory, loader, path, cache=None):
super(InventoryModule, self).parse(inventory, loader, path) super(InventoryModule, self).parse(inventory, loader, path)
if cache is None:
cache = self._options['cache']
# Support inventory scripts that are not prefixed with some # Support inventory scripts that are not prefixed with some
# path information but happen to be in the current working # path information but happen to be in the current working
# directory when '.' is not in PATH. # directory when '.' is not in PATH.
cmd = [path, "--list"] cmd = [path, "--list"]
try: try:
cache_key = self.get_cache_prefix(path) cache_key = self._get_cache_prefix(path)
if not cache or cache_key not in self._cache: if not cache or cache_key not in self._cache:
try: try:
sp = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) sp = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@ -125,7 +139,7 @@ class InventoryModule(BaseInventoryPlugin):
except AttributeError as e: except AttributeError as e:
raise AnsibleError("Improperly formatted host information for %s: %s" % (host, to_native(e))) raise AnsibleError("Improperly formatted host information for %s: %s" % (host, to_native(e)))
self.populate_host_vars([host], got) self._populate_host_vars([host], got)
except Exception as e: except Exception as e:
raise AnsibleParserError(to_native(e)) raise AnsibleParserError(to_native(e))

@ -12,6 +12,8 @@ DOCUMENTATION = '''
- Get inventory hosts from the local virtualbox installation. - Get inventory hosts from the local virtualbox installation.
- Uses a <name>.vbox.yaml (or .vbox.yml) YAML configuration file. - Uses a <name>.vbox.yaml (or .vbox.yml) YAML configuration file.
- The inventory_hostname is always the 'Name' of the virtualbox instance. - The inventory_hostname is always the 'Name' of the virtualbox instance.
extends_documentation_fragment:
- constructed
options: options:
running_only: running_only:
description: toggles showing all vms vs only those currently running description: toggles showing all vms vs only those currently running
@ -26,14 +28,6 @@ DOCUMENTATION = '''
description: create vars from virtualbox properties description: create vars from virtualbox properties
type: dictionary type: dictionary
default: {} default: {}
compose:
description: create vars from jinja2 expressions, these are created AFTER the query block
type: dictionary
default: {}
groups:
description: add hosts to group based on Jinja2 conditionals, these also run after query block
type: dictionary
default: {}
''' '''
EXAMPLES = ''' EXAMPLES = '''
@ -54,10 +48,10 @@ from subprocess import Popen, PIPE
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.plugins.inventory import BaseInventoryPlugin from ansible.plugins.inventory import BaseInventoryPlugin, Constructable, Cacheable
class InventoryModule(BaseInventoryPlugin): class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
''' Host inventory parser for ansible using local virtualbox. ''' ''' Host inventory parser for ansible using local virtualbox. '''
NAME = 'virtualbox' NAME = 'virtualbox'
@ -76,33 +70,34 @@ class InventoryModule(BaseInventoryPlugin):
pass pass
return ret return ret
def _set_variables(self, hostvars, data): def _set_variables(self, hostvars):
# set vars in inventory from hostvars # set vars in inventory from hostvars
for host in hostvars: for host in hostvars:
query = self._options['query']
# create vars from vbox properties # create vars from vbox properties
if data.get('query') and isinstance(data['query'], MutableMapping): if query and isinstance(query, MutableMapping):
for varname in data['query']: for varname in query:
hostvars[host][varname] = self._query_vbox_data(host, data['query'][varname]) hostvars[host][varname] = self._query_vbox_data(host, query[varname])
# create composite vars # create composite vars
self._set_composite_vars(data.get('compose'), hostvars, host) self._set_composite_vars(self._options['compose'], hostvars, host)
# actually update inventory # actually update inventory
for key in hostvars[host]: for key in hostvars[host]:
self.inventory.set_variable(host, key, hostvars[host][key]) self.inventory.set_variable(host, key, hostvars[host][key])
# constructed groups based on conditionals # constructed groups based on conditionals
self._add_host_to_composed_groups(data.get('groups'), hostvars, host) self._add_host_to_composed_groups(self._options['groups'], hostvars, host)
def _populate_from_source(self, source_data, config_data): def _populate_from_source(self, source_data):
hostvars = {} hostvars = {}
prevkey = pref_k = '' prevkey = pref_k = ''
current_host = None current_host = None
# needed to possibly set ansible_host # needed to possibly set ansible_host
netinfo = config_data.get('network_info_path', "/VirtualBox/GuestInfo/Net/0/V4/IP") netinfo = self._options['network_info_path']
for line in source_data: for line in source_data:
try: try:
@ -149,7 +144,7 @@ class InventoryModule(BaseInventoryPlugin):
prevkey = pref_k prevkey = pref_k
self._set_variables(hostvars, config_data) self._set_variables(hostvars)
def verify_file(self, path): def verify_file(self, path):
@ -163,17 +158,12 @@ class InventoryModule(BaseInventoryPlugin):
super(InventoryModule, self).parse(inventory, loader, path) super(InventoryModule, self).parse(inventory, loader, path)
cache_key = self.get_cache_prefix(path) cache_key = self._get_cache_prefix(path)
# file is config file config_data = self._read_config_data(path)
try:
config_data = self.loader.load_from_file(path)
except Exception as e:
raise AnsibleParserError(to_native(e))
if not config_data or config_data.get('plugin') != self.NAME: # set _options from config data
# this is not my config file self._consume_options(config_data)
raise AnsibleParserError("Incorrect plugin name in file: %s" % config_data.get('plugin', 'none found'))
source_data = None source_data = None
if cache and cache_key in self._cache: if cache and cache_key in self._cache:
@ -183,8 +173,8 @@ class InventoryModule(BaseInventoryPlugin):
pass pass
if not source_data: if not source_data:
b_pwfile = to_bytes(config_data.get('settings_password_file'), errors='surrogate_or_strict') b_pwfile = to_bytes(self._options['settings_password_file'], errors='surrogate_or_strict')
running = config_data.get('running_only', False) running = self._options['running_only']
# start getting data # start getting data
cmd = [self.VBOX, b'list', b'-l'] cmd = [self.VBOX, b'list', b'-l']
@ -205,4 +195,4 @@ class InventoryModule(BaseInventoryPlugin):
source_data = p.stdout.read() source_data = p.stdout.read()
self._cache[cache_key] = to_text(source_data, errors='surrogate_or_strict') self._cache[cache_key] = to_text(source_data, errors='surrogate_or_strict')
self._populate_from_source(source_data.splitlines(), config_data) self._populate_from_source(source_data.splitlines())

@ -18,10 +18,19 @@ DOCUMENTATION = '''
- It takes the place of the previously hardcoded YAML inventory. - It takes the place of the previously hardcoded YAML inventory.
- To function it requires being whitelisted in configuration. - To function it requires being whitelisted in configuration.
options: options:
yaml_extensions: yaml_extensions:
description: list of 'valid' extensions for files containing YAML description: list of 'valid' extensions for files containing YAML
type: list type: list
default: ['.yaml', '.yml', '.json'] default: ['.yaml', '.yml', '.json']
env:
- name: ANSIBLE_YAML_FILENAME_EXT
- name: ANSIBLE_INVENTORY_PLUGIN_EXTS
ini:
- key: yaml_valid_extensions
section: defaults
- section: inventory_plugin_yaml
key: yaml_valid_extensions
''' '''
EXAMPLES = ''' EXAMPLES = '''
all: # keys must be unique, i.e. only one 'hosts' per group all: # keys must be unique, i.e. only one 'hosts' per group
@ -52,7 +61,6 @@ all: # keys must be unique, i.e. only one 'hosts' per group
import os import os
from collections import MutableMapping from collections import MutableMapping
from ansible import constants as C
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
from ansible.module_utils.six import string_types from ansible.module_utils.six import string_types
from ansible.module_utils._text import to_native from ansible.module_utils._text import to_native
@ -73,7 +81,7 @@ class InventoryModule(BaseFileInventoryPlugin):
valid = False valid = False
if super(InventoryModule, self).verify_file(path): if super(InventoryModule, self).verify_file(path):
file_name, ext = os.path.splitext(path) file_name, ext = os.path.splitext(path)
if not ext or ext in C.YAML_FILENAME_EXTENSIONS: if not ext or ext in self._options['yaml_extensions']:
valid = True valid = True
return valid return valid
@ -131,7 +139,7 @@ class InventoryModule(BaseFileInventoryPlugin):
elif key == 'hosts': elif key == 'hosts':
for host_pattern in group_data['hosts']: for host_pattern in group_data['hosts']:
hosts, port = self._parse_host(host_pattern) hosts, port = self._parse_host(host_pattern)
self.populate_host_vars(hosts, group_data['hosts'][host_pattern] or {}, group, port) self._populate_host_vars(hosts, group_data['hosts'][host_pattern] or {}, group, port)
else: else:
self.display.warning('Skipping unexpected key (%s) in group (%s), only "vars", "children" and "hosts" are valid' % (key, group)) self.display.warning('Skipping unexpected key (%s) in group (%s), only "vars", "children" and "hosts" are valid' % (key, group))

@ -210,7 +210,7 @@ class PluginLoader:
type_name = get_plugin_class(self.class_name) type_name = get_plugin_class(self.class_name)
# FIXME: expand from just connection and callback # FIXME: expand from just connection and callback
if type_name in ('connection', 'callback'): if type_name in ('callback', 'connection', 'inventory'):
dstring = read_docstring(path, verbose=False, ignore_errors=False) dstring = read_docstring(path, verbose=False, ignore_errors=False)
if dstring.get('doc', False): if dstring.get('doc', False):

@ -0,0 +1,46 @@
# (c) 2017 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
class ModuleDocFragment(object):
# inventory cache
DOCUMENTATION = """
options:
cache:
description:
- Toggle to enable/disable the caching of the inventory's source data, requires a cache plugin setup to work.
type: boolean
default: False
env:
- name: ANSIBLE_INVENTORY_CACHE
ini:
- section: inventory
key: cache
cache_plugin:
description:
- Cache plugin to use for the inventory's source data.
env:
- name: ANSIBLE_INVENTORY_CACHE_PLUGIN
ini:
- section: inventory
key: cache_plugin
cache_timeout:
description:
- Cache duration in seconds
default: 3600
type: integer
env:
- name: ANSIBLE_INVENTORY_CACHE_TIMEOUT
ini:
- section: inventory
key: cache_timeout
cache_connection:
description:
- Cache connection data or path, read cache plugin documentation for specifics.
env:
- name: ANSIBLE_INVENTORY_CACHE_CONNECTION
ini:
- section: inventory
key: cache_connection
"""

@ -22,8 +22,9 @@ __metaclass__ = type
from collections import MutableMapping, MutableSet, MutableSequence from collections import MutableMapping, MutableSet, MutableSequence
from ansible.errors import AnsibleAssertionError from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils.six import string_types from ansible.module_utils.six import string_types
from ansible.module_utils._text import to_native
from ansible.parsing.plugin_docs import read_docstring from ansible.parsing.plugin_docs import read_docstring
from ansible.parsing.yaml.loader import AnsibleLoader from ansible.parsing.yaml.loader import AnsibleLoader
from ansible.plugins.loader import fragment_loader from ansible.plugins.loader import fragment_loader
@ -42,6 +43,22 @@ BLACKLIST = {
} }
def merge_fragment(target, source):
for key, value in source.items():
if key in target:
# assumes both structures have same type
if isinstance(target[key], MutableMapping):
value.update(target[key])
elif isinstance(target[key], MutableSet):
value.add(target[key])
elif isinstance(target[key], MutableSequence):
value = sorted(frozenset(value + target[key]))
else:
raise Exception("Attempt to extend a documentation fragement, invalid type for %s" % key)
target[key] = value
def add_fragments(doc, filename): def add_fragments(doc, filename):
fragments = doc.get('extends_documentation_fragment', []) fragments = doc.get('extends_documentation_fragment', [])
@ -76,18 +93,18 @@ def add_fragments(doc, filename):
if 'options' not in fragment: if 'options' not in fragment:
raise Exception("missing options in fragment (%s), possibly misformatted?: %s" % (fragment_name, filename)) raise Exception("missing options in fragment (%s), possibly misformatted?: %s" % (fragment_name, filename))
for key, value in fragment.items(): # ensure options themselves are directly merged
if key in doc: if 'options' in doc:
# assumes both structures have same type try:
if isinstance(doc[key], MutableMapping): merge_fragment(doc['options'], fragment.pop('options'))
value.update(doc[key]) except Exception as e:
elif isinstance(doc[key], MutableSet): raise AnsibleError("%s options (%s) of unknown type: %s" % (to_native(e), fragment_name, filename))
value.add(doc[key])
elif isinstance(doc[key], MutableSequence): # merge rest of the sections
value = sorted(frozenset(value + doc[key])) try:
else: merge_fragment(doc, fragment)
raise Exception("Attempt to extend a documentation fragement (%s) of unknown type: %s" % (fragment_name, filename)) except Exception as e:
doc[key] = value raise AnsibleError("%s (%s) of unknown type: %s" % (to_native(e), fragment_name, filename))
def get_docstring(filename, verbose=False): def get_docstring(filename, verbose=False):

@ -35,7 +35,7 @@ from ansible.plugins.connection import network_cli
class TestConnectionClass(unittest.TestCase): class TestConnectionClass(unittest.TestCase):
@patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect") @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect")
def test_network_cli__connect_error(self, mocked_super): def test_network_cli__connect_error(self, mocked_super):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
@ -47,7 +47,7 @@ class TestConnectionClass(unittest.TestCase):
pc.network_os = None pc.network_os = None
self.assertRaises(AnsibleConnectionFailure, conn._connect) self.assertRaises(AnsibleConnectionFailure, conn._connect)
@patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect") @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect")
def test_network_cli__invalid_os(self, mocked_super): def test_network_cli__invalid_os(self, mocked_super):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
@ -60,7 +60,7 @@ class TestConnectionClass(unittest.TestCase):
self.assertRaises(AnsibleConnectionFailure, conn._connect) self.assertRaises(AnsibleConnectionFailure, conn._connect)
@patch("ansible.plugins.connection.network_cli.terminal_loader") @patch("ansible.plugins.connection.network_cli.terminal_loader")
@patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect") @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect")
def test_network_cli__connect(self, mocked_super, mocked_terminal_loader): def test_network_cli__connect(self, mocked_super, mocked_terminal_loader):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
@ -84,7 +84,7 @@ class TestConnectionClass(unittest.TestCase):
conn._connect() conn._connect()
conn._terminal.on_authorize.assert_called_with(passwd='password') conn._terminal.on_authorize.assert_called_with(passwd='password')
@patch("ansible.plugins.connection.network_cli.ParamikoSshConnection.close") @patch("ansible.plugins.connection.paramiko_ssh.Connection.close")
def test_network_cli_close(self, mocked_super): def test_network_cli_close(self, mocked_super):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
@ -99,7 +99,7 @@ class TestConnectionClass(unittest.TestCase):
self.assertTrue(terminal.on_close_shell.called) self.assertTrue(terminal.on_close_shell.called)
self.assertIsNone(conn._ssh_shell) self.assertIsNone(conn._ssh_shell)
@patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect") @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect")
def test_network_cli_exec_command(self, mocked_super): def test_network_cli_exec_command(self, mocked_super):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()

@ -23,17 +23,18 @@ __metaclass__ = type
import pytest import pytest
from ansible import constants as C
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.plugins.loader import PluginLoader
from ansible.compat.tests import mock from ansible.compat.tests import mock
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.module_utils._text import to_bytes, to_native from ansible.module_utils._text import to_bytes, to_native
from ansible.plugins.inventory.script import InventoryModule
class TestInventoryModule(unittest.TestCase): class TestInventoryModule(unittest.TestCase):
def setUp(self): def setUp(self):
class Inventory(): class Inventory():
cache = dict() cache = dict()
@ -50,6 +51,10 @@ class TestInventoryModule(unittest.TestCase):
self.loader = mock.MagicMock() self.loader = mock.MagicMock()
self.loader.load = mock.MagicMock() self.loader.load = mock.MagicMock()
inv_loader = PluginLoader('InventoryModule', 'ansible.plugins.inventory', C.DEFAULT_INVENTORY_PLUGIN_PATH, 'inventory_plugins')
self.inventory_module = inv_loader.get('script')
self.inventory_module.set_options()
def register_patch(name): def register_patch(name):
patcher = mock.patch(name) patcher = mock.patch(name)
self.addCleanup(patcher.stop) self.addCleanup(patcher.stop)
@ -64,9 +69,8 @@ class TestInventoryModule(unittest.TestCase):
def test_parse_subprocess_path_not_found_fail(self): def test_parse_subprocess_path_not_found_fail(self):
self.popen.side_effect = OSError("dummy text") self.popen.side_effect = OSError("dummy text")
inventory_module = InventoryModule()
with pytest.raises(AnsibleError) as e: with pytest.raises(AnsibleError) as e:
inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py') self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py')
assert e.value.message == "problem running /foo/bar/foobar.py --list (dummy text)" assert e.value.message == "problem running /foo/bar/foobar.py --list (dummy text)"
def test_parse_subprocess_err_code_fail(self): def test_parse_subprocess_err_code_fail(self):
@ -75,9 +79,8 @@ class TestInventoryModule(unittest.TestCase):
self.popen_result.returncode = 1 self.popen_result.returncode = 1
inventory_module = InventoryModule()
with pytest.raises(AnsibleError) as e: with pytest.raises(AnsibleError) as e:
inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py') self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py')
assert e.value.message == to_native("Inventory script (/foo/bar/foobar.py) had an execution error: " assert e.value.message == to_native("Inventory script (/foo/bar/foobar.py) had an execution error: "
"dummyédata\n ") "dummyédata\n ")
@ -86,9 +89,8 @@ class TestInventoryModule(unittest.TestCase):
self.popen_result.stderr = to_bytes("dummyédata") self.popen_result.stderr = to_bytes("dummyédata")
self.loader.load.side_effect = TypeError('obj must be string') self.loader.load.side_effect = TypeError('obj must be string')
inventory_module = InventoryModule()
with pytest.raises(AnsibleError) as e: with pytest.raises(AnsibleError) as e:
inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py') self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py')
assert e.value.message == to_native("failed to parse executable inventory script results from " assert e.value.message == to_native("failed to parse executable inventory script results from "
"/foo/bar/foobar.py: obj must be string\ndummyédata\n") "/foo/bar/foobar.py: obj must be string\ndummyédata\n")
@ -97,8 +99,7 @@ class TestInventoryModule(unittest.TestCase):
self.popen_result.stderr = to_bytes("dummyédata") self.popen_result.stderr = to_bytes("dummyédata")
self.loader.load.return_value = 'i am not a dict' self.loader.load.return_value = 'i am not a dict'
inventory_module = InventoryModule()
with pytest.raises(AnsibleError) as e: with pytest.raises(AnsibleError) as e:
inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py') self.inventory_module.parse(self.inventory, self.loader, '/foo/bar/foobar.py')
assert e.value.message == to_native("failed to parse executable inventory script results from " assert e.value.message == to_native("failed to parse executable inventory script results from "
"/foo/bar/foobar.py: needs to be a json dict\ndummyédata\n") "/foo/bar/foobar.py: needs to be a json dict\ndummyédata\n")

Loading…
Cancel
Save