ssh/tailssh: fix exit-status ordering and improve signal/exit code handling

Fixes a race where CloseWrite() could be called before Exit(), causing
SSH clients (especially on macOS) to miss the exit status. Simplified
run() to use sync.WaitGroup and guarantee Exit() is sent before EOF per
RFC 4254 section 6.10.

Also:
- Send SIGHUP instead of SIGKILL on session termination
- Use exit code 127 for command not found
- Use exit code 255 for SSH permission/protocol errors
- Use exit code 254 for recording failures
- Complete TCP handlers only after I/O completes

Fixes tailscale/tailscale#18256

Signed-off-by: James Tucker <james@tailscale.com>
raggi/ssh-shutdown
James Tucker 7 days ago
parent d451cd54a7
commit e9a28ff0db
No known key found for this signature in database

@ -25,7 +25,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@ -804,11 +803,18 @@ func (ss *sshSession) killProcessOnContextDone() {
// We don't need to Process.Wait here, sshSession.run() does
// the waiting regardless of termination reason.
// TODO(maisem): should this be a SIGTERM followed by a SIGKILL?
ss.cmd.Process.Kill()
// Send SIGHUP like a real terminal disconnect would.
// The process may ignore it or exit cleanly.
ss.cmd.Process.Signal(syscall.SIGHUP)
})
}
// isNotFoundOrExecutable reports whether err is an error indicating
// the command could not be found or executed.
func isNotFoundOrExecutable(err error) bool {
return errors.Is(err, exec.ErrNotFound) || errors.Is(err, os.ErrNotExist)
}
// attachSession registers ss as an active session.
func (c *conn) attachSession(ss *sshSession) {
c.srv.sessionWaitGroup.Add(1)
@ -894,10 +900,11 @@ func (ss *sshSession) run() {
metricActiveSessions.Add(1)
defer metricActiveSessions.Add(-1)
defer ss.cancelCtx(errSessionDone)
defer ss.Close()
if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached {
fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n")
ss.Exit(1)
ss.Exit(255)
return
}
defer ss.conn.detachSession(ss)
@ -919,7 +926,10 @@ func (ss *sshSession) run() {
if lu.Uid != fmt.Sprint(euid) {
ss.logf("can't switch to user %q from process euid %v", lu.Username, euid)
fmt.Fprintf(ss, "can't switch user\r\n")
ss.Exit(1)
// Exit code 255 indicates SSH protocol/permission error.
// This matches OpenSSH behavior for fatal errors that prevent
// the session from starting.
ss.Exit(255)
return
}
}
@ -948,7 +958,9 @@ func (ss *sshSession) run() {
fmt.Fprintf(ss, "can't start new recording\r\n")
}
ss.logf("startNewRecording: %v", err)
ss.Exit(1)
// Exit code 254 for recording infrastructure failure.
// Distinct from 255 (SSH protocol error) and 1 (general command failure).
ss.Exit(254)
return
}
ss.logf("startNewRecording: <nil>")
@ -961,95 +973,90 @@ func (ss *sshSession) run() {
err := ss.launchProcess()
if err != nil {
logf("start failed: %v", err.Error())
exitCode := 1
if errors.Is(err, context.Canceled) {
err := context.Cause(ss.ctx)
var uve userVisibleError
if errors.As(err, &uve) {
fmt.Fprintf(ss, "%s\r\n", uve)
}
} else if isNotFoundOrExecutable(err) {
// Use exit code 127 for "command not found" per shell convention.
// This matches standard SSH behavior.
exitCode = 127
}
ss.Exit(1)
ss.Exit(exitCode)
return
}
go ss.killProcessOnContextDone()
var processDone atomic.Bool
// Start goroutines to copy stdin/stdout/stderr.
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
defer ss.wrStdin.Close()
if _, err := io.Copy(rec.writer("i", ss.wrStdin), ss); err != nil {
logf("stdin copy: %v", err)
ss.cancelCtx(err)
}
}()
outputDone := make(chan struct{})
var openOutputStreams atomic.Int32
if ss.rdStderr != nil {
openOutputStreams.Store(2)
} else {
openOutputStreams.Store(1)
}
wg.Add(1)
go func() {
defer wg.Done()
defer ss.rdStdout.Close()
_, err := io.Copy(rec.writer("o", ss), ss.rdStdout)
if err != nil && !errors.Is(err, io.EOF) {
isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO)
if !isErrBecauseProcessExited {
logf("stdout copy: %v", err)
ss.cancelCtx(err)
}
}
if openOutputStreams.Add(-1) == 0 {
ss.CloseWrite()
close(outputDone)
if _, err := io.Copy(rec.writer("o", ss), ss.rdStdout); err != nil && !errors.Is(err, io.EOF) {
logf("stdout copy: %v", err)
}
// Send EOF as soon as stdout copying completes. This allows sibling
// processes waiting for EOF to proceed, even if the main process hasn't
// exited yet. The channel remains open for sending exit-status later.
ss.CloseWrite()
}()
// rdStderr is nil for ptys.
if ss.rdStderr != nil {
wg.Add(1)
go func() {
defer wg.Done()
defer ss.rdStderr.Close()
_, err := io.Copy(ss.Stderr(), ss.rdStderr)
if err != nil {
if _, err := io.Copy(ss.Stderr(), ss.rdStderr); err != nil {
logf("stderr copy: %v", err)
}
if openOutputStreams.Add(-1) == 0 {
ss.CloseWrite()
close(outputDone)
}
}()
}
err = ss.cmd.Wait()
processDone.Store(true)
// This will either make the SSH Termination goroutine be a no-op,
// or itself will be a no-op because the process was killed by the
// aforementioned goroutine.
ss.exitOnce.Do(func() {})
// Close the process-side of all pipes to signal the asynchronous
// io.Copy routines reading/writing from the pipes to terminate.
// Block for the io.Copy to finish before calling ss.Exit below.
closeAll(ss.childPipes...)
select {
case <-outputDone:
case <-ss.ctx.Done():
}
var exitCode int
if err == nil {
ss.logf("Session complete")
ss.Exit(0)
return
}
if ee, ok := err.(*exec.ExitError); ok {
code := ee.ProcessState.ExitCode()
ss.logf("Wait: code=%v", code)
ss.Exit(code)
return
exitCode = 0
} else if ee, ok := err.(*exec.ExitError); ok {
exitCode = ee.ProcessState.ExitCode()
ss.logf("Wait: code=%v", exitCode)
} else {
ss.logf("Wait: %v", err)
exitCode = 1
}
ss.logf("Wait: %v", err)
ss.Exit(1)
return
// Send exit-status immediately. Per RFC 4254 section 6.10, exit-status
// should be sent before channel close. EOF will be sent by the stdout
// goroutine when it finishes, and that's fine - CloseWrite only closes
// the data stream but keeps the channel open for exit-status.
ss.Exit(exitCode)
// Close process-side of pipes to signal io.Copy goroutines to finish.
closeAll(ss.childPipes...)
// Wait for all IO to complete.
wg.Wait()
}
// recordSSHToLocalDisk is a deprecated dev knob to allow recording SSH sessions

@ -188,7 +188,12 @@ func (sess *session) Exit(code int) error {
if err != nil {
return err
}
return sess.Close()
// Don't close the channel here. Per RFC 4254 section 6.10, the exit-status
// message should be sent before the channel is closed. By not closing immediately,
// we allow the session handler to complete any remaining I/O operations (like
// flushing output and sending EOF via CloseWrite) before the channel is closed
// by the request handler's cleanup code.
return nil
}
func (sess *session) User() string {
@ -273,6 +278,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
go func() {
sess.handler(sess)
sess.Exit(0)
sess.Close()
}()
case "subsystem":
if sess.handled {
@ -307,6 +313,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
go func() {
handler(sess)
sess.Exit(0)
sess.Close()
}()
case "env":
if sess.handled {

@ -53,16 +53,37 @@ func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewCh
}
go gossh.DiscardRequests(reqs)
defer ch.Close()
defer dconn.Close()
done := make(chan struct{}, 2)
go func() {
defer ch.Close()
defer dconn.Close()
defer ch.CloseWrite()
defer closeRead(dconn)
io.Copy(ch, dconn)
done <- struct{}{}
}()
go func() {
defer ch.Close()
defer dconn.Close()
defer closeWrite(dconn)
io.Copy(dconn, ch)
done <- struct{}{}
}()
<-done
<-done
}
func closeWrite(c net.Conn) error {
if cw, ok := c.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
return c.Close()
}
func closeRead(c net.Conn) error {
if cr, ok := c.(interface{ CloseRead() error }); ok {
return cr.CloseRead()
}
return c.Close()
}
type remoteForwardRequest struct {

Loading…
Cancel
Save