mirror of https://github.com/ansible/ansible.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
346 lines
11 KiB
Python
346 lines
11 KiB
Python
#!/usr/bin/env python
|
|
|
|
# (c) 2016, Ansible, Inc. <support@ansible.com>
|
|
#
|
|
# This file is part of Ansible
|
|
#
|
|
# Ansible is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# Ansible is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
########################################################
|
|
from __future__ import (absolute_import, division, print_function)
|
|
|
|
# Module-level __metaclass__ makes the base-less classes defined in this file
# new-style classes when run under Python 2.
__metaclass__ = type
# NOTE(review): presumably consumed by pkg_resources (imported just below) to
# select the installed 'ansible' distribution -- confirm against setuptools.
__requires__ = ['ansible']

try:
    # Optional import: any failure to load pkg_resources is deliberately ignored.
    import pkg_resources
except Exception:
    pass
|
|
|
|
import fcntl
|
|
import hashlib
|
|
import os
|
|
import shlex
|
|
import signal
|
|
import socket
|
|
import struct
|
|
import sys
|
|
import time
|
|
import traceback
|
|
import syslog
|
|
import datetime
|
|
|
|
from io import BytesIO
|
|
|
|
from ansible import constants as C
|
|
from ansible.module_utils._text import to_bytes, to_native
|
|
from ansible.module_utils.six.moves import cPickle, StringIO
|
|
from ansible.playbook.play_context import PlayContext
|
|
from ansible.plugins import connection_loader
|
|
from ansible.utils.path import unfrackpath, makedirs_safe
|
|
from ansible.errors import AnsibleConnectionFailure
|
|
from ansible.utils.display import Display
|
|
|
|
|
|
def do_fork():
    '''
    Detach from the caller using a daemon-style double fork. Based on
    http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/

    Returns the first child's pid (> 0) in the calling process and 0 in the
    grandchild that keeps running. The intermediate child exits with status
    0, and an OSError raised at any step ends the current process with
    status 1.
    '''
    try:
        first_pid = os.fork()
        if first_pid > 0:
            # calling process: report the child's pid and carry on
            return first_pid

        # first child: new working directory, new session, permissive umask
        os.chdir("/")
        os.setsid()
        os.umask(0)

        second_pid = os.fork()
        if second_pid > 0:
            # the intermediate child has nothing left to do
            sys.exit(0)

        # grandchild: close the standard streams' file descriptors
        for stream in (sys.stdin, sys.stdout, sys.stderr):
            os.close(stream.fileno())

        return second_pid
    except OSError:
        sys.exit(1)
|
|
|
|
def send_data(s, data):
    '''
    Write one length-prefixed message to the socket ``s``.

    The frame is the length of ``data`` packed as an 8 byte unsigned integer
    in network byte order ('!Q'), immediately followed by ``data``. Returns
    whatever ``s.sendall`` returns.
    '''
    frame = struct.pack('!Q', len(data)) + data
    return s.sendall(frame)
|
|
|
|
def recv_data(s):
    '''
    Read one length-prefixed message from the socket ``s``.

    Expects an 8 byte '!Q' length header followed by that many payload
    bytes. Returns the payload (``b""`` for a zero-length message), or None
    if ``s.recv`` returns nothing before the header or the payload is
    complete.
    '''
    def read_exact(count):
        # keep requesting the outstanding bytes until ``count`` have arrived;
        # an empty read is treated as "no more data"
        pieces = []
        received = 0
        while received < count:
            piece = s.recv(count - received)
            if not piece:
                return None
            pieces.append(piece)
            received += len(piece)
        return b"".join(pieces)

    header = read_exact(8)  # size of a packed unsigned long long
    if header is None:
        return None

    (body_len,) = struct.unpack('!Q', header)
    return read_exact(body_len)
|
|
|
|
def log(msg, host=None, **kwargs):
    '''
    Send ``msg`` to syslog at LOG_INFO under the 'ansible-connection' ident.

    When ``host`` is truthy the message is prefixed with '<host> '. The
    facility is looked up on the syslog module by the name held in
    C.DEFAULT_SYSLOG_FACILITY, falling back to LOG_USER. Extra keyword
    arguments are accepted and ignored.
    '''
    text = '<%s> %s' % (host, msg) if host else msg
    syslog.openlog(
        'ansible-connection',
        0,
        getattr(syslog, C.DEFAULT_SYSLOG_FACILITY, syslog.LOG_USER),
    )
    syslog.syslog(syslog.LOG_INFO, str(text))
|
|
|
|
|
|
class Server():
    '''
    Daemon side of a persistent connection.

    Holds one connection plugin instance and serves framed requests
    (see send_data/recv_data) arriving on a UNIX stream socket bound at
    ``path``. Requests start with one of the markers 'EXEC: ', 'PUT: ',
    'FETCH: ' or 'CONTEXT: '.
    '''

    def __init__(self, path, play_context):
        '''
        Loads and connects the plugin named by ``play_context.connection``,
        then binds and listens on a UNIX stream socket at ``path``.

        Raises AnsibleConnectionFailure when the plugin's _connect() returns
        a non-zero return code.
        '''
        self.path = path
        self.play_context = play_context

        self._start_time = datetime.datetime.now()

        self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin)
        rc, out, err = self.conn._connect()
        if rc != 0:
            raise AnsibleConnectionFailure(err)

        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.socket.bind(path)
        self.socket.listen(1)

        signal.signal(signal.SIGALRM, self.alarm_handler)

    def dispatch(self, obj, name, *args, **kwargs):
        '''
        Calls ``obj.name(*args, **kwargs)`` and returns its result. Returns
        None without calling anything when ``obj`` has no such (truthy)
        attribute.
        '''
        meth = getattr(obj, name, None)
        if meth:
            return meth(*args, **kwargs)

    def log(self, msg):
        '''
        Logs ``msg`` tagged with this connection's remote address.
        '''
        log(msg, host=self.play_context.remote_addr)

    def alarm_handler(self, signum, frame):
        '''
        SIGALRM handler: gives the connection plugin a chance to clean up
        and closes the listening socket.
        '''
        # FIXME: this should also set internal flags for other
        #        areas of code to check, so they can terminate
        #        earlier than the socket going back to the accept
        #        call and failing there.
        #
        # hooks the connection plugin to handle any cleanup
        self.dispatch(self.conn, 'alarm_handler', signum, frame)
        self.socket.close()

    @staticmethod
    def _payload(data, prefix):
        '''
        Returns ``data`` with its leading ``prefix`` marker removed.

        Slicing (rather than splitting on the marker) keeps payloads that
        happen to contain the marker text themselves intact.
        '''
        return data[len(prefix):]

    def _handle_request(self, data):
        '''
        Executes a single request.

        Returns a (rc, stdout, stderr) tuple to send back, or None for a
        'CONTEXT: ' request, which is not answered. Exceptions other than
        PUT/FETCH transfer failures propagate to the caller.
        '''
        rc = 255
        stdout = stderr = ''

        if data.startswith(b'EXEC: '):
            cmd = self._payload(data, b'EXEC: ')
            (rc, stdout, stderr) = self.conn.exec_command(cmd)
        elif data.startswith((b'PUT: ', b'FETCH: ')):
            (op, src, dst) = shlex.split(to_native(data))
            try:
                if op == 'FETCH:':
                    self.conn.fetch_file(src, dst)
                elif op == 'PUT:':
                    self.conn.put_file(src, dst)
                rc = 0
            except Exception:
                # rc stays 255; report why the transfer failed instead of
                # returning an empty stderr
                stderr = traceback.format_exc()
        elif data.startswith(b'CONTEXT: '):
            # NOTE(review): this unpickles bytes received on the socket.
            # Unpickling can run arbitrary code, so this relies on only
            # trusted clients being able to reach the socket file.
            pc_data = cPickle.loads(self._payload(data, b'CONTEXT: '))

            pc = PlayContext()
            pc.deserialize(pc_data)

            self.dispatch(self.conn, 'update_play_context', pc)
            return None
        else:
            stderr = 'Invalid action specified'

        return (rc, stdout, stderr)

    def _serve_client(self, s):
        '''
        Reads requests from the accepted socket ``s`` until no more data
        arrives, sending rc, stdout and stderr back for each answered one.
        '''
        while True:
            data = recv_data(s)
            if not data:
                break

            signal.alarm(C.DEFAULT_TIMEOUT)

            try:
                reply = self._handle_request(data)
            except Exception:
                reply = (255, '', traceback.format_exc())

            if reply is None:
                # play context update: nothing is sent back and the alarm
                # stays armed until the next request re-arms or clears it
                continue

            signal.alarm(0)

            rc, stdout, stderr = reply
            send_data(s, to_bytes(str(rc)))
            send_data(s, to_bytes(stdout))
            send_data(s, to_bytes(stderr))

    def _cleanup(self):
        '''
        Closes the connection and the listening socket and removes the
        socket file. Each step is attempted even if an earlier one fails;
        failures are logged rather than raised.
        '''
        for what, closer in (('connection', self.conn.close), ('socket', self.socket.close)):
            try:
                closer()
            except Exception as exc:
                self.log('failed to close %s: %s' % (what, exc))
        try:
            os.remove(self.path)
        except OSError as exc:
            self.log('failed to remove socket file %s: %s' % (self.path, exc))

    def run(self):
        '''
        Accept loop of the daemon. Serves one client at a time until
        accept() fails, then shuts the connection down and removes the
        socket file so it can be recreated.
        '''
        try:
            while True:
                # set the alarm, if we don't get an accept before it
                # goes off we exit (via an exception caused by the socket
                # getting closed while waiting on accept())
                # FIXME: is this the best way to exit? as noted above in the
                #        handler we should probably be setting a flag to check
                #        here and in other parts of the code
                signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)
                try:
                    (s, addr) = self.socket.accept()
                    # clear the alarm
                    # FIXME: potential race condition here between the accept and
                    #        time to this call.
                    signal.alarm(0)
                except Exception:
                    break

                try:
                    self._serve_client(s)
                finally:
                    # never leak the client socket, even when serving it failed
                    s.close()
        except Exception:
            self.log(traceback.format_exc())
        finally:
            # when done, close the connection properly and cleanup
            # the socket file so it can be recreated
            end_time = datetime.datetime.now()
            delta = end_time - self._start_time
            log('shutting down connection, connection was active for %s secs' % delta, self.play_context.remote_addr)
            self._cleanup()
|
|
|
|
def main():
    '''
    Client entry point.

    Reads a pickled play context from stdin (terminated by a '#END_INIT#'
    line), makes sure a Server daemon is listening on the socket file
    derived from the connection details, then forwards the first non-blank
    line read from stdin to that daemon. The daemon's stdout/stderr are
    copied to this process' stdout/stderr and its return code becomes the
    exit status.

    Exits 1 when the init data cannot be loaded and 255 when the daemon
    cannot be reached or closes the socket before replying in full.
    '''
    try:
        # read the play context data via stdin, which means depickling it
        # FIXME: as noted above, we will probably need to deserialize the
        #        connection loader here as well at some point, otherwise this
        #        won't find role- or playbook-based connection plugins
        init_lines = []
        cur_line = sys.stdin.readline()
        while cur_line.strip() != '#END_INIT#':
            if cur_line == '':
                raise Exception("EOL found before init data was complete")
            init_lines.append(cur_line)
            cur_line = sys.stdin.readline()
        # kept as bytes so the very same payload can be relayed to the daemon
        init_bytes = to_bytes(''.join(init_lines))
        # NOTE(review): unpickling stdin assumes the parent process is trusted
        pc_data = cPickle.loads(init_bytes)

        pc = PlayContext()
        pc.deserialize(pc_data)
    except Exception as e:
        # FIXME: better error message/handling/logging
        print("FAIL: %s" % e)
        print(traceback.format_exc())
        sys.exit(1)

    display.verbosity = pc.verbosity

    # here we create a hash to use later when creating the socket file,
    # so we can hide the info about the target host/user/etc.
    m = hashlib.sha256()
    for attr in ('connection', 'remote_addr', 'port', 'remote_user'):
        val = getattr(pc, attr, None)
        if val:
            m.update(to_bytes(val))

    # create the persistent connection dir if need be and create the paths
    # which we will be using later
    tmp_path = unfrackpath("$HOME/.ansible/pc")
    makedirs_safe(tmp_path)
    lk_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
    sf_path = unfrackpath("%s/conn-%s" % (tmp_path, m.hexdigest()[0:12]))

    # if the socket file doesn't exist, spin up the daemon process
    lock_fd = os.open(lk_path, os.O_RDWR | os.O_CREAT, 0o600)
    fcntl.lockf(lock_fd, fcntl.LOCK_EX)
    if not os.path.exists(sf_path):
        pid = do_fork()
        if pid == 0:
            server = None
            try:
                server = Server(sf_path, pc)
            except AnsibleConnectionFailure as exc:
                log(str(exc), pc.remote_addr)
            except Exception:
                log(traceback.format_exc(), pc.remote_addr)
            fcntl.lockf(lock_fd, fcntl.LOCK_UN)
            os.close(lock_fd)
            if server is None:
                # the failure was logged above; there is nothing to run
                sys.exit(1)
            server.run()
            sys.exit(0)
    fcntl.lockf(lock_fd, fcntl.LOCK_UN)
    os.close(lock_fd)

    # now connect to the daemon process
    # FIXME: if the socket file existed but the daemonized process was killed,
    #        the connection will timeout here. Need to make this more resilient.
    rc = 0
    while rc == 0:
        data = sys.stdin.readline()
        if data == '':
            break
        if data.strip() == '':
            continue
        sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        attempts = 1
        while True:
            try:
                sf.connect(sf_path)
                break
            except socket.error:
                # FIXME: better error handling/logging/message here
                time.sleep(C.PERSISTENT_CONNECT_INTERVAL)
                attempts += 1
                if attempts > C.PERSISTENT_CONNECT_RETRIES:
                    sys.stderr.write('failed to connect to the listener socket, '
                                     'connection timeout. See syslog for more details')
                    sys.exit(255)

        try:
            # send the play_context back into the connection so the connection
            # can handle any privilege escalation activities; the marker is
            # joined as bytes so the pickle is relayed unmodified
            send_data(sf, b'CONTEXT: ' + init_bytes)

            send_data(sf, to_bytes(data.strip()))

            raw_rc = recv_data(sf)
            stdout = recv_data(sf)
            stderr = recv_data(sf)
        finally:
            sf.close()

        # recv_data returns None when the peer goes away mid-reply
        if raw_rc is None or stdout is None or stderr is None:
            sys.stderr.write('the listener socket was closed before a complete '
                             'response was received. See syslog for more details')
            sys.exit(255)

        rc = int(raw_rc, 10)

        sys.stdout.write(to_native(stdout))
        sys.stderr.write(to_native(stderr))

        # only a single command is handled per invocation
        break

    sys.exit(rc)
|
|
|
|
if __name__ == '__main__':
    # Script entry: create the module-level ``display`` object that main()
    # sets the verbosity on, and replace its display() method with log() so
    # that those calls are written to syslog.
    display = Display()
    display.display = log
    main()
|