diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 7d12ab45f..40f376da9 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -32,6 +32,7 @@ import ( gossh "golang.org/x/crypto/ssh" "tailscale.com/envknob" "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnstate" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" "tailscale.com/sessionrecording" @@ -76,6 +77,7 @@ type ipnLocalBackend interface { Dialer() *tsdial.Dialer TailscaleVarRoot() string NodeKey() key.NodePublic + Ping(ctx context.Context, ip netip.Addr, pingType tailcfg.PingType, size int) (*ipnstate.PingResult, error) } type server struct { @@ -834,6 +836,7 @@ func (c *conn) detachSession(ss *sshSession) { } var errSessionDone = errors.New("session is done") +var errClientUnreachable = errors.New("client is unreachable") // handleSSHAgentForwarding starts a Unix socket listener and in the background // forwards agent connections between the listener and the ssh.Session. @@ -954,6 +957,57 @@ func (ss *sshSession) run() { ss.logf("startNewRecording: ") if rec != nil { defer rec.Close() + + ping := func() bool { + clientIP := ss.conn.info.src.Addr() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := ss.conn.srv.lb.Ping(ctx, clientIP, tailcfg.PingICMP, 0) + if err != nil { + ss.logf("pinging SSH client %s failed: %v", clientIP, err) + return false + } + + ss.logf("pinging SSH client %s successful", clientIP) + return true + } + + go func() { + ss.logf("starting connection monitor for session %s", ss.sharedID) + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + consecutiveFailures := 0 + const maxFailures = 3 + + for { + select { + case <-ss.ctx.Done(): + ss.logf("session terminated, closing recording: %v", context.Cause(ss.ctx)) + rec.Close() + return + + case <-ticker.C: + pong := ping() + if pong { + consecutiveFailures = 0 + ss.logf("connection test passed for session %s", ss.sharedID) + } else { + consecutiveFailures++ + ss.logf("connection test failed (%d/%d) for session %s", consecutiveFailures, maxFailures, ss.sharedID) + + if consecutiveFailures >= maxFailures { + ss.logf("connection lost (connection test failed %d times), closing recording", maxFailures) + ss.cancelCtx(errClientUnreachable) + rec.Close() + return + } + } + } + } + }() } } }