@ -15,19 +15,20 @@ import (
"net/http/httputil"
"net/netip"
"net/url"
"runtime"
"slices"
"strconv"
"strings"
"sync"
"testing"
"testing/synctest"
"time"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp/controlhttpcommon"
"tailscale.com/control/controlhttp/controlhttpserver"
"tailscale.com/health"
"tailscale.com/net/memnet"
"tailscale.com/net/netmon"
"tailscale.com/net/netx"
"tailscale.com/net/socks5"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
@ -36,6 +37,7 @@ import (
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus/eventbustest"
"tailscale.com/util/must"
)
type httpTestParam struct {
@ -532,6 +534,28 @@ EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
}
}
// slowListener wraps a memnet listener to delay accept operations
type slowListener struct {
net . Listener
delay time . Duration
}
func ( sl * slowListener ) Accept ( ) ( net . Conn , error ) {
// Add delay before accepting connections
timer := time . NewTimer ( sl . delay )
defer timer . Stop ( )
<- timer . C
return sl . Listener . Accept ( )
}
func newSlowListener ( inner net . Listener , delay time . Duration ) net . Listener {
return & slowListener {
Listener : inner ,
delay : delay ,
}
}
func brokenMITMHandler ( clock tstime . Clock ) http . HandlerFunc {
return func ( w http . ResponseWriter , r * http . Request ) {
w . Header ( ) . Set ( "Upgrade" , controlhttpcommon . UpgradeHeaderValue )
@ -545,33 +569,102 @@ func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc {
}
func TestDialPlan ( t * testing . T ) {
if runtime . GOOS != "linux" {
t . Skip ( "only works on Linux due to multiple localhost addresses" )
testCases := [ ] struct {
name string
plan * tailcfg . ControlDialPlan
want [ ] netip . Addr
allowFallback bool
maxDuration time . Duration
} {
{
name : "single" ,
plan : & tailcfg . ControlDialPlan { Candidates : [ ] tailcfg . ControlIPCandidate {
{ IP : netip . MustParseAddr ( "10.0.0.2" ) , DialTimeoutSec : 10 } ,
} } ,
want : [ ] netip . Addr { netip . MustParseAddr ( "10.0.0.2" ) } ,
} ,
{
name : "broken-then-good" ,
plan : & tailcfg . ControlDialPlan { Candidates : [ ] tailcfg . ControlIPCandidate {
{ IP : netip . MustParseAddr ( "10.0.0.10" ) , DialTimeoutSec : 10 } ,
{ IP : netip . MustParseAddr ( "10.0.0.2" ) , DialTimeoutSec : 10 , DialStartDelaySec : 1 } ,
} } ,
want : [ ] netip . Addr { netip . MustParseAddr ( "10.0.0.2" ) } ,
} ,
{
name : "multiple-candidates-with-broken" ,
plan : & tailcfg . ControlDialPlan { Candidates : [ ] tailcfg . ControlIPCandidate {
// Multiple good IPs plus a broken one
// Should succeed with any of the good ones
{ IP : netip . MustParseAddr ( "10.0.0.10" ) , DialTimeoutSec : 10 } ,
{ IP : netip . MustParseAddr ( "10.0.0.2" ) , DialTimeoutSec : 10 } ,
{ IP : netip . MustParseAddr ( "10.0.0.4" ) , DialTimeoutSec : 10 } ,
{ IP : netip . MustParseAddr ( "10.0.0.3" ) , DialTimeoutSec : 10 } ,
} } ,
want : [ ] netip . Addr { netip . MustParseAddr ( "10.0.0.2" ) , netip . MustParseAddr ( "10.0.0.4" ) , netip . MustParseAddr ( "10.0.0.3" ) } ,
} ,
{
name : "multiple-candidates-race" ,
plan : & tailcfg . ControlDialPlan { Candidates : [ ] tailcfg . ControlIPCandidate {
{ IP : netip . MustParseAddr ( "10.0.0.10" ) , DialTimeoutSec : 10 } ,
{ IP : netip . MustParseAddr ( "10.0.0.3" ) , DialTimeoutSec : 10 } ,
{ IP : netip . MustParseAddr ( "10.0.0.2" ) , DialTimeoutSec : 10 } ,
} } ,
want : [ ] netip . Addr { netip . MustParseAddr ( "10.0.0.3" ) , netip . MustParseAddr ( "10.0.0.2" ) } ,
} ,
{
name : "fallback" ,
plan : & tailcfg . ControlDialPlan { Candidates : [ ] tailcfg . ControlIPCandidate {
{ IP : netip . MustParseAddr ( "10.0.0.10" ) , DialTimeoutSec : 1 } ,
} } ,
want : [ ] netip . Addr { netip . MustParseAddr ( "10.0.0.1" ) } ,
allowFallback : true ,
} ,
{
// In tailscale/corp#32534 we discovered that a prior implementation
// of the dial race was waiting for all dials to complete when the
// top priority dial was failing. This delay was long enough that in
// real scenarios the server will close the connection due to
// inactivity, because the client does not send the first inside of
// noise request soon enough. This test is a regression guard
// against that behavior - proving that the dial returns promptly
// even if there is some cause of a slow race.
name : "slow-endpoint-doesnt-block" ,
plan : & tailcfg . ControlDialPlan { Candidates : [ ] tailcfg . ControlIPCandidate {
{ IP : netip . MustParseAddr ( "10.0.0.12" ) , Priority : 5 , DialTimeoutSec : 10 } ,
{ IP : netip . MustParseAddr ( "10.0.0.2" ) , Priority : 1 , DialTimeoutSec : 10 } ,
} } ,
want : [ ] netip . Addr { netip . MustParseAddr ( "10.0.0.2" ) } ,
maxDuration : 2 * time . Second , // Must complete quickly, not wait for slow endpoint
} ,
}
for _ , tt := range testCases {
t . Run ( tt . name , func ( t * testing . T ) {
synctest . Test ( t , func ( t * testing . T ) {
runDialPlanTest ( t , tt . plan , tt . want , tt . allowFallback , tt . maxDuration )
} )
} )
}
}
func runDialPlanTest ( t * testing . T , plan * tailcfg . ControlDialPlan , want [ ] netip . Addr , allowFallback bool , maxDuration time . Duration ) {
client , server := key . NewMachine ( ) , key . NewMachine ( )
const (
testProtocolVersion = 1
httpPort = "80"
httpsPort = "443"
)
getRandomPort := func ( ) string {
ln , err := net . Listen ( "tcp" , ":0" )
if err != nil {
t . Fatalf ( "net.Listen: %v" , err )
}
defer ln . Close ( )
_ , port , err := net . SplitHostPort ( ln . Addr ( ) . String ( ) )
if err != nil {
t . Fatal ( err )
}
return port
}
memNetwork := & memnet . Network { }
// We need consistent ports for each address; these are chosen
// randomly and we hope that they won't conflict during this test.
httpPort := getRandomPort ( )
httpsPort := getRandomPort ( )
fallbackAddr := netip . MustParseAddr ( "10.0.0.1" )
goodAddr := netip . MustParseAddr ( "10.0.0.2" )
otherAddr := netip . MustParseAddr ( "10.0.0.3" )
other2Addr := netip . MustParseAddr ( "10.0.0.4" )
brokenAddr := netip . MustParseAddr ( "10.0.0.10" )
slowAddr := netip . MustParseAddr ( "10.0.0.12" )
makeHandler := func ( t * testing . T , name string , host netip . Addr , wrap func ( http . Handler ) http . Handler ) {
done := make ( chan struct { } )
@ -592,17 +685,66 @@ func TestDialPlan(t *testing.T) {
handler = wrap ( handler )
}
httpLn , err := net . Listen ( "tcp" , host . String ( ) + ":" + httpPort )
httpLn := must . Get ( memNetwork . Listen ( "tcp" , host . String ( ) + ":" + httpPort ) )
httpsLn := must . Get ( memNetwork . Listen ( "tcp" , host . String ( ) + ":" + httpsPort ) )
httpServer := & http . Server { Handler : handler }
go httpServer . Serve ( httpLn )
t . Cleanup ( func ( ) {
httpServer . Close ( )
} )
httpsServer := & http . Server {
Handler : handler ,
TLSConfig : tlsConfig ( t ) ,
ErrorLog : logger . StdLogger ( logger . WithPrefix ( t . Logf , "http.Server.ErrorLog: " ) ) ,
}
go httpsServer . ServeTLS ( httpsLn , "" , "" )
t . Cleanup ( func ( ) {
httpsServer . Close ( )
} )
}
// Use synctest's controlled time
clock := tstime . StdClock { }
makeHandler ( t , "fallback" , fallbackAddr , nil )
makeHandler ( t , "good" , goodAddr , nil )
makeHandler ( t , "other" , otherAddr , nil )
makeHandler ( t , "other2" , other2Addr , nil )
makeHandler ( t , "broken" , brokenAddr , func ( h http . Handler ) http . Handler {
return brokenMITMHandler ( clock )
} )
// Create slow listener that delays accept by 5 seconds
makeSlowHandler := func ( t * testing . T , name string , host netip . Addr , delay time . Duration ) {
done := make ( chan struct { } )
t . Cleanup ( func ( ) {
close ( done )
} )
handler := http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
conn , err := controlhttpserver . AcceptHTTP ( context . Background ( ) , w , r , server , nil )
if err != nil {
log . Print ( err )
} else {
defer conn . Close ( )
}
w . Header ( ) . Set ( "X-Handler-Name" , name )
<- done
} )
httpLn , err := memNetwork . Listen ( "tcp" , host . String ( ) + ":" + httpPort )
if err != nil {
t . Fatalf ( "HTTP listen: %v" , err )
}
httpsLn , err := net . Listen ( "tcp" , host . String ( ) + ":" + httpsPort )
httpsLn , err := memNetwork . Listen ( "tcp" , host . String ( ) + ":" + httpsPort )
if err != nil {
t . Fatalf ( "HTTPS listen: %v" , err )
}
slowHttpLn := newSlowListener ( httpLn , delay )
slowHttpsLn := newSlowListener ( httpsLn , delay )
httpServer := & http . Server { Handler : handler }
go httpServer . Serve ( httpLn )
go httpServer . Serve ( slowH ttpLn)
t . Cleanup ( func ( ) {
httpServer . Close ( )
} )
@ -612,213 +754,148 @@ func TestDialPlan(t *testing.T) {
TLSConfig : tlsConfig ( t ) ,
ErrorLog : logger . StdLogger ( logger . WithPrefix ( t . Logf , "http.Server.ErrorLog: " ) ) ,
}
go httpsServer . ServeTLS ( h ttpsLn, "" , "" )
go httpsServer . ServeTLS ( slowH ttpsLn, "" , "" )
t . Cleanup ( func ( ) {
httpsServer . Close ( )
} )
return
}
makeSlowHandler ( t , "slow" , slowAddr , 5 * time . Second )
fallbackAddr := netip . MustParseAddr ( "127.0.0.1" )
goodAddr := netip . MustParseAddr ( "127.0.0.2" )
otherAddr := netip . MustParseAddr ( "127.0.0.3" )
other2Addr := netip . MustParseAddr ( "127.0.0.4" )
brokenAddr := netip . MustParseAddr ( "127.0.0.10" )
// memnetDialer with connection tracking, so we can catch connection leaks.
dialer := & memnetDialer {
inner : memNetwork . Dial ,
t : t ,
}
defer dialer . waitForAllClosedSynctest ( )
testCases := [ ] struct {
name string
plan * tailcfg . ControlDialPlan
wrap func ( http . Handler ) http . Handler
want netip . Addr
ctx , cancel := context . WithTimeout ( context . Background ( ) , 10 * time . Second )
defer cancel ( )
allowFallback bool
} {
{
name : "single" ,
plan : & tailcfg . ControlDialPlan { Candidates : [ ] tailcfg . ControlIPCandidate {
{ IP : goodAddr , Priority : 1 , DialTimeoutSec : 10 } ,
} } ,
want : goodAddr ,
} ,
{
name : "broken-then-good" ,
plan : & tailcfg . ControlDialPlan { Candidates : [ ] tailcfg . ControlIPCandidate {
// Dials the broken one, which fails, and then
// eventually dials the good one and succeeds
{ IP : brokenAddr , Priority : 2 , DialTimeoutSec : 10 } ,
{ IP : goodAddr , Priority : 1 , DialTimeoutSec : 10 , DialStartDelaySec : 1 } ,
} } ,
want : goodAddr ,
} ,
// TODO(#8442): fix this test
// {
// name: "multiple-priority-fast-path",
// plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
// // Dials some good IPs and our bad one (which
// // hangs forever), which then hits the fast
// // path where we bail without waiting.
// {IP: brokenAddr, Priority: 1, DialTimeoutSec: 10},
// {IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
// {IP: other2Addr, Priority: 1, DialTimeoutSec: 10},
// {IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
// }},
// want: otherAddr,
// },
{
name : "multiple-priority-slow-path" ,
plan : & tailcfg . ControlDialPlan { Candidates : [ ] tailcfg . ControlIPCandidate {
// Our broken address is the highest priority,
// so we don't hit our fast path.
{ IP : brokenAddr , Priority : 10 , DialTimeoutSec : 10 } ,
{ IP : otherAddr , Priority : 2 , DialTimeoutSec : 10 } ,
{ IP : goodAddr , Priority : 1 , DialTimeoutSec : 10 } ,
} } ,
want : otherAddr ,
} ,
{
name : "fallback" ,
plan : & tailcfg . ControlDialPlan { Candidates : [ ] tailcfg . ControlIPCandidate {
{ IP : brokenAddr , Priority : 1 , DialTimeoutSec : 1 } ,
} } ,
want : fallbackAddr ,
allowFallback : true ,
} ,
host := "example.com"
if allowFallback {
host = fallbackAddr . String ( )
}
for _ , tt := range testCases {
t . Run ( tt . name , func ( t * testing . T ) {
// TODO(awly): replace this with tstest.NewClock and update the
// test to advance the clock correctly.
clock := tstime . StdClock { }
makeHandler ( t , "fallback" , fallbackAddr , nil )
makeHandler ( t , "good" , goodAddr , nil )
makeHandler ( t , "other" , otherAddr , nil )
makeHandler ( t , "other2" , other2Addr , nil )
makeHandler ( t , "broken" , brokenAddr , func ( h http . Handler ) http . Handler {
return brokenMITMHandler ( clock )
} )
dialer := closeTrackDialer {
t : t ,
inner : tsdial . NewDialer ( netmon . NewStatic ( ) ) . SystemDial ,
conns : make ( map [ * closeTrackConn ] bool ) ,
}
defer dialer . Done ( )
a := & Dialer {
Hostname : host ,
HTTPPort : httpPort ,
HTTPSPort : httpsPort ,
MachineKey : client ,
ControlKey : server . Public ( ) ,
ProtocolVersion : testProtocolVersion ,
Dialer : dialer . Dial ,
Logf : t . Logf ,
DialPlan : plan ,
proxyFunc : func ( * http . Request ) ( * url . URL , error ) { return nil , nil } ,
omitCertErrorLogging : true ,
testFallbackDelay : 50 * time . Millisecond ,
Clock : clock ,
HealthTracker : health . NewTracker ( eventbustest . NewBus ( t ) ) ,
}
ctx , cancel := context . WithTimeout ( context . Background ( ) , 10 * time . Second )
defer cancel ( )
start := time . Now ( )
conn , err := a . dial ( ctx )
duration := time . Since ( start )
// By default, we intentionally point to something that
// we know won't connect, since we want a fallback to
// DNS to be an error.
host := "example.com"
if tt . allowFallback {
host = "localhost"
}
if err != nil {
t . Fatalf ( "dialing controlhttp: %v" , err )
}
defer conn . Close ( )
drained := make ( chan struct { } )
a := & Dialer {
Hostname : host ,
HTTPPort : httpPort ,
HTTPSPort : httpsPort ,
MachineKey : client ,
ControlKey : server . Public ( ) ,
ProtocolVersion : testProtocolVersion ,
Dialer : dialer . Dial ,
Logf : t . Logf ,
DialPlan : tt . plan ,
proxyFunc : func ( * http . Request ) ( * url . URL , error ) { return nil , nil } ,
drainFinished : drained ,
omitCertErrorLogging : true ,
testFallbackDelay : 50 * time . Millisecond ,
Clock : clock ,
HealthTracker : health . NewTracker ( eventbustest . NewBus ( t ) ) ,
}
if maxDuration > 0 && duration > maxDuration {
t . Errorf ( "dial took %v, expected < %v (should not wait for slow endpoints)" , duration , maxDuration )
}
conn , err := a . dial ( ctx )
if err != nil {
t . Fatalf ( "dialing controlhttp: %v" , err )
}
defer conn . Close ( )
raddr := conn . RemoteAddr ( )
raddrStr := raddr . String ( )
raddr := conn . RemoteAddr ( ) . ( * net . TCPAddr )
// split on "|" first to remove memnet pipe suffix
addrPart := raddrStr
if idx := strings . Index ( raddrStr , "|" ) ; idx >= 0 {
addrPart = raddrStr [ : idx ]
}
got , ok := netip . AddrFromSlice ( raddr . IP )
if ! ok {
t . Errorf ( "invalid remote IP: %v" , raddr . IP )
} else if got != tt . want {
t . Errorf ( "got connection from %q; want %q" , got , tt . want )
} else {
t . Logf ( "successfully connected to %q" , raddr . String ( ) )
}
host , _ , err2 := net . SplitHostPort ( addrPart )
if err2 != nil {
t . Fatalf ( "failed to parse remote address %q: %v" , addrPart , err2 )
}
// Wait until our dialer drains so we can verify that
// all connections are closed.
<- drained
} )
got , err3 := netip . ParseAddr ( host )
if err3 != nil {
t . Errorf ( "invalid remote IP: %v" , host )
} else {
found := slices . Contains ( want , got )
if ! found {
t . Errorf ( "got connection from %q; want one of %v" , got , want )
} else {
t . Logf ( "successfully connected to %q" , raddr . String ( ) )
}
}
}
type closeTrackDialer struct {
t testing . TB
inner netx . DialFunc
// memnetDialer wraps memnet.Network.Dial to track connections for testing
type memnetDialer struct {
inner func ( ctx context . Context , network , addr string ) ( net . Conn , error )
t * testing . T
mu sync . Mutex
conns map [ * closeTrackConn ] bool
conns map [ net . Conn ] string // conn -> remote address for debugging
}
func ( d * closeTrack Dialer) Dial ( ctx context . Context , network , addr string ) ( net . Conn , error ) {
c , err := d . inner ( ctx , network , addr )
func ( d * memnet Dialer) Dial ( ctx context . Context , network , addr string ) ( net . Conn , error ) {
c onn , err := d . inner ( ctx , network , addr )
if err != nil {
return nil , err
}
ct := & closeTrackConn { Conn : c , d : d }
d . mu . Lock ( )
d . conns [ ct ] = true
if d . conns == nil {
d . conns = make ( map [ net . Conn ] string )
}
d . conns [ conn ] = conn . RemoteAddr ( ) . String ( )
d . t . Logf ( "tracked connection opened to %s" , conn . RemoteAddr ( ) )
d . mu . Unlock ( )
return ct , nil
return & memnetTrackedConn { Conn : conn , dialer : d } , nil
}
func ( d * closeTrackDialer ) Done ( ) {
// Unfortunately, tsdial.Dialer.SystemDial closes connections
// asynchronously in a goroutine, so we can't assume that everything is
// closed by the time we get here.
//
// Sleep/wait a few times on the assumption that things will close
// "eventually".
const iters = 100
for i := range iters {
func ( d * memnetDialer ) waitForAllClosedSynctest ( ) {
const maxWait = 15 * time . Second
const checkInterval = 100 * time . Millisecond
for range int ( maxWait / checkInterval ) {
d . mu . Lock ( )
if len ( d . conns ) == 0 {
remaining := len ( d . conns )
if remaining == 0 {
d . mu . Unlock ( )
return
}
d . mu . Unlock ( )
// Only error on last iteration
if i != iters - 1 {
d . mu . Unlock ( )
time . Sleep ( 100 * time . Millisecond )
continue
}
time . Sleep ( checkInterval )
}
for conn := range d . conns {
d . t . Errorf ( "expected close of conn %p; RemoteAddr=%q" , conn , conn . RemoteAddr ( ) . String ( ) )
}
d . mu. Unlock ( )
d . mu . Lock ( )
defer d . mu . Unlock ( )
for _ , addr := range d . conns {
d . t. Errorf ( "connection to %s was not closed after %v" , addr , maxWait )
}
}
func ( d * closeTrackDialer) noteClose ( c * closeTrack Conn) {
func ( d * memnetDialer) noteClose ( conn net . Conn) {
d . mu . Lock ( )
delete ( d . conns , c ) // safe if already deleted
if addr , exists := d . conns [ conn ] ; exists {
d . t . Logf ( "tracked connection closed to %s" , addr )
delete ( d . conns , conn )
}
d . mu . Unlock ( )
}
type closeTrack Conn struct {
type memnetTracked Conn struct {
net . Conn
d * closeTrack Dialer
d ialer * memnet Dialer
}
func ( c * closeTrack Conn) Close ( ) error {
c . d . noteClose ( c )
func ( c * memnetTracked Conn) Close ( ) error {
c . d ialer . noteClose ( c . Conn )
return c . Conn . Close ( )
}