Replace GetStream/SetStream with Disconnect()

pull/35/head
David Wilson 8 years ago
parent e042bfa954
commit 29f38d578c

@ -367,8 +367,8 @@ class Stream(BasicStream):
"""Close our associated file descriptor and tell registered callbacks """Close our associated file descriptor and tell registered callbacks
the connection has been destroyed.""" the connection has been destroyed."""
LOG.debug('%r.Disconnect()', self) LOG.debug('%r.Disconnect()', self)
if self._context.GetStream() is self: if self._context.stream is self:
self._context.SetStream(None) self._context.Disconnect()
try: try:
os.close(self.read_side.fd) os.close(self.read_side.fd)
@ -392,7 +392,7 @@ class Stream(BasicStream):
def Accept(self, rfd, wfd): def Accept(self, rfd, wfd):
self.read_side = Side(self, os.dup(rfd)) self.read_side = Side(self, os.dup(rfd))
self.write_side = Side(self, os.dup(wfd)) self.write_side = Side(self, os.dup(wfd))
self._context.SetStream(self) self._context.stream = self
def Connect(self): def Connect(self):
"""Connect to a Broker at the address specified in our associated """Connect to a Broker at the address specified in our associated
@ -499,14 +499,17 @@ class Context(object):
""" """
Represent a remote context regardless of connection method. Represent a remote context regardless of connection method.
""" """
stream = None
def __init__(self, broker, name=None, hostname=None, username=None, def __init__(self, broker, name=None, hostname=None, username=None,
key=None, parent_addr=None): key=None, parent_addr=None, finalize_on_disconnect=False):
self.broker = broker self.broker = broker
self.name = name self.name = name
self.hostname = hostname self.hostname = hostname
self.username = username self.username = username
self.parent_addr = parent_addr
self.key = key or ('%016x' % random.getrandbits(128)) self.key = key or ('%016x' % random.getrandbits(128))
self.parent_addr = parent_addr
self.finalize_on_disconnect = finalize_on_disconnect
self._last_handle = 1000L self._last_handle = 1000L
self._handle_map = {} self._handle_map = {}
@ -515,12 +518,11 @@ class Context(object):
self.responder = MasterModuleResponder(self) self.responder = MasterModuleResponder(self)
self.log_forwarder = LogForwarder(self) self.log_forwarder = LogForwarder(self)
def GetStream(self): def Disconnect(self):
return self._stream self.stream = None
if self.finalize_on_disconnect:
def SetStream(self, stream): LOG.debug('Parent stream is gone, dying.')
self._stream = stream self.broker.Finalize(wait=False)
return stream
def AllocHandle(self): def AllocHandle(self):
"""Allocate a handle.""" """Allocate a handle."""
@ -556,12 +558,12 @@ class Context(object):
queue.put(data) queue.put(data)
self.AddHandleCB(_Receive, reply_to, persist=False) self.AddHandleCB(_Receive, reply_to, persist=False)
self._stream.Enqueue(handle, (reply_to,) + data) self.stream.Enqueue(handle, (reply_to,) + data)
try: try:
data = queue.get(True, deadline) data = queue.get(True, deadline)
except Queue.Empty: except Queue.Empty:
self._stream.Disconnect() self.stream.Disconnect()
raise TimeoutError('deadline exceeded.') raise TimeoutError('deadline exceeded.')
if data == _DEAD: if data == _DEAD:
@ -594,14 +596,6 @@ class Context(object):
return 'Context(%s)' % ', '.join(bits) return 'Context(%s)' % ', '.join(bits)
class ParentContext(Context):
def SetStream(self, stream):
super(ParentContext, self).SetStream(stream)
if stream is None:
LOG.debug('Parent stream is gone, dying.')
self.broker.Finalize(wait=False)
class Waker(BasicStream): class Waker(BasicStream):
def __init__(self, broker): def __init__(self, broker):
self._broker = broker self._broker = broker
@ -716,9 +710,9 @@ class Broker(object):
def Register(self, context): def Register(self, context):
"""Put a context under control of this broker.""" """Put a context under control of this broker."""
LOG.debug('%r.Register(%r) -> r=%r w=%r', self, context, LOG.debug('%r.Register(%r) -> r=%r w=%r', self, context,
context.GetStream().read_side, context.stream.read_side,
context.GetStream().write_side) context.stream.write_side)
self.UpdateStream(context.GetStream()) self.UpdateStream(context.stream)
self._contexts[context.name] = context self._contexts[context.name] = context
return context return context
@ -726,7 +720,8 @@ class Broker(object):
"""Get the named context running on the local machine, creating it if """Get the named context running on the local machine, creating it if
it does not exist.""" it does not exist."""
context = Context(self, name) context = Context(self, name)
context.SetStream(LocalStream(context)).Connect() context.stream = LocalStream(context)
context.stream.Connect()
return self.Register(context) return self.Register(context)
def GetRemote(self, hostname, username, name=None, python_path=None): def GetRemote(self, hostname, username, name=None, python_path=None):
@ -736,11 +731,10 @@ class Broker(object):
(username, socket.gethostname(), os.getpid()) (username, socket.gethostname(), os.getpid())
context = Context(self, name, hostname, username) context = Context(self, name, hostname, username)
stream = SSHStream(context) context.stream = SSHStream(context)
if python_path: if python_path:
stream.python_path = python_path context.stream.python_path = python_path
context.SetStream(stream) context.stream.Connect()
stream.Connect()
return self.Register(context) return self.Register(context)
def _CallAndUpdate(self, stream, func): def _CallAndUpdate(self, stream, func):
@ -771,9 +765,8 @@ class Broker(object):
self._LoopOnce() self._LoopOnce()
for context in self._contexts.itervalues(): for context in self._contexts.itervalues():
stream = context.GetStream() if context.stream:
if stream: context.stream.Disconnect()
stream.Disconnect()
except Exception: except Exception:
LOG.exception('Loop() crashed') LOG.exception('Loop() crashed')
@ -810,7 +803,8 @@ class ExternalContext(object):
def _SetupMaster(self, key): def _SetupMaster(self, key):
self.broker = Broker() self.broker = Broker()
self.context = ParentContext(self.broker, 'parent', key=key) self.context = Context(self.broker, 'parent', key=key,
finalize_on_disconnect=True)
self.channel = Channel(self.context, CALL_FUNCTION) self.channel = Channel(self.context, CALL_FUNCTION)
self.stream = Stream(self.context) self.stream = Stream(self.context)
self.stream.Accept(0, 1) self.stream.Accept(0, 1)

Loading…
Cancel
Save