wgengine/magicsock: add a little LRU cache for netaddr.IPPort lookups

And while plumbing, a bit of discovery work I'll need: the
endpointOfAddr map to map from validated paths to the discoEndpoint.
Not being populated yet.

Updates #483
pull/514/head^2
Brad Fitzpatrick 4 years ago
parent 2d6e84e19e
commit 92252b0988

@ -25,6 +25,7 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/golang/groupcache/lru"
"github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/wgcfg" "github.com/tailscale/wireguard-go/wgcfg"
@ -66,7 +67,7 @@ type Conn struct {
// bufferedIPv4From and bufferedIPv4Packet are owned by // bufferedIPv4From and bufferedIPv4Packet are owned by
// ReceiveIPv4, and used when both a DERP and IPv4 packet arrive // ReceiveIPv4, and used when both a DERP and IPv4 packet arrive
// at the same time. It stores the IPv4 packet for use in the next call. // at the same time. It stores the IPv4 packet for use in the next call.
bufferedIPv4From *net.UDPAddr // if non-nil, then bufferedIPv4Packet is valid bufferedIPv4From netaddr.IPPort // if non-zero, then bufferedIPv4Packet is valid
bufferedIPv4Packet []byte // the received packet (reused, owned by ReceiveIPv4) bufferedIPv4Packet []byte // the received packet (reused, owned by ReceiveIPv4)
connCtx context.Context // closed on Conn.Close connCtx context.Context // closed on Conn.Close
@ -95,6 +96,7 @@ type Conn struct {
nodeOfDisco map[tailcfg.DiscoKey]*tailcfg.Node nodeOfDisco map[tailcfg.DiscoKey]*tailcfg.Node
discoOfNode map[tailcfg.NodeKey]tailcfg.DiscoKey discoOfNode map[tailcfg.NodeKey]tailcfg.DiscoKey
endpointOfAddr map[netaddr.IPPort]*discoEndpoint // validated non-DERP paths only
endpointOfDisco map[tailcfg.DiscoKey]*discoEndpoint endpointOfDisco map[tailcfg.DiscoKey]*discoEndpoint
sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte // nacl/box precomputed key sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte // nacl/box precomputed key
@ -253,6 +255,7 @@ func newConn() *Conn {
peerLastDerp: make(map[key.Public]int), peerLastDerp: make(map[key.Public]int),
endpointOfDisco: make(map[tailcfg.DiscoKey]*discoEndpoint), endpointOfDisco: make(map[tailcfg.DiscoKey]*discoEndpoint),
sharedDiscoKey: make(map[tailcfg.DiscoKey]*[32]byte), sharedDiscoKey: make(map[tailcfg.DiscoKey]*[32]byte),
endpointOfAddr: make(map[netaddr.IPPort]*discoEndpoint),
} }
c.endpointsUpdateWaiter = sync.NewCond(&c.mu) c.endpointsUpdateWaiter = sync.NewCond(&c.mu)
return c return c
@ -1150,28 +1153,27 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan
// findEndpoint maps from a UDP address to a WireGuard endpoint, for // findEndpoint maps from a UDP address to a WireGuard endpoint, for
// ReceiveIPv4/ReceiveIPv6. // ReceiveIPv4/ReceiveIPv6.
func (c *Conn) findEndpoint(addr *net.UDPAddr) conn.Endpoint { // The provided addr and ipp must match.
if as := c.findAddrSet(addr); as != nil { func (c *Conn) findEndpoint(ipp netaddr.IPPort, addr *net.UDPAddr) conn.Endpoint {
return as c.mu.Lock()
} defer c.mu.Unlock()
// The peer that sent this packet has roamed beyond the
// knowledge provided by the control server.
// If the packet is valid wireguard will call UpdateDst
// on the original endpoint using this addr.
return (*singleEndpoint)(addr)
}
func (c *Conn) findAddrSet(addr *net.UDPAddr) *AddrSet { // See if they have a discoEndpoint, for a set of peers
ip, ok := netaddr.FromStdIP(addr.IP) // both supporting active discovery.
if !ok { if de, ok := c.endpointOfAddr[ipp]; ok {
return nil return de
} }
ipp := netaddr.IPPort{ip, uint16(addr.Port)}
c.mu.Lock() // Pre-disco: look up their AddrSet.
defer c.mu.Unlock() if as, ok := c.addrsByUDP[ipp]; ok {
return as
}
return c.addrsByUDP[ipp] // Pre-disco: the peer that sent this packet has roamed beyond
// the knowledge provided by the control server. If the
// packet is valid wireguard will call UpdateDst on the
// original endpoint using this addr.
return (*singleEndpoint)(addr)
} }
type udpReadResult struct { type udpReadResult struct {
@ -1179,6 +1181,7 @@ type udpReadResult struct {
n int n int
err error err error
addr *net.UDPAddr addr *net.UDPAddr
ipp netaddr.IPPort
} }
// aLongTimeAgo is a non-zero time, far in the past, used for // aLongTimeAgo is a non-zero time, far in the past, used for
@ -1198,7 +1201,7 @@ func (c *Conn) awaitUDP4(b []byte) {
return return
} }
addr := pAddr.(*net.UDPAddr) addr := pAddr.(*net.UDPAddr)
ipp, ok := netaddr.FromStdAddr(addr.IP, addr.Port, addr.Zone) ipp, ok := c.pconn4.ippCache.IPPort(addr)
if !ok { if !ok {
continue continue
} }
@ -1211,7 +1214,7 @@ func (c *Conn) awaitUDP4(b []byte) {
} }
select { select {
case c.udpRecvCh <- udpReadResult{n: n, addr: addr}: case c.udpRecvCh <- udpReadResult{n: n, addr: addr, ipp: ipp}:
case <-c.donec(): case <-c.donec():
} }
return return
@ -1232,9 +1235,9 @@ func wgRecvAddr(e conn.Endpoint, addr *net.UDPAddr) *net.UDPAddr {
func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) { func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) {
Top: Top:
// First, process any buffered packet from earlier. // First, process any buffered packet from earlier.
if addr := c.bufferedIPv4From; addr != nil { if from := c.bufferedIPv4From; from != (netaddr.IPPort{}) {
c.bufferedIPv4From = nil c.bufferedIPv4From = netaddr.IPPort{}
ep := c.findEndpoint(addr) ep := c.findEndpoint(from, from.UDPAddr())
return copy(b, c.bufferedIPv4Packet), ep, wgRecvAddr(ep, addr), nil return copy(b, c.bufferedIPv4Packet), ep, wgRecvAddr(ep, addr), nil
} }
@ -1246,6 +1249,7 @@ Top:
var addrSet *AddrSet var addrSet *AddrSet
var discoEp *discoEndpoint var discoEp *discoEndpoint
var ipp netaddr.IPPort
select { select {
case dm := <-c.derpRecvCh: case dm := <-c.derpRecvCh:
@ -1263,7 +1267,7 @@ Top:
// but DERP sent first. So now we have both ready. // but DERP sent first. So now we have both ready.
// Save the UDP packet away for use by the next // Save the UDP packet away for use by the next
// ReceiveIPv4 call. // ReceiveIPv4 call.
c.bufferedIPv4From = um.addr c.bufferedIPv4From = um.ipp
c.bufferedIPv4Packet = append(c.bufferedIPv4Packet[:0], b[:um.n]...) c.bufferedIPv4Packet = append(c.bufferedIPv4Packet[:0], b[:um.n]...)
} }
c.pconn4.SetReadDeadline(time.Time{}) c.pconn4.SetReadDeadline(time.Time{})
@ -1279,8 +1283,8 @@ Top:
return 0, nil, nil, err return 0, nil, nil, err
} }
addr := netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(regionID)} ipp = netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(regionID)}
if c.handleDiscoMessage(b[:n], addr) { if c.handleDiscoMessage(b[:n], ipp) {
goto Top goto Top
} }
@ -1302,7 +1306,7 @@ Top:
if um.err != nil { if um.err != nil {
return 0, nil, nil, err return 0, nil, nil, err
} }
n, addr = um.n, um.addr n, addr, ipp = um.n, um.addr, um.ipp
case <-c.donec(): case <-c.donec():
// Socket has been shut down. All the producers of packets // Socket has been shut down. All the producers of packets
@ -1323,7 +1327,7 @@ Top:
} else if discoEp != nil { } else if discoEp != nil {
ep = discoEp ep = discoEp
} else { } else {
ep = c.findEndpoint(addr) ep = c.findEndpoint(ipp, addr)
} }
return n, ep, wgRecvAddr(ep, addr), nil return n, ep, wgRecvAddr(ep, addr), nil
} }
@ -1338,7 +1342,7 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) {
return 0, nil, nil, err return 0, nil, nil, err
} }
addr := pAddr.(*net.UDPAddr) addr := pAddr.(*net.UDPAddr)
ipp, ok := netaddr.FromStdAddr(addr.IP, addr.Port, addr.Zone) ipp, ok := c.pconn6.ippCache.IPPort(addr)
if !ok { if !ok {
continue continue
} }
@ -1350,7 +1354,7 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) {
continue continue
} }
ep := c.findEndpoint(addr) ep := c.findEndpoint(ipp, addr)
return n, ep, wgRecvAddr(ep, addr), nil return n, ep, wgRecvAddr(ep, addr), nil
} }
} }
@ -2306,6 +2310,10 @@ func (e *singleEndpoint) Addrs() []wgcfg.Endpoint {
// RebindingUDPConn is a UDP socket that can be re-bound. // RebindingUDPConn is a UDP socket that can be re-bound.
// Unix has no notion of re-binding a socket, so we swap it out for a new one. // Unix has no notion of re-binding a socket, so we swap it out for a new one.
type RebindingUDPConn struct { type RebindingUDPConn struct {
// ippCache is a cache from UDPAddr => netaddr.IPPort. It's not safe for concurrent use.
// This is used by ReceiveIPv6 and awaitUDP4 (called from ReceiveIPv4).
ippCache ippCache
mu sync.Mutex mu sync.Mutex
pconn *net.UDPConn pconn *net.UDPConn
} }
@ -2567,3 +2575,43 @@ func (de *discoEndpoint) cleanup() {
// TODO: real work later, when there's stuff to do // TODO: real work later, when there's stuff to do
de.c.logf("magicsock: doing cleanup for discovery key %x", de.discoKey[:]) de.c.logf("magicsock: doing cleanup for discovery key %x", de.discoKey[:])
} }
// ippCache is a cache of *net.UDPAddr => netaddr.IPPort mappings.
//
// It's not safe for concurrent use.
type ippCache struct {
c *lru.Cache
}
// IPPort is a caching wrapper around netaddr.FromStdAddr.
//
// It is not safe for concurrent use.
func (ic *ippCache) IPPort(u *net.UDPAddr) (netaddr.IPPort, bool) {
if u == nil || len(u.IP) > 16 {
return netaddr.IPPort{}, false
}
if ic.c == nil {
ic.c = lru.New(64) // arbitrary
}
key := ippCacheKey{ipLen: uint8(len(u.IP)), port: uint16(u.Port), zone: u.Zone}
copy(key.ip[:], u.IP[:])
if v, ok := ic.c.Get(key); ok {
return v.(netaddr.IPPort), true
}
ipp, ok := netaddr.FromStdAddr(u.IP, u.Port, u.Zone)
if ok {
ic.c.Add(key, ipp)
}
return ipp, ok
}
// ippCacheKey is the cache key type used by ippCache.IPPort.
// It must be comparable, being used as a map key in the lru package.
type ippCacheKey struct {
ip [16]byte
port uint16
ipLen uint8 // bytes in ip that are valid; rest are zero
zone string
}

Loading…
Cancel
Save