SSH working

* Get rid of persistent functions for now.
* Split select into read/write sides for unidirectional SSH IO.
* Put more of Loop in a try/except.
pull/35/head
David Wilson 8 years ago
parent 1a30570057
commit e62b891b9a

@ -90,7 +90,6 @@ def CreateChild(*args):
if not pid:
os.dup2(childfp.fileno(), 0)
os.dup2(childfp.fileno(), 1)
sys.stderr = open('milf2', 'w', 1)
childfp.close()
parentfp.close()
os.execvp(args[0], args)
@ -121,21 +120,6 @@ class Formatter(logging.Formatter):
return p + ('{%s} %s' % (os.getpid(), s))
class PartialFunction(object):
'''
Partial function implementation.
'''
def __init__(self, fn, *partial_args):
self.fn = fn
self.partial_args = partial_args
def __call__(self, *args, **kwargs):
return self.fn(*(self.partial_args+args), **kwargs)
def __repr__(self):
return 'PartialFunction(%r, *%r)' % (self.fn, self.partial_args)
class Channel(object):
def __init__(self, stream, handle):
self._stream = stream
@ -268,11 +252,12 @@ class MasterModuleResponder(object):
def __init__(self, stream):
self._stream = stream
def GetModule(self, killed, (_, (reply_handle, fullname))):
LOG.debug('SlaveModuleImporter.GetModule(%r, %r)', killed, fullname)
def GetModule(self, killed, data):
if killed:
return
_, (reply_handle, fullname) = data
LOG.debug('SlaveModuleImporter.GetModule(%r, %r)', killed, fullname)
mod = sys.modules.get(fullname)
if mod:
source = zlib.compress(inspect.getsource(mod))
@ -285,12 +270,24 @@ class MasterModuleResponder(object):
#
class BasicStream(object):
class Side(object):
def __init__(self, stream, fd):
self.stream = stream
self.fd = fd
def __repr__(self):
return '<fd %r of %r>' % (self.fd, self.stream)
def fileno(self):
return self._fd
return self.fd
class BasicStream(object):
read_side = None
write_side = None
def Disconnect(self):
LOG.debug('%r: disconnect on %r fd %d', self._broker, self, self._fd)
LOG.debug('%r: disconnect on %r', self._broker, self)
self._broker.RemoveStream(self)
def ReadMore(self):
@ -325,11 +322,9 @@ class Stream(BasicStream):
self._pickler_file = cStringIO.StringIO()
self._pickler = cPickle.Pickler(self._pickler_file, protocol=2)
self._pickler.persistent_id = self._CheckFunctionPerID
self._unpickler_file = cStringIO.StringIO()
self._unpickler = cPickle.Unpickler(self._unpickler_file)
self._unpickler.persistent_load = self._LoadFunctionFromPerID
def Pickle(self, obj):
'''
@ -365,37 +360,6 @@ class Stream(BasicStream):
self._unpickler_file.truncate(0)
return data
def _CheckFunctionPerID(self, obj):
'''
Return None or a persistent ID for an object.
Please see the cPickle documentation.
Args:
obj: object
Returns:
str
'''
if isinstance(obj, (types.FunctionType, types.MethodType)):
pid = 'FUNC:' + repr(obj)
self._func_refs[per_id] = obj
return pid
def _LoadFunctionFromPerID(self, pid):
'''
Load an object from a persistent ID.
Please see the cPickle documentation.
Args:
pid: str
Returns:
object
'''
if not pid.startswith('FUNC:'):
raise CorruptMessageError('unrecognized persistent ID received: %r', pid)
return PartialFunction(self._CallPersistentWhatsit, pid)
def AllocHandle(self):
'''
Allocate a unique handle for this stream.
@ -434,7 +398,7 @@ class Stream(BasicStream):
'''
LOG.debug('%r.Receive()', self)
buf = os.read(self._fd, 4096)
buf = os.read(self.read_side.fd, 4096)
if not buf:
return self.Disconnect()
@ -484,7 +448,7 @@ class Stream(BasicStream):
IOError
'''
LOG.debug('%r.Transmit()', self)
written = os.write(self._fd, self._output_buf[:4096])
written = os.write(self.write_side.fd, self._output_buf[:4096])
self._output_buf = self._output_buf[written:]
def WriteMore(self):
@ -513,32 +477,41 @@ class Stream(BasicStream):
def Disconnect(self):
'''
Close our associated file descriptor and tell any registered callbacks
that the connection has been destroyed.
Close our associated file descriptor and tell registered callbacks the
connection has been destroyed.
'''
LOG.debug('%r.Disconnect()', self)
if self._context.GetStream() is self:
self._context.SetStream(None)
try:
os.close(self._fd)
os.close(self.read_side.fd)
except OSError, e:
LOG.debug('%r.Disconnect(): did not close fd %s: %s',
self, self._fd, e)
self, self.read_side.fd, e)
self._fd = None
if self._context.GetStream() is self:
self._context.SetStream(None)
if self.read_side.fd != self.write_side.fd:
try:
os.close(self.write_side.fd)
except OSError, e:
LOG.debug('%r.Disconnect(): did not close fd %s: %s',
self, self.write_side.fd, e)
self.read_side.fd = None
self.write_side.fd = None
for handle, (persist, fn) in self._handle_map.iteritems():
LOG.debug('%r.Disconnect(): stale callback handle=%r; fn=%r',
self, handle, fn)
fn(True, None)
@classmethod
def Accept(cls, context, fd):
def Accept(cls, context, rfd, wfd):
'''
'''
stream = cls(context)
stream._fd = os.dup(fd)
stream.read_side = Side(stream, os.dup(rfd))
stream.write_side = Side(stream, os.dup(wfd))
context.SetStream(stream)
context.broker.Register(context)
return stream
@ -550,7 +523,8 @@ class Stream(BasicStream):
LOG.debug('%r.Connect()', self)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._fd = sock.fileno()
self.read_side = Side(self, sock.fileno())
self.write_side = Side(self, sock.fileno())
sock.connect(self._context.parent_addr)
self.Enqueue(0, self._context.name)
@ -635,9 +609,11 @@ class LocalStream(Stream):
def Connect(self):
LOG.debug('%r.Connect()', self)
pid, sock = CreateChild(*self.GetBootCommand())
self._fd = os.dup(sock.fileno())
self.read_side = Side(self, os.dup(sock.fileno()))
self.write_side = self.read_side
sock.close()
LOG.debug('%r.Connect(): child process stdin/stdout=%r', self, self._fd)
LOG.debug('%r.Connect(): child process stdin/stdout=%r',
self, self.read_side.fd)
source = inspect.getsource(sys.modules[__name__])
source += '\nExternalContextMain(%r, %r, %r)\n' % (
@ -648,8 +624,8 @@ class LocalStream(Stream):
compressed = zlib.compress(source)
preamble = str(len(compressed)) + '\n' + compressed
write_all(self._fd, preamble)
assert os.read(self._fd, 3) == 'OK\n'
write_all(self.write_side.fd, preamble)
assert os.read(self.read_side.fd, 3) == 'OK\n'
class SSHStream(LocalStream):
@ -751,16 +727,17 @@ class Context(object):
class Waker(BasicStream):
def __init__(self, broker):
self._broker = broker
self._rfd, self._wfd = os.pipe()
self._fd = self._rfd
rfd, wfd = os.pipe()
self.read_side = Side(self, rfd)
self.write_side = Side(self, wfd)
broker.AddStream(self)
def Wake(self):
os.write(self._wfd, ' ')
os.write(self.write_side.fd, ' ')
def Receive(self):
LOG.debug('%r: waking %r', self, self._broker)
os.read(self._rfd, 1)
os.read(self.read_side.fd, 1)
class Listener(BasicStream):
@ -770,7 +747,7 @@ class Listener(BasicStream):
self._sock.bind(address or ('0.0.0.0', 0))
self._sock.listen(backlog)
self._listen_addr = self._sock.getsockname()
self._fd = self._sock.fileno()
self.read_side = Side(self, self._sock.fileno())
broker.AddStream(self)
def Receive(self):
@ -795,7 +772,6 @@ class Broker(object):
self._waker = Waker(self)
self._thread = threading.Thread(target=self._Loop, name='Broker')
self._thread.setDaemon(True)
self._thread.start()
def CreateListener(self, address=None, backlog=30):
@ -809,16 +785,15 @@ class Broker(object):
def UpdateStream(self, stream, wake=False):
LOG.debug('UpdateStream(%r, wake=%s)', stream, wake)
fileno = stream.fileno()
if fileno is not None and stream.ReadMore():
self._readers.add(stream)
if stream.ReadMore() and stream.read_side.fileno():
self._readers.add(stream.read_side)
else:
self._readers.discard(stream)
self._readers.discard(stream.read_side)
if fileno is not None and stream.WriteMore():
self._writers.add(stream)
if stream.WriteMore() and stream.write_side.fileno():
self._writers.add(stream.write_side)
else:
self._writers.discard(stream)
self._writers.discard(stream.write_side)
if wake:
self._waker.Wake()
@ -836,8 +811,9 @@ class Broker(object):
'''
Put a context under control of this broker.
'''
LOG.debug('%r.Register(%r) -> fd=%r', self, context,
context.GetStream().fileno())
LOG.debug('%r.Register(%r) -> r=%r w=%r', self, context,
context.GetStream().read_side,
context.GetStream().write_side)
self.AddStream(context.GetStream())
self._contexts[context.name] = context
return context
@ -855,7 +831,7 @@ class Broker(object):
context.SetStream(LocalStream(context)).Connect()
return self.Register(context)
def GetRemote(self, hostname, username, name=None):
def GetRemote(self, hostname, username, name=None, python_path=None):
'''
Return the named remote context, or create it if it doesn't exist.
'''
@ -864,51 +840,56 @@ class Broker(object):
(username, os.getenv('HOSTNAME'), os.getpid())
context = Context(self, name, hostname, username)
context.SetStream(SSHStream(context)).Connect()
stream = SSHStream(context)
if python_path:
stream.python_path = python_path
context.SetStream(stream)
stream.Connect()
return self.Register(context)
def _Loop(self):
try:
self.Loop()
except Exception:
LOG.exception('Loop() crashed')
def _LoopOnce(self):
LOG.debug('%r.Loop()', self)
self._lock.acquire()
self._lock.release()
#LOG.debug('readers = %r', self._readers)
#LOG.debug('rfds = %r', [r.fileno() for r in self._readers])
#LOG.debug('writers = %r', self._writers)
#LOG.debug('wfds = %r', [w.fileno() for w in self._writers])
rsides, wsides, _ = select.select(self._readers, self._writers, ())
for side in rsides:
LOG.debug('%r: POLLIN for %r', self, side.stream)
side.stream.Receive()
self.UpdateStream(side.stream)
for side in wsides:
LOG.debug('%r: POLLOUT for %r', self, side.stream)
side.stream.Transmit()
self.UpdateStream(side.stream)
def Loop(self):
def _Loop(self):
'''
Handle stream events until Finalize() is called.
'''
while not self._dead:
LOG.debug('%r.Loop()', self)
self._lock.acquire()
self._lock.release()
#LOG.debug('readers = %r', self._readers)
#LOG.debug('rfds = %r', [r.fileno() for r in self._readers])
#LOG.debug('writers = %r', self._writers)
rstrms, wstrms, _ = select.select(self._readers, self._writers, ())
for stream in rstrms:
LOG.debug('%r: POLLIN for %r', self, stream)
stream.Receive()
self.UpdateStream(stream)
try:
while not self._dead:
self._LoopOnce()
for stream in wstrms:
LOG.debug('%r: POLLOUT for %r', self, stream)
stream.Transmit()
self.UpdateStream(stream)
for context in self._contexts.itervalues():
stream = context.GetStream()
if stream:
stream.Disconnect()
except Exception:
LOG.exception('Loop() crashed')
def Finalize(self):
'''
Tell all active streams to disconnect.
'''
self._dead = True
self._waker.Wake()
self._lock.acquire()
try:
for name, context in self._contexts.iteritems():
stream = context.GetStream()
if stream:
stream.Disconnect()
finally:
self._lock.release()
self._lock.release()
def __repr__(self):
return 'econtext.Broker(<contexts=%s>)' % (self._contexts.keys(),)
@ -918,19 +899,20 @@ def ExternalContextMain(context_name, parent_addr, key):
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)
logging.getLogger('').handlers[0].formatter = Formatter(False)
LOG.debug('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key)
# os.wait() # Reap the first stage.
os.wait() # Reap the first stage.
os.dup2(100, 0)
os.close(100)
broker = Broker()
context = Context(broker, 'parent', parent_addr=parent_addr, key=key)
stream = Stream.Accept(context, 0)
stream = Stream.Accept(context, 0, 1)
os.close(0)
os.close(1)
# stream = context.SetStream(Stream(context))
# stream.
@ -942,7 +924,7 @@ def ExternalContextMain(context_name, parent_addr, key):
for call_info in Channel(stream, CALL_FUNCTION):
LOG.debug('ExternalContextMain(): CALL_FUNCTION %r', call_info)
(reply_handle, mod_name, class_name, func_name, args, kwargs) = call_info
reply_handle, mod_name, class_name, func_name, args, kwargs = call_info
try:
fn = getattr(__import__(mod_name), func_name)

Loading…
Cancel
Save