From 61886e031e383f3218eeacce57d4661f0c13c454 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Wed, 21 Jun 2023 19:57:45 -0700 Subject: [PATCH] ssh/tailssh: fix double race condition with non-pty command (#8405) There are two race conditions in output handling. The first race condition is due to a misuse of exec.Cmd.StdoutPipe. The documentation explicitly forbids concurrent use of StdoutPipe with exec.Cmd.Wait (see golang/go#60908) because Wait will close both sides of the pipe once the process ends without any guarantees that all data has been read from the pipe. To fix this, we allocate the os.Pipes ourselves and manage cleanup ourselves when the process has ended. The second race condition arises because sshSession.run waits for exec.Cmd to finish and then immediately proceeds to call ss.Exit, which will close all output streams going to the SSH client. This may interrupt any asynchronous io.Copy still copying data. To fix this, we close the write-side of the os.Pipes after the process has finished (and before calling ss.Exit) and synchronously wait for the io.Copy routines to finish. 
Fixes #7601 Signed-off-by: Joe Tsai Co-authored-by: Maisem Ali --- ssh/tailssh/incubator.go | 41 ++++++++++---------------- ssh/tailssh/tailssh.go | 58 +++++++++++++++++++++++++------------ ssh/tailssh/tailssh_test.go | 14 +++++++++ 3 files changed, 69 insertions(+), 44 deletions(-) diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 4de3e2b88..e52ffbfce 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -476,10 +476,10 @@ func (ss *sshSession) launchProcess() error { } go resizeWindow(ptyDup /* arbitrary fd */, winCh) - ss.tty = tty - ss.stdin = pty - ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name()) - ss.stderr = nil // not available for pty + ss.wrStdin = pty + ss.rdStdout = os.NewFile(uintptr(ptyDup), pty.Name()) + ss.rdStderr = nil // not available for pty + ss.childPipes = []io.Closer{tty} return nil } @@ -658,40 +658,29 @@ func (ss *sshSession) startWithPTY() (ptyFile, tty *os.File, err error) { // startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr. 
func (ss *sshSession) startWithStdPipes() (err error) { - var stdin io.WriteCloser - var stdout, stderr io.ReadCloser + var rdStdin, wrStdout, wrStderr io.ReadWriteCloser defer func() { if err != nil { - for _, c := range []io.Closer{stdin, stdout, stderr} { - if c != nil { - c.Close() - } - } + closeAll(rdStdin, ss.wrStdin, ss.rdStdout, wrStdout, ss.rdStderr, wrStderr) } }() - cmd := ss.cmd - if cmd == nil { + if ss.cmd == nil { return errors.New("nil cmd") } - stdin, err = cmd.StdinPipe() - if err != nil { + if rdStdin, ss.wrStdin, err = os.Pipe(); err != nil { return err } - stdout, err = cmd.StdoutPipe() - if err != nil { + if ss.rdStdout, wrStdout, err = os.Pipe(); err != nil { return err } - stderr, err = cmd.StderrPipe() - if err != nil { - return err - } - if err := cmd.Start(); err != nil { + if ss.rdStderr, wrStderr, err = os.Pipe(); err != nil { return err } - ss.stdin = stdin - ss.stdout = stdout - ss.stderr = stderr - return nil + ss.cmd.Stdin = rdStdin + ss.cmd.Stdout = wrStdout + ss.cmd.Stderr = wrStderr + ss.childPipes = []io.Closer{rdStdin, wrStdout, wrStderr} + return ss.cmd.Start() } func envForUser(u *userMeta) []string { diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 4253b2471..274f8cc70 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -823,12 +823,16 @@ type sshSession struct { agentListener net.Listener // non-nil if agent-forwarding requested+allowed // initialized by launchProcess: - cmd *exec.Cmd - stdin io.WriteCloser - stdout io.ReadCloser - stderr io.Reader // nil for pty sessions - ptyReq *ssh.Pty // non-nil for pty sessions - tty *os.File // non-nil for pty sessions, must be closed after process exits + cmd *exec.Cmd + wrStdin io.WriteCloser + rdStdout io.ReadCloser + rdStderr io.ReadCloser // rdStderr is nil for pty sessions + ptyReq *ssh.Pty // non-nil for pty sessions + + // childPipes is a list of pipes that need to be closed when the process exits. + // For pty sessions, this is the tty fd. 
+ // For non-pty sessions, this is the stdin, stdout, stderr fds. + childPipes []io.Closer // We use this sync.Once to ensure that we only terminate the process once, // either it exits itself or is terminated @@ -1107,21 +1111,22 @@ func (ss *sshSession) run() { var processDone atomic.Bool go func() { - defer ss.stdin.Close() - if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil { + defer ss.wrStdin.Close() + if _, err := io.Copy(rec.writer("i", ss.wrStdin), ss); err != nil { logf("stdin copy: %v", err) ss.cancelCtx(err) } }() + outputDone := make(chan struct{}) var openOutputStreams atomic.Int32 - if ss.stderr != nil { + if ss.rdStderr != nil { openOutputStreams.Store(2) } else { openOutputStreams.Store(1) } go func() { - defer ss.stdout.Close() - _, err := io.Copy(rec.writer("o", ss), ss.stdout) + defer ss.rdStdout.Close() + _, err := io.Copy(rec.writer("o", ss), ss.rdStdout) if err != nil && !errors.Is(err, io.EOF) { isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO) if !isErrBecauseProcessExited { @@ -1131,32 +1136,41 @@ func (ss *sshSession) run() { } if openOutputStreams.Add(-1) == 0 { ss.CloseWrite() + close(outputDone) } }() - // stderr is nil for ptys. - if ss.stderr != nil { + // rdStderr is nil for ptys. + if ss.rdStderr != nil { go func() { - _, err := io.Copy(ss.Stderr(), ss.stderr) + defer ss.rdStderr.Close() + _, err := io.Copy(ss.Stderr(), ss.rdStderr) if err != nil { logf("stderr copy: %v", err) } if openOutputStreams.Add(-1) == 0 { ss.CloseWrite() + close(outputDone) } }() } - if ss.tty != nil { - // If running a tty session, close the tty when the session is done. - defer ss.tty.Close() - } err = ss.cmd.Wait() processDone.Store(true) + // This will either make the SSH Termination goroutine be a no-op, // or itself will be a no-op because the process was killed by the // aforementioned goroutine. 
ss.exitOnce.Do(func() {}) + // Close the process-side of all pipes to signal the asynchronous + // io.Copy routines reading/writing from the pipes to terminate. + // Block for the io.Copy to finish before calling ss.Exit below. + closeAll(ss.childPipes...) + select { + case <-outputDone: + case <-ss.ctx.Done(): + } + if err == nil { ss.logf("Session complete") ss.Exit(0) @@ -1894,3 +1908,11 @@ type SSHTerminationError interface { error SSHTerminationMessage() string } + +func closeAll(cs ...io.Closer) { + for _, c := range cs { + if c != nil { + c.Close() + } + } +} diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index ed08fa584..fac2c70e6 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -25,6 +25,7 @@ import ( "os/user" "reflect" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -947,6 +948,19 @@ func TestSSH(t *testing.T) { // "foo\n" and "bar\n", not "\n" and "bar\n". }) + t.Run("large_file", func(t *testing.T) { + const wantSize = 1e6 + var outBuf bytes.Buffer + cmd := execSSH("head", "-c", strconv.Itoa(wantSize), "/dev/zero") + cmd.Stdout = &outBuf + if err := cmd.Run(); err != nil { + t.Fatal(err) + } + if gotSize := outBuf.Len(); gotSize != wantSize { + t.Fatalf("got %d, want %d", gotSize, int(wantSize)) + } + }) + t.Run("stdin", func(t *testing.T) { if cibuild.On() { t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")