diff --git a/net/udprelay/server.go b/net/udprelay/server.go index de1376b64..69e0de095 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -77,11 +77,17 @@ type Server struct { addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints closed bool lamportID uint64 - vniPool []uint32 // the pool of available VNIs + nextVNI uint32 byVNI map[uint32]*serverEndpoint byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint } +const ( + minVNI = uint32(1) + maxVNI = uint32(1<<24 - 1) + totalPossibleVNI = maxVNI - minVNI + 1 +) + // serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state. // serverEndpoint methods are not thread-safe. type serverEndpoint struct { @@ -281,15 +287,10 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve steadyStateLifetime: defaultSteadyStateLifetime, closeCh: make(chan struct{}), byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint), + nextVNI: minVNI, byVNI: make(map[uint32]*serverEndpoint), } s.discoPublic = s.disco.Public() - // TODO: instead of allocating 10s of MBs for the full pool, allocate - // smaller chunks and increase as needed - s.vniPool = make([]uint32, 0, 1<<24-1) - for i := 1; i < 1<<24; i++ { - s.vniPool = append(s.vniPool, uint32(i)) - } // TODO(creachadair): Find a way to plumb this in during initialization. // As-written, messages published here will not be seen by other components @@ -572,7 +573,6 @@ func (s *Server) Close() error { defer s.mu.Unlock() clear(s.byVNI) clear(s.byDisco) - s.vniPool = nil s.closed = true s.bus.Close() }) @@ -594,7 +594,6 @@ func (s *Server) endpointGCLoop() { if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) { delete(s.byDisco, k) delete(s.byVNI, v.vni) - s.vniPool = append(s.vniPool, v.vni) } } } @@ -729,6 +728,27 @@ func (e ErrServerNotReady) Error() string { return fmt.Sprintf("server not ready, retry after %v", e.RetryAfter) } +// getNextVNILocked returns the next available VNI. It implements the +// "Traditional BSD Port Selection Algorithm" from RFC6056. This algorithm does +// not attempt to obfuscate the selection, i.e. the selection is predictable. +// For now, we favor simplicity and reducing VNI re-use over more complex +// ephemeral port (VNI) selection algorithms. +func (s *Server) getNextVNILocked() (uint32, error) { + for i := uint32(0); i < totalPossibleVNI; i++ { + vni := s.nextVNI + if vni == maxVNI { + s.nextVNI = minVNI + } else { + s.nextVNI++ + } + _, ok := s.byVNI[vni] + if !ok { + return vni, nil + } + } + return 0, errors.New("VNI pool exhausted") +} + // AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair // of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB // it is returned without modification/reallocation. AllocateEndpoint returns @@ -777,8 +797,9 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv }, nil } - if len(s.vniPool) == 0 { - return endpoint.ServerEndpoint{}, errors.New("VNI pool exhausted") + vni, err := s.getNextVNILocked() + if err != nil { + return endpoint.ServerEndpoint{}, err } s.lamportID++ @@ -786,10 +807,10 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv discoPubKeys: pair, lamportID: s.lamportID, allocatedAt: time.Now(), + vni: vni, } e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0]) e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1]) - e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:] s.byDisco[pair] = e s.byVNI[e.vni] = e diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index 8fc4a4f78..bf7f0a9b5 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + qt "github.com/frankban/quicktest" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "go4.org/mem" @@ -319,3 +320,25 @@ func TestServer(t *testing.T) { }) } } + +func TestServer_getNextVNILocked(t *testing.T) { + t.Parallel() + c := qt.New(t) + s := &Server{ + nextVNI: minVNI, + byVNI: make(map[uint32]*serverEndpoint), + } + for i := uint64(0); i < uint64(totalPossibleVNI); i++ { + vni, err := s.getNextVNILocked() + if err != nil { // using quicktest here triples test time + t.Fatal(err) + } + s.byVNI[vni] = nil + } + c.Assert(s.nextVNI, qt.Equals, minVNI) + _, err := s.getNextVNILocked() + c.Assert(err, qt.IsNotNil) + delete(s.byVNI, minVNI) + _, err = s.getNextVNILocked() + c.Assert(err, qt.IsNil) +}