ssh/tailssh: explore client connection monitoring

Run a connection monitor that pings the SSH client while a session is
being recorded. If enough consecutive pings fail, close the recording
and then cancel the connection. This is one way to ensure that session
recordings get flushed promptly when using S3 multi-part upload.

Timeouts and consecutive failure threshold are hardcoded because
this is just an experiment.

Fixes tailscale.com/corp#33968

Signed-off-by: Gesa Stupperich <gesa@tailscale.com>
gesa/ssh-client-session-monitoring
Gesa Stupperich 3 weeks ago
parent 1eba5b0cbd
commit 9f3da7ab26

@ -32,6 +32,7 @@ import (
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/ipnlocal"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/net/tsdial" "tailscale.com/net/tsdial"
"tailscale.com/sessionrecording" "tailscale.com/sessionrecording"
@ -76,6 +77,7 @@ type ipnLocalBackend interface {
Dialer() *tsdial.Dialer Dialer() *tsdial.Dialer
TailscaleVarRoot() string TailscaleVarRoot() string
NodeKey() key.NodePublic NodeKey() key.NodePublic
Ping(ctx context.Context, ip netip.Addr, pingType tailcfg.PingType, size int) (*ipnstate.PingResult, error)
} }
type server struct { type server struct {
@ -834,6 +836,7 @@ func (c *conn) detachSession(ss *sshSession) {
} }
var errSessionDone = errors.New("session is done") var errSessionDone = errors.New("session is done")
// errClientUnreachable is the cause passed to a session's cancelCtx by the
// connection monitor once pings to the SSH client have failed repeatedly.
var errClientUnreachable = errors.New("client is unreachable")
// handleSSHAgentForwarding starts a Unix socket listener and in the background // handleSSHAgentForwarding starts a Unix socket listener and in the background
// forwards agent connections between the listener and the ssh.Session. // forwards agent connections between the listener and the ssh.Session.
@ -954,6 +957,57 @@ func (ss *sshSession) run() {
ss.logf("startNewRecording: <nil>") ss.logf("startNewRecording: <nil>")
if rec != nil { if rec != nil {
defer rec.Close() defer rec.Close()
// ping reports whether the SSH client answered a single ICMP ping sent
// through the local backend within five seconds.
//
// A ping counts as failed when Ping returns an error (which includes the
// timeout expiring), when it returns no result, or when the returned
// PingResult has its Err field set. The last case matters because the
// backend reports ping-level failures through PingResult.Err while still
// returning a nil error; checking only the error would count such a ping
// as a success.
ping := func() bool {
	clientIP := ss.conn.info.src.Addr()
	// NOTE(review): the timeout is derived from context.Background(), not
	// ss.ctx, so a ping that is already in flight is not cut short when
	// the session ends — confirm this is intended.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	res, err := ss.conn.srv.lb.Ping(ctx, clientIP, tailcfg.PingICMP, 0)
	switch {
	case err != nil:
		ss.logf("pinging SSH client %s failed: %v", clientIP, err)
		return false
	case res == nil:
		ss.logf("pinging SSH client %s failed: no ping result", clientIP)
		return false
	case res.Err != "":
		ss.logf("pinging SSH client %s failed: %s", clientIP, res.Err)
		return false
	}
	ss.logf("pinging SSH client %s successful", clientIP)
	return true
}
// Connection monitor for a recorded session. Every 15 seconds it calls
// ping; a successful ping resets the failure count. After maxFailures
// failed pings in a row it cancels the session context with
// errClientUnreachable and then closes the recording. If the session
// context is done first, it closes the recording and exits.
go func() {
	const maxFailures = 3

	ss.logf("starting connection monitor for session %s", ss.sharedID)

	tick := time.NewTicker(15 * time.Second)
	defer tick.Stop()

	failed := 0
	for {
		select {
		case <-ss.ctx.Done():
			ss.logf("session terminated, closing recording: %v", context.Cause(ss.ctx))
			rec.Close()
			return
		case <-tick.C:
			// Fall through to the connection test below.
		}

		if ping() {
			failed = 0
			ss.logf("connection test passed for session %s", ss.sharedID)
			continue
		}

		failed++
		ss.logf("connection test failed (%d/%d) for session %s", failed, maxFailures, ss.sharedID)
		if failed < maxFailures {
			continue
		}

		// Threshold reached: cancel the session first, then close the
		// recording, in that order.
		ss.logf("connection lost (connection test failed %d times), closing recording", maxFailures)
		ss.cancelCtx(errClientUnreachable)
		rec.Close()
		return
	}
}()
} }
} }
} }

Loading…
Cancel
Save