[stream-refactor] merge stdout+stderr when reporting EofError

Fixes sudo regression
pull/607/head
David Wilson 5 years ago
parent 1d2bfc28da
commit 4eecc08047

@ -1021,13 +1021,9 @@ class LineLoggingProtocolMixin(object):
self.logged_lines = [] self.logged_lines = []
self.logged_partial = None self.logged_partial = None
def get_history(self):
s = b('\n').join(self.logged_lines) + (self.logged_partial or b(''))
return mitogen.core.to_text(s)
def on_line_received(self, line): def on_line_received(self, line):
self.logged_partial = None self.logged_partial = None
self.logged_lines.append(line) self.logged_lines.append((time.time(), line))
self.logged_lines[:] = self.logged_lines[-100:] self.logged_lines[:] = self.logged_lines[-100:]
return super(LineLoggingProtocolMixin, self).on_line_received(line) return super(LineLoggingProtocolMixin, self).on_line_received(line)
@ -1035,8 +1031,25 @@ class LineLoggingProtocolMixin(object):
self.logged_partial = line self.logged_partial = line
return super(LineLoggingProtocolMixin, self).on_partial_line_received(line) return super(LineLoggingProtocolMixin, self).on_partial_line_received(line)
def on_disconnect(self, broker):
if self.logged_partial:
self.logged_lines.append((time.time(), self.logged_partial))
self.logged_partial = None
super(LineLoggingProtocolMixin, self).on_disconnect(broker)
def get_history(streams):
history = []
for stream in streams:
if stream:
history.extend(getattr(stream.protocol, 'logged_lines', []))
history.sort()
class RegexProtocol(mitogen.core.DelimitedProtocol): s = b('\n').join(h[1] for h in history)
return mitogen.core.to_text(s)
class RegexProtocol(LineLoggingProtocolMixin, mitogen.core.DelimitedProtocol):
""" """
Implement a delimited protocol where messages matching a set of regular Implement a delimited protocol where messages matching a set of regular
expressions are dispatched to individual handler methods. Input is expressions are dispatched to individual handler methods. Input is
@ -1055,6 +1068,7 @@ class RegexProtocol(mitogen.core.DelimitedProtocol):
PARTIAL_PATTERNS = [] PARTIAL_PATTERNS = []
def on_line_received(self, line): def on_line_received(self, line):
super(RegexProtocol, self).on_line_received(line)
for pattern, func in self.PATTERNS: for pattern, func in self.PATTERNS:
match = pattern.search(line) match = pattern.search(line)
if match is not None: if match is not None:
@ -1067,6 +1081,7 @@ class RegexProtocol(mitogen.core.DelimitedProtocol):
self.stream.name, line.decode('utf-8', 'replace')) self.stream.name, line.decode('utf-8', 'replace'))
def on_partial_line_received(self, line): def on_partial_line_received(self, line):
super(RegexProtocol, self).on_partial_line_received(line)
LOG.debug('%s: (partial): %s', LOG.debug('%s: (partial): %s',
self.stream.name, line.decode('utf-8', 'replace')) self.stream.name, line.decode('utf-8', 'replace'))
for pattern, func in self.PARTIAL_PATTERNS: for pattern, func in self.PARTIAL_PATTERNS:
@ -1081,7 +1096,7 @@ class RegexProtocol(mitogen.core.DelimitedProtocol):
self.stream.name, line.decode('utf-8', 'replace')) self.stream.name, line.decode('utf-8', 'replace'))
class BootstrapProtocol(LineLoggingProtocolMixin, RegexProtocol): class BootstrapProtocol(RegexProtocol):
""" """
Respond to stdout of a child during bootstrap. Wait for EC0_MARKER to be Respond to stdout of a child during bootstrap. Wait for EC0_MARKER to be
written by the first stage to indicate it can receive the bootstrap, then written by the first stage to indicate it can receive the bootstrap, then
@ -1124,7 +1139,7 @@ class BootstrapProtocol(LineLoggingProtocolMixin, RegexProtocol):
] ]
class LogProtocol(mitogen.core.DelimitedProtocol): class LogProtocol(LineLoggingProtocolMixin, mitogen.core.DelimitedProtocol):
""" """
For "hybrid TTY/socketpair" mode, after connection setup a spare TTY master For "hybrid TTY/socketpair" mode, after connection setup a spare TTY master
FD exists that cannot be closed, and to which SSH or sudo may continue FD exists that cannot be closed, and to which SSH or sudo may continue
@ -1136,6 +1151,7 @@ class LogProtocol(mitogen.core.DelimitedProtocol):
written to it. written to it.
""" """
def on_line_received(self, line): def on_line_received(self, line):
super(LogProtocol, self).on_line_received(line)
LOG.info(u'%s: %s', self.stream.name, line.decode('utf-8', 'replace')) LOG.info(u'%s: %s', self.stream.name, line.decode('utf-8', 'replace'))
@ -1425,7 +1441,9 @@ class Connection(object):
if not self.timer.cancelled: if not self.timer.cancelled:
self.timer.cancel() self.timer.cancel()
self._fail_connection(EofError( self._fail_connection(EofError(
self.eof_error_msg + self.stream.protocol.get_history() self.eof_error_msg + get_history(
[self.stream, self.stderr_stream]
)
)) ))
self.proc._async_reap(self, self._router) self.proc._async_reap(self, self._router)

@ -266,6 +266,4 @@ class Connection(mitogen.parent.Connection):
if self.options.selinux_type: if self.options.selinux_type:
bits += ['-t', self.options.selinux_type] bits += ['-t', self.options.selinux_type]
bits = bits + ['--'] + super(Connection, self).get_boot_command() return bits + ['--'] + super(Connection, self).get_boot_command()
LOG.debug('sudo command line: %r', bits)
return bits

Loading…
Cancel
Save