net/udprelay: implement Server.SetStaticAddrPorts (#17909)

Only used in tests for now.

Updates tailscale/corp#31489

Signed-off-by: Jordan Whited <jordan@tailscale.com>
pull/17752/merge
Jordan Whited 3 weeks ago committed by GitHub
parent a96ef432cf
commit e1f0ad7a05
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -8,14 +8,10 @@ package relayserver
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"net/http" "net/http"
"net/netip"
"strings"
"sync" "sync"
"tailscale.com/disco" "tailscale.com/disco"
"tailscale.com/envknob"
"tailscale.com/feature" "tailscale.com/feature"
"tailscale.com/ipn" "tailscale.com/ipn"
"tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnext"
@ -71,8 +67,8 @@ func servePeerRelayDebugSessions(h *localapi.Handler, w http.ResponseWriter, r *
// imported. // imported.
func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) { func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) {
e := &extension{ e := &extension{
newServerFn: func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) { newServerFn: func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
return udprelay.NewServer(logf, port, overrideAddrs) return udprelay.NewServer(logf, port, onlyStaticAddrPorts)
}, },
logf: logger.WithPrefix(logf, featureName+": "), logf: logger.WithPrefix(logf, featureName+": "),
} }
@ -94,7 +90,7 @@ type relayServer interface {
// extension is an [ipnext.Extension] managing the relay server on platforms // extension is an [ipnext.Extension] managing the relay server on platforms
// that import this package. // that import this package.
type extension struct { type extension struct {
newServerFn func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) // swappable for tests newServerFn func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) // swappable for tests
logf logger.Logf logf logger.Logf
ec *eventbus.Client ec *eventbus.Client
respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp] respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp]
@ -170,7 +166,7 @@ func (e *extension) onAllocReq(req magicsock.UDPRelayAllocReq) {
} }
func (e *extension) tryStartRelayServerLocked() { func (e *extension) tryStartRelayServerLocked() {
rs, err := e.newServerFn(e.logf, *e.port, overrideAddrs()) rs, err := e.newServerFn(e.logf, *e.port, false)
if err != nil { if err != nil {
e.logf("error initializing server: %v", err) e.logf("error initializing server: %v", err)
return return
@ -217,26 +213,6 @@ func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsV
e.handleRelayServerLifetimeLocked() e.handleRelayServerLifetimeLocked()
} }
// overrideAddrs returns TS_DEBUG_RELAY_SERVER_ADDRS as []netip.Addr, if set. It
// can be between 0 and 3 comma-separated Addrs. TS_DEBUG_RELAY_SERVER_ADDRS is
// not a stable interface, and is subject to change.
var overrideAddrs = sync.OnceValue(func() (ret []netip.Addr) {
all := envknob.String("TS_DEBUG_RELAY_SERVER_ADDRS")
const max = 3
remain := all
for remain != "" && len(ret) < max {
var s string
s, remain, _ = strings.Cut(remain, ",")
addr, err := netip.ParseAddr(s)
if err != nil {
log.Printf("ignoring invalid Addr %q in TS_DEBUG_RELAY_SERVER_ADDRS %q: %v", s, all, err)
continue
}
ret = append(ret, addr)
}
return
})
func (e *extension) stopRelayServerLocked() { func (e *extension) stopRelayServerLocked() {
if e.rs != nil { if e.rs != nil {
e.rs.Close() e.rs.Close()

@ -5,7 +5,6 @@ package relayserver
import ( import (
"errors" "errors"
"net/netip"
"reflect" "reflect"
"testing" "testing"
@ -157,7 +156,7 @@ func Test_extension_profileStateChanged(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
e := ipne.(*extension) e := ipne.(*extension)
e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) { e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
return &mockRelayServer{}, nil return &mockRelayServer{}, nil
} }
e.port = tt.fields.port e.port = tt.fields.port
@ -289,7 +288,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
e := ipne.(*extension) e := ipne.(*extension)
e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) { e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
return &mockRelayServer{}, nil return &mockRelayServer{}, nil
} }
e.shutdown = tt.shutdown e.shutdown = tt.shutdown

@ -36,6 +36,7 @@ import (
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/nettype" "tailscale.com/types/nettype"
"tailscale.com/types/views"
"tailscale.com/util/eventbus" "tailscale.com/util/eventbus"
"tailscale.com/util/set" "tailscale.com/util/set"
) )
@ -72,15 +73,16 @@ type Server struct {
closeCh chan struct{} closeCh chan struct{}
netChecker *netcheck.Client netChecker *netcheck.Client
mu sync.Mutex // guards the following fields mu sync.Mutex // guards the following fields
derpMap *tailcfg.DERPMap derpMap *tailcfg.DERPMap
addrDiscoveryOnce bool // addrDiscovery completed once (successfully or unsuccessfully) onlyStaticAddrPorts bool // no dynamic addr port discovery when set
addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints staticAddrPorts views.Slice[netip.AddrPort] // static ip:port pairs set with [Server.SetStaticAddrPorts]
closed bool dynamicAddrPorts []netip.AddrPort // dynamically discovered ip:port pairs
lamportID uint64 closed bool
nextVNI uint32 lamportID uint64
byVNI map[uint32]*serverEndpoint nextVNI uint32
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint byVNI map[uint32]*serverEndpoint
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
} }
const ( const (
@ -278,15 +280,17 @@ func (e *serverEndpoint) isBound() bool {
// NewServer constructs a [Server] listening on port. If port is zero, then // NewServer constructs a [Server] listening on port. If port is zero, then
// port selection is left up to the host networking stack. If // port selection is left up to the host networking stack. If
// len(overrideAddrs) > 0 these will be used in place of dynamic discovery, // onlyStaticAddrPorts is true, then dynamic addr:port discovery will be
// which is useful to override in tests. // disabled, and only addr:port's set via [Server.SetStaticAddrPorts] will be
func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) { // used.
func NewServer(logf logger.Logf, port int, onlyStaticAddrPorts bool) (s *Server, err error) {
s = &Server{ s = &Server{
logf: logf, logf: logf,
disco: key.NewDisco(), disco: key.NewDisco(),
bindLifetime: defaultBindLifetime, bindLifetime: defaultBindLifetime,
steadyStateLifetime: defaultSteadyStateLifetime, steadyStateLifetime: defaultSteadyStateLifetime,
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
onlyStaticAddrPorts: onlyStaticAddrPorts,
byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint), byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
nextVNI: minVNI, nextVNI: minVNI,
byVNI: make(map[uint32]*serverEndpoint), byVNI: make(map[uint32]*serverEndpoint),
@ -321,19 +325,7 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
return nil, err return nil, err
} }
if len(overrideAddrs) > 0 { if !s.onlyStaticAddrPorts {
addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs))
for _, addr := range overrideAddrs {
if addr.IsValid() {
if addr.Is4() {
addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port))
} else if s.uc6 != nil {
addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port))
}
}
}
s.addrPorts = addrPorts.Slice()
} else {
s.wg.Add(1) s.wg.Add(1)
go s.addrDiscoveryLoop() go s.addrDiscoveryLoop()
} }
@ -429,8 +421,7 @@ func (s *Server) addrDiscoveryLoop() {
s.logf("error discovering IP:port candidates: %v", err) s.logf("error discovering IP:port candidates: %v", err)
} }
s.mu.Lock() s.mu.Lock()
s.addrPorts = addrPorts s.dynamicAddrPorts = addrPorts
s.addrDiscoveryOnce = true
s.mu.Unlock() s.mu.Unlock()
case <-s.closeCh: case <-s.closeCh:
return return
@ -747,6 +738,15 @@ func (s *Server) getNextVNILocked() (uint32, error) {
return 0, errors.New("VNI pool exhausted") return 0, errors.New("VNI pool exhausted")
} }
// getAllAddrPortsCopyLocked returns a copy of the combined
// [Server.staticAddrPorts] and [Server.dynamicAddrPorts] slices.
func (s *Server) getAllAddrPortsCopyLocked() []netip.AddrPort {
addrPorts := make([]netip.AddrPort, 0, len(s.dynamicAddrPorts)+s.staticAddrPorts.Len())
addrPorts = append(addrPorts, s.staticAddrPorts.AsSlice()...)
addrPorts = append(addrPorts, slices.Clone(s.dynamicAddrPorts)...)
return addrPorts
}
// AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair // AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair
// of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB // of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB
// it is returned without modification/reallocation. AllocateEndpoint returns // it is returned without modification/reallocation. AllocateEndpoint returns
@ -760,11 +760,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
return endpoint.ServerEndpoint{}, ErrServerClosed return endpoint.ServerEndpoint{}, ErrServerClosed
} }
if len(s.addrPorts) == 0 { if s.staticAddrPorts.Len() == 0 && len(s.dynamicAddrPorts) == 0 {
if !s.addrDiscoveryOnce { return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter}
return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter}
}
return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known")
} }
if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 { if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
@ -787,7 +784,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
// consider storing them (maybe interning) in the [*serverEndpoint] // consider storing them (maybe interning) in the [*serverEndpoint]
// at allocation time. // at allocation time.
ClientDisco: pair.Get(), ClientDisco: pair.Get(),
AddrPorts: slices.Clone(s.addrPorts), AddrPorts: s.getAllAddrPortsCopyLocked(),
VNI: e.vni, VNI: e.vni,
LamportID: e.lamportID, LamportID: e.lamportID,
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
@ -817,7 +814,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
return endpoint.ServerEndpoint{ return endpoint.ServerEndpoint{
ServerDisco: s.discoPublic, ServerDisco: s.discoPublic,
ClientDisco: pair.Get(), ClientDisco: pair.Get(),
AddrPorts: slices.Clone(s.addrPorts), AddrPorts: s.getAllAddrPortsCopyLocked(),
VNI: e.vni, VNI: e.vni,
LamportID: e.lamportID, LamportID: e.lamportID,
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
@ -880,3 +877,13 @@ func (s *Server) getDERPMap() *tailcfg.DERPMap {
defer s.mu.Unlock() defer s.mu.Unlock()
return s.derpMap return s.derpMap
} }
// SetStaticAddrPorts sets addr:port pairs the [Server] will advertise
// as candidates it is potentially reachable over, in combination with
// dynamically discovered pairs. This replaces any previously-provided static
// values.
func (s *Server) SetStaticAddrPorts(addrPorts views.Slice[netip.AddrPort]) {
s.mu.Lock()
defer s.mu.Unlock()
s.staticAddrPorts = addrPorts
}

@ -17,6 +17,7 @@ import (
"tailscale.com/disco" "tailscale.com/disco"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/views"
) )
type testClient struct { type testClient struct {
@ -185,31 +186,40 @@ func TestServer(t *testing.T) {
cases := []struct { cases := []struct {
name string name string
overrideAddrs []netip.Addr staticAddrs []netip.Addr
forceClientsMixedAF bool forceClientsMixedAF bool
}{ }{
{ {
name: "over ipv4", name: "over ipv4",
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
}, },
{ {
name: "over ipv6", name: "over ipv6",
overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")}, staticAddrs: []netip.Addr{netip.MustParseAddr("::1")},
}, },
{ {
name: "mixed address families", name: "mixed address families",
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")}, staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")},
forceClientsMixedAF: true, forceClientsMixedAF: true,
}, },
} }
for _, tt := range cases { for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
server, err := NewServer(t.Logf, 0, tt.overrideAddrs) server, err := NewServer(t.Logf, 0, true)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer server.Close() defer server.Close()
addrPorts := make([]netip.AddrPort, 0, len(tt.staticAddrs))
for _, addr := range tt.staticAddrs {
if addr.Is4() {
addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc4Port))
} else if server.uc6Port != 0 {
addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc6Port))
}
}
server.SetStaticAddrPorts(views.SliceOf(addrPorts))
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
if err != nil { if err != nil {

Loading…
Cancel
Save