Allow the use of _paramiko_conn even if the connection hasn't been started. (#61570)

* Allow the use of _paramiko_conn even if the connection hasn't been started.

I'm not sure what the benefit is of Noneing paramiko_conn on close, but will keep for now

* Fix test

* Try to fix up net_put & net_get

* Add changelog
pull/62068/head
Nathaniel Case 6 years ago committed by GitHub
parent 6e8d430872
commit 50e09be14f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,2 @@
bugfixes:
- "fixed issues when using net_get & net_put before the persistent connection has been started"

@ -25,7 +25,7 @@ import hashlib
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.module_utils._text import to_text, to_bytes from ansible.module_utils._text import to_text, to_bytes
from ansible.module_utils.connection import Connection from ansible.module_utils.connection import Connection, ConnectionError
from ansible.plugins.action import ActionBase from ansible.plugins.action import ActionBase
from ansible.module_utils.six.moves.urllib.parse import urlsplit from ansible.module_utils.six.moves.urllib.parse import urlsplit
from ansible.utils.display import Display from ansible.utils.display import Display
@ -66,33 +66,31 @@ class ActionModule(ActionBase):
if proto is None: if proto is None:
proto = 'scp' proto = 'scp'
sock_timeout = play_context.timeout
if socket_path is None: if socket_path is None:
socket_path = self._connection.socket_path socket_path = self._connection.socket_path
conn = Connection(socket_path) conn = Connection(socket_path)
sock_timeout = conn.get_option('persistent_command_timeout')
try: try:
changed = self._handle_existing_file(conn, src, dest, proto, sock_timeout) changed = self._handle_existing_file(conn, src, dest, proto, sock_timeout)
if changed is False: if changed is False:
result['changed'] = False result['changed'] = changed
result['destination'] = dest result['destination'] = dest
return result return result
except Exception as exc: except Exception as exc:
result['msg'] = ('Warning: exception %s idempotency check failed. Check ' result['msg'] = ('Warning: %s idempotency check failed. Check dest' % exc)
'dest' % exc)
try: try:
out = conn.get_file( conn.get_file(
source=src, destination=dest, source=src, destination=dest,
proto=proto, timeout=sock_timeout proto=proto, timeout=sock_timeout
) )
except Exception as exc: except Exception as exc:
result['failed'] = True result['failed'] = True
result['msg'] = ('Exception received : %s' % exc) result['msg'] = 'Exception received: %s' % exc
result['changed'] = True result['changed'] = changed
result['destination'] = dest result['destination'] = dest
return result return result
@ -117,27 +115,37 @@ class ActionModule(ActionBase):
return filename return filename
def _handle_existing_file(self, conn, source, dest, proto, timeout): def _handle_existing_file(self, conn, source, dest, proto, timeout):
"""
Determines whether the source and destination file match.
:return: False if source and dest both exist and have matching sha1 sums, True otherwise.
"""
if not os.path.exists(dest): if not os.path.exists(dest):
return True return True
cwd = self._loader.get_basedir() cwd = self._loader.get_basedir()
filename = str(uuid.uuid4()) filename = str(uuid.uuid4())
tmp_dest_file = os.path.join(cwd, filename) tmp_dest_file = os.path.join(cwd, filename)
try: try:
out = conn.get_file( conn.get_file(
source=source, destination=tmp_dest_file, source=source, destination=tmp_dest_file,
proto=proto, timeout=timeout proto=proto, timeout=timeout
) )
except Exception as exc: except ConnectionError as exc:
error = to_text(exc)
if error.endswith("No such file or directory"):
if os.path.exists(tmp_dest_file):
os.remove(tmp_dest_file) os.remove(tmp_dest_file)
raise Exception(exc) return True
try: try:
with open(tmp_dest_file, 'r') as f: with open(tmp_dest_file, 'r') as f:
new_content = f.read() new_content = f.read()
with open(dest, 'r') as f: with open(dest, 'r') as f:
old_content = f.read() old_content = f.read()
except (IOError, OSError) as ioexc: except (IOError, OSError):
raise IOError(ioexc) os.remove(tmp_dest_file)
raise
sha1 = hashlib.sha1() sha1 = hashlib.sha1()
old_content_b = to_bytes(old_content, errors='surrogate_or_strict') old_content_b = to_bytes(old_content, errors='surrogate_or_strict')
@ -151,7 +159,6 @@ class ActionModule(ActionBase):
os.remove(tmp_dest_file) os.remove(tmp_dest_file)
if checksum_old == checksum_new: if checksum_old == checksum_new:
return False return False
else:
return True return True
def _get_working_path(self): def _get_working_path(self):

@ -19,15 +19,12 @@ __metaclass__ = type
import copy import copy
import os import os
import time
import uuid import uuid
import hashlib import hashlib
import sys
import re
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.module_utils._text import to_text, to_bytes from ansible.module_utils._text import to_text, to_bytes
from ansible.module_utils.connection import Connection from ansible.module_utils.connection import Connection, ConnectionError
from ansible.plugins.action import ActionBase from ansible.plugins.action import ActionBase
from ansible.module_utils.six.moves.urllib.parse import urlsplit from ansible.module_utils.six.moves.urllib.parse import urlsplit
from ansible.utils.display import Display from ansible.utils.display import Display
@ -38,7 +35,6 @@ display = Display()
class ActionModule(ActionBase): class ActionModule(ActionBase):
def run(self, tmp=None, task_vars=None): def run(self, tmp=None, task_vars=None):
changed = True
socket_path = None socket_path = None
play_context = copy.deepcopy(self._play_context) play_context = copy.deepcopy(self._play_context)
play_context.network_os = self._get_network_os(task_vars) play_context.network_os = self._get_network_os(task_vars)
@ -52,7 +48,7 @@ class ActionModule(ActionBase):
return result return result
try: try:
src = self._task.args.get('src') src = self._task.args['src']
except KeyError as exc: except KeyError as exc:
return {'failed': True, 'msg': 'missing required argument: %s' % exc} return {'failed': True, 'msg': 'missing required argument: %s' % exc}
@ -106,15 +102,14 @@ class ActionModule(ActionBase):
try: try:
changed = self._handle_existing_file(conn, output_file, dest, proto, sock_timeout) changed = self._handle_existing_file(conn, output_file, dest, proto, sock_timeout)
if changed is False: if changed is False:
result['changed'] = False result['changed'] = changed
result['destination'] = dest result['destination'] = dest
return result return result
except Exception as exc: except Exception as exc:
result['msg'] = ('Warning: Exc %s idempotency check failed. Check' result['msg'] = ('Warning: %s idempotency check failed. Check dest' % exc)
'dest' % exc)
try: try:
out = conn.copy_file( conn.copy_file(
source=output_file, destination=dest, source=output_file, destination=dest,
proto=proto, timeout=sock_timeout proto=proto, timeout=sock_timeout
) )
@ -126,7 +121,7 @@ class ActionModule(ActionBase):
result['msg'] = 'Warning: iosxr scp server pre close issue. Please check dest' result['msg'] = 'Warning: iosxr scp server pre close issue. Please check dest'
else: else:
result['failed'] = True result['failed'] = True
result['msg'] = ('Exception received : %s' % exc) result['msg'] = 'Exception received: %s' % exc
if mode == 'text': if mode == 'text':
# Cleanup tmp file expanded wih ansible vars # Cleanup tmp file expanded wih ansible vars
@ -137,35 +132,34 @@ class ActionModule(ActionBase):
return result return result
def _handle_existing_file(self, conn, source, dest, proto, timeout): def _handle_existing_file(self, conn, source, dest, proto, timeout):
"""
Determines whether the source and destination file match.
:return: False if source and dest both exist and have matching sha1 sums, True otherwise.
"""
cwd = self._loader.get_basedir() cwd = self._loader.get_basedir()
filename = str(uuid.uuid4()) filename = str(uuid.uuid4())
source_file = os.path.join(cwd, filename) tmp_source_file = os.path.join(cwd, filename)
try: try:
out = conn.get_file( conn.get_file(
source=dest, destination=source_file, source=dest, destination=tmp_source_file,
proto=proto, timeout=timeout proto=proto, timeout=timeout
) )
except Exception as exc: except ConnectionError as exc:
pattern = to_text(exc) error = to_text(exc)
not_found_exc = "No such file or directory" if error.endswith("No such file or directory"):
if re.search(not_found_exc, pattern, re.I): if os.path.exists(tmp_source_file):
if os.path.exists(source_file): os.remove(tmp_source_file)
os.remove(source_file)
return True return True
else:
try:
os.remove(source_file)
except OSError as osex:
raise Exception(osex)
try: try:
with open(source, 'r') as f: with open(source, 'r') as f:
new_content = f.read() new_content = f.read()
with open(source_file, 'r') as f: with open(tmp_source_file, 'r') as f:
old_content = f.read() old_content = f.read()
except (IOError, OSError) as ioexc: except (IOError, OSError):
os.remove(source_file) os.remove(tmp_source_file)
raise IOError(ioexc) raise
sha1 = hashlib.sha1() sha1 = hashlib.sha1()
old_content_b = to_bytes(old_content, errors='surrogate_or_strict') old_content_b = to_bytes(old_content, errors='surrogate_or_strict')
@ -176,10 +170,9 @@ class ActionModule(ActionBase):
new_content_b = to_bytes(new_content, errors='surrogate_or_strict') new_content_b = to_bytes(new_content, errors='surrogate_or_strict')
sha1.update(new_content_b) sha1.update(new_content_b)
checksum_new = sha1.digest() checksum_new = sha1.digest()
os.remove(source_file) os.remove(tmp_source_file)
if checksum_old == checksum_new: if checksum_old == checksum_new:
return False return False
else:
return True return True
def _get_binary_src_file(self, src): def _get_binary_src_file(self, src):

@ -365,8 +365,12 @@ class CliconfBase(AnsiblePlugin):
if proto == 'scp': if proto == 'scp':
if not HAS_SCP: if not HAS_SCP:
raise AnsibleError("Required library scp is not installed. Please install it using `pip install scp`") raise AnsibleError("Required library scp is not installed. Please install it using `pip install scp`")
try:
with SCPClient(ssh.get_transport(), socket_timeout=timeout) as scp: with SCPClient(ssh.get_transport(), socket_timeout=timeout) as scp:
scp.get(source, destination) scp.get(source, destination)
except EOFError:
# This appears to be benign.
pass
elif proto == 'sftp': elif proto == 'sftp':
with ssh.open_sftp() as sftp: with ssh.open_sftp() as sftp:
sftp.get(source, destination) sftp.get(source, destination)

@ -318,7 +318,7 @@ class Connection(NetworkConnectionBase):
self._terminal = None self._terminal = None
self.cliconf = None self.cliconf = None
self.paramiko_conn = None self._paramiko_conn = None
if self._play_context.verbosity > 3: if self._play_context.verbosity > 3:
logging.getLogger('paramiko').setLevel(logging.DEBUG) logging.getLogger('paramiko').setLevel(logging.DEBUG)
@ -341,6 +341,13 @@ class Connection(NetworkConnectionBase):
) )
self.queue_message('log', 'network_os is set to %s' % self._network_os) self.queue_message('log', 'network_os is set to %s' % self._network_os)
@property
def paramiko_conn(self):
if self._paramiko_conn is None:
self._paramiko_conn = connection_loader.get('paramiko', self._play_context, '/dev/null')
self._paramiko_conn.set_options(direct={'look_for_keys': not bool(self._play_context.password and not self._play_context.private_key_file)})
return self._paramiko_conn
def _get_log_channel(self): def _get_log_channel(self):
name = "p=%s u=%s | " % (os.getpid(), getpass.getuser()) name = "p=%s u=%s | " % (os.getpid(), getpass.getuser())
name += "paramiko [%s]" % self._play_context.remote_addr name += "paramiko [%s]" % self._play_context.remote_addr
@ -405,9 +412,7 @@ class Connection(NetworkConnectionBase):
Connects to the remote device and starts the terminal Connects to the remote device and starts the terminal
''' '''
if not self.connected: if not self.connected:
self.paramiko_conn = connection_loader.get('paramiko', self._play_context, '/dev/null')
self.paramiko_conn._set_log_channel(self._get_log_channel()) self.paramiko_conn._set_log_channel(self._get_log_channel())
self.paramiko_conn.set_options(direct={'look_for_keys': not bool(self._play_context.password and not self._play_context.private_key_file)})
self.paramiko_conn.force_persistence = self.force_persistence self.paramiko_conn.force_persistence = self.force_persistence
command_timeout = self.get_option('persistent_command_timeout') command_timeout = self.get_option('persistent_command_timeout')
@ -474,7 +479,7 @@ class Connection(NetworkConnectionBase):
self.queue_message('debug', "cli session is now closed") self.queue_message('debug', "cli session is now closed")
self.paramiko_conn.close() self.paramiko_conn.close()
self.paramiko_conn = None self._paramiko_conn = None
self.queue_message('debug', "ssh connection has been closed successfully") self.queue_message('debug', "ssh connection has been closed successfully")
super(Connection, self).close() super(Connection, self).close()

@ -77,13 +77,13 @@ class TestConnectionClass(unittest.TestCase):
terminal = MagicMock(supports_multiplexing=False) terminal = MagicMock(supports_multiplexing=False)
conn._terminal = terminal conn._terminal = terminal
conn._ssh_shell = MagicMock() conn._ssh_shell = MagicMock()
conn.paramiko_conn = MagicMock() conn._paramiko_conn = MagicMock()
conn._connected = True conn._connected = True
conn.close() conn.close()
self.assertTrue(terminal.on_close_shell.called) self.assertTrue(terminal.on_close_shell.called)
self.assertIsNone(conn._ssh_shell) self.assertIsNone(conn._ssh_shell)
self.assertIsNone(conn.paramiko_conn) self.assertIsNone(conn._paramiko_conn)
@patch("ansible.plugins.connection.paramiko_ssh.Connection._connect") @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect")
def test_network_cli_exec_command(self, mocked_super): def test_network_cli_exec_command(self, mocked_super):

Loading…
Cancel
Save