Fix mixing of bytes and str in module replacer (caused traceback on python3)

pull/14688/head
Toshio Kuratomi 8 years ago
parent b0bed27211
commit c29f51804b

@ -21,7 +21,7 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
# from python and deps
from ansible.compat.six.moves import StringIO
from io import BytesIO
import json
import os
import shlex
@ -30,20 +30,20 @@ import shlex
from ansible import __version__
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.utils.unicode import to_bytes
from ansible.utils.unicode import to_bytes, to_unicode
REPLACER = "#<<INCLUDE_ANSIBLE_MODULE_COMMON>>"
REPLACER_ARGS = "\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\""
REPLACER_COMPLEX = "\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\""
REPLACER_WINDOWS = "# POWERSHELL_COMMON"
REPLACER_WINARGS = "<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
REPLACER_JSONARGS = "<<INCLUDE_ANSIBLE_MODULE_JSON_ARGS>>"
REPLACER_VERSION = "\"<<ANSIBLE_VERSION>>\""
REPLACER_SELINUX = "<<SELINUX_SPECIAL_FILESYSTEMS>>"
REPLACER = b"#<<INCLUDE_ANSIBLE_MODULE_COMMON>>"
REPLACER_ARGS = b"\"<<INCLUDE_ANSIBLE_MODULE_ARGS>>\""
REPLACER_COMPLEX = b"\"<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>\""
REPLACER_WINDOWS = b"# POWERSHELL_COMMON"
REPLACER_WINARGS = b"<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
REPLACER_JSONARGS = b"<<INCLUDE_ANSIBLE_MODULE_JSON_ARGS>>"
REPLACER_VERSION = b"\"<<ANSIBLE_VERSION>>\""
REPLACER_SELINUX = b"<<SELINUX_SPECIAL_FILESYSTEMS>>"
# We could end up writing out parameters with unicode characters so we need to
# specify an encoding for the python source file
ENCODING_STRING = '# -*- coding: utf-8 -*-'
ENCODING_STRING = b'# -*- coding: utf-8 -*-'
# we've moved the module_common relative to the snippets, so fix the path
_SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
@ -53,7 +53,7 @@ _SNIPPET_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils')
def _slurp(path):
if not os.path.exists(path):
raise AnsibleError("imported module support code does not exist at %s" % path)
fd = open(path)
fd = open(path, 'rb')
data = fd.read()
fd.close()
return data
@ -71,49 +71,49 @@ def _find_snippet_imports(module_data, module_path, strip_comments):
module_style = 'new'
elif REPLACER_JSONARGS in module_data:
module_style = 'new'
elif 'from ansible.module_utils.' in module_data:
elif b'from ansible.module_utils.' in module_data:
module_style = 'new'
elif 'WANT_JSON' in module_data:
elif b'WANT_JSON' in module_data:
module_style = 'non_native_want_json'
output = StringIO()
lines = module_data.split('\n')
output = BytesIO()
lines = module_data.split(b'\n')
snippet_names = []
for line in lines:
if REPLACER in line:
output.write(_slurp(os.path.join(_SNIPPET_PATH, "basic.py")))
snippet_names.append('basic')
snippet_names.append(b'basic')
if REPLACER_WINDOWS in line:
ps_data = _slurp(os.path.join(_SNIPPET_PATH, "powershell.ps1"))
output.write(ps_data)
snippet_names.append('powershell')
elif line.startswith('from ansible.module_utils.'):
tokens=line.split(".")
snippet_names.append(b'powershell')
elif line.startswith(b'from ansible.module_utils.'):
tokens=line.split(b".")
import_error = False
if len(tokens) != 3:
import_error = True
if " import *" not in line:
if b" import *" not in line:
import_error = True
if import_error:
raise AnsibleError("error importing module in %s, expecting format like 'from ansible.module_utils.<lib name> import *'" % module_path)
snippet_name = tokens[2].split()[0]
snippet_names.append(snippet_name)
output.write(_slurp(os.path.join(_SNIPPET_PATH, snippet_name + ".py")))
output.write(_slurp(os.path.join(_SNIPPET_PATH, to_unicode(snippet_name) + ".py")))
else:
if strip_comments and line.startswith("#") or line == '':
if strip_comments and line.startswith(b"#") or line == b'':
pass
output.write(line)
output.write("\n")
output.write(b"\n")
if not module_path.endswith(".ps1"):
# Unixy modules
if len(snippet_names) > 0 and not 'basic' in snippet_names:
if len(snippet_names) > 0 and not b'basic' in snippet_names:
raise AnsibleError("missing required import in %s: from ansible.module_utils.basic import *" % module_path)
else:
# Windows modules
if len(snippet_names) > 0 and not 'powershell' in snippet_names:
if len(snippet_names) > 0 and not b'powershell' in snippet_names:
raise AnsibleError("missing required import in %s: # POWERSHELL_COMMON" % module_path)
return (output.getvalue(), module_style)
@ -158,28 +158,28 @@ def modify_module(module_path, module_args, task_vars=dict(), strip_comments=Fal
# * Cache the modified module? If only the args are different and we do
# that as the last step we could cache all the work up to that point.
with open(module_path) as f:
with open(module_path, 'rb') as f:
# read in the module source
module_data = f.read()
(module_data, module_style) = _find_snippet_imports(module_data, module_path, strip_comments)
module_args_json = json.dumps(module_args).encode('utf-8')
python_repred_args = repr(module_args_json)
module_args_json = to_bytes(json.dumps(module_args))
python_repred_args = to_bytes(repr(module_args_json))
# these strings should be part of the 'basic' snippet which is required to be included
module_data = module_data.replace(REPLACER_VERSION, repr(__version__))
module_data = module_data.replace(REPLACER_VERSION, to_bytes(__version__, nonstring='repr'))
module_data = module_data.replace(REPLACER_COMPLEX, python_repred_args)
module_data = module_data.replace(REPLACER_WINARGS, module_args_json)
module_data = module_data.replace(REPLACER_JSONARGS, module_args_json)
module_data = module_data.replace(REPLACER_SELINUX, ','.join(C.DEFAULT_SELINUX_SPECIAL_FS))
module_data = module_data.replace(REPLACER_SELINUX, to_bytes(','.join(C.DEFAULT_SELINUX_SPECIAL_FS)))
if module_style == 'new':
facility = C.DEFAULT_SYSLOG_FACILITY
if 'ansible_syslog_facility' in task_vars:
facility = task_vars['ansible_syslog_facility']
module_data = module_data.replace('syslog.LOG_USER', "syslog.%s" % facility)
module_data = module_data.replace(b'syslog.LOG_USER', to_bytes("syslog.%s" % facility))
lines = module_data.split(b"\n", 1)
shebang = None
@ -188,12 +188,13 @@ def modify_module(module_path, module_args, task_vars=dict(), strip_comments=Fal
args = shlex.split(str(shebang[2:]))
interpreter = args[0]
interpreter_config = 'ansible_%s_interpreter' % os.path.basename(interpreter)
interpreter = to_bytes(interpreter)
if interpreter_config in task_vars:
interpreter = to_bytes(task_vars[interpreter_config], errors='strict')
lines[0] = shebang = b"#!{0} {1}".format(interpreter, b" ".join(args[1:]))
if os.path.basename(interpreter).startswith('python'):
if os.path.basename(interpreter).startswith(b'python'):
lines.insert(1, ENCODING_STRING)
else:
# No shebang, assume a binary module?

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# (c) 2015, Florian Apolloner <florian@apolloner.eu>
#
# This file is part of Ansible
@ -34,15 +35,17 @@ from ansible import constants as C
from ansible.compat.six import text_type
from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock, mock_open
from ansible.errors import AnsibleError
from ansible.playbook.play_context import PlayContext
from ansible.plugins import PluginLoader
from ansible.plugins.action import ActionBase
from ansible.template import Templar
from ansible.utils.unicode import to_bytes
from units.mock.loader import DictDataLoader
python_module_replacers = """
python_module_replacers = b"""
#!/usr/bin/python
#ANSIBLE_VERSION = "<<ANSIBLE_VERSION>>"
@ -50,14 +53,95 @@ python_module_replacers = """
#MODULE_COMPLEX_ARGS = "<<INCLUDE_ANSIBLE_MODULE_COMPLEX_ARGS>>"
#SELINUX_SPECIAL_FS="<<SELINUX_SPECIAL_FILESYSTEMS>>"
test = u'Toshio \u304f\u3089\u3068\u307f'
from ansible.module_utils.basic import *
"""
powershell_module_replacers = """
powershell_module_replacers = b"""
WINDOWS_ARGS = "<<INCLUDE_ANSIBLE_MODULE_WINDOWS_ARGS>>"
# POWERSHELL_COMMON
"""
# Prior to 3.4.4, mock_open cannot handle binary read_data
if version_info >= (3,) and version_info < (3, 4, 4):
file_spec = None
def _iterate_read_data(read_data):
# Helper for mock_open:
# Retrieve lines from read_data via a generator so that separate calls to
# readline, read, and readlines are properly interleaved
sep = b'\n' if isinstance(read_data, bytes) else '\n'
data_as_list = [l + sep for l in read_data.split(sep)]
if data_as_list[-1] == sep:
# If the last line ended in a newline, the list comprehension will have an
# extra entry that's just a newline. Remove this.
data_as_list = data_as_list[:-1]
else:
# If there wasn't an extra newline by itself, then the file being
# emulated doesn't have a newline to end the last line remove the
# newline that our naive format() added
data_as_list[-1] = data_as_list[-1][:-1]
for line in data_as_list:
yield line
def mock_open(mock=None, read_data=''):
"""
A helper function to create a mock to replace the use of `open`. It works
for `open` called directly or used as a context manager.
The `mock` argument is the mock object to configure. If `None` (the
default) then a `MagicMock` will be created for you, with the API limited
to methods or attributes available on standard file handles.
`read_data` is a string for the `read` methoddline`, and `readlines` of the
file handle to return. This is an empty string by default.
"""
def _readlines_side_effect(*args, **kwargs):
if handle.readlines.return_value is not None:
return handle.readlines.return_value
return list(_data)
def _read_side_effect(*args, **kwargs):
if handle.read.return_value is not None:
return handle.read.return_value
return type(read_data)().join(_data)
def _readline_side_effect():
if handle.readline.return_value is not None:
while True:
yield handle.readline.return_value
for line in _data:
yield line
global file_spec
if file_spec is None:
import _io
file_spec = list(set(dir(_io.TextIOWrapper)).union(set(dir(_io.BytesIO))))
if mock is None:
mock = MagicMock(name='open', spec=open)
handle = MagicMock(spec=file_spec)
handle.__enter__.return_value = handle
_data = _iterate_read_data(read_data)
handle.write.return_value = None
handle.read.return_value = None
handle.readline.return_value = None
handle.readlines.return_value = None
handle.read.side_effect = _read_side_effect
handle.readline.side_effect = _readline_side_effect()
handle.readlines.side_effect = _readlines_side_effect
mock.return_value = handle
return mock
class DerivedActionBase(ActionBase):
def run(self, tmp=None, task_vars=None):
# We're not testing the plugin run() method, just the helper
@ -124,18 +208,18 @@ class TestActionBase(unittest.TestCase):
)
# test python module formatting
with patch.object(builtins, 'open', mock_open(read_data=text_type(python_module_replacers.strip()))) as m:
with patch.object(builtins, 'open', mock_open(read_data=to_bytes(python_module_replacers.strip(), encoding='utf-8'))) as m:
mock_task.args = dict(a=1)
mock_connection.module_implementation_preferences = ('',)
(style, shebang, data) = action_base._configure_module(mock_task.action, mock_task.args)
self.assertEqual(style, "new")
self.assertEqual(shebang, "#!/usr/bin/python")
self.assertEqual(shebang, b"#!/usr/bin/python")
# test module not found
self.assertRaises(AnsibleError, action_base._configure_module, 'badmodule', mock_task.args)
# test powershell module formatting
with patch.object(builtins, 'open', mock_open(read_data=text_type(powershell_module_replacers.strip()))) as m:
with patch.object(builtins, 'open', mock_open(read_data=to_bytes(powershell_module_replacers.strip(), encoding='utf-8'))) as m:
mock_task.action = 'win_copy'
mock_task.args = dict(b=2)
mock_connection.module_implementation_preferences = ('.ps1',)

Loading…
Cancel
Save