diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index a8003b4ae..b84eb2ec6 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -25,6 +25,7 @@ import ( "syscall" "time" + "github.com/golang/groupcache/lru" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/wgcfg" @@ -66,8 +67,8 @@ type Conn struct { // bufferedIPv4From and bufferedIPv4Packet are owned by // ReceiveIPv4, and used when both a DERP and IPv4 packet arrive // at the same time. It stores the IPv4 packet for use in the next call. - bufferedIPv4From *net.UDPAddr // if non-nil, then bufferedIPv4Packet is valid - bufferedIPv4Packet []byte // the received packet (reused, owned by ReceiveIPv4) + bufferedIPv4From netaddr.IPPort // if non-zero, then bufferedIPv4Packet is valid + bufferedIPv4Packet []byte // the received packet (reused, owned by ReceiveIPv4) connCtx context.Context // closed on Conn.Close connCtxCancel func() // closes connCtx @@ -95,6 +96,7 @@ type Conn struct { nodeOfDisco map[tailcfg.DiscoKey]*tailcfg.Node discoOfNode map[tailcfg.NodeKey]tailcfg.DiscoKey + endpointOfAddr map[netaddr.IPPort]*discoEndpoint // validated non-DERP paths only endpointOfDisco map[tailcfg.DiscoKey]*discoEndpoint sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte // nacl/box precomputed key @@ -253,6 +255,7 @@ func newConn() *Conn { peerLastDerp: make(map[key.Public]int), endpointOfDisco: make(map[tailcfg.DiscoKey]*discoEndpoint), sharedDiscoKey: make(map[tailcfg.DiscoKey]*[32]byte), + endpointOfAddr: make(map[netaddr.IPPort]*discoEndpoint), } c.endpointsUpdateWaiter = sync.NewCond(&c.mu) return c @@ -1150,28 +1153,27 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan // findEndpoint maps from a UDP address to a WireGuard endpoint, for // ReceiveIPv4/ReceiveIPv6. -func (c *Conn) findEndpoint(addr *net.UDPAddr) conn.Endpoint { - if as := c.findAddrSet(addr); as != nil { - return as - } - // The peer that sent this packet has roamed beyond the - // knowledge provided by the control server. - // If the packet is valid wireguard will call UpdateDst - // on the original endpoint using this addr. - return (*singleEndpoint)(addr) -} +// The provided addr and ipp must match. +func (c *Conn) findEndpoint(ipp netaddr.IPPort, addr *net.UDPAddr) conn.Endpoint { + c.mu.Lock() + defer c.mu.Unlock() -func (c *Conn) findAddrSet(addr *net.UDPAddr) *AddrSet { - ip, ok := netaddr.FromStdIP(addr.IP) - if !ok { - return nil + // See if they have a discoEndpoint, for a set of peers + // both supporting active discovery. + if de, ok := c.endpointOfAddr[ipp]; ok { + return de } - ipp := netaddr.IPPort{ip, uint16(addr.Port)} - c.mu.Lock() - defer c.mu.Unlock() + // Pre-disco: look up their AddrSet. + if as, ok := c.addrsByUDP[ipp]; ok { + return as + } - return c.addrsByUDP[ipp] + // Pre-disco: the peer that sent this packet has roamed beyond + // the knowledge provided by the control server. If the + // packet is valid wireguard will call UpdateDst on the + // original endpoint using this addr. + return (*singleEndpoint)(addr) } type udpReadResult struct { @@ -1179,6 +1181,7 @@ type udpReadResult struct { n int err error addr *net.UDPAddr + ipp netaddr.IPPort } // aLongTimeAgo is a non-zero time, far in the past, used for @@ -1198,7 +1201,7 @@ func (c *Conn) awaitUDP4(b []byte) { return } addr := pAddr.(*net.UDPAddr) - ipp, ok := netaddr.FromStdAddr(addr.IP, addr.Port, addr.Zone) + ipp, ok := c.pconn4.ippCache.IPPort(addr) if !ok { continue } @@ -1211,7 +1214,7 @@ func (c *Conn) awaitUDP4(b []byte) { } select { - case c.udpRecvCh <- udpReadResult{n: n, addr: addr}: + case c.udpRecvCh <- udpReadResult{n: n, addr: addr, ipp: ipp}: case <-c.donec(): } return @@ -1232,9 +1235,9 @@ func wgRecvAddr(e conn.Endpoint, addr *net.UDPAddr) *net.UDPAddr { func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) { Top: // First, process any buffered packet from earlier. - if addr := c.bufferedIPv4From; addr != nil { - c.bufferedIPv4From = nil - ep := c.findEndpoint(addr) + if from := c.bufferedIPv4From; from != (netaddr.IPPort{}) { + c.bufferedIPv4From = netaddr.IPPort{} + ep := c.findEndpoint(from, from.UDPAddr()) return copy(b, c.bufferedIPv4Packet), ep, wgRecvAddr(ep, addr), nil } @@ -1246,6 +1249,7 @@ Top: var addrSet *AddrSet var discoEp *discoEndpoint + var ipp netaddr.IPPort select { case dm := <-c.derpRecvCh: @@ -1263,7 +1267,7 @@ Top: // but DERP sent first. So now we have both ready. // Save the UDP packet away for use by the next // ReceiveIPv4 call. - c.bufferedIPv4From = um.addr + c.bufferedIPv4From = um.ipp c.bufferedIPv4Packet = append(c.bufferedIPv4Packet[:0], b[:um.n]...) } c.pconn4.SetReadDeadline(time.Time{}) @@ -1279,8 +1283,8 @@ Top: return 0, nil, nil, err } - addr := netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(regionID)} - if c.handleDiscoMessage(b[:n], addr) { + ipp = netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(regionID)} + if c.handleDiscoMessage(b[:n], ipp) { goto Top } @@ -1302,7 +1306,7 @@ Top: if um.err != nil { return 0, nil, nil, err } - n, addr = um.n, um.addr + n, addr, ipp = um.n, um.addr, um.ipp case <-c.donec(): // Socket has been shut down. All the producers of packets @@ -1323,7 +1327,7 @@ Top: } else if discoEp != nil { ep = discoEp } else { - ep = c.findEndpoint(addr) + ep = c.findEndpoint(ipp, addr) } return n, ep, wgRecvAddr(ep, addr), nil } @@ -1338,7 +1342,7 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) { return 0, nil, nil, err } addr := pAddr.(*net.UDPAddr) - ipp, ok := netaddr.FromStdAddr(addr.IP, addr.Port, addr.Zone) + ipp, ok := c.pconn6.ippCache.IPPort(addr) if !ok { continue } @@ -1350,7 +1354,7 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) { continue } - ep := c.findEndpoint(addr) + ep := c.findEndpoint(ipp, addr) return n, ep, wgRecvAddr(ep, addr), nil } } @@ -2306,6 +2310,10 @@ func (e *singleEndpoint) Addrs() []wgcfg.Endpoint { // RebindingUDPConn is a UDP socket that can be re-bound. // Unix has no notion of re-binding a socket, so we swap it out for a new one. type RebindingUDPConn struct { + // ippCache is a cache from UDPAddr => netaddr.IPPort. It's not safe for concurrent use. + // This is used by ReceiveIPv6 and awaitUDP4 (called from ReceiveIPv4). + ippCache ippCache + mu sync.Mutex pconn *net.UDPConn } @@ -2567,3 +2575,43 @@ func (de *discoEndpoint) cleanup() { // TODO: real work later, when there's stuff to do de.c.logf("magicsock: doing cleanup for discovery key %x", de.discoKey[:]) } + +// ippCache is a cache of *net.UDPAddr => netaddr.IPPort mappings. +// +// It's not safe for concurrent use. +type ippCache struct { + c *lru.Cache +} + +// IPPort is a caching wrapper around netaddr.FromStdAddr. +// +// It is not safe for concurrent use. +func (ic *ippCache) IPPort(u *net.UDPAddr) (netaddr.IPPort, bool) { + if u == nil || len(u.IP) > 16 { + return netaddr.IPPort{}, false + } + if ic.c == nil { + ic.c = lru.New(64) // arbitrary + } + + key := ippCacheKey{ipLen: uint8(len(u.IP)), port: uint16(u.Port), zone: u.Zone} + copy(key.ip[:], u.IP[:]) + + if v, ok := ic.c.Get(key); ok { + return v.(netaddr.IPPort), true + } + ipp, ok := netaddr.FromStdAddr(u.IP, u.Port, u.Zone) + if ok { + ic.c.Add(key, ipp) + } + return ipp, ok +} + +// ippCacheKey is the cache key type used by ippCache.IPPort. +// It must be comparable, being used as a map key in the lru package. +type ippCacheKey struct { + ip [16]byte + port uint16 + ipLen uint8 // bytes in ip that are valid; rest are zero + zone string +}