diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index 69cbb63f47c..8de8f7027ab 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -210,7 +210,6 @@ class TaskExecutor: # get the connection and the handler for this execution self._connection = self._get_connection(variables) self._connection.set_host_overrides(host=self._host) - self._connection._connect() self._handler = self._get_action_handler(connection=self._connection, templar=templar) diff --git a/lib/ansible/plugins/connections/__init__.py b/lib/ansible/plugins/connections/__init__.py index 897bc58982b..da0775530d6 100644 --- a/lib/ansible/plugins/connections/__init__.py +++ b/lib/ansible/plugins/connections/__init__.py @@ -22,6 +22,7 @@ __metaclass__ = type from abc import ABCMeta, abstractmethod, abstractproperty +from functools import wraps from six import with_metaclass from ansible import constants as C @@ -32,7 +33,16 @@ from ansible.errors import AnsibleError # which may want to output display/logs too from ansible.utils.display import Display -__all__ = ['ConnectionBase'] +__all__ = ['ConnectionBase', 'ensure_connect'] + + +def ensure_connect(func): + @wraps(func) + def wrapped(self, *args, **kwargs): + self._connect() + return func(self, *args, **kwargs) + return wrapped + class ConnectionBase(with_metaclass(ABCMeta, object)): ''' diff --git a/lib/ansible/plugins/connections/paramiko_ssh.py b/lib/ansible/plugins/connections/paramiko_ssh.py index 0d7a82c34b5..8beaecf4928 100644 --- a/lib/ansible/plugins/connections/paramiko_ssh.py +++ b/lib/ansible/plugins/connections/paramiko_ssh.py @@ -41,7 +41,7 @@ from binascii import hexlify from ansible import constants as C from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound -from ansible.plugins.connections import ConnectionBase +from ansible.plugins.connections import ConnectionBase, ensure_connect from ansible.utils.path import makedirs_safe AUTHENTICITY_MSG=""" @@ -61,6 +61,7 @@ with warnings.catch_warnings(): except ImportError: pass + class MyAddPolicy(object): """ Based on AutoAddPolicy in paramiko so we can determine when keys are added @@ -188,6 +189,7 @@ class Connection(ConnectionBase): return ssh + @ensure_connect def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): ''' run a command on the remote host ''' @@ -248,6 +250,7 @@ class Connection(ConnectionBase): return (chan.recv_exit_status(), '', no_prompt_out + stdout, no_prompt_out + stderr) + @ensure_connect def put_file(self, in_path, out_path): ''' transfer a file from local to remote ''' @@ -272,9 +275,10 @@ class Connection(ConnectionBase): if cache_key in SFTP_CONNECTION_CACHE: return SFTP_CONNECTION_CACHE[cache_key] else: - result = SFTP_CONNECTION_CACHE[cache_key] = self.connect().ssh.open_sftp() + result = SFTP_CONNECTION_CACHE[cache_key] = self._connect().ssh.open_sftp() return result + @ensure_connect def fetch_file(self, in_path, out_path): ''' save a remote file to the specified path ''' diff --git a/lib/ansible/plugins/connections/ssh.py b/lib/ansible/plugins/connections/ssh.py index b3ada343c04..5a435093d00 100644 --- a/lib/ansible/plugins/connections/ssh.py +++ b/lib/ansible/plugins/connections/ssh.py @@ -34,7 +34,8 @@ from hashlib import sha1 from ansible import constants as C from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound -from ansible.plugins.connections import ConnectionBase +from ansible.plugins.connections import ConnectionBase, ensure_connect + class Connection(ConnectionBase): ''' ssh based connections ''' @@ -269,6 +270,7 @@ class Connection(ConnectionBase): self._display.vvv("EXEC previous known host file not found for {0}".format(host)) return True + @ensure_connect def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): ''' run a command on the remote host ''' @@ -390,6 +392,7 @@ class Connection(ConnectionBase): return (p.returncode, '', no_prompt_out + stdout, no_prompt_err + stderr) + @ensure_connect def put_file(self, in_path, out_path): ''' transfer a file from local to remote ''' self._display.vvv("PUT {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr) @@ -425,6 +428,7 @@ class Connection(ConnectionBase): if returncode != 0: raise AnsibleError("failed to transfer file to {0}:\n{1}\n{2}".format(out_path, stdout, stderr)) + @ensure_connect def fetch_file(self, in_path, out_path): ''' fetch a file from remote to local ''' self._display.vvv("FETCH {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr) diff --git a/lib/ansible/plugins/connections/winrm.py b/lib/ansible/plugins/connections/winrm.py index f16da0f6e63..ee287491897 100644 --- a/lib/ansible/plugins/connections/winrm.py +++ b/lib/ansible/plugins/connections/winrm.py @@ -42,10 +42,11 @@ except ImportError: from ansible import constants as C from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound -from ansible.plugins.connections import ConnectionBase +from ansible.plugins.connections import ConnectionBase, ensure_connect from ansible.plugins import shell_loader from ansible.utils.path import makedirs_safe + class Connection(ConnectionBase): '''WinRM connections over HTTP/HTTPS.''' @@ -151,6 +152,7 @@ class Connection(ConnectionBase): self.protocol = self._winrm_connect() return self + @ensure_connect def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None): cmd = cmd.encode('utf-8') @@ -172,6 +174,7 @@ class Connection(ConnectionBase): raise AnsibleError("failed to exec cmd %s" % cmd) return (result.status_code, '', result.std_out.encode('utf-8'), result.std_err.encode('utf-8')) + @ensure_connect def put_file(self, in_path, out_path): self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr) if not os.path.exists(in_path): @@ -210,6 +213,7 @@ class Connection(ConnectionBase): traceback.print_exc() raise AnsibleError("failed to transfer file to %s" % out_path) + @ensure_connect def fetch_file(self, in_path, out_path): out_path = out_path.replace('\\', '/') self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr)