From e5c2c03dea0998872a6b16a18d6c187685a5fc7a Mon Sep 17 00:00:00 2001 From: James Cammarata Date: Tue, 15 Dec 2015 09:39:13 -0500 Subject: [PATCH] Enable host_key checking at the strategy level Implements a new method in the ssh connection plugin (fetch_and_store_key) which is used to prefetch the key using ssh-keyscan. --- lib/ansible/executor/task_executor.py | 17 +- lib/ansible/inventory/host.py | 11 +- lib/ansible/plugins/connection/__init__.py | 5 +- lib/ansible/plugins/connection/ssh.py | 193 +++++++++++++++++++-- lib/ansible/plugins/strategy/__init__.py | 30 +++- lib/ansible/utils/connection.py | 50 ++++++ 6 files changed, 273 insertions(+), 33 deletions(-) create mode 100644 lib/ansible/utils/connection.py diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index 5d7430fad25..2623bc775b2 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -32,6 +32,7 @@ from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVar from ansible.playbook.conditional import Conditional from ansible.playbook.task import Task from ansible.template import Templar +from ansible.utils.connection import get_smart_connection_type from ansible.utils.encrypt import key_for_hostname from ansible.utils.listify import listify_lookup_plugin_terms from ansible.utils.unicode import to_unicode @@ -564,21 +565,7 @@ class TaskExecutor: conn_type = self._play_context.connection if conn_type == 'smart': - conn_type = 'ssh' - if sys.platform.startswith('darwin') and self._play_context.password: - # due to a current bug in sshpass on OSX, which can trigger - # a kernel panic even for non-privileged users, we revert to - # paramiko on that OS when a SSH password is specified - conn_type = "paramiko" - else: - # see if SSH can support ControlPersist if not use paramiko - try: - cmd = subprocess.Popen(['ssh','-o','ControlPersist'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - (out, err) = cmd.communicate() - if "Bad configuration option" in err or "Usage:" in err: - conn_type = "paramiko" - except OSError: - conn_type = "paramiko" + conn_type = get_smart_connection_type(self._play_context) connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin) if not connection: diff --git a/lib/ansible/inventory/host.py b/lib/ansible/inventory/host.py index 6263dcbc80d..70f9f57b5f1 100644 --- a/lib/ansible/inventory/host.py +++ b/lib/ansible/inventory/host.py @@ -57,6 +57,7 @@ class Host: name=self.name, vars=self.vars.copy(), address=self.address, + has_hostkey=self.has_hostkey, uuid=self._uuid, gathered_facts=self._gathered_facts, groups=groups, @@ -65,10 +66,11 @@ class Host: def deserialize(self, data): self.__init__() - self.name = data.get('name') - self.vars = data.get('vars', dict()) - self.address = data.get('address', '') - self._uuid = data.get('uuid', uuid.uuid4()) + self.name = data.get('name') + self.vars = data.get('vars', dict()) + self.address = data.get('address', '') + self.has_hostkey = data.get('has_hostkey', False) + self._uuid = data.get('uuid', uuid.uuid4()) groups = data.get('groups', []) for group_data in groups: @@ -89,6 +91,7 @@ class Host: self._gathered_facts = False self._uuid = uuid.uuid4() + self.has_hostkey = False def __repr__(self): return self.get_name() diff --git a/lib/ansible/plugins/connection/__init__.py b/lib/ansible/plugins/connection/__init__.py index 06616bac4ca..7fc19c8c195 100644 --- a/lib/ansible/plugins/connection/__init__.py +++ b/lib/ansible/plugins/connection/__init__.py @@ -23,11 +23,11 @@ __metaclass__ = type import fcntl import gettext import os -from abc import ABCMeta, abstractmethod, abstractproperty +from abc import ABCMeta, abstractmethod, abstractproperty from functools import wraps -from ansible.compat.six import with_metaclass +from ansible.compat.six import with_metaclass from ansible import constants as C from ansible.errors import AnsibleError from ansible.plugins import shell_loader @@ -233,3 +233,4 @@ class ConnectionBase(with_metaclass(ABCMeta, object)): f = self._play_context.connection_lockfd fcntl.lockf(f, fcntl.LOCK_UN) display.vvvv('CONNECTION: pid %d released lock on %d' % (os.getpid(), f)) + diff --git a/lib/ansible/plugins/connection/ssh.py b/lib/ansible/plugins/connection/ssh.py index a2abcf20aee..cce29824e1a 100644 --- a/lib/ansible/plugins/connection/ssh.py +++ b/lib/ansible/plugins/connection/ssh.py @@ -19,7 +19,12 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +from ansible.compat.six import text_type + +import base64 import fcntl +import hmac +import operator import os import pipes import pty @@ -28,9 +33,13 @@ import shlex import subprocess import time +from hashlib import md5, sha1, sha256 + from ansible import constants as C from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound from ansible.plugins.connection import ConnectionBase +from ansible.utils.boolean import boolean +from ansible.utils.connection import get_smart_connection_type from ansible.utils.path import unfrackpath, makedirs_safe from ansible.utils.unicode import to_bytes, to_unicode @@ -41,7 +50,128 @@ except ImportError: display = Display() SSHPASS_AVAILABLE = None +HASHED_KEY_MAGIC = "|1|" + +def split_args(argstring): + """ + Takes a string like '-o Foo=1 -o Bar="foo bar"' and returns a + list ['-o', 'Foo=1', '-o', 'Bar=foo bar'] that can be added to + the argument list. The list will not contain any empty elements. + """ + return [to_unicode(x.strip()) for x in shlex.split(to_bytes(argstring)) if x.strip()] + +def get_ssh_opts(play_context): + # FIXME: caching may help here + opts_dict = dict() + try: + cmd = ['ssh', '-G', play_context.remote_addr] + res = subprocess.check_output(cmd) + for line in res.split('\n'): + if ' ' in line: + (key, val) = line.split(' ', 1) + else: + key = line + val = '' + opts_dict[key.lower()] = val + + # next, we manually override any options that are being + # set via ssh_args or due to the fact that `ssh -G` doesn't + # actually use the options set via -o + for opt in ['ssh_args', 'ssh_common_args', 'ssh_extra_args']: + attr = getattr(play_context, opt, None) + if attr is not None: + args = split_args(attr) + for arg in args: + if '=' in arg: + (key, val) = arg.split('=', 1) + opts_dict[key.lower()] = val + + return opts_dict + except subprocess.CalledProcessError: + return dict() + +def host_in_known_hosts(host, ssh_opts): + # the setting from the ssh_opts may actually be multiple files, so + # we use shlex.split and simply take the first one specified + user_host_file = os.path.expanduser(shlex.split(ssh_opts.get('userknownhostsfile', '~/.ssh/known_hosts'))[0]) + + host_file_list = [] + host_file_list.append(user_host_file) + host_file_list.append("/etc/ssh/ssh_known_hosts") + host_file_list.append("/etc/ssh/ssh_known_hosts2") + + hfiles_not_found = 0 + for hf in host_file_list: + if not os.path.exists(hf): + continue + try: + host_fh = open(hf) + except (OSError, IOError) as e: + continue + else: + data = host_fh.read() + host_fh.close() + + for line in data.split("\n"): + line = line.strip() + if line is None or " " not in line: + continue + tokens = line.split() + if not tokens: + continue + if tokens[0].find(HASHED_KEY_MAGIC) == 0: + # this is a hashed known host entry + try: + (kn_salt, kn_host) = tokens[0][len(HASHED_KEY_MAGIC):].split("|",2) + hash = hmac.new(kn_salt.decode('base64'), digestmod=sha1) + hash.update(host) + if hash.digest() == kn_host.decode('base64'): + return True + except: + # invalid hashed host key, skip it + continue + else: + # standard host file entry + if host in tokens[0]: + return True + + return False + +def fetch_ssh_host_key(play_context, ssh_opts): + keyscan_cmd = ['ssh-keyscan'] + + if play_context.port: + keyscan_cmd.extend(['-p', text_type(play_context.port)]) + + if boolean(ssh_opts.get('hashknownhosts', 'no')): + keyscan_cmd.append('-H') + keyscan_cmd.append(play_context.remote_addr) + + p = subprocess.Popen(keyscan_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True) + (stdout, stderr) = p.communicate() + if stdout == '': + raise AnsibleConnectionFailure("Failed to connect to the host to fetch the host key: %s." % stderr) + else: + return stdout + +def add_host_key(host_key, ssh_opts): + # the setting from the ssh_opts may actually be multiple files, so + # we use shlex.split and simply take the first one specified + user_known_hosts = os.path.expanduser(shlex.split(ssh_opts.get('userknownhostsfile', '~/.ssh/known_hosts'))[0]) + user_ssh_dir = os.path.dirname(user_known_hosts) + + if not os.path.exists(user_ssh_dir): + raise AnsibleError("the user ssh directory does not exist: %s" % user_ssh_dir) + elif not os.path.isdir(user_ssh_dir): + raise AnsibleError("%s is not a directory" % user_ssh_dir) + + try: + display.vv("adding to known_hosts file: %s" % user_known_hosts) + with open(user_known_hosts, 'a') as f: + f.write(host_key) + except (OSError, IOError) as e: + raise AnsibleError("error when trying to access the known hosts file: '%s', error was: %s" % (user_known_hosts, text_type(e))) class Connection(ConnectionBase): ''' ssh based connections ''' @@ -62,6 +192,56 @@ class Connection(ConnectionBase): def _connect(self): return self + @staticmethod + def fetch_and_store_key(host, play_context): + ssh_opts = get_ssh_opts(play_context) + if not host_in_known_hosts(play_context.remote_addr, ssh_opts): + display.debug("host %s does not have a known host key, fetching it" % host) + + # build the list of valid host key types, for use later as we scan for keys. + # we also use this to determine the most preferred key when multiple keys are available + valid_host_key_types = [x.lower() for x in ssh_opts.get('hostbasedkeytypes', '').split(',')] + + # attempt to fetch the key with ssh-keyscan. More than one key may be + # returned, so we save all and use the above list to determine which + host_key_data = fetch_ssh_host_key(play_context, ssh_opts).strip().split('\n') + host_keys = dict() + for host_key in host_key_data: + (host_info, key_type, key_hash) = host_key.strip().split(' ', 3) + key_type = key_type.lower() + if key_type in valid_host_key_types and key_type not in host_keys: + host_keys[key_type.lower()] = host_key + + if len(host_keys) == 0: + raise AnsibleConnectionFailure("none of the available host keys found were in the HostBasedKeyTypes configuration option") + + # now we determine the preferred key by sorting the above dict on the + # index of the key type in the valid keys list + preferred_key = sorted(host_keys.items(), cmp=lambda x,y: cmp(valid_host_key_types.index(x), valid_host_key_types.index(y)), key=operator.itemgetter(0))[0] + + # shamelessly copied from here: + # https://github.com/ojarva/python-sshpubkeys/blob/master/sshpubkeys/__init__.py#L39 + # (which shamelessly copied it from somewhere else...) + (host_info, key_type, key_hash) = preferred_key[1].strip().split(' ', 3) + decoded_key = key_hash.decode('base64') + fp_plain = md5(decoded_key).hexdigest() + key_data = ':'.join(a+b for a, b in zip(fp_plain[::2], fp_plain[1::2])) + + # prompt the user to add the key + # if yes, add it, otherwise raise AnsibleConnectionFailure + display.display("\nThe authenticity of host %s (%s) can't be established." % (host.name, play_context.remote_addr)) + display.display("%s key fingerprint is SHA256:%s." % (key_type.upper(), sha256(decoded_key).digest().encode('base64').strip())) + display.display("%s key fingerprint is MD5:%s." % (key_type.upper(), key_data)) + response = display.prompt("Are you sure you want to continue connecting (yes/no)? ") + display.display("") + if boolean(response): + add_host_key(host_key, ssh_opts) + return True + else: + raise AnsibleConnectionFailure("Host key validation failed.") + + return False + @staticmethod def _sshpass_available(): global SSHPASS_AVAILABLE @@ -100,15 +280,6 @@ class Connection(ConnectionBase): return controlpersist, controlpath - @staticmethod - def _split_args(argstring): - """ - Takes a string like '-o Foo=1 -o Bar="foo bar"' and returns a - list ['-o', 'Foo=1', '-o', 'Bar=foo bar'] that can be added to - the argument list. The list will not contain any empty elements. - """ - return [to_unicode(x.strip()) for x in shlex.split(to_bytes(argstring)) if x.strip()] - def _add_args(self, explanation, args): """ Adds the given args to self._command and displays a caller-supplied @@ -157,7 +328,7 @@ class Connection(ConnectionBase): # Next, we add [ssh_connection]ssh_args from ansible.cfg. if self._play_context.ssh_args: - args = self._split_args(self._play_context.ssh_args) + args = split_args(self._play_context.ssh_args) self._add_args("ansible.cfg set ssh_args", args) # Now we add various arguments controlled by configuration file settings @@ -210,7 +381,7 @@ class Connection(ConnectionBase): for opt in ['ssh_common_args', binary + '_extra_args']: attr = getattr(self._play_context, opt, None) if attr is not None: - args = self._split_args(attr) + args = split_args(attr) self._add_args("PlayContext set %s" % opt, args) # Check if ControlPersist is enabled and add a ControlPath if one hasn't diff --git a/lib/ansible/plugins/strategy/__init__.py b/lib/ansible/plugins/strategy/__init__.py index 7b2a3794efc..e460708f906 100644 --- a/lib/ansible/plugins/strategy/__init__.py +++ b/lib/ansible/plugins/strategy/__init__.py @@ -29,7 +29,7 @@ import zlib from jinja2.exceptions import UndefinedError from ansible import constants as C -from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable +from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure from ansible.executor.play_iterator import PlayIterator from ansible.executor.process.worker import WorkerProcess from ansible.executor.task_result import TaskResult @@ -39,6 +39,7 @@ from ansible.playbook.helpers import load_list_of_blocks from ansible.playbook.included_file import IncludedFile from ansible.plugins import action_loader, connection_loader, filter_loader, lookup_loader, module_loader, test_loader from ansible.template import Templar +from ansible.utils.connection import get_smart_connection_type from ansible.vars.unsafe_proxy import wrap_var try: @@ -139,6 +140,33 @@ class StrategyBase: display.debug("entering _queue_task() for %s/%s" % (host, task)) + if C.HOST_KEY_CHECKING and not host.has_hostkey: + # caveat here, regarding with loops. It is assumed that none of the connection + # related variables would contain '{{item}}' as it would cause some really + # weird loops. As is, if someone did something odd like that they would need + # to disable host key checking + templar = Templar(loader=self._loader, variables=task_vars) + temp_pc = play_context.set_task_and_variable_override(task=task, variables=task_vars, templar=templar) + temp_pc.post_validate(templar) + if temp_pc.connection in ('smart', 'ssh') and get_smart_connection_type(temp_pc) == 'ssh': + try: + # get the ssh connection plugin's class, and use its builtin + # static method to fetch and save the key to the known_hosts file + ssh_conn = connection_loader.get('ssh', class_only=True) + ssh_conn.fetch_and_store_key(host, temp_pc) + except AnsibleConnectionFailure as e: + # if that fails, add the host to the list of unreachable + # hosts and send the appropriate callback + self._tqm._unreachable_hosts[host.name] = True + self._tqm._stats.increment('dark', host.name) + tr = TaskResult(host=host, task=task, return_data=dict(msg=text_type(e))) + self._tqm.send_callback('v2_runner_on_unreachable', tr) + return + + # finally, we set the has_hostkey flag to true for this + # host so we can skip it quickly in the future + host.has_hostkey = True + task_vars['hostvars'] = self._tqm.hostvars # and then queue the new task display.debug("%s - putting task (%s) in queue" % (host, task)) diff --git a/lib/ansible/utils/connection.py b/lib/ansible/utils/connection.py new file mode 100644 index 00000000000..6f6b405640e --- /dev/null +++ b/lib/ansible/utils/connection.py @@ -0,0 +1,50 @@ +# (c) 2015, Ansible, Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import subprocess +import sys + + +__all__ = ['get_smart_connection_type'] + +def get_smart_connection_type(play_context): + ''' + Uses the ssh command with the ControlPersist option while checking + for an error to determine if we should use ssh or paramiko. Also + may take other factors into account. + ''' + + conn_type = 'ssh' + if sys.platform.startswith('darwin') and play_context.password: + # due to a current bug in sshpass on OSX, which can trigger + # a kernel panic even for non-privileged users, we revert to + # paramiko on that OS when a SSH password is specified + conn_type = "paramiko" + else: + # see if SSH can support ControlPersist if not use paramiko + try: + cmd = subprocess.Popen(['ssh','-o','ControlPersist'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + (out, err) = cmd.communicate() + if "Bad configuration option" in err or "Usage:" in err: + conn_type = "paramiko" + except OSError: + conn_type = "paramiko" + + return conn_type