From d86081f3535168921cd9b4ddf1dd127f3639bd00 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 15 Oct 2021 20:45:33 -0700 Subject: [PATCH] wgengine/magicsock: add new discoInfo type for DiscoKey state, move some fields As more prep for removing the false assumption that you're able to map from DiscoKey to a single peer, move the lastPingFrom and lastPingTime fields from the endpoint type to a new discoInfo type, effectively upgrading the old sharedDiscoKey map (which only held a *[32]byte nacl precomputed key as its value) to discoInfo which then includes that naclbox key. Then start plumbing it into handlePing in prep for removing the need for handlePing to take an endpoint parameter. Updates #3088 Signed-off-by: Brad Fitzpatrick --- wgengine/magicsock/magicsock.go | 113 ++++++++++++++++++--------- wgengine/magicsock/magicsock_test.go | 2 +- 2 files changed, 77 insertions(+), 38 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 05756fb92..141bdb707 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -343,9 +343,9 @@ type Conn struct { // nodeOfDisco tracks the networkmap Node entity for each peer // discovery key. peerMap peerMap - // sharedDiscoKey is the precomputed nacl/box key for - // communication with the peer that has the given DiscoKey. - sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte + + // discoInfo is the state for an active DiscoKey. + discoInfo map[tailcfg.DiscoKey]*discoInfo // netInfoFunc is a callback that provides a tailcfg.NetInfo when // discovered network conditions change. @@ -506,11 +506,11 @@ func (o *Options) derpActiveFunc() func() { // of NewConn. Mostly for tests. func newConn() *Conn { c := &Conn{ - derpRecvCh: make(chan derpReadResult), - derpStarted: make(chan struct{}), - peerLastDerp: make(map[key.Public]int), - peerMap: newPeerMap(), - sharedDiscoKey: make(map[tailcfg.DiscoKey]*[32]byte), + derpRecvCh: make(chan derpReadResult), + derpStarted: make(chan struct{}), + peerLastDerp: make(map[key.Public]int), + peerMap: newPeerMap(), + discoInfo: make(map[tailcfg.DiscoKey]*discoInfo), } c.bind = &connBind{Conn: c, closed: true} c.muCond = sync.NewCond(&c.mu) @@ -1596,7 +1596,7 @@ func (c *Conn) receiveIP(b []byte, ipp netaddr.IPPort, cache *ippEndpointCache) c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b, ipp) return nil, false } - if c.handleDiscoMessage(b, ipp, key.Public{}) { + if c.handleDiscoMessage(b, ipp, tailcfg.NodeKey{}) { return nil, false } if !c.havePrivateKey.Get() { @@ -1659,7 +1659,7 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en } ipp := netaddr.IPPortFrom(derpMagicIPAddr, uint16(regionID)) - if c.handleDiscoMessage(b[:n], ipp, dm.src) { + if c.handleDiscoMessage(b[:n], ipp, tailcfg.NodeKey(dm.src)) { return 0, nil } @@ -1703,10 +1703,10 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD pkt = append(pkt, disco.Magic...) pkt = append(pkt, c.discoPublic[:]...) pkt = append(pkt, nonce[:]...) - sharedKey := c.sharedDiscoKeyLocked(dstDisco) + di := c.discoInfoLocked(dstDisco) c.mu.Unlock() - pkt = box.SealAfterPrecomputation(pkt, m.AppendMarshal(nil), &nonce, sharedKey) + pkt = box.SealAfterPrecomputation(pkt, m.AppendMarshal(nil), &nonce, di.sharedKey) sent, err = c.sendAddr(dst, key.Public(dstKey), pkt) if sent { if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco) { @@ -1736,7 +1736,7 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD // src.Port() being the region ID) and the derpNodeSrc will be the node key // it was received from at the DERP layer. derpNodeSrc is zero when received // over UDP. -func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc key.Public) (isDiscoMsg bool) { +func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc tailcfg.NodeKey) (isDiscoMsg bool) { const headerLen = len(disco.Magic) + len(tailcfg.DiscoKey{}) + disco.NonceLen if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { return false @@ -1784,10 +1784,12 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ke // // From here on, peerNode and de are non-nil. + di := c.discoInfoLocked(sender) + var nonce [disco.NonceLen]byte copy(nonce[:], msg[len(disco.Magic)+len(key.Public{}):]) sealedBox := msg[headerLen:] - payload, ok := box.OpenAfterPrecomputation(nil, sealedBox, &nonce, c.sharedDiscoKeyLocked(sender)) + payload, ok := box.OpenAfterPrecomputation(nil, sealedBox, &nonce, di.sharedKey) if !ok { // This might be have been intended for a previous // disco key. When we restart we get a new disco key @@ -1834,7 +1836,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ke switch dm := dm.(type) { case *disco.Ping: - c.handlePingLocked(dm, ep, src, sender) + c.handlePingLocked(dm, ep, src, di, derpNodeSrc) case *disco.Pong: ep.handlePongConnLocked(dm, src) case *disco.CallMeMaybe: @@ -1860,20 +1862,22 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ke return } -func (c *Conn) handlePingLocked(dm *disco.Ping, de *endpoint, src netaddr.IPPort, sender tailcfg.DiscoKey) { - likelyHeartBeat := src == de.lastPingFrom && time.Since(de.lastPingTime) < 5*time.Second - de.lastPingFrom = src - de.lastPingTime = time.Now() +// di is the discoInfo of the source of the ping. +// derpNodeSrc is non-zero if the ping arrived via DERP. +func (c *Conn) handlePingLocked(dm *disco.Ping, de *endpoint, src netaddr.IPPort, di *discoInfo, derpNodeSrc tailcfg.NodeKey) { + likelyHeartBeat := src == di.lastPingFrom && time.Since(di.lastPingTime) < 5*time.Second + di.lastPingFrom = src + di.lastPingTime = time.Now() if !likelyHeartBeat || debugDisco { - c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, de.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6]) + c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, di.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6]) } // Remember this route if not present. - c.setAddrToDiscoLocked(src, sender) + c.setAddrToDiscoLocked(src, di.discoKey) de.addCandidateEndpoint(src) ipDst := src - discoDest := sender + discoDest := di.discoKey go c.sendDiscoMessage(ipDst, de.publicKey, discoDest, &disco.Pong{ TxID: dm.TxID, Src: src, @@ -1935,14 +1939,21 @@ func (c *Conn) setAddrToDiscoLocked(src netaddr.IPPort, newk tailcfg.DiscoKey) { c.peerMap.setDiscoKeyForIPPort(src, newk) } -func (c *Conn) sharedDiscoKeyLocked(k tailcfg.DiscoKey) *[32]byte { - if v, ok := c.sharedDiscoKey[k]; ok { - return v +// discoInfoLocked returns the previous or new discoInfo for k. +// +// c.mu must be held. +func (c *Conn) discoInfoLocked(k tailcfg.DiscoKey) *discoInfo { + di, ok := c.discoInfo[k] + if !ok { + di = &discoInfo{ + discoKey: k, + discoShort: k.ShortString(), + sharedKey: new([32]byte), + } + box.Precompute(di.sharedKey, key.Public(k).B32(), c.discoPrivate.B32()) + c.discoInfo[k] = di } - shared := new([32]byte) - box.Precompute(shared, key.Public(k).B32(), c.discoPrivate.B32()) - c.sharedDiscoKey[k] = shared - return shared + return di } func (c *Conn) SetNetworkUp(up bool) { @@ -2191,10 +2202,10 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { }) } - // discokeys might have changed in the above. Discard unused cached keys. - for discoKey := range c.sharedDiscoKey { - if !c.peerMap.anyEndpointForDiscoKey(discoKey) { - delete(c.sharedDiscoKey, discoKey) + // discokeys might have changed in the above. Discard unused info. + for dk := range c.discoInfo { + if !c.peerMap.anyEndpointForDiscoKey(dk) { + delete(c.discoInfo, dk) } } } @@ -2999,10 +3010,6 @@ type endpoint struct { fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using wgEndpoint string // string from ParseEndpoint, holds a JSON-serialized wgcfg.Endpoints - // Owned by Conn.mu: - lastPingFrom netaddr.IPPort - lastPingTime time.Time - // mu protects all following fields. mu sync.Mutex // Lock ordering: Conn.mu, then endpoint.mu @@ -3788,3 +3795,35 @@ type ippEndpointCache struct { gen int64 de *endpoint } + +// discoInfo is the info and state for the DiscoKey +// in the Conn.discoInfo map key. +// +// Note that a DiscoKey does not necessarily map to exactly one +// node. In the case of shared nodes and users switching accounts, two +// nodes in the NetMap may legitimately have the same DiscoKey. As +// such, no fields in here should be considered node-specific. +type discoInfo struct { + // discoKey is the same as the Conn.discoInfo map key, + // just so you can pass around a *discoInfo alone. + // Not modifed once initiazed. + discoKey tailcfg.DiscoKey + + // discoShort is discoKey.ShortString(). + // Not modifed once initiazed; + discoShort string + + // sharedKey is the precomputed nacl/box key for + // communication with the peer that has the DiscoKey + // used to look up this *discoInfo in Conn.discoInfo. + // Not modifed once initialized. + sharedKey *[32]byte + + // Mutable fields follow, owned by Conn.mu: + + // lastPingFrom is the src of a ping for discoKey. + lastPingFrom netaddr.IPPort + + // lastPingTime is the last time of a ping for discoKey. + lastPingTime time.Time +} diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index eb99c90ac..8b9df78b1 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -1158,7 +1158,7 @@ func TestDiscoMessage(t *testing.T) { pkt = append(pkt, nonce[:]...) pkt = box.Seal(pkt, []byte(payload), &nonce, c.discoPrivate.Public().B32(), peer1Priv.B32()) - got := c.handleDiscoMessage(pkt, netaddr.IPPort{}, key.Public{}) + got := c.handleDiscoMessage(pkt, netaddr.IPPort{}, tailcfg.NodeKey{}) if !got { t.Error("failed to open it") }