net/udprelay: replace map+sync.Mutex with sync.Map for VNI lookup

This commit also introduces a sync.Mutex for guarding mutable fields
on serverEndpoint, now that it is no longer guarded by the sync.Mutex
in Server.

These changes reduce lock contention and, as a result, increase aggregate
throughput under high flow count load. A benchmark on Linux with AWS
c8gn instances showed a ~30% increase in aggregate throughput (37Gb/s
vs 28Gb/s) for 12 tailscaled flows.

Updates tailscale/corp#35264

Signed-off-by: Jordan Whited <jordan@tailscale.com>
pull/17143/merge
Jordan Whited 6 days ago committed by Jordan Whited
parent 951d711054
commit a663639bea

@ -77,8 +77,8 @@ type Server struct {
closeCh chan struct{}
netChecker *netcheck.Client
mu sync.Mutex // guards the following fields
macSecrets [][blake2s.Size]byte // [0] is most recent, max 2 elements
mu sync.Mutex // guards the following fields
macSecrets views.Slice[[blake2s.Size]byte] // [0] is most recent, max 2 elements
macSecretRotatedAt mono.Time
derpMap *tailcfg.DERPMap
onlyStaticAddrPorts bool // no dynamic addr port discovery when set
@ -87,8 +87,11 @@ type Server struct {
closed bool
lamportID uint64
nextVNI uint32
byVNI map[uint32]*serverEndpoint
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
// serverEndpointByVNI is consistent with serverEndpointByDisco while mu is
// held, i.e. mu must be held around write ops. Read ops in performance
// sensitive paths, e.g. packet forwarding, do not need to acquire mu.
serverEndpointByVNI sync.Map // key is uint32 (Geneve VNI), value is [*serverEndpoint]
serverEndpointByDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
}
const macSecretRotationInterval = time.Minute * 2
@ -100,23 +103,23 @@ const (
)
// serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state.
// serverEndpoint methods are not thread-safe.
type serverEndpoint struct {
// discoPubKeys contains the key.DiscoPublic of the served clients. The
// indexing of this array aligns with the following fields, e.g.
// discoSharedSecrets[0] is the shared secret to use when sealing
// Disco protocol messages for transmission towards discoPubKeys[0].
discoPubKeys key.SortedPairOfDiscoPublic
discoSharedSecrets [2]key.DiscoShared
discoPubKeys key.SortedPairOfDiscoPublic
discoSharedSecrets [2]key.DiscoShared
lamportID uint64
vni uint32
allocatedAt mono.Time
mu sync.Mutex // guards the following fields
inProgressGeneration [2]uint32 // or zero if a handshake has never started, or has just completed
boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg
lastSeen [2]mono.Time
packetsRx [2]uint64 // num packets received from/sent by each client after they are bound
bytesRx [2]uint64 // num bytes received from/sent by each client after they are bound
lamportID uint64
vni uint32
allocatedAt mono.Time
}
func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg disco.BindUDPRelayEndpointCommon) ([blake2s.Size]byte, error) {
@ -141,7 +144,10 @@ func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg di
return out, nil
}
func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte, now mono.Time) (write []byte, to netip.AddrPort) {
func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic, macSecrets views.Slice[[blake2s.Size]byte], now mono.Time) (write []byte, to netip.AddrPort) {
e.mu.Lock()
defer e.mu.Unlock()
if senderIndex != 0 && senderIndex != 1 {
return nil, netip.AddrPort{}
}
@ -186,7 +192,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
}
reply = append(reply, disco.Magic...)
reply = serverDisco.AppendTo(reply)
mac, err := blakeMACFromBindMsg(macSecrets[0], from, m.BindUDPRelayEndpointCommon)
mac, err := blakeMACFromBindMsg(macSecrets.At(0), from, m.BindUDPRelayEndpointCommon)
if err != nil {
return nil, netip.AddrPort{}
}
@ -206,7 +212,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
// silently drop
return nil, netip.AddrPort{}
}
for _, macSecret := range macSecrets {
for _, macSecret := range macSecrets.All() {
mac, err := blakeMACFromBindMsg(macSecret, from, discoMsg.BindUDPRelayEndpointCommon)
if err != nil {
// silently drop
@ -230,7 +236,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
}
}
func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte, now mono.Time) (write []byte, to netip.AddrPort) {
func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic, macSecrets views.Slice[[blake2s.Size]byte], now mono.Time) (write []byte, to netip.AddrPort) {
senderRaw, isDiscoMsg := disco.Source(b)
if !isDiscoMsg {
// Not a Disco message
@ -265,7 +271,9 @@ func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []by
}
func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now mono.Time) (write []byte, to netip.AddrPort) {
if !e.isBound() {
e.mu.Lock()
defer e.mu.Unlock()
if !e.isBoundLocked() {
// not a control packet, but serverEndpoint isn't bound
return nil, netip.AddrPort{}
}
@ -287,7 +295,9 @@ func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now mon
}
func (e *serverEndpoint) isExpired(now mono.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
if !e.isBound() {
e.mu.Lock()
defer e.mu.Unlock()
if !e.isBoundLocked() {
if now.Sub(e.allocatedAt) > bindLifetime {
return true
}
@ -299,9 +309,9 @@ func (e *serverEndpoint) isExpired(now mono.Time, bindLifetime, steadyStateLifet
return false
}
// isBound returns true if both clients have completed a 3-way handshake,
// isBoundLocked returns true if both clients have completed a 3-way handshake,
// otherwise false.
func (e *serverEndpoint) isBound() bool {
func (e *serverEndpoint) isBoundLocked() bool {
return e.boundAddrPorts[0].IsValid() &&
e.boundAddrPorts[1].IsValid()
}
@ -313,15 +323,14 @@ func (e *serverEndpoint) isBound() bool {
// used.
func NewServer(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (s *Server, err error) {
s = &Server{
logf: logf,
disco: key.NewDisco(),
bindLifetime: defaultBindLifetime,
steadyStateLifetime: defaultSteadyStateLifetime,
closeCh: make(chan struct{}),
onlyStaticAddrPorts: onlyStaticAddrPorts,
byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
nextVNI: minVNI,
byVNI: make(map[uint32]*serverEndpoint),
logf: logf,
disco: key.NewDisco(),
bindLifetime: defaultBindLifetime,
steadyStateLifetime: defaultSteadyStateLifetime,
closeCh: make(chan struct{}),
onlyStaticAddrPorts: onlyStaticAddrPorts,
serverEndpointByDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
nextVNI: minVNI,
}
s.discoPublic = s.disco.Public()
@ -640,8 +649,8 @@ func (s *Server) Close() error {
// acquire s.mu.
s.mu.Lock()
defer s.mu.Unlock()
clear(s.byVNI)
clear(s.byDisco)
s.serverEndpointByVNI.Clear()
clear(s.serverEndpointByDisco)
s.closed = true
s.bus.Close()
})
@ -659,10 +668,10 @@ func (s *Server) endpointGCLoop() {
// holding s.mu for the duration. Keep it simple (and slow) for now.
s.mu.Lock()
defer s.mu.Unlock()
for k, v := range s.byDisco {
for k, v := range s.serverEndpointByDisco {
if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) {
delete(s.byDisco, k)
delete(s.byVNI, v.vni)
delete(s.serverEndpointByDisco, k)
s.serverEndpointByVNI.Delete(v.vni)
}
}
}
@ -690,12 +699,7 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to n
if err != nil {
return nil, netip.AddrPort{}
}
// TODO: consider performance implications of holding s.mu for the remainder
// of this method, which does a bunch of disco/crypto work depending. Keep
// it simple (and slow) for now.
s.mu.Lock()
defer s.mu.Unlock()
e, ok := s.byVNI[gh.VNI.Get()]
e, ok := s.serverEndpointByVNI.Load(gh.VNI.Get())
if !ok {
// unknown VNI
return nil, netip.AddrPort{}
@ -708,27 +712,36 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to n
return nil, netip.AddrPort{}
}
msg := b[packet.GeneveFixedHeaderLength:]
s.maybeRotateMACSecretLocked(now)
return e.handleSealedDiscoControlMsg(from, msg, s.discoPublic, s.macSecrets, now)
secrets := s.getMACSecrets(now)
return e.(*serverEndpoint).handleSealedDiscoControlMsg(from, msg, s.discoPublic, secrets, now)
}
return e.handleDataPacket(from, b, now)
return e.(*serverEndpoint).handleDataPacket(from, b, now)
}
// getMACSecrets returns s.macSecrets as a [views.Slice]. It acquires s.mu for
// the duration of the call and invokes maybeRotateMACSecretLocked with now
// before reading the field, so the returned value reflects any rotation
// performed by this call. Callers must not hold s.mu.
func (s *Server) getMACSecrets(now mono.Time) views.Slice[[blake2s.Size]byte] {
s.mu.Lock()
defer s.mu.Unlock()
s.maybeRotateMACSecretLocked(now)
return s.macSecrets
}
func (s *Server) maybeRotateMACSecretLocked(now mono.Time) {
if !s.macSecretRotatedAt.IsZero() && now.Sub(s.macSecretRotatedAt) < macSecretRotationInterval {
return
}
switch len(s.macSecrets) {
secrets := s.macSecrets.AsSlice()
switch len(secrets) {
case 0:
s.macSecrets = make([][blake2s.Size]byte, 1, 2)
secrets = make([][blake2s.Size]byte, 1, 2)
case 1:
s.macSecrets = append(s.macSecrets, [blake2s.Size]byte{})
secrets = append(secrets, [blake2s.Size]byte{})
fallthrough
case 2:
s.macSecrets[1] = s.macSecrets[0]
secrets[1] = secrets[0]
}
rand.Read(s.macSecrets[0][:])
rand.Read(secrets[0][:])
s.macSecretRotatedAt = now
s.macSecrets = views.SliceOf(secrets)
return
}
@ -838,7 +851,7 @@ func (s *Server) getNextVNILocked() (uint32, error) {
} else {
s.nextVNI++
}
_, ok := s.byVNI[vni]
_, ok := s.serverEndpointByVNI.Load(vni)
if !ok {
return vni, nil
}
@ -877,7 +890,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
}
pair := key.NewSortedPairOfDiscoPublic(discoA, discoB)
e, ok := s.byDisco[pair]
e, ok := s.serverEndpointByDisco[pair]
if ok {
// Return the existing allocation. Clients can resolve duplicate
// [endpoint.ServerEndpoint]'s via [endpoint.ServerEndpoint.LamportID].
@ -915,8 +928,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0])
e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1])
s.byDisco[pair] = e
s.byVNI[e.vni] = e
s.serverEndpointByDisco[pair] = e
s.serverEndpointByVNI.Store(e.vni, e)
s.logf("allocated endpoint vni=%d lamportID=%d disco[0]=%v disco[1]=%v", e.vni, e.lamportID, pair.Get()[0].ShortString(), pair.Get()[1].ShortString())
return endpoint.ServerEndpoint{
@ -930,19 +943,19 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
}, nil
}
// extractClientInfo constructs a [status.ClientInfo] for one of the two peer
// relay clients involved in this session.
func extractClientInfo(idx int, ep *serverEndpoint) status.ClientInfo {
if idx != 0 && idx != 1 {
panic(fmt.Sprintf("idx passed to extractClientInfo() must be 0 or 1; got %d", idx))
}
return status.ClientInfo{
Endpoint: ep.boundAddrPorts[idx],
ShortDisco: ep.discoPubKeys.Get()[idx].ShortString(),
PacketsTx: ep.packetsRx[idx],
BytesTx: ep.bytesRx[idx],
// extractClientInfo constructs a [status.ClientInfo] for both relay clients
// involved in this session.
func (e *serverEndpoint) extractClientInfo() [2]status.ClientInfo {
e.mu.Lock()
defer e.mu.Unlock()
ret := [2]status.ClientInfo{}
for i := range e.boundAddrPorts {
ret[i].Endpoint = e.boundAddrPorts[i]
ret[i].ShortDisco = e.discoPubKeys.Get()[i].ShortString()
ret[i].PacketsTx = e.packetsRx[i]
ret[i].BytesTx = e.bytesRx[i]
}
return ret
}
// GetSessions returns a slice of peer relay session statuses, with each
@ -955,14 +968,13 @@ func (s *Server) GetSessions() []status.ServerSession {
if s.closed {
return nil
}
var sessions = make([]status.ServerSession, 0, len(s.byDisco))
for _, se := range s.byDisco {
c1 := extractClientInfo(0, se)
c2 := extractClientInfo(1, se)
var sessions = make([]status.ServerSession, 0, len(s.serverEndpointByDisco))
for _, se := range s.serverEndpointByDisco {
clientInfos := se.extractClientInfo()
sessions = append(sessions, status.ServerSession{
VNI: se.vni,
Client1: c1,
Client2: c2,
Client1: clientInfos[0],
Client2: clientInfos[1],
})
}
return sessions

@ -339,19 +339,18 @@ func TestServer_getNextVNILocked(t *testing.T) {
c := qt.New(t)
s := &Server{
nextVNI: minVNI,
byVNI: make(map[uint32]*serverEndpoint),
}
for i := uint64(0); i < uint64(totalPossibleVNI); i++ {
vni, err := s.getNextVNILocked()
if err != nil { // using quicktest here triples test time
t.Fatal(err)
}
s.byVNI[vni] = nil
s.serverEndpointByVNI.Store(vni, nil)
}
c.Assert(s.nextVNI, qt.Equals, minVNI)
_, err := s.getNextVNILocked()
c.Assert(err, qt.IsNotNil)
delete(s.byVNI, minVNI)
s.serverEndpointByVNI.Delete(minVNI)
_, err = s.getNextVNILocked()
c.Assert(err, qt.IsNil)
}
@ -455,17 +454,17 @@ func TestServer_maybeRotateMACSecretLocked(t *testing.T) {
s := &Server{}
start := mono.Now()
s.maybeRotateMACSecretLocked(start)
qt.Assert(t, len(s.macSecrets), qt.Equals, 1)
macSecret := s.macSecrets[0]
qt.Assert(t, s.macSecrets.Len(), qt.Equals, 1)
macSecret := s.macSecrets.At(0)
s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval - time.Nanosecond))
qt.Assert(t, len(s.macSecrets), qt.Equals, 1)
qt.Assert(t, s.macSecrets[0], qt.Equals, macSecret)
qt.Assert(t, s.macSecrets.Len(), qt.Equals, 1)
qt.Assert(t, s.macSecrets.At(0), qt.Equals, macSecret)
s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval))
qt.Assert(t, len(s.macSecrets), qt.Equals, 2)
qt.Assert(t, s.macSecrets[1], qt.Equals, macSecret)
qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1])
qt.Assert(t, s.macSecrets.Len(), qt.Equals, 2)
qt.Assert(t, s.macSecrets.At(1), qt.Equals, macSecret)
qt.Assert(t, s.macSecrets.At(0), qt.Not(qt.Equals), s.macSecrets.At(1))
s.maybeRotateMACSecretLocked(s.macSecretRotatedAt.Add(macSecretRotationInterval))
qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[0])
qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[1])
qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1])
qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets.At(0))
qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets.At(1))
qt.Assert(t, s.macSecrets.At(0), qt.Not(qt.Equals), s.macSecrets.At(1))
}

Loading…
Cancel
Save