@ -8,6 +8,7 @@ package udprelay
import (
import (
"bytes"
"bytes"
"context"
"crypto/rand"
"crypto/rand"
"errors"
"errors"
"fmt"
"fmt"
@ -19,11 +20,18 @@ import (
"time"
"time"
"go4.org/mem"
"go4.org/mem"
"tailscale.com/client/local"
"tailscale.com/disco"
"tailscale.com/disco"
"tailscale.com/net/netcheck"
"tailscale.com/net/netmon"
"tailscale.com/net/packet"
"tailscale.com/net/packet"
"tailscale.com/net/stun"
"tailscale.com/net/udprelay/endpoint"
"tailscale.com/net/udprelay/endpoint"
"tailscale.com/tstime"
"tailscale.com/tstime"
"tailscale.com/types/key"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
"tailscale.com/util/set"
)
)
const (
const (
@ -42,25 +50,22 @@ const (
// Server implements an experimental UDP relay server.
// Server implements an experimental UDP relay server.
type Server struct {
type Server struct {
// disco keypair used as part of 3-way bind handshake
// The following fields are initialized once and never mutated.
disco key . DiscoPrivate
logf logger . Logf
disco Public key . DiscoPublic
disco key . DiscoPrivate
discoPublic key . DiscoPublic
bindLifetime time . Duration
bindLifetime time . Duration
steadyStateLifetime time . Duration
steadyStateLifetime time . Duration
bus * eventbus . Bus
// addrPorts contains the ip:port pairs returned as candidate server
uc * net . UDPConn
// endpoints in response to an allocation request.
closeOnce sync . Once
addrPorts [ ] netip . AddrPort
wg sync . WaitGroup
closeCh chan struct { }
uc * net . UDPConn
netChecker * netcheck . Client
closeOnce sync . Once
mu sync . Mutex // guards the following fields
wg sync . WaitGroup
addrPorts [ ] netip . AddrPort // the ip:port pairs returned as candidate endpoints
closeCh chan struct { }
closed bool
closed bool
mu sync . Mutex // guards the following fields
lamportID uint64
lamportID uint64
vniPool [ ] uint32 // the pool of available VNIs
vniPool [ ] uint32 // the pool of available VNIs
byVNI map [ uint32 ] * serverEndpoint
byVNI map [ uint32 ] * serverEndpoint
@ -270,14 +275,13 @@ func (e *serverEndpoint) isBound() bool {
// NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet
// NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet
// supported. Port may be 0, and what ultimately gets bound is returned as
// supported. Port may be 0, and what ultimately gets bound is returned as
// 'boundPort'. Supplied 'addrs' are joined with 'boundPort' and returned as
// 'boundPort'. If len(overrideAddrs) > 0 these will be used in place of dynamic
// [endpoint.ServerEndpoint.AddrPorts] in response to Server.AllocateEndpoint()
// discovery, which is useful to override in tests.
// requests.
//
//
// TODO: IPv6 support
// TODO: IPv6 support
// TODO: dynamic addrs:port discovery
func NewServer ( logf logger . Logf , port int , overrideAddrs [ ] netip . Addr ) ( s * Server , boundPort uint16 , err error ) {
func NewServer ( port int , addrs [ ] netip . Addr ) ( s * Server , boundPort int , err error ) {
s = & Server {
s = & Server {
logf : logger . WithPrefix ( logf , "relayserver" ) ,
disco : key . NewDisco ( ) ,
disco : key . NewDisco ( ) ,
bindLifetime : defaultBindLifetime ,
bindLifetime : defaultBindLifetime ,
steadyStateLifetime : defaultSteadyStateLifetime ,
steadyStateLifetime : defaultSteadyStateLifetime ,
@ -292,26 +296,120 @@ func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err erro
for i := 1 ; i < 1 << 24 ; i ++ {
for i := 1 ; i < 1 << 24 ; i ++ {
s . vniPool = append ( s . vniPool , uint32 ( i ) )
s . vniPool = append ( s . vniPool , uint32 ( i ) )
}
}
boundPort , err = s . listenOn ( port )
bus := eventbus . New ( )
s . bus = bus
netMon , err := netmon . New ( s . bus , logf )
if err != nil {
if err != nil {
return nil , 0 , err
return nil , 0 , err
}
}
addrPorts := make ( [ ] netip . AddrPort , 0 , len ( addrs ) )
s . netChecker = & netcheck . Client {
for _ , addr := range addrs {
NetMon : netMon ,
addrPort , err := netip . ParseAddrPort ( net . JoinHostPort ( addr . String ( ) , strconv . Itoa ( boundPort ) ) )
Logf : logger . WithPrefix ( logf , "relayserver: netcheck:" ) ,
if err != nil {
SendPacket : func ( b [ ] byte , addrPort netip . AddrPort ) ( int , error ) {
return nil , 0 , err
return s . uc . WriteToUDPAddrPort ( b , addrPort )
}
} ,
addrPorts = append ( addrPorts , addrPort )
}
boundPort , err = s . listenOn ( port )
if err != nil {
return nil , 0 , err
}
}
s . addrPorts = addrPorts
s . wg . Add ( 2 )
s . wg . Add ( 1 )
go s . packetReadLoop ( )
go s . packetReadLoop ( )
s . wg . Add ( 1 )
go s . endpointGCLoop ( )
go s . endpointGCLoop ( )
if len ( overrideAddrs ) > 0 {
var addrPorts set . Set [ netip . AddrPort ]
addrPorts . Make ( )
for _ , addr := range overrideAddrs {
if addr . IsValid ( ) {
addrPorts . Add ( netip . AddrPortFrom ( addr , boundPort ) )
}
}
s . addrPorts = addrPorts . Slice ( )
} else {
s . wg . Add ( 1 )
go s . addrDiscoveryLoop ( )
}
return s , boundPort , nil
return s , boundPort , nil
}
}
func ( s * Server ) listenOn ( port int ) ( int , error ) {
func ( s * Server ) addrDiscoveryLoop ( ) {
defer s . wg . Done ( )
timer := time . NewTimer ( 0 ) // fire immediately
defer timer . Stop ( )
getAddrPorts := func ( ) ( [ ] netip . AddrPort , error ) {
var addrPorts set . Set [ netip . AddrPort ]
addrPorts . Make ( )
// get local addresses
localPort := s . uc . LocalAddr ( ) . ( * net . UDPAddr ) . Port
ips , _ , err := netmon . LocalAddresses ( )
if err != nil {
return nil , err
}
for _ , ip := range ips {
if ip . IsValid ( ) {
addrPorts . Add ( netip . AddrPortFrom ( ip , uint16 ( localPort ) ) )
}
}
// fetch DERPMap to feed to netcheck
derpMapCtx , derpMapCancel := context . WithTimeout ( context . Background ( ) , time . Second )
defer derpMapCancel ( )
localClient := & local . Client { }
// TODO(jwhited): We are in-process so use eventbus or similar.
// local.Client gets us going.
dm , err := localClient . CurrentDERPMap ( derpMapCtx )
if err != nil {
return nil , err
}
// get addrPorts as visible from DERP
netCheckerCtx , netCheckerCancel := context . WithTimeout ( context . Background ( ) , netcheck . ReportTimeout )
defer netCheckerCancel ( )
rep , err := s . netChecker . GetReport ( netCheckerCtx , dm , & netcheck . GetReportOpts {
OnlySTUN : true ,
} )
if err != nil {
return nil , err
}
if rep . GlobalV4 . IsValid ( ) {
addrPorts . Add ( rep . GlobalV4 )
}
if rep . GlobalV6 . IsValid ( ) {
addrPorts . Add ( rep . GlobalV6 )
}
// TODO(jwhited): consider logging if rep.MappingVariesByDestIP as
// that's a hint we are not well-positioned to operate as a UDP relay.
return addrPorts . Slice ( ) , nil
}
for {
select {
case <- timer . C :
// Mirror magicsock behavior for duration between STUN. We consider
// 30s a min bound for NAT timeout.
timer . Reset ( tstime . RandomDurationBetween ( 20 * time . Second , 26 * time . Second ) )
addrPorts , err := getAddrPorts ( )
if err != nil {
s . logf ( "error discovering IP:port candidates: %v" , err )
}
s . mu . Lock ( )
s . addrPorts = addrPorts
s . mu . Unlock ( )
case <- s . closeCh :
return
}
}
}
func ( s * Server ) listenOn ( port int ) ( uint16 , error ) {
uc , err := net . ListenUDP ( "udp4" , & net . UDPAddr { Port : port } )
uc , err := net . ListenUDP ( "udp4" , & net . UDPAddr { Port : port } )
if err != nil {
if err != nil {
return 0 , err
return 0 , err
@ -322,13 +420,13 @@ func (s *Server) listenOn(port int) (int, error) {
s . uc . Close ( )
s . uc . Close ( )
return 0 , err
return 0 , err
}
}
boundPort , err := strconv . Atoi( boundPortStr )
boundPort , err := strconv . ParseUint( boundPortStr , 10 , 16 )
if err != nil {
if err != nil {
s . uc . Close ( )
s . uc . Close ( )
return 0 , err
return 0 , err
}
}
s . uc = uc
s . uc = uc
return boundPort , nil
return uint16 ( boundPort ) , nil
}
}
// Close closes the server.
// Close closes the server.
@ -343,6 +441,7 @@ func (s *Server) Close() error {
clear ( s . byDisco )
clear ( s . byDisco )
s . vniPool = nil
s . vniPool = nil
s . closed = true
s . closed = true
s . bus . Close ( )
} )
} )
return nil
return nil
}
}
@ -378,6 +477,13 @@ func (s *Server) endpointGCLoop() {
}
}
func ( s * Server ) handlePacket ( from netip . AddrPort , b [ ] byte , uw udpWriter ) {
func ( s * Server ) handlePacket ( from netip . AddrPort , b [ ] byte , uw udpWriter ) {
if stun . Is ( b ) && b [ 1 ] == 0x01 {
// A b[1] value of 0x01 (STUN method binding) is sufficiently
// non-overlapping with the Geneve header where the LSB is always 0
// (part of 6 "reserved" bits).
s . netChecker . ReceiveSTUNPacket ( b , from )
return
}
gh := packet . GeneveHeader { }
gh := packet . GeneveHeader { }
err := gh . Decode ( b )
err := gh . Decode ( b )
if err != nil {
if err != nil {
@ -426,6 +532,10 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
return endpoint . ServerEndpoint { } , ErrServerClosed
return endpoint . ServerEndpoint { } , ErrServerClosed
}
}
if len ( s . addrPorts ) == 0 {
return endpoint . ServerEndpoint { } , errors . New ( "server addrPorts are not yet known" )
}
if discoA . Compare ( s . discoPublic ) == 0 || discoB . Compare ( s . discoPublic ) == 0 {
if discoA . Compare ( s . discoPublic ) == 0 || discoB . Compare ( s . discoPublic ) == 0 {
return endpoint . ServerEndpoint { } , fmt . Errorf ( "client disco equals server disco: %s" , s . discoPublic . ShortString ( ) )
return endpoint . ServerEndpoint { } , fmt . Errorf ( "client disco equals server disco: %s" , s . discoPublic . ShortString ( ) )
}
}
@ -439,8 +549,13 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
// TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt
// TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt
// to give the client a more accurate picture of the bind window.
// to give the client a more accurate picture of the bind window.
return endpoint . ServerEndpoint {
return endpoint . ServerEndpoint {
ServerDisco : s . discoPublic ,
ServerDisco : s . discoPublic ,
AddrPorts : s . addrPorts ,
// Returning the "latest" addrPorts for an existing allocation is
// the simple choice. It may not be the best depending on client
// behaviors and endpoint state (bound or not). We might want to
// consider storing them (maybe interning) in the [*serverEndpoint]
// at allocation time.
AddrPorts : slices . Clone ( s . addrPorts ) ,
VNI : e . vni ,
VNI : e . vni ,
LamportID : e . lamportID ,
LamportID : e . lamportID ,
BindLifetime : tstime . GoDuration { Duration : s . bindLifetime } ,
BindLifetime : tstime . GoDuration { Duration : s . bindLifetime } ,
@ -469,7 +584,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
return endpoint . ServerEndpoint {
return endpoint . ServerEndpoint {
ServerDisco : s . discoPublic ,
ServerDisco : s . discoPublic ,
AddrPorts : s . addrPorts ,
AddrPorts : s lices. Clone ( s . addrPorts ) ,
VNI : e . vni ,
VNI : e . vni ,
LamportID : e . lamportID ,
LamportID : e . lamportID ,
BindLifetime : tstime . GoDuration { Duration : s . bindLifetime } ,
BindLifetime : tstime . GoDuration { Duration : s . bindLifetime } ,