diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 4de3e2b88..e52ffbfce 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -476,10 +476,10 @@ func (ss *sshSession) launchProcess() error { } go resizeWindow(ptyDup /* arbitrary fd */, winCh) - ss.tty = tty - ss.stdin = pty - ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name()) - ss.stderr = nil // not available for pty + ss.wrStdin = pty + ss.rdStdout = os.NewFile(uintptr(ptyDup), pty.Name()) + ss.rdStderr = nil // not available for pty + ss.childPipes = []io.Closer{tty} return nil } @@ -658,40 +658,29 @@ func (ss *sshSession) startWithPTY() (ptyFile, tty *os.File, err error) { // startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr. func (ss *sshSession) startWithStdPipes() (err error) { - var stdin io.WriteCloser - var stdout, stderr io.ReadCloser + var rdStdin, wrStdout, wrStderr io.ReadWriteCloser defer func() { if err != nil { - for _, c := range []io.Closer{stdin, stdout, stderr} { - if c != nil { - c.Close() - } - } + closeAll(rdStdin, ss.wrStdin, ss.rdStdout, wrStdout, ss.rdStderr, wrStderr) } }() - cmd := ss.cmd - if cmd == nil { + if ss.cmd == nil { return errors.New("nil cmd") } - stdin, err = cmd.StdinPipe() - if err != nil { + if rdStdin, ss.wrStdin, err = os.Pipe(); err != nil { return err } - stdout, err = cmd.StdoutPipe() - if err != nil { + if ss.rdStdout, wrStdout, err = os.Pipe(); err != nil { return err } - stderr, err = cmd.StderrPipe() - if err != nil { - return err - } - if err := cmd.Start(); err != nil { + if ss.rdStderr, wrStderr, err = os.Pipe(); err != nil { return err } - ss.stdin = stdin - ss.stdout = stdout - ss.stderr = stderr - return nil + ss.cmd.Stdin = rdStdin + ss.cmd.Stdout = wrStdout + ss.cmd.Stderr = wrStderr + ss.childPipes = []io.Closer{rdStdin, wrStdout, wrStderr} + return ss.cmd.Start() } func envForUser(u *userMeta) []string { diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 4253b2471..274f8cc70 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -823,12 +823,16 @@ type sshSession struct { agentListener net.Listener // non-nil if agent-forwarding requested+allowed // initialized by launchProcess: - cmd *exec.Cmd - stdin io.WriteCloser - stdout io.ReadCloser - stderr io.Reader // nil for pty sessions - ptyReq *ssh.Pty // non-nil for pty sessions - tty *os.File // non-nil for pty sessions, must be closed after process exits + cmd *exec.Cmd + wrStdin io.WriteCloser + rdStdout io.ReadCloser + rdStderr io.ReadCloser // rdStderr is nil for pty sessions + ptyReq *ssh.Pty // non-nil for pty sessions + + // childPipes is a list of pipes that need to be closed when the process exits. + // For pty sessions, this is the tty fd. + // For non-pty sessions, this is the stdin, stdout, stderr fds. + childPipes []io.Closer // We use this sync.Once to ensure that we only terminate the process once, // either it exits itself or is terminated @@ -1107,21 +1111,22 @@ func (ss *sshSession) run() { var processDone atomic.Bool go func() { - defer ss.stdin.Close() - if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil { + defer ss.wrStdin.Close() + if _, err := io.Copy(rec.writer("i", ss.wrStdin), ss); err != nil { logf("stdin copy: %v", err) ss.cancelCtx(err) } }() + outputDone := make(chan struct{}) var openOutputStreams atomic.Int32 - if ss.stderr != nil { + if ss.rdStderr != nil { openOutputStreams.Store(2) } else { openOutputStreams.Store(1) } go func() { - defer ss.stdout.Close() - _, err := io.Copy(rec.writer("o", ss), ss.stdout) + defer ss.rdStdout.Close() + _, err := io.Copy(rec.writer("o", ss), ss.rdStdout) if err != nil && !errors.Is(err, io.EOF) { isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO) if !isErrBecauseProcessExited { @@ -1131,32 +1136,41 @@ func (ss *sshSession) run() { } if openOutputStreams.Add(-1) == 0 { ss.CloseWrite() + close(outputDone) } }() - // stderr is nil for ptys. - if ss.stderr != nil { + // rdStderr is nil for ptys. + if ss.rdStderr != nil { go func() { - _, err := io.Copy(ss.Stderr(), ss.stderr) + defer ss.rdStderr.Close() + _, err := io.Copy(ss.Stderr(), ss.rdStderr) if err != nil { logf("stderr copy: %v", err) } if openOutputStreams.Add(-1) == 0 { ss.CloseWrite() + close(outputDone) } }() } - if ss.tty != nil { - // If running a tty session, close the tty when the session is done. - defer ss.tty.Close() - } err = ss.cmd.Wait() processDone.Store(true) + // This will either make the SSH Termination goroutine be a no-op, // or itself will be a no-op because the process was killed by the // aforementioned goroutine. ss.exitOnce.Do(func() {}) + // Close the process-side of all pipes to signal the asynchronous + // io.Copy routines reading/writing from the pipes to terminate. + // Block for the io.Copy to finish before calling ss.Exit below. + closeAll(ss.childPipes...) + select { + case <-outputDone: + case <-ss.ctx.Done(): + } + if err == nil { ss.logf("Session complete") ss.Exit(0) @@ -1894,3 +1908,11 @@ type SSHTerminationError interface { error SSHTerminationMessage() string } + +func closeAll(cs ...io.Closer) { + for _, c := range cs { + if c != nil { + c.Close() + } + } +} diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index ed08fa584..fac2c70e6 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -25,6 +25,7 @@ import ( "os/user" "reflect" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -947,6 +948,19 @@ func TestSSH(t *testing.T) { // "foo\n" and "bar\n", not "\n" and "bar\n". }) + t.Run("large_file", func(t *testing.T) { + const wantSize = 1e6 + var outBuf bytes.Buffer + cmd := execSSH("head", "-c", strconv.Itoa(wantSize), "/dev/zero") + cmd.Stdout = &outBuf + if err := cmd.Run(); err != nil { + t.Fatal(err) + } + if gotSize := outBuf.Len(); gotSize != wantSize { + t.Fatalf("got %d, want %d", gotSize, int(wantSize)) + } + }) + t.Run("stdin", func(t *testing.T) { if cibuild.On() { t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")