@ -32,6 +32,7 @@ import (
gossh "golang.org/x/crypto/ssh"
gossh "golang.org/x/crypto/ssh"
"tailscale.com/envknob"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/tsaddr"
"tailscale.com/net/tsaddr"
"tailscale.com/net/tsdial"
"tailscale.com/net/tsdial"
"tailscale.com/sessionrecording"
"tailscale.com/sessionrecording"
@ -76,6 +77,7 @@ type ipnLocalBackend interface {
Dialer ( ) * tsdial . Dialer
Dialer ( ) * tsdial . Dialer
TailscaleVarRoot ( ) string
TailscaleVarRoot ( ) string
NodeKey ( ) key . NodePublic
NodeKey ( ) key . NodePublic
Ping ( ctx context . Context , ip netip . Addr , pingType tailcfg . PingType , size int ) ( * ipnstate . PingResult , error )
}
}
type server struct {
type server struct {
@ -834,6 +836,7 @@ func (c *conn) detachSession(ss *sshSession) {
}
}
var errSessionDone = errors . New ( "session is done" )
var errSessionDone = errors . New ( "session is done" )
var errClientUnreachable = errors . New ( "client is unreachable" )
// handleSSHAgentForwarding starts a Unix socket listener and in the background
// handleSSHAgentForwarding starts a Unix socket listener and in the background
// forwards agent connections between the listener and the ssh.Session.
// forwards agent connections between the listener and the ssh.Session.
@ -954,6 +957,57 @@ func (ss *sshSession) run() {
ss . logf ( "startNewRecording: <nil>" )
ss . logf ( "startNewRecording: <nil>" )
if rec != nil {
if rec != nil {
defer rec . Close ( )
defer rec . Close ( )
ping := func ( ) bool {
clientIP := ss . conn . info . src . Addr ( )
ctx , cancel := context . WithTimeout ( context . Background ( ) , 5 * time . Second )
defer cancel ( )
_ , err := ss . conn . srv . lb . Ping ( ctx , clientIP , tailcfg . PingICMP , 0 )
if err != nil {
ss . logf ( "pinging SSH client %s failed: %v" , clientIP , err )
return false
}
ss . logf ( "pinging SSH client %s successful" , clientIP )
return true
}
go func ( ) {
ss . logf ( "starting connection monitor for session %s" , ss . sharedID )
ticker := time . NewTicker ( 15 * time . Second )
defer ticker . Stop ( )
consecutiveFailures := 0
const maxFailures = 3
for {
select {
case <- ss . ctx . Done ( ) :
ss . logf ( "session terminated, closing recording: %v" , context . Cause ( ss . ctx ) )
rec . Close ( )
return
case <- ticker . C :
pong := ping ( )
if pong {
consecutiveFailures = 0
ss . logf ( "connection test passed for session %s" , ss . sharedID )
} else {
consecutiveFailures ++
ss . logf ( "connection test failed (%d/%d) for session %s" , consecutiveFailures , maxFailures , ss . sharedID )
if consecutiveFailures >= maxFailures {
ss . logf ( "connection lost (connection test failed %d times), closing recording" , maxFailures )
ss . cancelCtx ( errClientUnreachable )
rec . Close ( )
return
}
}
}
}
} ( )
}
}
}
}
}
}