diff --git a/lib/ansible/compat/selectors/__init__.py b/lib/ansible/compat/selectors/__init__.py new file mode 100644 index 00000000000..149656766fe --- /dev/null +++ b/lib/ansible/compat/selectors/__init__.py @@ -0,0 +1,47 @@ +# (c) 2014, 2017 Toshio Kuratomi +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +''' +Compat selectors library. Python-3.5 has this builtin. The selectors2 +package exists on pypi to backport the functionality as far as python-2.6. +''' +# The following makes it easier for us to script updates of the bundled code +_BUNDLED_METADATA = { "pypi_name": "selectors2", "version": "1.1.0" } + +import os.path +import sys + +try: + # Python 3.4+ + import selectors as _system_selectors +except ImportError: + try: + # backport package installed in the system + import selectors2 as _system_selectors + except ImportError: + _system_selectors = None + +if _system_selectors: + selectors = _system_selectors +else: + # Our bundled copy + from . import _selectors2 as selectors +sys.modules['ansible.compat.selectors'] = selectors diff --git a/lib/ansible/compat/selectors/_selectors2.py b/lib/ansible/compat/selectors/_selectors2.py new file mode 100644 index 00000000000..8b1dc866bb3 --- /dev/null +++ b/lib/ansible/compat/selectors/_selectors2.py @@ -0,0 +1,667 @@ +# This file is from the selectors2.py package. It backports the PSF Licensed +# selectors module from the Python-3.5 stdlib to older versions of Python. +# The author, Seth Michael Larson, dual licenses his modifications under the +# PSF License and MIT License: +# https://github.com/SethMichaelLarson/selectors2#license +# +# Seth's copy of the MIT license is reproduced below +# +# MIT License +# +# Copyright (c) 2016 Seth Michael Larson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +# Backport of selectors.py from Python 3.5+ to support Python < 3.4 +# Also has the behavior specified in PEP 475 which is to retry syscalls +# in the case of an EINTR error. This module is required because selectors34 +# does not follow this behavior and instead returns that no dile descriptor +# events have occurred rather than retry the syscall. The decision to drop +# support for select.devpoll is made to maintain 100% test coverage. + +import errno +import math +import select +import socket +import sys +import time +from collections import namedtuple, Mapping + +try: + monotonic = time.monotonic +except (AttributeError, ImportError): # Python 3.3< + monotonic = time.time + +__author__ = 'Seth Michael Larson' +__email__ = 'sethmichaellarson@protonmail.com' +__version__ = '1.1.0' +__license__ = 'MIT' + +__all__ = [ + 'EVENT_READ', + 'EVENT_WRITE', + 'SelectorError', + 'SelectorKey', + 'DefaultSelector' +] + +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + +HAS_SELECT = True # Variable that shows whether the platform has a selector. +_SYSCALL_SENTINEL = object() # Sentinel in case a system call returns None. + + +class SelectorError(Exception): + def __init__(self, errcode): + super(SelectorError, self).__init__() + self.errno = errcode + + def __repr__(self): + return "".format(self.errno) + + def __str__(self): + return self.__repr__() + + +def _fileobj_to_fd(fileobj): + """ Return a file descriptor from a file object. If + given an integer will simply return that integer back. """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: {0!r}".format(fileobj)) + if fd < 0: + raise ValueError("Invalid file descriptor: {0}".format(fd)) + return fd + +# Python 3.5 uses a more direct route to wrap system calls to increase speed. +if sys.version_info >= (3, 5): + def _syscall_wrapper(func, _, *args, **kwargs): + """ This is the short-circuit version of the below logic + because in Python 3.5+ all selectors restart system calls. """ + try: + return func(*args, **kwargs) + except (OSError, IOError, select.error) as e: + errcode = None + if hasattr(e, "errno"): + errcode = e.errno + elif hasattr(e, "args"): + errcode = e.args[0] + raise SelectorError(errcode) +else: + def _syscall_wrapper(func, recalc_timeout, *args, **kwargs): + """ Wrapper function for syscalls that could fail due to EINTR. + All functions should be retried if there is time left in the timeout + in accordance with PEP 475. """ + timeout = kwargs.get("timeout", None) + if timeout is None: + expires = None + recalc_timeout = False + else: + timeout = float(timeout) + if timeout < 0.0: # Timeout less than 0 treated as no timeout. + expires = None + else: + expires = monotonic() + timeout + + args = list(args) + if recalc_timeout and "timeout" not in kwargs: + raise ValueError( + "Timeout must be in args or kwargs to be recalculated") + + result = _SYSCALL_SENTINEL + while result is _SYSCALL_SENTINEL: + try: + result = func(*args, **kwargs) + # OSError is thrown by select.select + # IOError is thrown by select.epoll.poll + # select.error is thrown by select.poll.poll + # Aren't we thankful for Python 3.x rework for exceptions? + except (OSError, IOError, select.error) as e: + # select.error wasn't a subclass of OSError in the past. + errcode = None + if hasattr(e, "errno"): + errcode = e.errno + elif hasattr(e, "args"): + errcode = e.args[0] + + # Also test for the Windows equivalent of EINTR. + is_interrupt = (errcode == errno.EINTR or (hasattr(errno, "WSAEINTR") and + errcode == errno.WSAEINTR)) + + if is_interrupt: + if expires is not None: + current_time = monotonic() + if current_time > expires: + raise OSError(errno=errno.ETIMEDOUT) + if recalc_timeout: + if "timeout" in kwargs: + kwargs["timeout"] = expires - current_time + continue + if errcode: + raise SelectorError(errcode) + else: + raise + return result + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) + + +class _SelectorMapping(Mapping): + """ Mapping of file objects to selector keys """ + + def __init__(self, selector): + self._selector = selector + + def __len__(self): + return len(self._selector._fd_to_key) + + def __getitem__(self, fileobj): + try: + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key[fd] + except KeyError: + raise KeyError("{0!r} is not registered.".format(fileobj)) + + def __iter__(self): + return iter(self._selector._fd_to_key) + + +class BaseSelector(object): + """ Abstract Selector class + + A selector supports registering file objects to be monitored + for specific I/O events. + + A file object is a file descriptor or any object with a + `fileno()` method. An arbitrary object can be attached to the + file object which can be used for example to store context info, + a callback, etc. + + A selector can use various implementations (select(), poll(), epoll(), + and kqueue()) depending on the platform. The 'DefaultSelector' class uses + the most efficient implementation for the current platform. + """ + def __init__(self): + # Maps file descriptors to keys. + self._fd_to_key = {} + + # Read-only mapping returned by get_map() + self._map = _SelectorMapping(self) + + def _fileobj_lookup(self, fileobj): + """ Return a file descriptor from a file object. + This wraps _fileobj_to_fd() to do an exhaustive + search in case the object is invalid but we still + have it in our map. Used by unregister() so we can + unregister an object that was previously registered + even if it is closed. It is also used by _SelectorMapping + """ + try: + return _fileobj_to_fd(fileobj) + except ValueError: + + # Search through all our mapped keys. + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + return key.fd + + # Raise ValueError after all. + raise + + def register(self, fileobj, events, data=None): + """ Register a file object for a set of events to monitor. """ + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {0!r}".format(events)) + + key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{0!r} (FD {1}) is already registered" + .format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + """ Unregister a file object from being monitored. """ + try: + key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + + # Getting the fileno of a closed socket on Windows errors with EBADF. + except socket.error as err: + if err.errno != errno.EBADF: + raise + else: + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + self._fd_to_key.pop(key.fd) + break + else: + raise KeyError("{0!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """ Change a registered file object monitored events and data. """ + # NOTE: Some subclasses optimize this operation even further. + try: + key = self._fd_to_key[self._fileobj_lookup(fileobj)] + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + + if events != key.events: + self.unregister(fileobj) + key = self.register(fileobj, events, data) + + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + + return key + + def select(self, timeout=None): + """ Perform the actual selection until some monitored file objects + are ready or the timeout expires. """ + raise NotImplementedError() + + def close(self): + """ Close the selector. This must be called to ensure that all + underlying resources are freed. """ + self._fd_to_key.clear() + self._map = None + + def get_key(self, fileobj): + """ Return the key associated with a registered file object. """ + mapping = self.get_map() + if mapping is None: + raise RuntimeError("Selector is closed") + try: + return mapping[fileobj] + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + + def get_map(self): + """ Return a mapping of file objects to selector keys """ + return self._map + + def _key_from_fd(self, fd): + """ Return the key associated to a given file descriptor + Return None if it is not found. """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + +# Almost all platforms have select.select() +if hasattr(select, "select"): + class SelectSelector(BaseSelector): + """ Select-based selector. """ + def __init__(self): + super(SelectSelector, self).__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super(SelectSelector, self).register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super(SelectSelector, self).unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def _select(self, r, w, timeout=None): + """ Wrapper for select.select because timeout is a positional arg """ + return select.select(r, w, [], timeout) + + def select(self, timeout=None): + # Selecting on empty lists on Windows errors out. + if not len(self._readers) and not len(self._writers): + return [] + + timeout = None if timeout is None else max(timeout, 0.0) + ready = [] + r, w, _ = _syscall_wrapper(self._select, True, self._readers, + self._writers, timeout) + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + __all__.append('SelectSelector') + + +if hasattr(select, "poll"): + class PollSelector(BaseSelector): + """ Poll-based selector """ + def __init__(self): + super(PollSelector, self).__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super(PollSelector, self).register(fileobj, events, data) + event_mask = 0 + if events & EVENT_READ: + event_mask |= select.POLLIN + if events & EVENT_WRITE: + event_mask |= select.POLLOUT + self._poll.register(key.fd, event_mask) + return key + + def unregister(self, fileobj): + key = super(PollSelector, self).unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def _wrap_poll(self, timeout=None): + """ Wrapper function for select.poll.poll() so that + _syscall_wrapper can work with only seconds. """ + if timeout is not None: + if timeout <= 0: + timeout = 0 + else: + # select.poll.poll() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + + result = self._poll.poll(timeout) + return result + + def select(self, timeout=None): + ready = [] + fd_events = _syscall_wrapper(self._wrap_poll, True, timeout=timeout) + for fd, event_mask in fd_events: + events = 0 + if event_mask & ~select.POLLIN: + events |= EVENT_WRITE + if event_mask & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + + return ready + + __all__.append('PollSelector') + +if hasattr(select, "epoll"): + class EpollSelector(BaseSelector): + """ Epoll-based selector """ + def __init__(self): + super(EpollSelector, self).__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super(EpollSelector, self).register(fileobj, events, data) + events_mask = 0 + if events & EVENT_READ: + events_mask |= select.EPOLLIN + if events & EVENT_WRITE: + events_mask |= select.EPOLLOUT + _syscall_wrapper(self._epoll.register, False, key.fd, events_mask) + return key + + def unregister(self, fileobj): + key = super(EpollSelector, self).unregister(fileobj) + try: + _syscall_wrapper(self._epoll.unregister, False, key.fd) + except SelectorError: + # This can occur when the fd was closed since registry. + pass + return key + + def select(self, timeout=None): + if timeout is not None: + if timeout <= 0: + timeout = 0.0 + else: + # select.epoll.poll() has a resolution of 1 millisecond + # but luckily takes seconds so we don't need a wrapper + # like PollSelector. Just for better rounding. + timeout = math.ceil(timeout * 1e3) * 1e-3 + timeout = float(timeout) + else: + timeout = -1.0 # epoll.poll() must have a float. + + # We always want at least 1 to ensure that select can be called + # with no file descriptors registered. Otherwise will fail. + max_events = max(len(self._fd_to_key), 1) + + ready = [] + fd_events = _syscall_wrapper(self._epoll.poll, True, + timeout=timeout, + maxevents=max_events) + for fd, event_mask in fd_events: + events = 0 + if event_mask & ~select.EPOLLIN: + events |= EVENT_WRITE + if event_mask & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._epoll.close() + super(EpollSelector, self).close() + + __all__.append('EpollSelector') + + +if hasattr(select, "devpoll"): + class DevpollSelector(BaseSelector): + """Solaris /dev/poll selector.""" + + def __init__(self): + super(DevpollSelector, self).__init__() + self._devpoll = select.devpoll() + + def fileno(self): + return self._devpoll.fileno() + + def register(self, fileobj, events, data=None): + key = super(DevpollSelector, self).register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._devpoll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super(DevpollSelector, self).unregister(fileobj) + self._devpoll.unregister(key.fd) + return key + + def _wrap_poll(self, timeout=None): + """ Wrapper function for select.poll.poll() so that + _syscall_wrapper can work with only seconds. """ + if timeout is not None: + if timeout <= 0: + timeout = 0 + else: + # select.devpoll.poll() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + + result = self._devpoll.poll(timeout) + return result + + def select(self, timeout=None): + ready = [] + fd_events = _syscall_wrapper(self._wrap_poll, True, timeout=timeout) + for fd, event_mask in fd_events: + events = 0 + if event_mask & ~select.POLLIN: + events |= EVENT_WRITE + if event_mask & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + + return ready + + def close(self): + self._devpoll.close() + super(DevpollSelector, self).close() + + __all__.append('DevpollSelector') + + +if hasattr(select, "kqueue"): + class KqueueSelector(BaseSelector): + """ Kqueue / Kevent-based selector """ + def __init__(self): + super(KqueueSelector, self).__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super(KqueueSelector, self).register(fileobj, events, data) + if events & EVENT_READ: + kevent = select.kevent(key.fd, + select.KQ_FILTER_READ, + select.KQ_EV_ADD) + + _syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0) + + if events & EVENT_WRITE: + kevent = select.kevent(key.fd, + select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + + _syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0) + + return key + + def unregister(self, fileobj): + key = super(KqueueSelector, self).unregister(fileobj) + if key.events & EVENT_READ: + kevent = select.kevent(key.fd, + select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + try: + _syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0) + except SelectorError: + pass + if key.events & EVENT_WRITE: + kevent = select.kevent(key.fd, + select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + try: + _syscall_wrapper(self._kqueue.control, False, [kevent], 0, 0) + except SelectorError: + pass + + return key + + def select(self, timeout=None): + if timeout is not None: + timeout = max(timeout, 0) + + max_events = len(self._fd_to_key) * 2 + ready_fds = {} + + kevent_list = _syscall_wrapper(self._kqueue.control, True, + None, max_events, timeout) + + for kevent in kevent_list: + fd = kevent.ident + event_mask = kevent.filter + events = 0 + if event_mask == select.KQ_FILTER_READ: + events |= EVENT_READ + if event_mask == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + if key.fd not in ready_fds: + ready_fds[key.fd] = (key, events & key.events) + else: + old_events = ready_fds[key.fd][1] + ready_fds[key.fd] = (key, (events | old_events) & key.events) + + return list(ready_fds.values()) + + def close(self): + self._kqueue.close() + super(KqueueSelector, self).close() + + __all__.append('KqueueSelector') + + +# Choose the best implementation, roughly: +# kqueue == epoll == devpoll > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): # Platform-specific: Mac OS and BSD + DefaultSelector = KqueueSelector +if 'DevpollSelector' in globals(): + DefaultSelector = DevpollSelector +elif 'EpollSelector' in globals(): # Platform-specific: Linux + DefaultSelector = EpollSelector +elif 'PollSelector' in globals(): # Platform-specific: Linux + DefaultSelector = PollSelector +elif 'SelectSelector' in globals(): # Platform-specific: Windows + DefaultSelector = SelectSelector +else: # Platform-specific: AppEngine + def no_selector(_): + raise ValueError("Platform does not have a selector") + DefaultSelector = no_selector + HAS_SELECT = False diff --git a/lib/ansible/plugins/connection/local.py b/lib/ansible/plugins/connection/local.py index da25f1e2306..133401581d7 100644 --- a/lib/ansible/plugins/connection/local.py +++ b/lib/ansible/plugins/connection/local.py @@ -1,5 +1,5 @@ # (c) 2012, Michael DeHaan -# (c) 2015 Toshio Kuratomi +# (c) 2015, 2017 Toshio Kuratomi # # This file is part of Ansible # @@ -19,16 +19,14 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type import os -import select import shutil import subprocess import fcntl import getpass -from ansible.compat.six import text_type, binary_type - import ansible.constants as C - +from ansible.compat import selectors +from ansible.compat.six import text_type, binary_type from ansible.errors import AnsibleError, AnsibleFileNotFound from ansible.module_utils._text import to_bytes, to_native from ansible.plugins.connection import ConnectionBase @@ -90,21 +88,31 @@ class Connection(ConnectionBase): if self._play_context.prompt and sudoable: fcntl.fcntl(p.stdout, fcntl.F_SETFL, fcntl.fcntl(p.stdout, fcntl.F_GETFL) | os.O_NONBLOCK) fcntl.fcntl(p.stderr, fcntl.F_SETFL, fcntl.fcntl(p.stderr, fcntl.F_GETFL) | os.O_NONBLOCK) + selector = selectors.DefaultSelector() + selector.register(p.stdout, selectors.EVENT_READ) + selector.register(p.stderr, selectors.EVENT_READ) + become_output = b'' - while not self.check_become_success(become_output) and not self.check_password_prompt(become_output): - - rfd, wfd, efd = select.select([p.stdout, p.stderr], [], [p.stdout, p.stderr], self._play_context.timeout) - if p.stdout in rfd: - chunk = p.stdout.read() - elif p.stderr in rfd: - chunk = p.stderr.read() - else: - stdout, stderr = p.communicate() - raise AnsibleError('timeout waiting for privilege escalation password prompt:\n' + to_native(become_output)) - if not chunk: - stdout, stderr = p.communicate() - raise AnsibleError('privilege output closed while waiting for password prompt:\n' + to_native(become_output)) - become_output += chunk + try: + while not self.check_become_success(become_output) and not self.check_password_prompt(become_output): + events = selector.select(self._play_context.timeout) + if not events: + stdout, stderr = p.communicate() + raise AnsibleError('timeout waiting for privilege escalation password prompt:\n' + to_native(become_output)) + + for key, event in events: + if key.fileobj == p.stdout: + chunk = p.stdout.read() + elif key.fileobj == p.stderr: + chunk = p.stderr.read() + + if not chunk: + stdout, stderr = p.communicate() + raise AnsibleError('privilege output closed while waiting for password prompt:\n' + to_native(become_output)) + become_output += chunk + finally: + selector.close() + if not self.check_become_success(become_output): p.stdin.write(to_bytes(self._play_context.become_pass, errors='surrogate_or_strict') + b'\n') fcntl.fcntl(p.stdout, fcntl.F_SETFL, fcntl.fcntl(p.stdout, fcntl.F_GETFL) & ~os.O_NONBLOCK) diff --git a/lib/ansible/plugins/connection/ssh.py b/lib/ansible/plugins/connection/ssh.py index d1c61fd8018..7a26d158335 100644 --- a/lib/ansible/plugins/connection/ssh.py +++ b/lib/ansible/plugins/connection/ssh.py @@ -1,5 +1,6 @@ # (c) 2012, Michael DeHaan # Copyright 2015 Abhijit Menon-Sen +# Copyright 2017 Toshio Kuratomi # # This file is part of Ansible # @@ -24,11 +25,11 @@ import fcntl import hashlib import os import pty -import select import subprocess import time from ansible import constants as C +from ansible.compat import selectors from ansible.compat.six import PY3, text_type, binary_type from ansible.compat.six.moves import shlex_quote from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound @@ -443,148 +444,158 @@ class Connection(ConnectionBase): # they will race each other when we can't connect, and the connect # timeout usually fails timeout = 2 + self._play_context.timeout - rpipes = [p.stdout, p.stderr] - for fd in rpipes: + for fd in (p.stdout, p.stderr): fcntl.fcntl(fd, fcntl.F_SETFL, fcntl.fcntl(fd, fcntl.F_GETFL) | os.O_NONBLOCK) - # If we can send initial data without waiting for anything, we do so - # before we call select. + ### TODO: bcoca would like to use SelectSelector() when open + # filehandles is low, then switch to more efficient ones when higher. + # select is faster when filehandles is low. + selector = selectors.DefaultSelector() + selector.register(p.stdout, selectors.EVENT_READ) + selector.register(p.stderr, selectors.EVENT_READ) + # If we can send initial data without waiting for anything, we do so + # before we start polling if states[state] == 'ready_to_send' and in_data: self._send_initial_data(stdin, in_data) state += 1 - while True: - rfd, wfd, efd = select.select(rpipes, [], [], timeout) + try: + while True: + events = selector.select(timeout) + + # We pay attention to timeouts only while negotiating a prompt. + + if not events: + # We timed out + if state <= states.index('awaiting_escalation'): + # If the process has already exited, then it's not really a + # timeout; we'll let the normal error handling deal with it. + if p.poll() is not None: + break + self._terminate_process(p) + raise AnsibleError('Timeout (%ds) waiting for privilege escalation prompt: %s' % (timeout, to_native(b_stdout))) + + # Read whatever output is available on stdout and stderr, and stop + # listening to the pipe if it's been closed. + + for key, event in events: + if key.fileobj == p.stdout: + b_chunk = p.stdout.read() + if b_chunk == b'': + # stdout has been closed, stop watching it + selector.unregister(p.stdout) + # When ssh has ControlMaster (+ControlPath/Persist) enabled, the + # first connection goes into the background and we never see EOF + # on stderr. If we see EOF on stdout, lower the select timeout + # to reduce the time wasted selecting on stderr if we observe + # that the process has not yet existed after this EOF. Otherwise + # we may spend a long timeout period waiting for an EOF that is + # not going to arrive until the persisted connection closes. + timeout = 1 + b_tmp_stdout += b_chunk + display.debug("stdout chunk (state=%s):\n>>>%s<<<\n" % (state, to_text(b_chunk))) + elif key.fileobj == p.stderr: + b_chunk = p.stderr.read() + if b_chunk == b'': + # stderr has been closed, stop watching it + selector.unregister(p.stderr) + b_tmp_stderr += b_chunk + display.debug("stderr chunk (state=%s):\n>>>%s<<<\n" % (state, to_text(b_chunk))) + + # We examine the output line-by-line until we have negotiated any + # privilege escalation prompt and subsequent success/error message. + # Afterwards, we can accumulate output without looking at it. + + if state < states.index('ready_to_send'): + if b_tmp_stdout: + b_output, b_unprocessed = self._examine_output('stdout', states[state], b_tmp_stdout, sudoable) + b_stdout += b_output + b_tmp_stdout = b_unprocessed + + if b_tmp_stderr: + b_output, b_unprocessed = self._examine_output('stderr', states[state], b_tmp_stderr, sudoable) + b_stderr += b_output + b_tmp_stderr = b_unprocessed + else: + b_stdout += b_tmp_stdout + b_stderr += b_tmp_stderr + b_tmp_stdout = b_tmp_stderr = b'' + + # If we see a privilege escalation prompt, we send the password. + # (If we're expecting a prompt but the escalation succeeds, we + # didn't need the password and can carry on regardless.) + + if states[state] == 'awaiting_prompt': + if self._flags['become_prompt']: + display.debug('Sending become_pass in response to prompt') + stdin.write(to_bytes(self._play_context.become_pass) + b'\n') + self._flags['become_prompt'] = False + state += 1 + elif self._flags['become_success']: + state += 1 + + # We've requested escalation (with or without a password), now we + # wait for an error message or a successful escalation. + + if states[state] == 'awaiting_escalation': + if self._flags['become_success']: + display.debug('Escalation succeeded') + self._flags['become_success'] = False + state += 1 + elif self._flags['become_error']: + display.debug('Escalation failed') + self._terminate_process(p) + self._flags['become_error'] = False + raise AnsibleError('Incorrect %s password' % self._play_context.become_method) + elif self._flags['become_nopasswd_error']: + display.debug('Escalation requires password') + self._terminate_process(p) + self._flags['become_nopasswd_error'] = False + raise AnsibleError('Missing %s password' % self._play_context.become_method) + elif self._flags['become_prompt']: + # This shouldn't happen, because we should see the "Sorry, + # try again" message first. + display.debug('Escalation prompt repeated') + self._terminate_process(p) + self._flags['become_prompt'] = False + raise AnsibleError('Incorrect %s password' % self._play_context.become_method) + + # Once we're sure that the privilege escalation prompt, if any, has + # been dealt with, we can send any initial data and start waiting + # for output. + + if states[state] == 'ready_to_send': + if in_data: + self._send_initial_data(stdin, in_data) + state += 1 - # We pay attention to timeouts only while negotiating a prompt. + # Now we're awaiting_exit: has the child process exited? If it has, + # and we've read all available output from it, we're done. - if not rfd: - if state <= states.index('awaiting_escalation'): - # If the process has already exited, then it's not really a - # timeout; we'll let the normal error handling deal with it. - if p.poll() is not None: + if p.poll() is not None: + if not selector.get_map() or not events: break - self._terminate_process(p) - raise AnsibleError('Timeout (%ds) waiting for privilege escalation prompt: %s' % (timeout, to_native(b_stdout))) - - # Read whatever output is available on stdout and stderr, and stop - # listening to the pipe if it's been closed. - - if p.stdout in rfd: - b_chunk = p.stdout.read() - if b_chunk == b'': - rpipes.remove(p.stdout) - # When ssh has ControlMaster (+ControlPath/Persist) enabled, the - # first connection goes into the background and we never see EOF - # on stderr. If we see EOF on stdout, lower the select timeout - # to reduce the time wasted selecting on stderr if we observe - # that the process has not yet existed after this EOF. Otherwise - # we may spend a long timeout period waiting for an EOF that is - # not going to arrive until the persisted connection closes. - timeout = 1 - b_tmp_stdout += b_chunk - display.debug("stdout chunk (state=%s):\n>>>%s<<<\n" % (state, to_text(b_chunk))) - - if p.stderr in rfd: - b_chunk = p.stderr.read() - if b_chunk == b'': - rpipes.remove(p.stderr) - b_tmp_stderr += b_chunk - display.debug("stderr chunk (state=%s):\n>>>%s<<<\n" % (state, to_text(b_chunk))) - - # We examine the output line-by-line until we have negotiated any - # privilege escalation prompt and subsequent success/error message. - # Afterwards, we can accumulate output without looking at it. - - if state < states.index('ready_to_send'): - if b_tmp_stdout: - b_output, b_unprocessed = self._examine_output('stdout', states[state], b_tmp_stdout, sudoable) - b_stdout += b_output - b_tmp_stdout = b_unprocessed - - if b_tmp_stderr: - b_output, b_unprocessed = self._examine_output('stderr', states[state], b_tmp_stderr, sudoable) - b_stderr += b_output - b_tmp_stderr = b_unprocessed - else: - b_stdout += b_tmp_stdout - b_stderr += b_tmp_stderr - b_tmp_stdout = b_tmp_stderr = b'' - - # If we see a privilege escalation prompt, we send the password. - # (If we're expecting a prompt but the escalation succeeds, we - # didn't need the password and can carry on regardless.) - - if states[state] == 'awaiting_prompt': - if self._flags['become_prompt']: - display.debug('Sending become_pass in response to prompt') - stdin.write(to_bytes(self._play_context.become_pass) + b'\n') - self._flags['become_prompt'] = False - state += 1 - elif self._flags['become_success']: - state += 1 + # We should not see further writes to the stdout/stderr file + # descriptors after the process has closed, set the select + # timeout to gather any last writes we may have missed. + timeout = 0 + continue - # We've requested escalation (with or without a password), now we - # wait for an error message or a successful escalation. + # If the process has not yet exited, but we've already read EOF from + # its stdout and stderr (and thus no longer watching any file + # descriptors), we can just wait for it to exit. - if states[state] == 'awaiting_escalation': - if self._flags['become_success']: - display.debug('Escalation succeeded') - self._flags['become_success'] = False - state += 1 - elif self._flags['become_error']: - display.debug('Escalation failed') - self._terminate_process(p) - self._flags['become_error'] = False - raise AnsibleError('Incorrect %s password' % self._play_context.become_method) - elif self._flags['become_nopasswd_error']: - display.debug('Escalation requires password') - self._terminate_process(p) - self._flags['become_nopasswd_error'] = False - raise AnsibleError('Missing %s password' % self._play_context.become_method) - elif self._flags['become_prompt']: - # This shouldn't happen, because we should see the "Sorry, - # try again" message first. - display.debug('Escalation prompt repeated') - self._terminate_process(p) - self._flags['become_prompt'] = False - raise AnsibleError('Incorrect %s password' % self._play_context.become_method) - - # Once we're sure that the privilege escalation prompt, if any, has - # been dealt with, we can send any initial data and start waiting - # for output. - - if states[state] == 'ready_to_send': - if in_data: - self._send_initial_data(stdin, in_data) - state += 1 - - # Now we're awaiting_exit: has the child process exited? If it has, - # and we've read all available output from it, we're done. - - if p.poll() is not None: - if not rpipes or not rfd: + elif not selector.get_map(): + p.wait() break - # We should not see further writes to the stdout/stderr file - # descriptors after the process has closed, set the select - # timeout to gather any last writes we may have missed. - timeout = 0 - continue - - # If the process has not yet exited, but we've already read EOF from - # its stdout and stderr (and thus removed both from rpipes), we can - # just wait for it to exit. - - elif not rpipes: - p.wait() - break - - # Otherwise there may still be outstanding data to read. - # close stdin after process is terminated and stdout/stderr are read - # completely (see also issue #848) - stdin.close() + # Otherwise there may still be outstanding data to read. + finally: + selector.close() + # close stdin after process is terminated and stdout/stderr are read + # completely (see also issue #848) + stdin.close() if C.HOST_KEY_CHECKING: if cmd[0] == b"sshpass" and p.returncode == 6: diff --git a/setup.py b/setup.py index ad1bd661a93..1bb0ee2f00e 100644 --- a/setup.py +++ b/setup.py @@ -18,8 +18,8 @@ setup(name='ansible', author_email='info@ansible.com', url='http://ansible.com/', license='GPLv3', - # Ansible will also make use of a system copy of python-six if installed but use a - # Bundled copy if it's not. + # Ansible will also make use of a system copy of python-six and + # python-selectors2 if installed but use a Bundled copy if it's not. install_requires=['paramiko', 'jinja2', "PyYAML", 'setuptools', 'pycrypto >= 2.6'], package_dir={ '': 'lib' }, packages=find_packages('lib'), diff --git a/test/sanity/code-smell/boilerplate.sh b/test/sanity/code-smell/boilerplate.sh index b9d67ed4ec8..5ed3bac57b8 100755 --- a/test/sanity/code-smell/boilerplate.sh +++ b/test/sanity/code-smell/boilerplate.sh @@ -7,6 +7,7 @@ metaclass2=$(find ./lib/ansible -path ./lib/ansible/modules -prune \ -o -path ./lib/ansible/modules/__init__.py \ -o -path ./lib/ansible/module_utils -prune \ -o -path ./lib/ansible/compat/six/_six.py -prune \ + -o -path ./lib/ansible/compat/selectors/_selectors2.py -prune \ -o -path ./lib/ansible/utils/module_docs_fragments -prune \ -o -name '*.py' -exec grep -HL '__metaclass__ = type' '{}' '+') @@ -14,6 +15,7 @@ future2=$(find ./lib/ansible -path ./lib/ansible/modules -prune \ -o -path ./lib/ansible/modules/__init__.py \ -o -path ./lib/ansible/module_utils -prune \ -o -path ./lib/ansible/compat/six/_six.py -prune \ + -o -path ./lib/ansible/compat/selectors/_selectors2.py -prune \ -o -path ./lib/ansible/utils/module_docs_fragments -prune \ -o -name '*.py' -exec grep -HL 'from __future__ import (absolute_import, division, print_function)' '{}' '+') diff --git a/test/units/plugins/connection/test_ssh.py b/test/units/plugins/connection/test_ssh.py index 73dc0685efd..9ca4373c752 100644 --- a/test/units/plugins/connection/test_ssh.py +++ b/test/units/plugins/connection/test_ssh.py @@ -23,10 +23,13 @@ __metaclass__ = type from io import StringIO +import pytest + from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch, MagicMock from ansible import constants as C +from ansible.compat.selectors import SelectorKey, EVENT_READ from ansible.compat.six.moves import shlex_quote from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound from ansible.playbook.play_context import PlayContext @@ -83,82 +86,6 @@ class TestConnectionBaseClass(unittest.TestCase): res, stdout, stderr = conn._exec_command('ssh') res, stdout, stderr = conn._exec_command('ssh', 'this is some data') - @patch('select.select') - @patch('fcntl.fcntl') - @patch('os.write') - @patch('os.close') - @patch('pty.openpty') - @patch('subprocess.Popen') - def test_plugins_connection_ssh__run(self, mock_Popen, mock_openpty, mock_osclose, mock_oswrite, mock_fcntl, mock_select): - pc = PlayContext() - new_stdin = StringIO() - - conn = ssh.Connection(pc, new_stdin) - conn._send_initial_data = MagicMock() - conn._examine_output = MagicMock() - conn._terminate_process = MagicMock() - conn.sshpass_pipe = [MagicMock(), MagicMock()] - - mock_popen_res = MagicMock() - mock_popen_res.poll = MagicMock() - mock_popen_res.wait = MagicMock() - mock_popen_res.stdin = MagicMock() - mock_popen_res.stdin.fileno.return_value = 1000 - mock_popen_res.stdout = MagicMock() - mock_popen_res.stdout.fileno.return_value = 1001 - mock_popen_res.stderr = MagicMock() - mock_popen_res.stderr.fileno.return_value = 1002 - mock_popen_res.return_code = 0 - mock_Popen.return_value = mock_popen_res - - def _mock_select(rlist, wlist, elist, timeout=None): - rvals = [] - if mock_popen_res.stdin in rlist: - rvals.append(mock_popen_res.stdin) - if mock_popen_res.stderr in rlist: - rvals.append(mock_popen_res.stderr) - return (rvals, [], []) - - mock_select.side_effect = _mock_select - - mock_popen_res.stdout.read.side_effect = [b"some data", b""] - mock_popen_res.stderr.read.side_effect = [b""] - conn._run("ssh", "this is input data") - - # test with a password set to trigger the sshpass write - pc.password = '12345' - mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""] - mock_popen_res.stderr.read.side_effect = [b""] - conn._run(["ssh", "is", "a", "cmd"], "this is more data") - - # test with password prompting enabled - pc.password = None - pc.prompt = True - mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""] - mock_popen_res.stderr.read.side_effect = [b""] - conn._run("ssh", "this is input data") - - # test with some become settings - pc.prompt = False - pc.become = True - pc.success_key = 'BECOME-SUCCESS-abcdefg' - mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""] - mock_popen_res.stderr.read.side_effect = [b""] - conn._run("ssh", "this is input data") - - # simulate no data input - mock_openpty.return_value = (98, 99) - mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""] - mock_popen_res.stderr.read.side_effect = [b""] - conn._run("ssh", "") - - # simulate no data input but Popen using new pty's fails - mock_Popen.return_value = None - mock_Popen.side_effect = [OSError(), mock_popen_res] - mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""] - mock_popen_res.stderr.read.side_effect = [b""] - conn._run("ssh", "") - def test_plugins_connection_ssh__examine_output(self): pc = PlayContext() new_stdin = StringIO() @@ -341,7 +268,6 @@ class TestConnectionBaseClass(unittest.TestCase): conn.put_file(u'/path/to/in/file/with/unicode-fö〩', u'/path/to/dest/file/with/unicode-fö〩') conn._run.assert_called_with('some command to run', expected_in_data, checkrc=False) - # test that a non-zero rc raises an error conn._run.return_value = (1, 'stdout', 'some errors') self.assertRaises(AnsibleError, conn.put_file, '/path/to/bad/file', '/remote/path/to/file') @@ -398,3 +324,215 @@ class TestConnectionBaseClass(unittest.TestCase): # test that a non-zero rc raises an error conn._run.return_value = (1, 'stdout', 'some errors') self.assertRaises(AnsibleError, conn.fetch_file, '/path/to/bad/file', '/remote/path/to/file') + + +class MockSelector(object): + def __init__(self): + self.files_watched = 0 + self.register = MagicMock(side_effect=self._register) + self.unregister = MagicMock(side_effect=self._unregister) + self.close = MagicMock() + self.get_map = MagicMock(side_effect=self._get_map) + self.select = MagicMock() + + def _register(self, *args, **kwargs): + self.files_watched += 1 + + def _unregister(self, *args, **kwargs): + self.files_watched -= 1 + + def _get_map(self, *args, **kwargs): + return self.files_watched + + +@pytest.fixture +def mock_run_env(request, mocker): + pc = PlayContext() + new_stdin = StringIO() + + conn = ssh.Connection(pc, new_stdin) + conn._send_initial_data = MagicMock() + conn._examine_output = MagicMock() + conn._terminate_process = MagicMock() + conn.sshpass_pipe = [MagicMock(), MagicMock()] + + request.cls.pc = pc + request.cls.conn = conn + + mock_popen_res = MagicMock() + mock_popen_res.poll = MagicMock() + mock_popen_res.wait = MagicMock() + mock_popen_res.stdin = MagicMock() + mock_popen_res.stdin.fileno.return_value = 1000 + mock_popen_res.stdout = MagicMock() + mock_popen_res.stdout.fileno.return_value = 1001 + mock_popen_res.stderr = MagicMock() + mock_popen_res.stderr.fileno.return_value = 1002 + mock_popen_res.returncode = 0 + request.cls.mock_popen_res = mock_popen_res + + mock_popen = mocker.patch('subprocess.Popen', return_value=mock_popen_res) + request.cls.mock_popen = mock_popen + + request.cls.mock_selector = MockSelector() + mocker.patch('ansible.compat.selectors.DefaultSelector', lambda: request.cls.mock_selector) + + request.cls.mock_openpty = mocker.patch('pty.openpty') + + mocker.patch('fcntl.fcntl') + mocker.patch('os.write') + mocker.patch('os.close') + + +@pytest.mark.usefixtures('mock_run_env') +class TestSSHConnectionRun(object): + # FIXME: + # These tests are little more than a smoketest. Need to enhance them + # a bit to check that they're calling the relevant functions and making + # complete coverage of the code paths + def test_no_escalation(self): + self.mock_popen_res.stdout.read.side_effect = [b"my_stdout\n", b"second_line"] + self.mock_popen_res.stderr.read.side_effect = [b"my_stderr"] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data") + assert return_code == 0 + assert b_stdout == b'my_stdout\nsecond_line' + assert b_stderr == b'my_stderr' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is True + assert self.conn._send_initial_data.call_count == 1 + assert self.conn._send_initial_data.call_args[0][1] == 'this is input data' + + def test_with_password(self): + # test with a password set to trigger the sshpass write + self.pc.password = '12345' + self.mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""] + self.mock_popen_res.stderr.read.side_effect = [b""] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run(["ssh", "is", "a", "cmd"], "this is more data") + assert return_code == 0 + assert b_stdout == b'some data' + assert b_stderr == b'' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is True + assert self.conn._send_initial_data.call_count == 1 + assert self.conn._send_initial_data.call_args[0][1] == 'this is more data' + + def _password_with_prompt_examine_output(self, sourice, state, b_chunk, sudoable): + if state == 'awaiting_prompt': + self.conn._flags['become_prompt'] = True + elif state == 'awaiting_escalation': + self.conn._flags['become_success'] = True + return (b'', b'') + + def test_pasword_with_prompt(self): + # test with password prompting enabled + self.pc.password = None + self.pc.prompt = b'Password:' + self.conn._examine_output.side_effect = self._password_with_prompt_examine_output + self.mock_popen_res.stdout.read.side_effect = [b"Password:", b"Success", b""] + self.mock_popen_res.stderr.read.side_effect = [b""] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ), + (SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data") + assert return_code == 0 + assert b_stdout == b'' + assert b_stderr == b'' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is True + assert self.conn._send_initial_data.call_count == 1 + assert self.conn._send_initial_data.call_args[0][1] == 'this is input data' + + def test_pasword_with_become(self): + # test with some become settings + self.pc.prompt = b'Password:' + self.pc.become = True + self.pc.success_key = 'BECOME-SUCCESS-abcdefg' + self.conn._examine_output.side_effect = self._password_with_prompt_examine_output + self.mock_popen_res.stdout.read.side_effect = [b"Password:", b"BECOME-SUCCESS-abcdefg", b"abc"] + self.mock_popen_res.stderr.read.side_effect = [b"123"] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run("ssh", "this is input data") + assert return_code == 0 + assert b_stdout == b'abc' + assert b_stderr == b'123' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is True + assert self.conn._send_initial_data.call_count == 1 + assert self.conn._send_initial_data.call_args[0][1] == 'this is input data' + + def test_pasword_without_data(self): + # simulate no data input + self.mock_openpty.return_value = (98, 99) + self.mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""] + self.mock_popen_res.stderr.read.side_effect = [b""] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run("ssh", "") + assert return_code == 0 + assert b_stdout == b'some data' + assert b_stderr == b'' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is False + + def test_pasword_without_data(self): + # simulate no data input but Popen using new pty's fails + self.mock_popen.return_value = None + self.mock_popen.side_effect = [OSError(), self.mock_popen_res] + + # simulate no data input + self.mock_openpty.return_value = (98, 99) + self.mock_popen_res.stdout.read.side_effect = [b"some data", b"", b""] + self.mock_popen_res.stderr.read.side_effect = [b""] + self.mock_selector.select.side_effect = [ + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)], + [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)], + []] + self.mock_selector.get_map.side_effect = lambda: True + + return_code, b_stdout, b_stderr = self.conn._run("ssh", "") + assert return_code == 0 + assert b_stdout == b'some data' + assert b_stderr == b'' + assert self.mock_selector.register.called is True + assert self.mock_selector.register.call_count == 2 + assert self.conn._send_initial_data.called is False