Detect remote_user change in accelerate daemon and allow a restart

Fixes #5812
reviewable/pr18780/r1
James Cammarata 11 years ago
parent 9620af55b7
commit 0dff07b53e

@ -75,6 +75,7 @@ import getpass
import json import json
import os import os
import os.path import os.path
import pwd
import signal import signal
import socket import socket
import struct import struct
@ -280,6 +281,9 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
elif mode == 'fetch': elif mode == 'fetch':
vvvv("received a fetch request, getting it") vvvv("received a fetch request, getting it")
response = self.fetch(data) response = self.fetch(data)
elif mode == 'validate_user':
vvvv("received a request to validate the user id")
response = self.validate_user(data)
vvvv("response result is %s" % str(response)) vvvv("response result is %s" % str(response))
data2 = json.dumps(response) data2 = json.dumps(response)
@ -287,6 +291,10 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
vvvv("sending the response back to the controller") vvvv("sending the response back to the controller")
self.send_data(data2) self.send_data(data2)
vvvv("done sending the response") vvvv("done sending the response")
if mode == 'validate_user' and response.get('rc') == 1:
vvvv("detected a uid mismatch, shutting down")
self.server.shutdown()
except: except:
tb = traceback.format_exc() tb = traceback.format_exc()
log("encountered an unhandled exception in the handle() function") log("encountered an unhandled exception in the handle() function")
@ -295,6 +303,27 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
data2 = self.server.key.Encrypt(data2) data2 = self.server.key.Encrypt(data2)
self.send_data(data2) self.send_data(data2)
def validate_user(self, data):
if 'username' not in data:
return dict(failed=True, msg='No username specified')
vvvv("validating we're running as %s" % data['username'])
# get the current uid
c_uid = os.getuid()
try:
# the target uid
t_uid = pwd.getpwnam(data['username']).pw_uid
except:
vvvv("could not find user %s" % data['username'])
return dict(failed=True, msg='could not find user %s' % data['username'])
# and return rc=0 for success, rc=1 for failure
if c_uid == t_uid:
return dict(rc=0)
else:
return dict(rc=1)
def command(self, data): def command(self, data):
if 'cmd' not in data: if 'cmd' not in data:
return dict(failed=True, msg='internal error: cmd is required') return dict(failed=True, msg='internal error: cmd is required')
@ -409,14 +438,26 @@ def daemonize(module, password, port, timeout, minutes, ipv6):
signal.signal(signal.SIGALRM, catcher) signal.signal(signal.SIGALRM, catcher)
signal.setitimer(signal.ITIMER_REAL, 60 * minutes) signal.setitimer(signal.ITIMER_REAL, 60 * minutes)
if ipv6: tries = 5
server = ThreadedTCPV6Server(("::", port), ThreadedTCPRequestHandler, module, password, timeout) while tries > 0:
else: try:
server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout) if ipv6:
server.allow_reuse_address = True server = ThreadedTCPV6Server(("::", port), ThreadedTCPRequestHandler, module, password, timeout)
else:
server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout)
server.allow_reuse_address = True
break
except:
vv("Failed to create the TCP server (tries left = %d)" % tries)
tries -= 1
time.sleep(0.2)
if tries == 0:
vv("Maximum number of attempts to create the TCP server reached, bailing out")
raise Exception("max # of attempts to serve reached")
vv("serving!") vv("serving!")
server.serve_forever(poll_interval=1.0) server.serve_forever(poll_interval=0.1)
except Exception, e: except Exception, e:
tb = traceback.format_exc() tb = traceback.format_exc()
log("exception caught, exiting accelerated mode: %s\n%s" % (e, tb)) log("exception caught, exiting accelerated mode: %s\n%s" % (e, tb))

Loading…
Cancel
Save