Fix cli context check for network_cli connection (#64697)

* Fix cli context check for network_cli connection

Fixes #64575

*  Check cli context for network_cli connection
   at the start of new task run only.

* Pass task_uuid around to identify start of new task run

* Handle for local connection
pull/65278/head
Ganesh Nalawade 5 years ago committed by GitHub
parent bd68bcab95
commit ee3f8d28a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -71,10 +71,11 @@ class ConnectionProcess(object):
The connection process wraps around a Connection object that manages
the connection to a remote device that persists over the playbook
'''
def __init__(self, fd, play_context, socket_path, original_path, ansible_playbook_pid=None):
def __init__(self, fd, play_context, socket_path, original_path, task_uuid=None, ansible_playbook_pid=None):
self.play_context = play_context
self.socket_path = socket_path
self.original_path = original_path
self._task_uuid = task_uuid
self.fd = fd
self.exception = None
@ -98,7 +99,7 @@ class ConnectionProcess(object):
if self.play_context.private_key_file and self.play_context.private_key_file[0] not in '~/':
self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file)
self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null',
ansible_playbook_pid=self._ansible_playbook_pid)
task_uuid=self._task_uuid, ansible_playbook_pid=self._ansible_playbook_pid)
self.connection.set_options(var_options=variables)
self.connection._socket_path = self.socket_path
@ -257,8 +258,8 @@ def main():
if rc == 0:
ssh = connection_loader.get('ssh', class_only=True)
ansible_playbook_pid = sys.argv[1]
task_uuid = sys.argv[2]
cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user, play_context.connection, ansible_playbook_pid)
# create the persistent connection dir if need be and create the paths
# which we will be using later
tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
@ -278,7 +279,7 @@ def main():
try:
os.close(r)
wfd = os.fdopen(w, 'w')
process = ConnectionProcess(wfd, play_context, socket_path, original_path, ansible_playbook_pid)
process = ConnectionProcess(wfd, play_context, socket_path, original_path, task_uuid, ansible_playbook_pid)
process.start(variables)
except Exception:
messages.append(('error', traceback.format_exc()))
@ -305,7 +306,7 @@ def main():
pc_data = to_text(init_data)
try:
conn.update_play_context(pc_data)
conn.set_cli_prompt_context()
conn.update_cli_prompt_context(task_uuid)
except Exception as exc:
# Only network_cli has update_play context and set_cli_prompt_context, so missing this is
# not fatal e.g. netconf

@ -926,7 +926,7 @@ class TaskExecutor:
display.vvvv('using connection plugin %s' % connection.transport, host=self._play_context.remote_addr)
options = self._get_persistent_connection_options(connection, variables, templar)
socket_path = start_connection(self._play_context, options)
socket_path = start_connection(self._play_context, options, self._task._uuid)
display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr)
setattr(connection, '_socket_path', socket_path)
@ -1046,7 +1046,7 @@ class TaskExecutor:
return handler
def start_connection(play_context, variables):
def start_connection(play_context, variables, task_uuid):
'''
Starts the persistent connection
'''
@ -1078,7 +1078,7 @@ def start_connection(play_context, variables):
python = sys.executable
master, slave = pty.openpty()
p = subprocess.Popen(
[python, ansible_connection, to_text(os.getppid())],
[python, ansible_connection, to_text(os.getppid()), to_text(task_uuid)],
stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
)
os.close(slave)

@ -57,7 +57,7 @@ class ActionModule(ActionNetworkModule):
command_timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -58,7 +58,7 @@ class ActionModule(ActionNetworkModule):
command_timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -55,7 +55,7 @@ class ActionModule(ActionNetworkModule):
pc.become_method = 'enable'
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -69,7 +69,7 @@ class ActionModule(ActionNetworkModule):
command_timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -66,7 +66,7 @@ class ActionModule(_ActionModule):
command_timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -58,7 +58,7 @@ class ActionModule(ActionNetworkModule):
if self._task.action in ['ce_netconf'] or self._task.action not in CLI_SUPPORTED_MODULES:
pc.connection = 'netconf'
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -52,7 +52,7 @@ class ActionModule(ActionNetworkModule):
pc.become_method = 'enable'
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -62,7 +62,7 @@ class ActionModule(ActionNetworkModule):
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -63,7 +63,7 @@ class ActionModule(ActionNetworkModule):
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -63,7 +63,7 @@ class ActionModule(ActionNetworkModule):
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -52,7 +52,7 @@ class ActionModule(ActionNetworkModule):
pc.become_method = 'enable'
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -68,7 +68,7 @@ class ActionModule(ActionNetworkModule):
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
command_timeout = int(provider['timeout']) if provider['timeout'] else connection.get_option('persistent_command_timeout')
connection.set_options(direct={'persistent_command_timeout': command_timeout})

@ -59,7 +59,7 @@ class ActionModule(ActionNetworkModule):
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
command_timeout = int(provider['timeout']) if provider['timeout'] else connection.get_option('persistent_command_timeout')
connection.set_options(direct={'persistent_command_timeout': command_timeout})

@ -58,7 +58,7 @@ class ActionModule(ActionNetworkModule):
pc.password = provider['password'] or self._play_context.password
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
command_timeout = int(provider['timeout']) if provider['timeout'] else connection.get_option('persistent_command_timeout')
connection.set_options(direct={'persistent_command_timeout': command_timeout})

@ -58,7 +58,7 @@ class ActionModule(ActionNetworkModule):
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
command_timeout = int(provider['timeout']) if provider['timeout'] else connection.get_option('persistent_command_timeout')
connection.set_options(direct={'persistent_command_timeout': command_timeout})

@ -63,7 +63,7 @@ class ActionModule(ActionNetworkModule):
pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
command_timeout = int(provider['timeout']) if provider['timeout'] else connection.get_option('persistent_command_timeout')
connection.set_options(direct={'persistent_command_timeout': command_timeout})

@ -133,7 +133,7 @@ class ActionModule(ActionBase):
display.vvv('using connection plugin %s (was local)' % play_context.connection, play_context.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent',
play_context, sys.stdin)
play_context, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': play_context.timeout})

@ -54,7 +54,7 @@ class ActionModule(ActionNetworkModule):
pc.private_key_file = args.get('ssh_keyfile') or self._play_context.private_key_file
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
timeout = args.get('timeout')
command_timeout = int(timeout) if timeout else connection.get_option('persistent_command_timeout')

@ -101,7 +101,7 @@ class ActionModule(ActionNetworkModule):
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
command_timeout = int(provider['timeout']) if provider['timeout'] else connection.get_option('persistent_command_timeout')
connection.set_options(direct={'persistent_command_timeout': command_timeout})

@ -56,7 +56,7 @@ class ActionModule(ActionNetworkModule):
command_timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
connection.set_options(direct={'persistent_command_timeout': command_timeout})
socket_path = connection.run()

@ -55,7 +55,7 @@ class ActionModule(ActionNetworkModule):
pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file
display.vvv('using connection plugin %s (was local)' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin, task_uuid=self._task._uuid)
command_timeout = int(provider['timeout']) if provider['timeout'] else connection.get_option('persistent_command_timeout')
connection.set_options(direct={'persistent_command_timeout': command_timeout})

@ -319,6 +319,7 @@ class Connection(NetworkConnectionBase):
self._terminal = None
self.cliconf = None
self._paramiko_conn = None
self._task_uuid = to_text(kwargs.get('task_uuid', ''))
if self._play_context.verbosity > 3:
logging.getLogger('paramiko').setLevel(logging.DEBUG)
@ -408,6 +409,12 @@ class Connection(NetworkConnectionBase):
if hasattr(self, 'disable_response_logging'):
self.disable_response_logging()
def update_cli_prompt_context(self, task_uuid):
# set cli prompt context at the start of new task run only
if self._task_uuid != task_uuid:
self.set_cli_prompt_context()
self._task_uuid = task_uuid
def _connect(self):
'''
Connects to the remote device and starts the terminal

@ -31,6 +31,7 @@ options:
"""
from ansible.executor.task_executor import start_connection
from ansible.plugins.connection import ConnectionBase
from ansible.module_utils._text import to_text
from ansible.module_utils.connection import Connection as SocketConnection
from ansible.utils.display import Display
@ -43,6 +44,10 @@ class Connection(ConnectionBase):
transport = 'persistent'
has_pipelining = False
def __init__(self, play_context, new_stdin, *args, **kwargs):
super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs)
self._task_uuid = to_text(kwargs.get('task_uuid', ''))
def _connect(self):
self._connected = True
return self
@ -71,7 +76,7 @@ class Connection(ConnectionBase):
"""
display.vvvv('starting connection from persistent connection plugin', host=self._play_context.remote_addr)
variables = {'ansible_command_timeout': self.get_option('persistent_command_timeout')}
socket_path = start_connection(self._play_context, variables)
socket_path = start_connection(self._play_context, variables, self._task_uuid)
display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr)
setattr(self, '_socket_path', socket_path)
return socket_path

@ -29,6 +29,36 @@
- assert:
that:
- "result.changed == false"
- cli_command:
command: "{{item}}"
prompt:
- "New password"
- "Retype new password"
answer:
- "Test1234"
- "Test1234"
check_all: True
loop:
- "configure"
- "rollback"
- "set system login user ansible_test class operator authentication plain-text-password"
- "commit"
register: result
ignore_errors: True
- assert:
that:
- "'failed' not in result"
- junos_netconf:
register: result
ignore_errors: True
- assert:
that:
- "result.failed == false"
when: ansible_connection == 'network_cli'
- debug: msg="END cli/cli_command.yaml on connection={{ ansible_connection }}"

Loading…
Cancel
Save