Fix for persistent connection plugin on Python3 (#24431)

Fix for persistent connection plugin on Python3.  Note that fixes are also needed to each terminal plugin.  This PR only fixes the ios terminal (as proof that this approach is workable.)  Future PRs can address the other terminal types.

* On Python3, pickle needs to work with byte strings, not text strings.
* Set the pickle protocol version to 0 because we're using a pty to feed data to the connection plugin.  A pty can't have control characters.  So we have to send ascii only.  That means
only using protocol=0 for pickling the data.
* ansible-connection isn't being used with py3 in the bug but it needs
several changes to work with python3.
* In python3, closing the pty too early causes no data to be sent.  So
leave stdin open until after we finish with the ansible-connection
process.
* Fix typo using traceback.format_exc()
* Cleanup unnecessary StringIO, BytesIO, and to_bytes calls
* Modify the network_cli and terminal plugins for py3 compat.  Lots of mixing of text and byte strings that needs to be straightened out to be compatible with python3
* Documentation for the bytes<=>text strategy for terminal plugins
* Update unittests for more bytes-oriented internals

Fixes #24355
pull/24461/head
Toshio Kuratomi 7 years ago committed by GitHub
parent e539726543
commit d834412ead

@ -45,7 +45,8 @@ from io import BytesIO
from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.six.moves import cPickle, StringIO
from ansible.module_utils.six import PY3
from ansible.module_utils.six.moves import cPickle
from ansible.playbook.play_context import PlayContext
from ansible.plugins import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe
@ -73,11 +74,11 @@ def do_fork():
sys.exit(0)
if C.DEFAULT_LOG_PATH != '':
out_file = file(C.DEFAULT_LOG_PATH, 'a+')
err_file = file(C.DEFAULT_LOG_PATH, 'a+', 0)
out_file = open(C.DEFAULT_LOG_PATH, 'ab+')
err_file = open(C.DEFAULT_LOG_PATH, 'ab+', 0)
else:
out_file = file('/dev/null', 'a+')
err_file = file('/dev/null', 'a+', 0)
out_file = open('/dev/null', 'ab+')
err_file = open('/dev/null', 'ab+', 0)
os.dup2(out_file.fileno(), sys.stdout.fileno())
os.dup2(err_file.fileno(), sys.stderr.fileno())
@ -90,7 +91,7 @@ def do_fork():
sys.exit(1)
def send_data(s, data):
packed_len = struct.pack('!Q',len(data))
packed_len = struct.pack('!Q', len(data))
return s.sendall(packed_len + data)
def recv_data(s):
@ -101,7 +102,7 @@ def recv_data(s):
if not d:
return None
data += d
data_len = struct.unpack('!Q',data[:header_len])[0]
data_len = struct.unpack('!Q', data[:header_len])[0]
data = data[header_len:]
while len(data) < data_len:
d = s.recv(data_len - len(data))
@ -211,11 +212,9 @@ class Server():
pass
elif data.startswith(b'CONTEXT: '):
display.display("socket operation is CONTEXT", log_only=True)
pc_data = data.split(b'CONTEXT: ')[1]
pc_data = data.split(b'CONTEXT: ', 1)[1]
src = StringIO(pc_data)
pc_data = cPickle.load(src)
src.close()
pc_data = cPickle.loads(pc_data)
pc = PlayContext()
pc.deserialize(pc_data)
@ -234,12 +233,12 @@ class Server():
display.display("socket operation completed with rc %s" % rc, log_only=True)
send_data(s, to_bytes(str(rc)))
send_data(s, to_bytes(rc))
send_data(s, to_bytes(stdout))
send_data(s, to_bytes(stderr))
s.close()
except Exception as e:
display.display(traceback.format_exec(), log_only=True)
display.display(traceback.format_exc(), log_only=True)
finally:
# when done, close the connection properly and cleanup
# the socket file so it can be recreated
@ -254,21 +253,25 @@ class Server():
os.remove(self.path)
def main():
# Need stdin as a byte stream
if PY3:
stdin = sys.stdin.buffer
else:
stdin = sys.stdin
try:
# read the play context data via stdin, which means depickling it
# FIXME: as noted above, we will probably need to deserialize the
# connection loader here as well at some point, otherwise this
# won't find role- or playbook-based connection plugins
cur_line = sys.stdin.readline()
init_data = ''
while cur_line.strip() != '#END_INIT#':
if cur_line == '':
raise Exception("EOL found before init data was complete")
cur_line = stdin.readline()
init_data = b''
while cur_line.strip() != b'#END_INIT#':
if cur_line == b'':
raise Exception("EOF found before init data was complete")
init_data += cur_line
cur_line = sys.stdin.readline()
src = BytesIO(to_bytes(init_data))
pc_data = cPickle.load(src)
cur_line = stdin.readline()
pc_data = cPickle.loads(init_data)
pc = PlayContext()
pc.deserialize(pc_data)
@ -319,10 +322,10 @@ def main():
# the connection will timeout here. Need to make this more resilient.
rc = 0
while rc == 0:
data = sys.stdin.readline()
if data == '':
data = stdin.readline()
if data == b'':
break
if data.strip() == '':
if data.strip() == b'':
continue
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
attempts = 1
@ -342,11 +345,10 @@ def main():
# send the play_context back into the connection so the connection
# can handle any privilege escalation activities
pc_data = 'CONTEXT: %s' % src.getvalue()
send_data(sf, to_bytes(pc_data))
src.close()
pc_data = b'CONTEXT: %s' % init_data
send_data(sf, pc_data)
send_data(sf, to_bytes(data.strip()))
send_data(sf, data.strip())
rc = int(recv_data(sf), 10)
stdout = recv_data(sf)

@ -18,17 +18,18 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import re
import socket
import json
import logging
import re
import signal
import datetime
import socket
import traceback
import logging
from collections import Sequence
from ansible import constants as C
from ansible.errors import AnsibleConnectionFailure
from ansible.module_utils.six.moves import StringIO
from ansible.module_utils.six import BytesIO, binary_type, text_type
from ansible.module_utils._text import to_bytes, to_text
from ansible.plugins import terminal_loader
from ansible.plugins.connection import ensure_connect
from ansible.plugins.connection.paramiko_ssh import Connection as _Connection
@ -113,7 +114,7 @@ class Connection(_Connection):
self._terminal.on_authorize(passwd=auth_pass)
display.display('shell successfully opened', log_only=True)
return (0, 'ok', '')
return (0, b'ok', b'')
def close(self):
display.display('closing connection', log_only=True)
@ -131,11 +132,11 @@ class Connection(_Connection):
self._shell.close()
self._shell = None
return (0, 'ok', '')
return (0, b'ok', b'')
def receive(self, obj=None):
"""Handles receiving of output from command"""
recv = StringIO()
recv = BytesIO()
handled = False
self._matched_prompt = None
@ -162,30 +163,30 @@ class Connection(_Connection):
try:
command = obj['command']
self._history.append(command)
self._shell.sendall('%s\r' % command)
self._shell.sendall(b'%s\r' % command)
if obj.get('sendonly'):
return
return self.receive(obj)
except (socket.timeout, AttributeError) as exc:
except (socket.timeout, AttributeError):
display.display(traceback.format_exc(), log_only=True)
raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip())
def _strip(self, data):
"""Removes ANSI codes from device response"""
for regex in self._terminal.ansi_re:
data = regex.sub('', data)
data = regex.sub(b'', data)
return data
def _handle_prompt(self, resp, obj):
"""Matches the command prompt and responds"""
if not isinstance(obj['prompt'], list):
if isinstance(obj, (binary_type, text_type)) or not isinstance(obj['prompt'], Sequence):
obj['prompt'] = [obj['prompt']]
prompts = [re.compile(r, re.I) for r in obj['prompt']]
answer = obj['answer']
for regex in prompts:
match = regex.search(resp)
if match:
self._shell.sendall('%s\r' % answer)
self._shell.sendall(b'%s\r' % answer)
return True
def _sanitize(self, resp, obj=None):
@ -196,7 +197,7 @@ class Connection(_Connection):
if (command and line.startswith(command.strip())) or self._matched_prompt.strip() in line:
continue
cleaned.append(line)
return str("\n".join(cleaned)).strip()
return b"\n".join(cleaned).strip()
def _find_prompt(self, response):
"""Searches the buffered response for a matching command prompt"""
@ -225,9 +226,9 @@ class Connection(_Connection):
def exec_command(self, cmd):
"""Executes the cmd on in the shell and returns the output
The method accepts two forms of cmd. The first form is as a
The method accepts two forms of cmd. The first form is as a byte
string that represents the command to be executed in the shell. The
second form is as a JSON string with additional keyword.
second form is as a utf8 JSON byte string with additional keywords.
Keywords supported for cmd:
* command - the command string to execute
@ -235,28 +236,30 @@ class Connection(_Connection):
* answer - the string to respond to the prompt with
* sendonly - bool to disable waiting for response
:arg cmd: the string that represents the command to be executed
which can be a single command or a json encoded string
:arg cmd: the byte string that represents the command to be executed
which can be a single command or a json encoded string.
:returns: a tuple of (return code, stdout, stderr). The return
code is an integer and stdout and stderr are strings
code is an integer and stdout and stderr are byte strings
"""
try:
obj = json.loads(cmd)
obj = json.loads(to_text(cmd, errors='surrogate_or_strict'))
obj = dict((k, to_bytes(v, errors='surrogate_or_strict', nonstring='passthru')) for k, v in obj.items())
except (ValueError, TypeError):
obj = {'command': str(cmd).strip()}
obj = {'command': to_bytes(cmd.strip(), errors='surrogate_or_strict')}
if obj['command'] == 'close_shell()':
if obj['command'] == b'close_shell()':
return self.close_shell()
elif obj['command'] == 'open_shell()':
elif obj['command'] == b'open_shell()':
return self.open_shell()
elif obj['command'] == 'prompt()':
return (0, self._matched_prompt, '')
elif obj['command'] == b'prompt()':
return (0, self._matched_prompt, b'')
try:
if self._shell is None:
self.open_shell()
except AnsibleConnectionFailure as exc:
return (1, '', str(exc))
# FIXME: Feels like we should raise this rather than return it
return (1, b'', to_bytes(exc))
try:
if not signal.getsignal(signal.SIGALRM):
@ -264,6 +267,7 @@ class Connection(_Connection):
signal.alarm(self._play_context.timeout)
out = self.send(obj)
signal.alarm(0)
return (0, out, '')
return (0, out, b'')
except (AnsibleConnectionFailure, ValueError) as exc:
return (1, '', str(exc))
# FIXME: Feels like we should raise this rather than return it
return (1, b'', to_bytes(exc))

@ -24,7 +24,7 @@ import subprocess
import sys
from ansible.module_utils._text import to_bytes
from ansible.module_utils.six.moves import cPickle, StringIO
from ansible.module_utils.six.moves import cPickle
from ansible.plugins.connection import ConnectionBase
try:
@ -52,16 +52,20 @@ class Connection(ConnectionBase):
stdin = os.fdopen(master, 'wb', 0)
os.close(slave)
src = StringIO()
cPickle.dump(self._play_context.serialize(), src)
stdin.write(src.getvalue())
src.close()
# Need to force a protocol that is compatible with both py2 and py3.
# That would be protocol=2 or less.
# Also need to force a protocol that excludes certain control chars as
# stdin in this case is a pty and control chars will cause problems.
# that means only protocol=0 will work.
src = cPickle.dumps(self._play_context.serialize(), protocol=0)
stdin.write(src)
stdin.write(b'\n#END_INIT#\n')
stdin.write(to_bytes(action))
stdin.write(b'\n\n')
stdin.close()
(stdout, stderr) = p.communicate()
stdin.close()
return (p.returncode, stdout, stderr)

@ -30,33 +30,54 @@ from ansible.module_utils.six import with_metaclass
class TerminalBase(with_metaclass(ABCMeta, object)):
'''
A base class for implementing cli connections
.. note:: Unlike most of Ansible, nearly all strings in
:class:`TerminalBase` plugins are byte strings. This is because of
how close to the underlying platform these plugins operate. Remember
to mark literal strings as byte string (``b"string"``) and to use
:func:`~ansible.module_utils._text.to_bytes` and
:func:`~ansible.module_utils._text.to_text` to avoid unexpected
problems.
'''
# compiled regular expression as stdout
#: compiled bytes regular expressions as stdout
terminal_stdout_re = []
# compiled regular expression as stderr
#: compiled bytes regular expressions as stderr
terminal_stderr_re = []
# copiled regular expression to remove ANSI codes
#: compiled bytes regular expressions to remove ANSI codes
ansi_re = [
re.compile(r'(\x1b\[\?1h\x1b=)'),
re.compile(r'\x08.')
re.compile(br'(\x1b\[\?1h\x1b=)'),
re.compile(br'\x08.')
]
def __init__(self, connection):
self._connection = connection
def _exec_cli_command(self, cmd, check_rc=True):
"""Executes a CLI command on the device"""
"""
Executes a CLI command on the device
:arg cmd: Byte string consisting of the command to execute
:kwarg check_rc: If True, the default, raise an
:exc:`AnsibleConnectionFailure` if the return code from the
command is nonzero
:returns: A tuple of return code, stdout, and stderr from running the
command. stdout and stderr are both byte strings.
"""
rc, out, err = self._connection.exec_command(cmd)
if check_rc and rc != 0:
raise AnsibleConnectionFailure(err)
return rc, out, err
def _get_prompt(self):
""" Returns the current prompt from the device"""
for cmd in ['\n', 'prompt()']:
"""
Returns the current prompt from the device
:returns: A byte string of the prompt
"""
for cmd in (b'\n', b'prompt()'):
rc, out, err = self._exec_cli_command(cmd)
return out
@ -82,6 +103,8 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
def on_authorize(self, passwd=None):
"""Called when privilege escalation is requested
:kwarg passwd: String containing the password
This method is called when the privilege is requested to be elevated
in the play context by setting become to True. It is the responsibility
of the terminal plugin to actually do the privilege escalation such
@ -94,6 +117,6 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
This method is called when the privilege changed from escalated
(become=True) to non escalated (become=False). It is the responsibility
of the this method to actually perform the deauthorization procedure
of this method to actually perform the deauthorization procedure
"""
pass

@ -19,49 +19,52 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import re
import json
import re
from ansible.plugins.terminal import TerminalBase
from ansible.errors import AnsibleConnectionFailure
from ansible.module_utils._text import to_bytes
from ansible.plugins.terminal import TerminalBase
class TerminalModule(TerminalBase):
terminal_stdout_re = [
re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
re.compile(r"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$")
re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
re.compile(br"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$")
]
terminal_stderr_re = [
re.compile(r"% ?Error"),
#re.compile(r"^% \w+", re.M),
re.compile(r"% ?Bad secret"),
re.compile(r"invalid input", re.I),
re.compile(r"(?:incomplete|ambiguous) command", re.I),
re.compile(r"connection timed out", re.I),
re.compile(r"[^\r\n]+ not found", re.I),
re.compile(r"'[^']' +returned error code: ?\d+"),
re.compile(br"% ?Error"),
#re.compile(br"^% \w+", re.M),
re.compile(br"% ?Bad secret"),
re.compile(br"invalid input", re.I),
re.compile(br"(?:incomplete|ambiguous) command", re.I),
re.compile(br"connection timed out", re.I),
re.compile(br"[^\r\n]+ not found", re.I),
re.compile(br"'[^']' +returned error code: ?\d+"),
]
def on_open_shell(self):
try:
for cmd in ['terminal length 0', 'terminal width 512']:
for cmd in (b'terminal length 0', b'terminal width 512'):
self._exec_cli_command(cmd)
except AnsibleConnectionFailure:
raise AnsibleConnectionFailure('unable to set terminal parameters')
def on_authorize(self, passwd=None):
if self._get_prompt().endswith('#'):
if self._get_prompt().endswith(b'#'):
return
cmd = {'command': 'enable'}
cmd = {u'command': u'enable'}
if passwd:
cmd['prompt'] = r"[\r\n]?password: $"
cmd['answer'] = passwd
# Note: python-3.5 cannot combine u"" and r"" together. Thus make
# an r string and use to_text to ensure it's text on both py2 and py3.
cmd[u'prompt'] = to_text(r"[\r\n]?password: $", errors='surrogate_or_strict')
cmd[u'answer'] = passwd
try:
self._exec_cli_command(json.dumps(cmd))
self._exec_cli_command(to_bytes(json.dumps(cmd), errors='surrogate_or_strict'))
except AnsibleConnectionFailure:
raise AnsibleConnectionFailure('unable to elevate privilege to enable mode')
@ -71,11 +74,9 @@ class TerminalModule(TerminalBase):
# if prompt is None most likely the terminal is hung up at a prompt
return
if '(config' in prompt:
self._exec_cli_command('end')
self._exec_cli_command('disable')
elif prompt.endswith('#'):
self._exec_cli_command('disable')
if b'(config' in prompt:
self._exec_cli_command(b'end')
self._exec_cli_command(b'disable')
elif prompt.endswith(b'#'):
self._exec_cli_command(b'disable')

@ -117,21 +117,21 @@ class TestConnectionClass(unittest.TestCase):
mock_open_shell = MagicMock()
conn.open_shell = mock_open_shell
mock_send = MagicMock(return_value='command response')
mock_send = MagicMock(return_value=b'command response')
conn.send = mock_send
# test sending a single command and converting to dict
rc, out, err = conn.exec_command('command')
self.assertEqual(out, 'command response')
self.assertEqual(out, b'command response')
self.assertTrue(mock_open_shell.called)
mock_send.assert_called_with({'command': 'command'})
mock_send.assert_called_with({'command': b'command'})
mock_open_shell.reset_mock()
# test sending a json string
rc, out, err = conn.exec_command(json.dumps({'command': 'command'}))
self.assertEqual(out, 'command response')
mock_send.assert_called_with({'command': 'command'})
self.assertEqual(out, b'command response')
mock_send.assert_called_with({'command': b'command'})
self.assertTrue(mock_open_shell.called)
mock_open_shell.reset_mock()
@ -139,9 +139,9 @@ class TestConnectionClass(unittest.TestCase):
# test _shell already open
rc, out, err = conn.exec_command('command')
self.assertEqual(out, 'command response')
self.assertEqual(out, b'command response')
self.assertFalse(mock_open_shell.called)
mock_send.assert_called_with({'command': 'command'})
mock_send.assert_called_with({'command': b'command'})
def test_network_cli_send(self):
@ -150,14 +150,14 @@ class TestConnectionClass(unittest.TestCase):
conn = network_cli.Connection(pc, new_stdin)
mock__terminal = MagicMock()
mock__terminal.terminal_stdout_re = [re.compile('device#')]
mock__terminal.terminal_stderr_re = [re.compile('^ERROR')]
mock__terminal.terminal_stdout_re = [re.compile(b'device#')]
mock__terminal.terminal_stderr_re = [re.compile(b'^ERROR')]
conn._terminal = mock__terminal
mock__shell = MagicMock()
conn._shell = mock__shell
response = """device#command
response = b"""device#command
command response
device#
@ -165,15 +165,15 @@ class TestConnectionClass(unittest.TestCase):
mock__shell.recv.return_value = response
output = conn.send({'command': 'command'})
output = conn.send({'command': b'command'})
mock__shell.sendall.assert_called_with('command\r')
self.assertEqual(output, 'command response')
mock__shell.sendall.assert_called_with(b'command\r')
self.assertEqual(output, b'command response')
mock__shell.reset_mock()
mock__shell.recv.return_value = "ERROR: error message"
mock__shell.recv.return_value = b"ERROR: error message"
with self.assertRaises(AnsibleConnectionFailure) as exc:
conn.send({'command': 'command'})
conn.send({'command': b'command'})
self.assertEqual(str(exc.exception), 'ERROR: error message')

Loading…
Cancel
Save