diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index 868d5f61a..cfa372bd7 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -8,14 +8,10 @@ package relayserver import ( "encoding/json" "fmt" - "log" "net/http" - "net/netip" - "strings" "sync" "tailscale.com/disco" - "tailscale.com/envknob" "tailscale.com/feature" "tailscale.com/ipn" "tailscale.com/ipn/ipnext" @@ -71,8 +67,8 @@ func servePeerRelayDebugSessions(h *localapi.Handler, w http.ResponseWriter, r * // imported. func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) { e := &extension{ - newServerFn: func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) { - return udprelay.NewServer(logf, port, overrideAddrs) + newServerFn: func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) { + return udprelay.NewServer(logf, port, onlyStaticAddrPorts) }, logf: logger.WithPrefix(logf, featureName+": "), } @@ -94,7 +90,7 @@ type relayServer interface { // extension is an [ipnext.Extension] managing the relay server on platforms // that import this package. type extension struct { - newServerFn func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) // swappable for tests + newServerFn func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) // swappable for tests logf logger.Logf ec *eventbus.Client respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp] @@ -170,7 +166,7 @@ func (e *extension) onAllocReq(req magicsock.UDPRelayAllocReq) { } func (e *extension) tryStartRelayServerLocked() { - rs, err := e.newServerFn(e.logf, *e.port, overrideAddrs()) + rs, err := e.newServerFn(e.logf, *e.port, false) if err != nil { e.logf("error initializing server: %v", err) return @@ -217,26 +213,6 @@ func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsV e.handleRelayServerLifetimeLocked() } -// overrideAddrs returns TS_DEBUG_RELAY_SERVER_ADDRS as []netip.Addr, if set. It -// can be between 0 and 3 comma-separated Addrs. TS_DEBUG_RELAY_SERVER_ADDRS is -// not a stable interface, and is subject to change. -var overrideAddrs = sync.OnceValue(func() (ret []netip.Addr) { - all := envknob.String("TS_DEBUG_RELAY_SERVER_ADDRS") - const max = 3 - remain := all - for remain != "" && len(ret) < max { - var s string - s, remain, _ = strings.Cut(remain, ",") - addr, err := netip.ParseAddr(s) - if err != nil { - log.Printf("ignoring invalid Addr %q in TS_DEBUG_RELAY_SERVER_ADDRS %q: %v", s, all, err) - continue - } - ret = append(ret, addr) - } - return -}) - func (e *extension) stopRelayServerLocked() { if e.rs != nil { e.rs.Close() diff --git a/feature/relayserver/relayserver_test.go b/feature/relayserver/relayserver_test.go index 2184b5175..3d71c55d7 100644 --- a/feature/relayserver/relayserver_test.go +++ b/feature/relayserver/relayserver_test.go @@ -5,7 +5,6 @@ package relayserver import ( "errors" - "net/netip" "reflect" "testing" @@ -157,7 +156,7 @@ func Test_extension_profileStateChanged(t *testing.T) { t.Fatal(err) } e := ipne.(*extension) - e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) { + e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) { return &mockRelayServer{}, nil } e.port = tt.fields.port @@ -289,7 +288,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { t.Fatal(err) } e := ipne.(*extension) - e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) { + e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) { return &mockRelayServer{}, nil } e.shutdown = tt.shutdown diff --git a/net/udprelay/server.go b/net/udprelay/server.go index c050c9416..7138cec7a 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -36,6 +36,7 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/nettype" + "tailscale.com/types/views" "tailscale.com/util/eventbus" "tailscale.com/util/set" ) @@ -72,15 +73,16 @@ type Server struct { closeCh chan struct{} netChecker *netcheck.Client - mu sync.Mutex // guards the following fields - derpMap *tailcfg.DERPMap - addrDiscoveryOnce bool // addrDiscovery completed once (successfully or unsuccessfully) - addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints - closed bool - lamportID uint64 - nextVNI uint32 - byVNI map[uint32]*serverEndpoint - byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint + mu sync.Mutex // guards the following fields + derpMap *tailcfg.DERPMap + onlyStaticAddrPorts bool // no dynamic addr port discovery when set + staticAddrPorts views.Slice[netip.AddrPort] // static ip:port pairs set with [Server.SetStaticAddrPorts] + dynamicAddrPorts []netip.AddrPort // dynamically discovered ip:port pairs + closed bool + lamportID uint64 + nextVNI uint32 + byVNI map[uint32]*serverEndpoint + byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint } const ( @@ -278,15 +280,17 @@ func (e *serverEndpoint) isBound() bool { // NewServer constructs a [Server] listening on port. If port is zero, then // port selection is left up to the host networking stack. If -// len(overrideAddrs) > 0 these will be used in place of dynamic discovery, -// which is useful to override in tests. -func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) { +// onlyStaticAddrPorts is true, then dynamic addr:port discovery will be +// disabled, and only addr:port's set via [Server.SetStaticAddrPorts] will be +// used. +func NewServer(logf logger.Logf, port int, onlyStaticAddrPorts bool) (s *Server, err error) { s = &Server{ logf: logf, disco: key.NewDisco(), bindLifetime: defaultBindLifetime, steadyStateLifetime: defaultSteadyStateLifetime, closeCh: make(chan struct{}), + onlyStaticAddrPorts: onlyStaticAddrPorts, byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint), nextVNI: minVNI, byVNI: make(map[uint32]*serverEndpoint), @@ -321,19 +325,7 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve return nil, err } - if len(overrideAddrs) > 0 { - addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs)) - for _, addr := range overrideAddrs { - if addr.IsValid() { - if addr.Is4() { - addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port)) - } else if s.uc6 != nil { - addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port)) - } - } - } - s.addrPorts = addrPorts.Slice() - } else { + if !s.onlyStaticAddrPorts { s.wg.Add(1) go s.addrDiscoveryLoop() } @@ -429,8 +421,7 @@ func (s *Server) addrDiscoveryLoop() { s.logf("error discovering IP:port candidates: %v", err) } s.mu.Lock() - s.addrPorts = addrPorts - s.addrDiscoveryOnce = true + s.dynamicAddrPorts = addrPorts s.mu.Unlock() case <-s.closeCh: return @@ -747,6 +738,15 @@ func (s *Server) getNextVNILocked() (uint32, error) { return 0, errors.New("VNI pool exhausted") } +// getAllAddrPortsCopyLocked returns a copy of the combined +// [Server.staticAddrPorts] and [Server.dynamicAddrPorts] slices. +func (s *Server) getAllAddrPortsCopyLocked() []netip.AddrPort { + addrPorts := make([]netip.AddrPort, 0, len(s.dynamicAddrPorts)+s.staticAddrPorts.Len()) + addrPorts = append(addrPorts, s.staticAddrPorts.AsSlice()...) + addrPorts = append(addrPorts, slices.Clone(s.dynamicAddrPorts)...) + return addrPorts +} + // AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair // of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB // it is returned without modification/reallocation. AllocateEndpoint returns @@ -760,11 +760,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv return endpoint.ServerEndpoint{}, ErrServerClosed } - if len(s.addrPorts) == 0 { - if !s.addrDiscoveryOnce { - return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter} - } - return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known") + if s.staticAddrPorts.Len() == 0 && len(s.dynamicAddrPorts) == 0 { + return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter} } if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 { @@ -787,7 +784,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv // consider storing them (maybe interning) in the [*serverEndpoint] // at allocation time. ClientDisco: pair.Get(), - AddrPorts: slices.Clone(s.addrPorts), + AddrPorts: s.getAllAddrPortsCopyLocked(), VNI: e.vni, LamportID: e.lamportID, BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, @@ -817,7 +814,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv return endpoint.ServerEndpoint{ ServerDisco: s.discoPublic, ClientDisco: pair.Get(), - AddrPorts: slices.Clone(s.addrPorts), + AddrPorts: s.getAllAddrPortsCopyLocked(), VNI: e.vni, LamportID: e.lamportID, BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, @@ -880,3 +877,13 @@ func (s *Server) getDERPMap() *tailcfg.DERPMap { defer s.mu.Unlock() return s.derpMap } + +// SetStaticAddrPorts sets addr:port pairs the [Server] will advertise +// as candidates it is potentially reachable over, in combination with +// dynamically discovered pairs. This replaces any previously-provided static +// values. +func (s *Server) SetStaticAddrPorts(addrPorts views.Slice[netip.AddrPort]) { + s.mu.Lock() + defer s.mu.Unlock() + s.staticAddrPorts = addrPorts +} diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index bf7f0a9b5..6c3d61658 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -17,6 +17,7 @@ import ( "tailscale.com/disco" "tailscale.com/net/packet" "tailscale.com/types/key" + "tailscale.com/types/views" ) type testClient struct { @@ -185,31 +186,40 @@ func TestServer(t *testing.T) { cases := []struct { name string - overrideAddrs []netip.Addr + staticAddrs []netip.Addr forceClientsMixedAF bool }{ { - name: "over ipv4", - overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + name: "over ipv4", + staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, }, { - name: "over ipv6", - overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")}, + name: "over ipv6", + staticAddrs: []netip.Addr{netip.MustParseAddr("::1")}, }, { name: "mixed address families", - overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")}, + staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")}, forceClientsMixedAF: true, }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - server, err := NewServer(t.Logf, 0, tt.overrideAddrs) + server, err := NewServer(t.Logf, 0, true) if err != nil { t.Fatal(err) } defer server.Close() + addrPorts := make([]netip.AddrPort, 0, len(tt.staticAddrs)) + for _, addr := range tt.staticAddrs { + if addr.Is4() { + addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc4Port)) + } else if server.uc6Port != 0 { + addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc6Port)) + } + } + server.SetStaticAddrPorts(views.SliceOf(addrPorts)) endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) if err != nil {