@ -18,6 +18,7 @@ import (
"net/http/httptest"
"net/http/httptest"
"net/netip"
"net/netip"
"os"
"os"
"reflect"
"runtime"
"runtime"
"strconv"
"strconv"
"strings"
"strings"
@ -31,6 +32,7 @@ import (
"github.com/tailscale/wireguard-go/tun/tuntest"
"github.com/tailscale/wireguard-go/tun/tuntest"
"go4.org/mem"
"go4.org/mem"
"golang.org/x/exp/maps"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/net/ipv6"
"golang.org/x/net/ipv6"
"tailscale.com/cmd/testwrapper/flakytest"
"tailscale.com/cmd/testwrapper/flakytest"
"tailscale.com/derp"
"tailscale.com/derp"
@ -390,6 +392,7 @@ collectEndpoints:
for {
for {
select {
select {
case ep := <- epCh :
case ep := <- epCh :
t . Logf ( "TestNewConn: got endpoint: %v" , ep )
endpoints = append ( endpoints , ep )
endpoints = append ( endpoints , ep )
if strings . HasSuffix ( ep , suffix ) {
if strings . HasSuffix ( ep , suffix ) {
break collectEndpoints
break collectEndpoints
@ -2280,3 +2283,113 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) {
t . Fatal ( "no packet after 1s" )
t . Fatal ( "no packet after 1s" )
}
}
}
}
func TestEndpointTracker ( t * testing . T ) {
local := tailcfg . Endpoint {
Addr : netip . MustParseAddrPort ( "192.168.1.1:12345" ) ,
Type : tailcfg . EndpointLocal ,
}
stun4_1 := tailcfg . Endpoint {
Addr : netip . MustParseAddrPort ( "1.2.3.4:12345" ) ,
Type : tailcfg . EndpointSTUN ,
}
stun4_2 := tailcfg . Endpoint {
Addr : netip . MustParseAddrPort ( "5.6.7.8:12345" ) ,
Type : tailcfg . EndpointSTUN ,
}
stun6_1 := tailcfg . Endpoint {
Addr : netip . MustParseAddrPort ( "[2a09:8280:1::1111]:12345" ) ,
Type : tailcfg . EndpointSTUN ,
}
stun6_2 := tailcfg . Endpoint {
Addr : netip . MustParseAddrPort ( "[2a09:8280:1::2222]:12345" ) ,
Type : tailcfg . EndpointSTUN ,
}
start := time . Unix ( 1681503440 , 0 )
steps := [ ] struct {
name string
now time . Time
eps [ ] tailcfg . Endpoint
want [ ] tailcfg . Endpoint
} {
{
name : "initial endpoints" ,
now : start ,
eps : [ ] tailcfg . Endpoint { local , stun4_1 , stun6_1 } ,
want : [ ] tailcfg . Endpoint { local , stun4_1 , stun6_1 } ,
} ,
{
name : "no change" ,
now : start . Add ( 1 * time . Minute ) ,
eps : [ ] tailcfg . Endpoint { local , stun4_1 , stun6_1 } ,
want : [ ] tailcfg . Endpoint { local , stun4_1 , stun6_1 } ,
} ,
{
name : "missing stun4" ,
now : start . Add ( 2 * time . Minute ) ,
eps : [ ] tailcfg . Endpoint { local , stun6_1 } ,
want : [ ] tailcfg . Endpoint { local , stun4_1 , stun6_1 } ,
} ,
{
name : "missing stun6" ,
now : start . Add ( 3 * time . Minute ) ,
eps : [ ] tailcfg . Endpoint { local , stun4_1 } ,
want : [ ] tailcfg . Endpoint { local , stun4_1 , stun6_1 } ,
} ,
{
name : "multiple STUN addresses within timeout" ,
now : start . Add ( 4 * time . Minute ) ,
eps : [ ] tailcfg . Endpoint { local , stun4_2 , stun6_2 } ,
want : [ ] tailcfg . Endpoint { local , stun4_1 , stun4_2 , stun6_1 , stun6_2 } ,
} ,
{
name : "endpoint extended" ,
now : start . Add ( 3 * time . Minute + endpointTrackerLifetime - 1 ) ,
eps : [ ] tailcfg . Endpoint { local } ,
want : [ ] tailcfg . Endpoint {
local , stun4_2 , stun6_2 ,
// stun4_1 had its lifetime extended by the
// "missing stun6" test above to that start
// time plus the lifetime, while stun6 should
// have expired a minute sooner. It should thus
// be in this returned list.
stun4_1 ,
} ,
} ,
{
name : "after timeout" ,
now : start . Add ( 4 * time . Minute + endpointTrackerLifetime + 1 ) ,
eps : [ ] tailcfg . Endpoint { local , stun4_2 , stun6_2 } ,
want : [ ] tailcfg . Endpoint { local , stun4_2 , stun6_2 } ,
} ,
{
name : "after timeout still caches" ,
now : start . Add ( 4 * time . Minute + endpointTrackerLifetime + time . Minute ) ,
eps : [ ] tailcfg . Endpoint { local } ,
want : [ ] tailcfg . Endpoint { local , stun4_2 , stun6_2 } ,
} ,
}
var et endpointTracker
for _ , tt := range steps {
t . Logf ( "STEP: %s" , tt . name )
got := et . update ( tt . now , tt . eps )
// Sort both arrays for comparison
slices . SortFunc ( got , func ( a , b tailcfg . Endpoint ) bool {
return a . Addr . String ( ) < b . Addr . String ( )
} )
slices . SortFunc ( tt . want , func ( a , b tailcfg . Endpoint ) bool {
return a . Addr . String ( ) < b . Addr . String ( )
} )
if ! reflect . DeepEqual ( got , tt . want ) {
t . Errorf ( "endpoints mismatch\ngot: %+v\nwant: %+v" , got , tt . want )
}
}
}