net/udprelay: replace VNI pool with selection algorithm (#17868)

This reduces memory usage when tailscaled is acting as a peer relay.

Updates #17801

Signed-off-by: Jordan Whited <jordan@tailscale.com>
pull/17881/head
Jordan Whited 3 weeks ago committed by GitHub
parent 31fe75ad9e
commit f4f9dd7f8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -77,11 +77,17 @@ type Server struct {
addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints
closed bool closed bool
lamportID uint64 lamportID uint64
vniPool []uint32 // the pool of available VNIs nextVNI uint32
byVNI map[uint32]*serverEndpoint byVNI map[uint32]*serverEndpoint
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
} }
const (
minVNI = uint32(1)
maxVNI = uint32(1<<24 - 1)
totalPossibleVNI = maxVNI - minVNI + 1
)
// serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state. // serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state.
// serverEndpoint methods are not thread-safe. // serverEndpoint methods are not thread-safe.
type serverEndpoint struct { type serverEndpoint struct {
@ -281,15 +287,10 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
steadyStateLifetime: defaultSteadyStateLifetime, steadyStateLifetime: defaultSteadyStateLifetime,
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint), byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
nextVNI: minVNI,
byVNI: make(map[uint32]*serverEndpoint), byVNI: make(map[uint32]*serverEndpoint),
} }
s.discoPublic = s.disco.Public() s.discoPublic = s.disco.Public()
// TODO: instead of allocating 10s of MBs for the full pool, allocate
// smaller chunks and increase as needed
s.vniPool = make([]uint32, 0, 1<<24-1)
for i := 1; i < 1<<24; i++ {
s.vniPool = append(s.vniPool, uint32(i))
}
// TODO(creachadair): Find a way to plumb this in during initialization. // TODO(creachadair): Find a way to plumb this in during initialization.
// As-written, messages published here will not be seen by other components // As-written, messages published here will not be seen by other components
@ -572,7 +573,6 @@ func (s *Server) Close() error {
defer s.mu.Unlock() defer s.mu.Unlock()
clear(s.byVNI) clear(s.byVNI)
clear(s.byDisco) clear(s.byDisco)
s.vniPool = nil
s.closed = true s.closed = true
s.bus.Close() s.bus.Close()
}) })
@ -594,7 +594,6 @@ func (s *Server) endpointGCLoop() {
if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) { if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) {
delete(s.byDisco, k) delete(s.byDisco, k)
delete(s.byVNI, v.vni) delete(s.byVNI, v.vni)
s.vniPool = append(s.vniPool, v.vni)
} }
} }
} }
@ -729,6 +728,27 @@ func (e ErrServerNotReady) Error() string {
return fmt.Sprintf("server not ready, retry after %v", e.RetryAfter) return fmt.Sprintf("server not ready, retry after %v", e.RetryAfter)
} }
// getNextVNILocked returns the next available VNI. It implements the
// "Traditional BSD Port Selection Algorithm" from RFC6056. This algorithm does
// not attempt to obfuscate the selection, i.e. the selection is predictable.
// For now, we favor simplicity and reducing VNI re-use over more complex
// ephemeral port (VNI) selection algorithms.
func (s *Server) getNextVNILocked() (uint32, error) {
for i := uint32(0); i < totalPossibleVNI; i++ {
vni := s.nextVNI
if vni == maxVNI {
s.nextVNI = minVNI
} else {
s.nextVNI++
}
_, ok := s.byVNI[vni]
if !ok {
return vni, nil
}
}
return 0, errors.New("VNI pool exhausted")
}
// AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair // AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair
// of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB // of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB
// it is returned without modification/reallocation. AllocateEndpoint returns // it is returned without modification/reallocation. AllocateEndpoint returns
@ -777,8 +797,9 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
}, nil }, nil
} }
if len(s.vniPool) == 0 { vni, err := s.getNextVNILocked()
return endpoint.ServerEndpoint{}, errors.New("VNI pool exhausted") if err != nil {
return endpoint.ServerEndpoint{}, err
} }
s.lamportID++ s.lamportID++
@ -786,10 +807,10 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
discoPubKeys: pair, discoPubKeys: pair,
lamportID: s.lamportID, lamportID: s.lamportID,
allocatedAt: time.Now(), allocatedAt: time.Now(),
vni: vni,
} }
e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0]) e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0])
e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1]) e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1])
e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:]
s.byDisco[pair] = e s.byDisco[pair] = e
s.byVNI[e.vni] = e s.byVNI[e.vni] = e

@ -10,6 +10,7 @@ import (
"testing" "testing"
"time" "time"
qt "github.com/frankban/quicktest"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"go4.org/mem" "go4.org/mem"
@ -319,3 +320,25 @@ func TestServer(t *testing.T) {
}) })
} }
} }
func TestServer_getNextVNILocked(t *testing.T) {
t.Parallel()
c := qt.New(t)
s := &Server{
nextVNI: minVNI,
byVNI: make(map[uint32]*serverEndpoint),
}
for i := uint64(0); i < uint64(totalPossibleVNI); i++ {
vni, err := s.getNextVNILocked()
if err != nil { // using quicktest here triples test time
t.Fatal(err)
}
s.byVNI[vni] = nil
}
c.Assert(s.nextVNI, qt.Equals, minVNI)
_, err := s.getNextVNILocked()
c.Assert(err, qt.IsNotNil)
delete(s.byVNI, minVNI)
_, err = s.getNextVNILocked()
c.Assert(err, qt.IsNil)
}

Loading…
Cancel
Save