Replace GetStream/SetStream with Disconnect()

pull/35/head
David Wilson 8 years ago
parent e042bfa954
commit 29f38d578c

@ -367,8 +367,8 @@ class Stream(BasicStream):
"""Close our associated file descriptor and tell registered callbacks
the connection has been destroyed."""
LOG.debug('%r.Disconnect()', self)
if self._context.GetStream() is self:
self._context.SetStream(None)
if self._context.stream is self:
self._context.Disconnect()
try:
os.close(self.read_side.fd)
@ -392,7 +392,7 @@ class Stream(BasicStream):
def Accept(self, rfd, wfd):
self.read_side = Side(self, os.dup(rfd))
self.write_side = Side(self, os.dup(wfd))
self._context.SetStream(self)
self._context.stream = self
def Connect(self):
"""Connect to a Broker at the address specified in our associated
@ -499,14 +499,17 @@ class Context(object):
"""
Represent a remote context regardless of connection method.
"""
stream = None
def __init__(self, broker, name=None, hostname=None, username=None,
key=None, parent_addr=None):
key=None, parent_addr=None, finalize_on_disconnect=False):
self.broker = broker
self.name = name
self.hostname = hostname
self.username = username
self.parent_addr = parent_addr
self.key = key or ('%016x' % random.getrandbits(128))
self.parent_addr = parent_addr
self.finalize_on_disconnect = finalize_on_disconnect
self._last_handle = 1000L
self._handle_map = {}
@ -515,12 +518,11 @@ class Context(object):
self.responder = MasterModuleResponder(self)
self.log_forwarder = LogForwarder(self)
def GetStream(self):
return self._stream
def SetStream(self, stream):
self._stream = stream
return stream
def Disconnect(self):
self.stream = None
if self.finalize_on_disconnect:
LOG.debug('Parent stream is gone, dying.')
self.broker.Finalize(wait=False)
def AllocHandle(self):
"""Allocate a handle."""
@ -556,12 +558,12 @@ class Context(object):
queue.put(data)
self.AddHandleCB(_Receive, reply_to, persist=False)
self._stream.Enqueue(handle, (reply_to,) + data)
self.stream.Enqueue(handle, (reply_to,) + data)
try:
data = queue.get(True, deadline)
except Queue.Empty:
self._stream.Disconnect()
self.stream.Disconnect()
raise TimeoutError('deadline exceeded.')
if data == _DEAD:
@ -594,14 +596,6 @@ class Context(object):
return 'Context(%s)' % ', '.join(bits)
class ParentContext(Context):
def SetStream(self, stream):
super(ParentContext, self).SetStream(stream)
if stream is None:
LOG.debug('Parent stream is gone, dying.')
self.broker.Finalize(wait=False)
class Waker(BasicStream):
def __init__(self, broker):
self._broker = broker
@ -716,9 +710,9 @@ class Broker(object):
def Register(self, context):
"""Put a context under control of this broker."""
LOG.debug('%r.Register(%r) -> r=%r w=%r', self, context,
context.GetStream().read_side,
context.GetStream().write_side)
self.UpdateStream(context.GetStream())
context.stream.read_side,
context.stream.write_side)
self.UpdateStream(context.stream)
self._contexts[context.name] = context
return context
@ -726,7 +720,8 @@ class Broker(object):
"""Get the named context running on the local machine, creating it if
it does not exist."""
context = Context(self, name)
context.SetStream(LocalStream(context)).Connect()
context.stream = LocalStream(context)
context.stream.Connect()
return self.Register(context)
def GetRemote(self, hostname, username, name=None, python_path=None):
@ -736,11 +731,10 @@ class Broker(object):
(username, socket.gethostname(), os.getpid())
context = Context(self, name, hostname, username)
stream = SSHStream(context)
context.stream = SSHStream(context)
if python_path:
stream.python_path = python_path
context.SetStream(stream)
stream.Connect()
context.stream.python_path = python_path
context.stream.Connect()
return self.Register(context)
def _CallAndUpdate(self, stream, func):
@ -771,9 +765,8 @@ class Broker(object):
self._LoopOnce()
for context in self._contexts.itervalues():
stream = context.GetStream()
if stream:
stream.Disconnect()
if context.stream:
context.stream.Disconnect()
except Exception:
LOG.exception('Loop() crashed')
@ -810,7 +803,8 @@ class ExternalContext(object):
def _SetupMaster(self, key):
self.broker = Broker()
self.context = ParentContext(self.broker, 'parent', key=key)
self.context = Context(self.broker, 'parent', key=key,
finalize_on_disconnect=True)
self.channel = Channel(self.context, CALL_FUNCTION)
self.stream = Stream(self.context)
self.stream.Accept(0, 1)

Loading…
Cancel
Save