updates the code path for network modules (#21193)

* replaces persistent connection digest with _create_control_path()
* adds _ansible_socket to _legal_inputs in basic.py
* adds connection_user to play_context
* maps remote_user to connection_user when connection is local
* maps ansible_socket in task_vars to module_args _ansible_socket if exists
pull/19707/merge
Peter Sprygada 8 years ago committed by GitHub
parent 25cb281b9b
commit 138051540e

@ -29,7 +29,6 @@ except Exception:
pass pass
import fcntl import fcntl
import hashlib
import os import os
import shlex import shlex
import signal import signal
@ -119,6 +118,8 @@ class Server():
self._start_time = datetime.datetime.now() self._start_time = datetime.datetime.now()
self.log("setup connection %s" % self.play_context.connection)
self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin) self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin)
self.conn._connect() self.conn._connect()
if not self.conn.connected: if not self.conn.connected:
@ -259,20 +260,15 @@ def main():
display.verbosity = pc.verbosity display.verbosity = pc.verbosity
# here we create a hash to use later when creating the socket file, ssh = connection_loader.get('ssh', class_only=True)
# so we can hide the info about the target host/user/etc. m = ssh._create_control_path(pc.remote_addr, pc.port, pc.remote_user)
m = hashlib.sha256()
for attr in ('connection', 'remote_addr', 'port', 'remote_user'):
val = getattr(pc, attr, None)
if val:
m.update(to_bytes(val))
# create the persistent connection dir if need be and create the paths # create the persistent connection dir if need be and create the paths
# which we will be using later # which we will be using later
tmp_path = unfrackpath("$HOME/.ansible/pc") tmp_path = unfrackpath("$HOME/.ansible/pc")
makedirs_safe(tmp_path) makedirs_safe(tmp_path)
lk_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path) lk_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
sf_path = unfrackpath("%s/conn-%s" % (tmp_path, m.hexdigest()[0:12])) sf_path = unfrackpath(m % dict(directory=tmp_path))
# if the socket file doesn't exist, spin up the daemon process # if the socket file doesn't exist, spin up the daemon process
lock_fd = os.open(lk_path, os.O_RDWR|os.O_CREAT, 0o600) lock_fd = os.open(lk_path, os.O_RDWR|os.O_CREAT, 0o600)

@ -680,6 +680,7 @@ class AnsibleModule(object):
self.cleanup_files = [] self.cleanup_files = []
self._debug = False self._debug = False
self._diff = False self._diff = False
self._socket_path = None
self._verbosity = 0 self._verbosity = 0
# May be used to set modifications to the environment for any # May be used to set modifications to the environment for any
# run_command invocation # run_command invocation
@ -689,7 +690,7 @@ class AnsibleModule(object):
self._passthrough = ['warnings', 'deprecations'] self._passthrough = ['warnings', 'deprecations']
self.aliases = {} self.aliases = {}
self._legal_inputs = ['_ansible_check_mode', '_ansible_no_log', '_ansible_debug', '_ansible_diff', '_ansible_verbosity', '_ansible_selinux_special_fs', '_ansible_module_name', '_ansible_version', '_ansible_syslog_facility'] self._legal_inputs = ['_ansible_check_mode', '_ansible_no_log', '_ansible_debug', '_ansible_diff', '_ansible_verbosity', '_ansible_selinux_special_fs', '_ansible_module_name', '_ansible_version', '_ansible_syslog_facility', '_ansible_socket']
if add_file_common_args: if add_file_common_args:
for k, v in FILE_COMMON_ARGUMENTS.items(): for k, v in FILE_COMMON_ARGUMENTS.items():
@ -1414,6 +1415,9 @@ class AnsibleModule(object):
elif k == '_ansible_module_name': elif k == '_ansible_module_name':
self._name = v self._name = v
elif k == '_ansible_socket':
self._socket_path = v
elif check_invalid_arguments and k not in self._legal_inputs: elif check_invalid_arguments and k not in self._legal_inputs:
unsupported_parameters.add(k) unsupported_parameters.add(k)

@ -168,6 +168,7 @@ class PlayContext(Base):
_timeout = FieldAttribute(isa='int', default=C.DEFAULT_TIMEOUT) _timeout = FieldAttribute(isa='int', default=C.DEFAULT_TIMEOUT)
_shell = FieldAttribute(isa='string') _shell = FieldAttribute(isa='string')
_network_os = FieldAttribute(isa='string') _network_os = FieldAttribute(isa='string')
_connection_user = FieldAttribute(isa='string')
_ssh_args = FieldAttribute(isa='string', default=C.ANSIBLE_SSH_ARGS) _ssh_args = FieldAttribute(isa='string', default=C.ANSIBLE_SSH_ARGS)
_ssh_common_args = FieldAttribute(isa='string') _ssh_common_args = FieldAttribute(isa='string')
_sftp_extra_args = FieldAttribute(isa='string') _sftp_extra_args = FieldAttribute(isa='string')
@ -442,6 +443,7 @@ class PlayContext(Base):
# additionally, we need to do this check after final connection has been # additionally, we need to do this check after final connection has been
# correctly set above ... # correctly set above ...
if new_info.connection == 'local': if new_info.connection == 'local':
new_info.connection_user = new_info.remote_user
new_info.remote_user = pwd.getpwuid(os.getuid()).pw_name new_info.remote_user = pwd.getpwuid(os.getuid()).pw_name
# set no_log to default if it was not previouslly set # set no_log to default if it was not previouslly set

@ -586,6 +586,9 @@ class ActionBase(with_metaclass(ABCMeta, object)):
# let module know about filesystems that selinux treats specially # let module know about filesystems that selinux treats specially
module_args['_ansible_selinux_special_fs'] = C.DEFAULT_SELINUX_SPECIAL_FS module_args['_ansible_selinux_special_fs'] = C.DEFAULT_SELINUX_SPECIAL_FS
# give the module the socket for persistent connections
module_args['_ansible_socket'] = task_vars.get('ansible_socket')
def _execute_module(self, module_name=None, module_args=None, tmp=None, task_vars=None, persist_files=False, delete_remote_tmp=True, wrap_async=False): def _execute_module(self, module_name=None, module_args=None, tmp=None, task_vars=None, persist_files=False, delete_remote_tmp=True, wrap_async=False):

Loading…
Cancel
Save