net/udprelay: replace VNI pool with selection algorithm (#17868)

This reduces memory usage when tailscaled is acting as a peer relay.

Updates #17801

Signed-off-by: Jordan Whited <jordan@tailscale.com>
(cherry picked from commit f4f9dd7f8c)
pull/17969/head
Jordan Whited 3 weeks ago committed by Jordan Whited
parent 771a9d29ff
commit eb03b354f6

@ -77,11 +77,17 @@ type Server struct {
addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints
closed bool
lamportID uint64
vniPool []uint32 // the pool of available VNIs
nextVNI uint32
byVNI map[uint32]*serverEndpoint
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
}
const (
minVNI = uint32(1)
maxVNI = uint32(1<<24 - 1)
totalPossibleVNI = maxVNI - minVNI + 1
)
// serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state.
// serverEndpoint methods are not thread-safe.
type serverEndpoint struct {
@ -281,15 +287,10 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
steadyStateLifetime: defaultSteadyStateLifetime,
closeCh: make(chan struct{}),
byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
nextVNI: minVNI,
byVNI: make(map[uint32]*serverEndpoint),
}
s.discoPublic = s.disco.Public()
// TODO: instead of allocating 10s of MBs for the full pool, allocate
// smaller chunks and increase as needed
s.vniPool = make([]uint32, 0, 1<<24-1)
for i := 1; i < 1<<24; i++ {
s.vniPool = append(s.vniPool, uint32(i))
}
// TODO(creachadair): Find a way to plumb this in during initialization.
// As-written, messages published here will not be seen by other components
@ -557,7 +558,6 @@ func (s *Server) Close() error {
defer s.mu.Unlock()
clear(s.byVNI)
clear(s.byDisco)
s.vniPool = nil
s.closed = true
s.bus.Close()
})
@ -579,7 +579,6 @@ func (s *Server) endpointGCLoop() {
if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) {
delete(s.byDisco, k)
delete(s.byVNI, v.vni)
s.vniPool = append(s.vniPool, v.vni)
}
}
}
@ -714,6 +713,27 @@ func (e ErrServerNotReady) Error() string {
return fmt.Sprintf("server not ready, retry after %v", e.RetryAfter)
}
// getNextVNILocked returns the next available VNI. It implements the
// "Traditional BSD Port Selection Algorithm" from RFC6056. This algorithm does
// not attempt to obfuscate the selection, i.e. the selection is predictable.
// For now, we favor simplicity and reducing VNI re-use over more complex
// ephemeral port (VNI) selection algorithms.
func (s *Server) getNextVNILocked() (uint32, error) {
for i := uint32(0); i < totalPossibleVNI; i++ {
vni := s.nextVNI
if vni == maxVNI {
s.nextVNI = minVNI
} else {
s.nextVNI++
}
_, ok := s.byVNI[vni]
if !ok {
return vni, nil
}
}
return 0, errors.New("VNI pool exhausted")
}
// AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair
// of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB
// it is returned without modification/reallocation. AllocateEndpoint returns
@ -762,8 +782,9 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
}, nil
}
if len(s.vniPool) == 0 {
return endpoint.ServerEndpoint{}, errors.New("VNI pool exhausted")
vni, err := s.getNextVNILocked()
if err != nil {
return endpoint.ServerEndpoint{}, err
}
s.lamportID++
@ -771,10 +792,10 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
discoPubKeys: pair,
lamportID: s.lamportID,
allocatedAt: time.Now(),
vni: vni,
}
e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0])
e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1])
e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:]
s.byDisco[pair] = e
s.byVNI[e.vni] = e

@ -10,6 +10,7 @@ import (
"testing"
"time"
qt "github.com/frankban/quicktest"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"go4.org/mem"
@ -319,3 +320,25 @@ func TestServer(t *testing.T) {
})
}
}
func TestServer_getNextVNILocked(t *testing.T) {
t.Parallel()
c := qt.New(t)
s := &Server{
nextVNI: minVNI,
byVNI: make(map[uint32]*serverEndpoint),
}
for i := uint64(0); i < uint64(totalPossibleVNI); i++ {
vni, err := s.getNextVNILocked()
if err != nil { // using quicktest here triples test time
t.Fatal(err)
}
s.byVNI[vni] = nil
}
c.Assert(s.nextVNI, qt.Equals, minVNI)
_, err := s.getNextVNILocked()
c.Assert(err, qt.IsNotNil)
delete(s.byVNI, minVNI)
_, err = s.getNextVNILocked()
c.Assert(err, qt.IsNil)
}

Loading…
Cancel
Save