diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 7d12ab45f..07b0aa57a 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -25,7 +25,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "syscall" "time" @@ -804,11 +803,18 @@ func (ss *sshSession) killProcessOnContextDone() { // We don't need to Process.Wait here, sshSession.run() does // the waiting regardless of termination reason. - // TODO(maisem): should this be a SIGTERM followed by a SIGKILL? - ss.cmd.Process.Kill() + // Send SIGHUP like a real terminal disconnect would. + // The process may ignore it or exit cleanly. + ss.cmd.Process.Signal(syscall.SIGHUP) }) } +// isNotFoundOrExecutable reports whether err is an error indicating +// the command could not be found or executed. +func isNotFoundOrExecutable(err error) bool { + return errors.Is(err, exec.ErrNotFound) || errors.Is(err, os.ErrNotExist) +} + // attachSession registers ss as an active session. func (c *conn) attachSession(ss *sshSession) { c.srv.sessionWaitGroup.Add(1) @@ -894,10 +900,11 @@ func (ss *sshSession) run() { metricActiveSessions.Add(1) defer metricActiveSessions.Add(-1) defer ss.cancelCtx(errSessionDone) + defer ss.Close() if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached { fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n") - ss.Exit(1) + ss.Exit(255) return } defer ss.conn.detachSession(ss) @@ -919,7 +926,10 @@ func (ss *sshSession) run() { if lu.Uid != fmt.Sprint(euid) { ss.logf("can't switch to user %q from process euid %v", lu.Username, euid) fmt.Fprintf(ss, "can't switch user\r\n") - ss.Exit(1) + // Exit code 255 indicates SSH protocol/permission error. + // This matches OpenSSH behavior for fatal errors that prevent + // the session from starting. + ss.Exit(255) return } } @@ -948,7 +958,9 @@ func (ss *sshSession) run() { fmt.Fprintf(ss, "can't start new recording\r\n") } ss.logf("startNewRecording: %v", err) - ss.Exit(1) + // Exit code 254 for recording infrastructure failure. + // Distinct from 255 (SSH protocol error) and 1 (general command failure). + ss.Exit(254) return } ss.logf("startNewRecording: ") @@ -961,95 +973,90 @@ func (ss *sshSession) run() { err := ss.launchProcess() if err != nil { logf("start failed: %v", err.Error()) + exitCode := 1 if errors.Is(err, context.Canceled) { err := context.Cause(ss.ctx) var uve userVisibleError if errors.As(err, &uve) { fmt.Fprintf(ss, "%s\r\n", uve) } + } else if isNotFoundOrExecutable(err) { + // Use exit code 127 for "command not found" per shell convention. + // This matches standard SSH behavior. + exitCode = 127 } - ss.Exit(1) + ss.Exit(exitCode) return } go ss.killProcessOnContextDone() - var processDone atomic.Bool + // Start goroutines to copy stdin/stdout/stderr. + var wg sync.WaitGroup + + wg.Add(1) go func() { + defer wg.Done() defer ss.wrStdin.Close() if _, err := io.Copy(rec.writer("i", ss.wrStdin), ss); err != nil { logf("stdin copy: %v", err) ss.cancelCtx(err) } }() - outputDone := make(chan struct{}) - var openOutputStreams atomic.Int32 - if ss.rdStderr != nil { - openOutputStreams.Store(2) - } else { - openOutputStreams.Store(1) - } + + wg.Add(1) go func() { + defer wg.Done() defer ss.rdStdout.Close() - _, err := io.Copy(rec.writer("o", ss), ss.rdStdout) - if err != nil && !errors.Is(err, io.EOF) { - isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO) - if !isErrBecauseProcessExited { - logf("stdout copy: %v", err) - ss.cancelCtx(err) - } - } - if openOutputStreams.Add(-1) == 0 { - ss.CloseWrite() - close(outputDone) + if _, err := io.Copy(rec.writer("o", ss), ss.rdStdout); err != nil && !errors.Is(err, io.EOF) { + logf("stdout copy: %v", err) } + // Send EOF as soon as stdout copying completes. This allows sibling + // processes waiting for EOF to proceed, even if the main process hasn't + // exited yet. The channel remains open for sending exit-status later. + ss.CloseWrite() }() - // rdStderr is nil for ptys. + if ss.rdStderr != nil { + wg.Add(1) go func() { + defer wg.Done() defer ss.rdStderr.Close() - _, err := io.Copy(ss.Stderr(), ss.rdStderr) - if err != nil { + if _, err := io.Copy(ss.Stderr(), ss.rdStderr); err != nil { logf("stderr copy: %v", err) } - if openOutputStreams.Add(-1) == 0 { - ss.CloseWrite() - close(outputDone) - } }() } err = ss.cmd.Wait() - processDone.Store(true) // This will either make the SSH Termination goroutine be a no-op, // or itself will be a no-op because the process was killed by the // aforementioned goroutine. ss.exitOnce.Do(func() {}) - // Close the process-side of all pipes to signal the asynchronous - // io.Copy routines reading/writing from the pipes to terminate. - // Block for the io.Copy to finish before calling ss.Exit below. - closeAll(ss.childPipes...) - select { - case <-outputDone: - case <-ss.ctx.Done(): - } - + var exitCode int if err == nil { ss.logf("Session complete") - ss.Exit(0) - return - } - if ee, ok := err.(*exec.ExitError); ok { - code := ee.ProcessState.ExitCode() - ss.logf("Wait: code=%v", code) - ss.Exit(code) - return + exitCode = 0 + } else if ee, ok := err.(*exec.ExitError); ok { + exitCode = ee.ProcessState.ExitCode() + ss.logf("Wait: code=%v", exitCode) + } else { + ss.logf("Wait: %v", err) + exitCode = 1 } - ss.logf("Wait: %v", err) - ss.Exit(1) - return + // Send exit-status immediately. Per RFC 4254 section 6.10, exit-status + // should be sent before channel close. EOF will be sent by the stdout + // goroutine when it finishes, and that's fine - CloseWrite only closes + // the data stream but keeps the channel open for exit-status. + ss.Exit(exitCode) + + // Close process-side of pipes to signal io.Copy goroutines to finish. + closeAll(ss.childPipes...) + + // Wait for all IO to complete. + wg.Wait() } // recordSSHToLocalDisk is a deprecated dev knob to allow recording SSH sessions diff --git a/tempfork/gliderlabs/ssh/session.go b/tempfork/gliderlabs/ssh/session.go index a7a9a3eeb..ef068355e 100644 --- a/tempfork/gliderlabs/ssh/session.go +++ b/tempfork/gliderlabs/ssh/session.go @@ -188,7 +188,12 @@ func (sess *session) Exit(code int) error { if err != nil { return err } - return sess.Close() + // Don't close the channel here. Per RFC 4254 section 6.10, the exit-status + // message should be sent before the channel is closed. By not closing immediately, + // we allow the session handler to complete any remaining I/O operations (like + // flushing output and sending EOF via CloseWrite) before the channel is closed + // by the request handler's cleanup code. + return nil } func (sess *session) User() string { @@ -273,6 +278,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { go func() { sess.handler(sess) sess.Exit(0) + sess.Close() }() case "subsystem": if sess.handled { @@ -307,6 +313,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { go func() { handler(sess) sess.Exit(0) + sess.Close() }() case "env": if sess.handled { diff --git a/tempfork/gliderlabs/ssh/tcpip.go b/tempfork/gliderlabs/ssh/tcpip.go index 335fda657..307cc53cb 100644 --- a/tempfork/gliderlabs/ssh/tcpip.go +++ b/tempfork/gliderlabs/ssh/tcpip.go @@ -53,16 +53,37 @@ func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewCh } go gossh.DiscardRequests(reqs) + defer ch.Close() + defer dconn.Close() + + done := make(chan struct{}, 2) go func() { - defer ch.Close() - defer dconn.Close() + defer ch.CloseWrite() + defer closeRead(dconn) io.Copy(ch, dconn) + done <- struct{}{} }() go func() { - defer ch.Close() - defer dconn.Close() + defer closeWrite(dconn) io.Copy(dconn, ch) + done <- struct{}{} }() + <-done + <-done +} + +func closeWrite(c net.Conn) error { + if cw, ok := c.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + return c.Close() +} + +func closeRead(c net.Conn) error { + if cr, ok := c.(interface{ CloseRead() error }); ok { + return cr.CloseRead() + } + return c.Close() } type remoteForwardRequest struct {