Merge branch 'accelerate_improvements' into devel

Conflicts:
	library/utilities/accelerate
reviewable/pr18780/r1
James Cammarata 11 years ago
commit 88b9bc5de8

@ -35,6 +35,12 @@ options:
required: false
default: 5099
aliases: []
timeout:
description:
- The number of seconds the socket will wait for data. If none is received when the timeout value is reached, the connection will be closed.
required: false
default: 300
aliases: []
minutes:
description:
- The I(accelerate) listener daemon is started on nodes and will stay around for
@ -58,24 +64,25 @@ EXAMPLES = '''
- command: /usr/bin/anything
'''
import base64
import getpass
import os
import os.path
import tempfile
import sys
import shutil
import signal
import socket
import struct
import time
import base64
import getpass
import sys
import syslog
import signal
import tempfile
import time
import signal
import traceback
import SocketServer
from datetime import datetime
from threading import Thread
syslog.openlog('ansible-%s' % os.path.basename(__file__))
PIDFILE = os.path.expanduser("~/.accelerate.pid")
@ -85,8 +92,22 @@ PIDFILE = os.path.expanduser("~/.accelerate.pid")
# which leaves room for the TCP/IP header
CHUNK_SIZE=10240
def log(msg):
syslog.syslog(syslog.LOG_NOTICE|syslog.LOG_DAEMON, msg)
# FIXME: this all should be moved to module_common, as it's
# pretty much a copy from the callbacks/util code
DEBUG_LEVEL=0
def log(msg, cap=0):
global DEBUG_LEVEL
if cap >= DEBUG_LEVEL:
syslog.syslog(syslog.LOG_NOTICE|syslog.LOG_DAEMON, msg)
def vv(msg):
log(msg, cap=2)
def vvv(msg):
log(msg, cap=3)
def vvvv(msg):
log(msg, cap=4)
if os.path.exists(PIDFILE):
try:
@ -114,7 +135,7 @@ def daemonize_self(module, password, port, minutes):
try:
pid = os.fork()
if pid > 0:
log("exiting pid %s" % pid)
vvv("exiting pid %s" % pid)
# exit first parent
module.exit_json(msg="daemonized accelerate on port %s for %s minutes" % (port, minutes))
except OSError, e:
@ -134,7 +155,7 @@ def daemonize_self(module, password, port, minutes):
pid_file = open(PIDFILE, "w")
pid_file.write("%s" % pid)
pid_file.close()
log("pidfile written")
vvv("pidfile written")
sys.exit(0)
except OSError, e:
log("fork #2 failed: %d (%s)" % (e.errno, e.strerror))
@ -146,12 +167,25 @@ def daemonize_self(module, password, port, minutes):
os.dup2(dev_null.fileno(), sys.stderr.fileno())
log("daemonizing successful")
class ThreadWithReturnValue(Thread):
def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, Verbose=None):
Thread.__init__(self, group, target, name, args, kwargs, Verbose)
self._return = None
def run(self):
if self._Thread__target is not None:
self._return = self._Thread__target(*self._Thread__args,
**self._Thread__kwargs)
def join(self,timeout=None):
Thread.join(self, timeout=timeout)
return self._return
class ThreadedTCPServer(SocketServer.ThreadingTCPServer):
def __init__(self, server_address, RequestHandlerClass, module, password):
def __init__(self, server_address, RequestHandlerClass, module, password, timeout):
self.module = module
self.key = AesKey.Read(password)
self.allow_reuse_address = True
self.timeout = None
self.timeout = timeout
SocketServer.ThreadingTCPServer.__init__(self, server_address, RequestHandlerClass)
class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
@ -162,50 +196,85 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
def recv_data(self):
header_len = 8 # size of a packed unsigned long long
data = b""
vvvv("in recv_data(), waiting for the header")
while len(data) < header_len:
d = self.request.recv(1024)
d = self.request.recv(header_len - len(data))
if not d:
vvv("received nothing, bailing out")
return None
data += d
vvvv("in recv_data(), got the header, unpacking")
data_len = struct.unpack('Q',data[:header_len])[0]
data = data[header_len:]
vvvv("data received so far (expecting %d): %d" % (data_len,len(data)))
while len(data) < data_len:
d = self.request.recv(1024)
d = self.request.recv(data_len - len(data))
if not d:
vvv("received nothing, bailing out")
return None
data += d
vvvv("data received so far (expecting %d): %d" % (data_len,len(data)))
vvvv("received all of the data, returning")
return data
def handle(self):
while True:
#log("waiting for data")
data = self.recv_data()
if not data:
break
try:
#log("got data, decrypting")
data = self.server.key.Decrypt(data)
#log("decryption done")
except:
log("bad decrypt, skipping...")
data2 = json.dumps(dict(rc=1))
try:
while True:
vvvv("waiting for data")
data = self.recv_data()
if not data:
vvvv("received nothing back from recv_data(), breaking out")
break
try:
vvvv("got data, decrypting")
data = self.server.key.Decrypt(data)
vvvv("decryption done")
except:
vv("bad decrypt, skipping...")
data2 = json.dumps(dict(rc=1))
data2 = self.server.key.Encrypt(data2)
send_data(client, data2)
return
vvvv("loading json from the data")
data = json.loads(data)
mode = data['mode']
response = {}
last_pong = datetime.now()
if mode == 'command':
vvvv("received a command request, running it")
twrv = ThreadWithReturnValue(target=self.command, args=(data,))
twrv.start()
response = None
while twrv.is_alive():
if (datetime.now() - last_pong).seconds >= 15:
last_pong = datetime.now()
vvvv("command still running, sending keepalive packet")
data2 = json.dumps(dict(pong=True))
data2 = self.server.key.Encrypt(data2)
self.send_data(data2)
time.sleep(0.1)
response = twrv._return
vvvv("thread is done, response from join was %s" % response)
elif mode == 'put':
vvvv("received a put request, putting it")
response = self.put(data)
elif mode == 'fetch':
vvvv("received a fetch request, getting it")
response = self.fetch(data)
vvvv("response result is %s" % str(response))
data2 = json.dumps(response)
data2 = self.server.key.Encrypt(data2)
send_data(client, data2)
return
#log("loading json from the data")
data = json.loads(data)
mode = data['mode']
response = {}
if mode == 'command':
response = self.command(data)
elif mode == 'put':
response = self.put(data)
elif mode == 'fetch':
response = self.fetch(data)
data2 = json.dumps(response)
vvvv("sending the response back to the controller")
self.send_data(data2)
vvvv("done sending the response")
except:
tb = traceback.format_exc()
log("encountered an unhandled exception in the handle() function")
log("error was:\n%s" % tb)
data2 = json.dumps(dict(rc=1, failed=True, msg="unhandled error in the handle() function"))
data2 = self.server.key.Encrypt(data2)
self.send_data(data2)
@ -217,14 +286,14 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
if 'executable' not in data:
return dict(failed=True, msg='internal error: executable is required')
#log("executing: %s" % data['cmd'])
vvvv("executing: %s" % data['cmd'])
rc, stdout, stderr = self.server.module.run_command(data['cmd'], executable=data['executable'], close_fds=True)
if stdout is None:
stdout = ''
if stderr is None:
stderr = ''
#log("got stdout: %s" % stdout)
#log("got stderr: %s" % stderr)
vvvv("got stdout: %s" % stdout)
vvvv("got stderr: %s" % stderr)
return dict(rc=rc, stdout=stdout, stderr=stderr)
@ -235,7 +304,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
try:
fd = file(data['in_path'], 'rb')
fstat = os.stat(data['in_path'])
log("FETCH file is %d bytes" % fstat.st_size)
vvv("FETCH file is %d bytes" % fstat.st_size)
while fd.tell() < fstat.st_size:
data = fd.read(CHUNK_SIZE)
last = False
@ -276,7 +345,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
final_path = None
final_user = None
if 'user' in data and data.get('user') != getpass.getuser():
log("the target user doesn't match this user, we'll move the file into place via sudo")
vv("the target user doesn't match this user, we'll move the file into place via sudo")
(fd,out_path) = tempfile.mkstemp(prefix='ansible.', dir=os.path.expanduser('~/.ansible/tmp/'))
out_fd = os.fdopen(fd, 'w', 0)
final_path = data['out_path']
@ -306,15 +375,15 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
log("failed to put the file: %s" % tb)
return dict(failed=True, stdout="Could not write the file")
finally:
#log("wrote %d bytes" % bytes)
vvvv("wrote %d bytes" % bytes)
out_fd.close()
if final_path:
log("moving %s to %s" % (out_path, final_path))
vvv("moving %s to %s" % (out_path, final_path))
self.server.module.atomic_move(out_path, final_path)
return dict()
def daemonize(module, password, port, minutes):
def daemonize(module, password, port, timeout, minutes):
try:
daemonize_self(module, password, port, minutes)
@ -324,10 +393,10 @@ def daemonize(module, password, port, minutes):
signal.signal(signal.SIGALRM, catcher)
signal.setitimer(signal.ITIMER_REAL, 60 * minutes)
server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password)
server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout)
server.allow_reuse_address = True
log("serving!")
vv("serving!")
server.serve_forever(poll_interval=1.0)
except Exception, e:
tb = traceback.format_exc()
@ -335,24 +404,30 @@ def daemonize(module, password, port, minutes):
sys.exit(0)
def main():
global DEBUG_LEVEL
module = AnsibleModule(
argument_spec = dict(
port=dict(required=False, default=5099),
timeout=dict(required=False, default=300),
password=dict(required=True),
minutes=dict(required=False, default=30),
debug=dict(required=False, default=0, type='int')
),
supports_check_mode=True
)
password = base64.b64decode(module.params['password'])
port = int(module.params['port'])
timeout = int(module.params['timeout'])
minutes = int(module.params['minutes'])
debug = int(module.params['debug'])
if not HAS_KEYCZAR:
module.fail_json(msg="keyczar is not installed")
daemonize(module, password, port, minutes)
DEBUG_LEVEL=debug
daemonize(module, password, port, timeout, minutes)
# this is magic, see lib/ansible/module_common.py
#<<INCLUDE_ANSIBLE_MODULE_COMMON>>

Loading…
Cancel
Save