wgengine/magicsock: keep discoOfAddr populated, use it for findEndpoint

Update the mapping from ip:port to discokey, so when we retrieve a
packet from the network, we can find the same conn.Endpoint that we
gave to wireguard-go previously, without making it think we've
roamed. (We did, but we're not using its roaming.)

Updates #483
pull/514/head^2
Brad Fitzpatrick 4 years ago
parent 77e89c4a72
commit 275a20f817

@ -97,8 +97,7 @@ type Conn struct {
discoPublic tailcfg.DiscoKey // public of discoPrivate discoPublic tailcfg.DiscoKey // public of discoPrivate
nodeOfDisco map[tailcfg.DiscoKey]*tailcfg.Node nodeOfDisco map[tailcfg.DiscoKey]*tailcfg.Node
discoOfNode map[tailcfg.NodeKey]tailcfg.DiscoKey discoOfNode map[tailcfg.NodeKey]tailcfg.DiscoKey
discoOfAddr map[netaddr.IPPort]tailcfg.DiscoKey // validated non-DERP paths only
endpointOfAddr map[netaddr.IPPort]*discoEndpoint // validated non-DERP paths only
endpointOfDisco map[tailcfg.DiscoKey]*discoEndpoint endpointOfDisco map[tailcfg.DiscoKey]*discoEndpoint
sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte // nacl/box precomputed key sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte // nacl/box precomputed key
@ -257,7 +256,7 @@ func newConn() *Conn {
peerLastDerp: make(map[key.Public]int), peerLastDerp: make(map[key.Public]int),
endpointOfDisco: make(map[tailcfg.DiscoKey]*discoEndpoint), endpointOfDisco: make(map[tailcfg.DiscoKey]*discoEndpoint),
sharedDiscoKey: make(map[tailcfg.DiscoKey]*[32]byte), sharedDiscoKey: make(map[tailcfg.DiscoKey]*[32]byte),
endpointOfAddr: make(map[netaddr.IPPort]*discoEndpoint), discoOfAddr: make(map[netaddr.IPPort]tailcfg.DiscoKey),
} }
c.endpointsUpdateWaiter = sync.NewCond(&c.mu) c.endpointsUpdateWaiter = sync.NewCond(&c.mu)
return c return c
@ -1171,14 +1170,20 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan
// findEndpoint maps from a UDP address to a WireGuard endpoint, for // findEndpoint maps from a UDP address to a WireGuard endpoint, for
// ReceiveIPv4/ReceiveIPv6. // ReceiveIPv4/ReceiveIPv6.
// The provided addr and ipp must match. // The provided addr and ipp must match.
//
// TODO(bradfitz): add a fast path that returns nil here for normal
// wireguard-go transport packets; IIRC wireguard-go only uses this
// Endpoint for the relatively rare non-data packets.
func (c *Conn) findEndpoint(ipp netaddr.IPPort, addr *net.UDPAddr) conn.Endpoint { func (c *Conn) findEndpoint(ipp netaddr.IPPort, addr *net.UDPAddr) conn.Endpoint {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
// See if they have a discoEndpoint, for a set of peers // See if they have a discoEndpoint, for a set of peers
// both supporting active discovery. // both supporting active discovery.
if de, ok := c.endpointOfAddr[ipp]; ok { if dk, ok := c.discoOfAddr[ipp]; ok {
return de if ep, ok := c.endpointOfDisco[dk]; ok {
return ep
}
} }
// Pre-disco: look up their AddrSet. // Pre-disco: look up their AddrSet.
@ -1496,7 +1501,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
Src: src, Src: src,
}) })
case *disco.Pong: case *disco.Pong:
go de.handlePong(dm) de.handlePongConnLocked(dm, src)
case disco.CallMeMaybe: case disco.CallMeMaybe:
if src.IP != derpMagicIPAddr { if src.IP != derpMagicIPAddr {
// CallMeMaybe messages should only come via DERP. // CallMeMaybe messages should only come via DERP.
@ -1509,6 +1514,51 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
return true return true
} }
// cleanDiscoOfAddrLocked lazily checks a few entries in c.discoOfAddr
// and deletes them if they're stale. It has no pointers in it so we
// don't go through the effort of keeping it aggressively
// pruned. Instead, we lazily clean it whenever it grows.
//
// c.mu must be held.
//
// If the caller already has a discoEndpoint mutex held as well, it
// can be passed in as alreadyLocked so it won't be re-acquired.
func (c *Conn) cleanDiscoOfAddrLocked(alreadyLocked *discoEndpoint) {
// If it's small enough, don't worry about it.
if len(c.discoOfAddr) < 16 {
return
}
const checkEntries = 5 // per one unit of growth
// Take advantage of Go's random map iteration to check & clean
// a few entries.
n := 0
for ipp, dk := range c.discoOfAddr {
n++
if n > checkEntries {
return
}
de, ok := c.endpointOfDisco[dk]
if !ok {
// This discokey isn't even known anymore. Clean.
delete(c.discoOfAddr, ipp)
continue
}
if de != alreadyLocked {
de.mu.Lock()
}
if _, ok := de.endpointState[ipp]; !ok {
// The discoEndpoint no longer knows about that endpoint.
// It must've changed. Clean.
delete(c.discoOfAddr, ipp)
}
if de != alreadyLocked {
de.mu.Unlock()
}
}
}
func (c *Conn) sharedDiscoKeyLocked(k tailcfg.DiscoKey) *[32]byte { func (c *Conn) sharedDiscoKeyLocked(k tailcfg.DiscoKey) *[32]byte {
if v, ok := c.sharedDiscoKey[k]; ok { if v, ok := c.sharedDiscoKey[k]; ok {
return v return v
@ -2512,6 +2562,7 @@ func udpAddrDebugString(ua net.UDPAddr) string {
// discoEndpoint is a wireguard/conn.Endpoint for new-style peers that // discoEndpoint is a wireguard/conn.Endpoint for new-style peers that
// advertise a DiscoKey and participate in active discovery. // advertise a DiscoKey and participate in active discovery.
type discoEndpoint struct { type discoEndpoint struct {
// These fields are initialized once and never modified.
c *Conn c *Conn
publicKey key.Public // peer public key (for WireGuard + DERP) publicKey key.Public // peer public key (for WireGuard + DERP)
discoKey tailcfg.DiscoKey // for discovery mesages discoKey tailcfg.DiscoKey // for discovery mesages
@ -2754,10 +2805,21 @@ func (de *discoEndpoint) noteConnectivityChange() {
de.trustBestAddrUntil = time.Time{} de.trustBestAddrUntil = time.Time{}
} }
func (de *discoEndpoint) handlePong(m *disco.Pong) { // handlePongConnLocked handles a Pong message (a reply to an earlier ping).
// It should be called with the Conn.mu held.
func (de *discoEndpoint) handlePongConnLocked(m *disco.Pong, src netaddr.IPPort) {
de.mu.Lock() de.mu.Lock()
defer de.mu.Unlock() defer de.mu.Unlock()
if src.IP == derpMagicIPAddr {
// We might support pinging a node via DERP in the
// future to see if it's still there, but we don't
// yet. We shouldn't ever get here, but bail out early
// in case we do in the future. (In which case, hi!,
// you'll be modifying this code.)
return
}
sp, ok := de.sentPing[m.TxID] sp, ok := de.sentPing[m.TxID]
if !ok { if !ok {
// This is not a pong for a ping we sent. Ignore. // This is not a pong for a ping we sent. Ignore.
@ -2765,6 +2827,13 @@ func (de *discoEndpoint) handlePong(m *disco.Pong) {
} }
delete(de.sentPing, m.TxID) delete(de.sentPing, m.TxID)
if v, ok := de.c.discoOfAddr[src]; !ok || v != de.discoKey {
de.c.discoOfAddr[src] = de.discoKey
if !ok {
de.c.cleanDiscoOfAddrLocked(de)
}
}
now := time.Now() now := time.Now()
delay := now.Sub(sp.at) delay := now.Sub(sp.at)
de.c.logf("magicsock: disco: got pong reply after %v", delay) de.c.logf("magicsock: disco: got pong reply after %v", delay)

Loading…
Cancel
Save