From 9f3da7ab26af7445dae0b3a74e4ba2ec46cbd52a Mon Sep 17 00:00:00 2001 From: Gesa Stupperich Date: Tue, 11 Nov 2025 13:12:00 +0000 Subject: [PATCH] ssh/tailssh: explore client connection monitoring Run a connection monitor that periodically pings the SSH client while a session is being recorded. If three consecutive pings fail, cancel the session and then close the recording. This is one way to ensure that session recordings get flushed promptly when using S3 multi-part upload. The ping interval, ping timeout, and consecutive-failure threshold are hardcoded because this is just an experiment. Fixes tailscale.com/corp#33968 Signed-off-by: Gesa Stupperich --- ssh/tailssh/tailssh.go | 54 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 7d12ab45f..40f376da9 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -32,6 +32,7 @@ import ( gossh "golang.org/x/crypto/ssh" "tailscale.com/envknob" "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnstate" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" "tailscale.com/sessionrecording" @@ -76,6 +77,7 @@ type ipnLocalBackend interface { Dialer() *tsdial.Dialer TailscaleVarRoot() string NodeKey() key.NodePublic + Ping(ctx context.Context, ip netip.Addr, pingType tailcfg.PingType, size int) (*ipnstate.PingResult, error) } type server struct { @@ -834,6 +836,7 @@ func (c *conn) detachSession(ss *sshSession) { } var errSessionDone = errors.New("session is done") +var errClientUnreachable = errors.New("client is unreachable") // handleSSHAgentForwarding starts a Unix socket listener and in the background // forwards agent connections between the listener and the ssh.Session.
@@ -954,6 +957,57 @@ func (ss *sshSession) run() { ss.logf("startNewRecording: ") if rec != nil { defer rec.Close() + + ping := func() bool { + clientIP := ss.conn.info.src.Addr() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := ss.conn.srv.lb.Ping(ctx, clientIP, tailcfg.PingICMP, 0) + if err != nil { + ss.logf("pinging SSH client %s failed: %v", clientIP, err) + return false + } + + ss.logf("pinging SSH client %s successful", clientIP) + return true + } + + go func() { + ss.logf("starting connection monitor for session %s", ss.sharedID) + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + consecutiveFailures := 0 + const maxFailures = 3 + + for { + select { + case <-ss.ctx.Done(): + ss.logf("session terminated, closing recording: %v", context.Cause(ss.ctx)) + rec.Close() + return + + case <-ticker.C: + pong := ping() + if pong { + consecutiveFailures = 0 + ss.logf("connection test passed for session %s", ss.sharedID) + } else { + consecutiveFailures++ + ss.logf("connection test failed (%d/%d) for session %s", consecutiveFailures, maxFailures, ss.sharedID) + + if consecutiveFailures >= maxFailures { + ss.logf("connection lost (connection test failed %d times), closing recording", maxFailures) + ss.cancelCtx(errClientUnreachable) + rec.Close() + return + } + } + } + } + }() } } }