diff --git a/lib/ansible/parsing/dataloader.py b/lib/ansible/parsing/dataloader.py index 6366886cefc..dd6191de7ed 100644 --- a/lib/ansible/parsing/dataloader.py +++ b/lib/ansible/parsing/dataloader.py @@ -12,6 +12,7 @@ import os.path import re import tempfile +from ansible import constants as C from ansible.errors import AnsibleFileNotFound, AnsibleParserError from ansible.module_utils.basic import is_executable from ansible.module_utils.six import binary_type, text_type @@ -393,3 +394,53 @@ class DataLoader: self.cleanup_tmp_file(f) except Exception as e: display.warning("Unable to cleanup temp files: %s" % to_native(e)) + + def find_vars_files(self, path, name, extensions=None, allow_dir=True): + """ + Find vars files in a given path with specified name. This will find + files in a dir named / or a file called ending in known + extensions. + """ + + b_path = to_bytes(os.path.join(path, name)) + found = [] + + if extensions is None: + # Look for file with no extension first to find dir before file + extensions = [''] + C.YAML_FILENAME_EXTENSIONS + # add valid extensions to name + for ext in extensions: + + if '.' in ext: + full_path = b_path + to_bytes(ext) + elif ext: + full_path = b'.'.join([b_path, to_bytes(ext)]) + else: + full_path = b_path + + if self.path_exists(full_path): + if self.is_directory(full_path): + if allow_dir: + found.extend(self._get_dir_vars_files(to_text(full_path), extensions)) + else: + next + else: + found.append(full_path) + break + return found + + def _get_dir_vars_files(self, path, extensions): + found = [] + for spath in sorted(self.list_directory(path)): + if not spath.startswith(u'.') and not spath.endswith(u'~'): # skip hidden and backups + + ext = os.path.splitext(spath)[-1] + full_spath = os.path.join(path, spath) + + if self.is_directory(full_spath) and not ext: # recursive search if dir + found.extend(self._get_dir_vars_files(full_spath, extensions)) + elif self.is_file(full_spath) and (not ext or to_text(ext) in extensions): + # only consider files with valid extensions or no extension + found.append(full_spath) + + return found diff --git a/lib/ansible/playbook/role/__init__.py b/lib/ansible/playbook/role/__init__.py index 90d6abb14fe..73bd96e7340 100644 --- a/lib/ansible/playbook/role/__init__.py +++ b/lib/ansible/playbook/role/__init__.py @@ -223,58 +223,46 @@ class Role(Base, Become, Conditional, Taggable): obj=handler_data, orig_exc=e) # vars and default vars are regular dictionaries - self._role_vars = self._load_role_yaml('vars', main=self._from_files.get('vars')) + self._role_vars = self._load_role_yaml('vars', main=self._from_files.get('vars'), allow_dir=True) if self._role_vars is None: self._role_vars = dict() elif not isinstance(self._role_vars, dict): raise AnsibleParserError("The vars/main.yml file for role '%s' must contain a dictionary of variables" % self._role_name) - self._default_vars = self._load_role_yaml('defaults', main=self._from_files.get('defaults')) + self._default_vars = self._load_role_yaml('defaults', main=self._from_files.get('defaults'), allow_dir=True) if self._default_vars is None: self._default_vars = dict() elif not isinstance(self._default_vars, dict): raise AnsibleParserError("The defaults/main.yml file for role '%s' must contain a dictionary of variables" % self._role_name) - def _load_role_yaml(self, subdir, main=None): + def _load_role_yaml(self, subdir, main=None, allow_dir=False): file_path = os.path.join(self._role_path, subdir) if self._loader.path_exists(file_path) and self._loader.is_directory(file_path): - main_file = self._resolve_main(file_path, main) - if self._loader.path_exists(main_file): - return self._loader.load_from_file(main_file) + # Valid extensions and ordering for roles is hard-coded to maintain + # role portability + extensions = ['.yml', '.yaml', '.json'] + # If no
is specified by the user, look for files with + # extensions before bare name. Otherwise, look for bare name first. + if main is None: + _main = 'main' + extensions.append('') + else: + _main = main + extensions.insert(0, '') + found_files = self._loader.find_vars_files(file_path, _main, extensions, allow_dir) + if found_files: + data = {} + for found in found_files: + new_data = self._loader.load_from_file(found) + if new_data and allow_dir: + data = combine_vars(data, new_data) + else: + data = new_data + return data elif main is not None: raise AnsibleParserError("Could not find specified file in role: %s/%s" % (subdir, main)) return None - def _resolve_main(self, basepath, main=None): - ''' flexibly handle variations in main filenames ''' - - post = False - # allow override if set, otherwise use default - if main is None: - main = 'main' - post = True - - bare_main = os.path.join(basepath, main) - - possible_mains = ( - os.path.join(basepath, '%s.yml' % main), - os.path.join(basepath, '%s.yaml' % main), - os.path.join(basepath, '%s.json' % main), - ) - - if post: - possible_mains = possible_mains + (bare_main,) - else: - possible_mains = (bare_main,) + possible_mains - - if sum([self._loader.is_file(x) for x in possible_mains]) > 1: - raise AnsibleError("found multiple main files at %s, only one allowed" % (basepath)) - else: - for m in possible_mains: - if self._loader.is_file(m): - return m # exactly one main file - return possible_mains[0] # zero mains (we still need to return something) - def _load_dependencies(self): ''' Recursively loads role dependencies from the metadata list of diff --git a/lib/ansible/plugins/vars/host_group_vars.py b/lib/ansible/plugins/vars/host_group_vars.py index d27b337ff03..361752b6e90 100644 --- a/lib/ansible/plugins/vars/host_group_vars.py +++ b/lib/ansible/plugins/vars/host_group_vars.py @@ -89,7 +89,7 @@ class VarsModule(BaseVarsPlugin): if os.path.exists(b_opath): if os.path.isdir(b_opath): self._display.debug("\tprocessing dir %s" % opath) - found_files = self._find_vars_files(opath, entity.name) + found_files = loader.find_vars_files(opath, entity.name) FOUND[key] = found_files else: self._display.warning("Found %s that is not a directory, skipping: %s" % (subdir, opath)) @@ -102,48 +102,3 @@ class VarsModule(BaseVarsPlugin): except Exception as e: raise AnsibleParserError(to_native(e)) return data - - def _find_vars_files(self, path, name): - """ Find {group,host}_vars files """ - - b_path = to_bytes(os.path.join(path, name)) - found = [] - - # first look for w/o extensions - if os.path.exists(b_path): - if os.path.isdir(b_path): - found.extend(self._get_dir_files(to_text(b_path))) - else: - found.append(b_path) - else: - # add valid extensions to name - for ext in C.YAML_FILENAME_EXTENSIONS: - - if '.' in ext: - full_path = b_path + to_bytes(ext) - elif ext: - full_path = b'.'.join([b_path, to_bytes(ext)]) - else: - full_path = b_path - - if os.path.exists(full_path) and os.path.isfile(full_path): - found.append(full_path) - break - return found - - def _get_dir_files(self, path): - - found = [] - for spath in sorted(os.listdir(path)): - if not spath.startswith(u'.') and not spath.endswith(u'~'): # skip hidden and backups - - ext = os.path.splitext(spath)[-1] - full_spath = os.path.join(path, spath) - - if os.path.isdir(full_spath) and not ext: # recursive search if dir - found.extend(self._get_dir_files(full_spath)) - elif os.path.isfile(full_spath) and (not ext or to_text(ext) in C.YAML_FILENAME_EXTENSIONS): - # only consider files with valid extensions or no extension - found.append(full_spath) - - return found diff --git a/test/units/mock/loader.py b/test/units/mock/loader.py index 73817a46506..793a891536d 100644 --- a/test/units/mock/loader.py +++ b/test/units/mock/loader.py @@ -23,7 +23,7 @@ import os from ansible.errors import AnsibleParserError from ansible.parsing.dataloader import DataLoader -from ansible.module_utils._text import to_bytes +from ansible.module_utils._text import to_bytes, to_text class DictDataLoader(DataLoader): @@ -39,6 +39,7 @@ class DictDataLoader(DataLoader): self._vault_secrets = None def load_from_file(self, path, unsafe=False): + path = to_text(path) if path in self._file_mapping: return self.load(self._file_mapping[path], path) return None @@ -46,22 +47,32 @@ class DictDataLoader(DataLoader): # TODO: the real _get_file_contents returns a bytestring, so we actually convert the # unicode/text it's created with to utf-8 def _get_file_contents(self, path): + path = to_text(path) if path in self._file_mapping: return (to_bytes(self._file_mapping[path]), False) else: raise AnsibleParserError("file not found: %s" % path) def path_exists(self, path): + path = to_text(path) return path in self._file_mapping or path in self._known_directories def is_file(self, path): + path = to_text(path) return path in self._file_mapping def is_directory(self, path): + path = to_text(path) return path in self._known_directories def list_directory(self, path): - return [x for x in self._known_directories] + ret = [] + path = to_text(path) + for x in (list(self._file_mapping.keys()) + self._known_directories): + if x.startswith(path): + if os.path.dirname(x) == path: + ret.append(os.path.basename(x)) + return ret def is_executable(self, path): # FIXME: figure out a way to make paths return true for this diff --git a/test/units/playbook/role/test_role.py b/test/units/playbook/role/test_role.py index 5cce8f7aa1c..b3fb063cb09 100644 --- a/test/units/playbook/role/test_role.py +++ b/test/units/playbook/role/test_role.py @@ -187,6 +187,68 @@ class TestRole(unittest.TestCase): self.assertEqual(r._default_vars, dict(foo='bar')) self.assertEqual(r._role_vars, dict(foo='bam')) + @patch('ansible.playbook.role.definition.unfrackpath', mock_unfrackpath_noop) + def test_load_role_with_vars_dirs(self): + + fake_loader = DictDataLoader({ + "/etc/ansible/roles/foo_vars/defaults/main/foo.yml": """ + foo: bar + """, + "/etc/ansible/roles/foo_vars/vars/main/bar.yml": """ + foo: bam + """, + }) + + mock_play = MagicMock() + mock_play.ROLE_CACHE = {} + + i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) + r = Role.load(i, play=mock_play) + + self.assertEqual(r._default_vars, dict(foo='bar')) + self.assertEqual(r._role_vars, dict(foo='bam')) + + @patch('ansible.playbook.role.definition.unfrackpath', mock_unfrackpath_noop) + def test_load_role_with_vars_nested_dirs(self): + + fake_loader = DictDataLoader({ + "/etc/ansible/roles/foo_vars/defaults/main/foo/bar.yml": """ + foo: bar + """, + "/etc/ansible/roles/foo_vars/vars/main/bar/foo.yml": """ + foo: bam + """, + }) + + mock_play = MagicMock() + mock_play.ROLE_CACHE = {} + + i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) + r = Role.load(i, play=mock_play) + + self.assertEqual(r._default_vars, dict(foo='bar')) + self.assertEqual(r._role_vars, dict(foo='bam')) + + @patch('ansible.playbook.role.definition.unfrackpath', mock_unfrackpath_noop) + def test_load_role_with_vars_dir_vs_file(self): + + fake_loader = DictDataLoader({ + "/etc/ansible/roles/foo_vars/vars/main/foo.yml": """ + foo: bar + """, + "/etc/ansible/roles/foo_vars/vars/main.yml": """ + foo: bam + """, + }) + + mock_play = MagicMock() + mock_play.ROLE_CACHE = {} + + i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) + r = Role.load(i, play=mock_play) + + self.assertEqual(r._role_vars, dict(foo='bam')) + @patch('ansible.playbook.role.definition.unfrackpath', mock_unfrackpath_noop) def test_load_role_with_metadata(self):