diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 1a277d63e..a0e79011e 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -452,7 +452,7 @@ func (ss *sshSession) launchProcess() error { return ss.startWithStdPipes() } ss.ptyReq = &ptyReq - pty, err := ss.startWithPTY() + pty, tty, err := ss.startWithPTY() if err != nil { return err } @@ -461,10 +461,13 @@ func (ss *sshSession) launchProcess() error { // dup. ptyDup, err := syscall.Dup(int(pty.Fd())) if err != nil { + pty.Close() + tty.Close() return err } go resizeWindow(ptyDup /* arbitrary fd */, winCh) + ss.tty = tty ss.stdin = pty ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name()) ss.stderr = nil // not available for pty @@ -544,17 +547,16 @@ var opcodeShortName = map[uint8]string{ } // startWithPTY starts cmd with a pseudo-terminal attached to Stdin, Stdout and Stderr. -func (ss *sshSession) startWithPTY() (ptyFile *os.File, err error) { +func (ss *sshSession) startWithPTY() (ptyFile, tty *os.File, err error) { ptyReq := ss.ptyReq cmd := ss.cmd if cmd == nil { - return nil, errors.New("nil ss.cmd") + return nil, nil, errors.New("nil ss.cmd") } if ptyReq == nil { - return nil, errors.New("nil ss.ptyReq") + return nil, nil, errors.New("nil ss.ptyReq") } - var tty *os.File ptyFile, tty, err = pty.Open() if err != nil { err = fmt.Errorf("pty.Open: %w", err) @@ -568,7 +570,7 @@ func (ss *sshSession) startWithPTY() (ptyFile *os.File, err error) { }() ptyRawConn, err := tty.SyscallConn() if err != nil { - return nil, fmt.Errorf("SyscallConn: %w", err) + return nil, nil, fmt.Errorf("SyscallConn: %w", err) } var ctlErr error if err := ptyRawConn.Control(func(fd uintptr) { @@ -615,10 +617,10 @@ func (ss *sshSession) startWithPTY() (ptyFile *os.File, err error) { return } }); err != nil { - return nil, fmt.Errorf("ptyRawConn.Control: %w", err) + return nil, nil, fmt.Errorf("ptyRawConn.Control: %w", err) } if ctlErr != nil { - return nil, fmt.Errorf("ptyRawConn.Control func: %w", ctlErr) + return nil, nil, fmt.Errorf("ptyRawConn.Control func: %w", ctlErr) } cmd.SysProcAttr = &syscall.SysProcAttr{ Setctty: true, @@ -642,7 +644,7 @@ func (ss *sshSession) startWithPTY() (ptyFile *os.File, err error) { if err = cmd.Start(); err != nil { return } - return ptyFile, nil + return ptyFile, tty, nil } // startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr. diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index e15d4d991..18a2f5a7a 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -28,6 +28,7 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "time" gossh "github.com/tailscale/golang-x-crypto/ssh" @@ -811,6 +812,7 @@ type sshSession struct { stdout io.ReadCloser stderr io.Reader // nil for pty sessions ptyReq *ssh.Pty // non-nil for pty sessions + tty *os.File // non-nil for pty sessions, must be closed after process exits // We use this sync.Once to ensure that we only terminate the process once, // either it exits itself or is terminated @@ -1087,6 +1089,7 @@ func (ss *sshSession) run() { } go ss.killProcessOnContextDone() + var processDone atomic.Bool go func() { defer ss.stdin.Close() if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil { @@ -1104,8 +1107,11 @@ func (ss *sshSession) run() { defer ss.stdout.Close() _, err := io.Copy(rec.writer("o", ss), ss.stdout) if err != nil && !errors.Is(err, io.EOF) { - logf("stdout copy: %v", err) - ss.cancelCtx(err) + isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO) + if !isErrBecauseProcessExited { + logf("stdout copy: %v, %T", err) + ss.cancelCtx(err) + } } if openOutputStreams.Add(-1) == 0 { ss.CloseWrite() @@ -1124,7 +1130,12 @@ func (ss *sshSession) run() { }() } + if ss.tty != nil { + // If running a tty session, close the tty when the session is done. + defer ss.tty.Close() + } err = ss.cmd.Wait() + processDone.Store(true) // This will either make the SSH Termination goroutine be a no-op, // or itself will be a no-op because the process was killed by the // aforementioned goroutine.