diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 05756fb92..141bdb707 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -343,9 +343,9 @@ type Conn struct { // nodeOfDisco tracks the networkmap Node entity for each peer // discovery key. peerMap peerMap - // sharedDiscoKey is the precomputed nacl/box key for - // communication with the peer that has the given DiscoKey. - sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte + + // discoInfo is the state for an active DiscoKey. + discoInfo map[tailcfg.DiscoKey]*discoInfo // netInfoFunc is a callback that provides a tailcfg.NetInfo when // discovered network conditions change. @@ -506,11 +506,11 @@ func (o *Options) derpActiveFunc() func() { // of NewConn. Mostly for tests. func newConn() *Conn { c := &Conn{ - derpRecvCh: make(chan derpReadResult), - derpStarted: make(chan struct{}), - peerLastDerp: make(map[key.Public]int), - peerMap: newPeerMap(), - sharedDiscoKey: make(map[tailcfg.DiscoKey]*[32]byte), + derpRecvCh: make(chan derpReadResult), + derpStarted: make(chan struct{}), + peerLastDerp: make(map[key.Public]int), + peerMap: newPeerMap(), + discoInfo: make(map[tailcfg.DiscoKey]*discoInfo), } c.bind = &connBind{Conn: c, closed: true} c.muCond = sync.NewCond(&c.mu) @@ -1596,7 +1596,7 @@ func (c *Conn) receiveIP(b []byte, ipp netaddr.IPPort, cache *ippEndpointCache) c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b, ipp) return nil, false } - if c.handleDiscoMessage(b, ipp, key.Public{}) { + if c.handleDiscoMessage(b, ipp, tailcfg.NodeKey{}) { return nil, false } if !c.havePrivateKey.Get() { @@ -1659,7 +1659,7 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en } ipp := netaddr.IPPortFrom(derpMagicIPAddr, uint16(regionID)) - if c.handleDiscoMessage(b[:n], ipp, dm.src) { + if c.handleDiscoMessage(b[:n], ipp, tailcfg.NodeKey(dm.src)) { return 0, nil } @@ -1703,10 +1703,10 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD pkt = append(pkt, disco.Magic...) pkt = append(pkt, c.discoPublic[:]...) pkt = append(pkt, nonce[:]...) - sharedKey := c.sharedDiscoKeyLocked(dstDisco) + di := c.discoInfoLocked(dstDisco) c.mu.Unlock() - pkt = box.SealAfterPrecomputation(pkt, m.AppendMarshal(nil), &nonce, sharedKey) + pkt = box.SealAfterPrecomputation(pkt, m.AppendMarshal(nil), &nonce, di.sharedKey) sent, err = c.sendAddr(dst, key.Public(dstKey), pkt) if sent { if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco) { @@ -1736,7 +1736,7 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD // src.Port() being the region ID) and the derpNodeSrc will be the node key // it was received from at the DERP layer. derpNodeSrc is zero when received // over UDP. -func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc key.Public) (isDiscoMsg bool) { +func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc tailcfg.NodeKey) (isDiscoMsg bool) { const headerLen = len(disco.Magic) + len(tailcfg.DiscoKey{}) + disco.NonceLen if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { return false @@ -1784,10 +1784,12 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ke // // From here on, peerNode and de are non-nil. + di := c.discoInfoLocked(sender) + var nonce [disco.NonceLen]byte copy(nonce[:], msg[len(disco.Magic)+len(key.Public{}):]) sealedBox := msg[headerLen:] - payload, ok := box.OpenAfterPrecomputation(nil, sealedBox, &nonce, c.sharedDiscoKeyLocked(sender)) + payload, ok := box.OpenAfterPrecomputation(nil, sealedBox, &nonce, di.sharedKey) if !ok { // This might be have been intended for a previous // disco key. When we restart we get a new disco key @@ -1834,7 +1836,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ke switch dm := dm.(type) { case *disco.Ping: - c.handlePingLocked(dm, ep, src, sender) + c.handlePingLocked(dm, ep, src, di, derpNodeSrc) case *disco.Pong: ep.handlePongConnLocked(dm, src) case *disco.CallMeMaybe: @@ -1860,20 +1862,22 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ke return } -func (c *Conn) handlePingLocked(dm *disco.Ping, de *endpoint, src netaddr.IPPort, sender tailcfg.DiscoKey) { - likelyHeartBeat := src == de.lastPingFrom && time.Since(de.lastPingTime) < 5*time.Second - de.lastPingFrom = src - de.lastPingTime = time.Now() +// di is the discoInfo of the source of the ping. +// derpNodeSrc is non-zero if the ping arrived via DERP. +func (c *Conn) handlePingLocked(dm *disco.Ping, de *endpoint, src netaddr.IPPort, di *discoInfo, derpNodeSrc tailcfg.NodeKey) { + likelyHeartBeat := src == di.lastPingFrom && time.Since(di.lastPingTime) < 5*time.Second + di.lastPingFrom = src + di.lastPingTime = time.Now() if !likelyHeartBeat || debugDisco { - c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, de.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6]) + c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, di.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6]) } // Remember this route if not present. - c.setAddrToDiscoLocked(src, sender) + c.setAddrToDiscoLocked(src, di.discoKey) de.addCandidateEndpoint(src) ipDst := src - discoDest := sender + discoDest := di.discoKey go c.sendDiscoMessage(ipDst, de.publicKey, discoDest, &disco.Pong{ TxID: dm.TxID, Src: src, @@ -1935,14 +1939,21 @@ func (c *Conn) setAddrToDiscoLocked(src netaddr.IPPort, newk tailcfg.DiscoKey) { c.peerMap.setDiscoKeyForIPPort(src, newk) } -func (c *Conn) sharedDiscoKeyLocked(k tailcfg.DiscoKey) *[32]byte { - if v, ok := c.sharedDiscoKey[k]; ok { - return v +// discoInfoLocked returns the previous or new discoInfo for k. +// +// c.mu must be held. +func (c *Conn) discoInfoLocked(k tailcfg.DiscoKey) *discoInfo { + di, ok := c.discoInfo[k] + if !ok { + di = &discoInfo{ + discoKey: k, + discoShort: k.ShortString(), + sharedKey: new([32]byte), + } + box.Precompute(di.sharedKey, key.Public(k).B32(), c.discoPrivate.B32()) + c.discoInfo[k] = di } - shared := new([32]byte) - box.Precompute(shared, key.Public(k).B32(), c.discoPrivate.B32()) - c.sharedDiscoKey[k] = shared - return shared + return di } func (c *Conn) SetNetworkUp(up bool) { @@ -2191,10 +2202,10 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { }) } - // discokeys might have changed in the above. Discard unused cached keys. - for discoKey := range c.sharedDiscoKey { - if !c.peerMap.anyEndpointForDiscoKey(discoKey) { - delete(c.sharedDiscoKey, discoKey) + // discokeys might have changed in the above. Discard unused info. + for dk := range c.discoInfo { + if !c.peerMap.anyEndpointForDiscoKey(dk) { + delete(c.discoInfo, dk) } } } @@ -2999,10 +3010,6 @@ type endpoint struct { fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using wgEndpoint string // string from ParseEndpoint, holds a JSON-serialized wgcfg.Endpoints - // Owned by Conn.mu: - lastPingFrom netaddr.IPPort - lastPingTime time.Time - // mu protects all following fields. mu sync.Mutex // Lock ordering: Conn.mu, then endpoint.mu @@ -3788,3 +3795,35 @@ type ippEndpointCache struct { gen int64 de *endpoint } + +// discoInfo is the info and state for the DiscoKey +// in the Conn.discoInfo map key. +// +// Note that a DiscoKey does not necessarily map to exactly one +// node. In the case of shared nodes and users switching accounts, two +// nodes in the NetMap may legitimately have the same DiscoKey. As +// such, no fields in here should be considered node-specific. +type discoInfo struct { + // discoKey is the same as the Conn.discoInfo map key, + // just so you can pass around a *discoInfo alone. + // Not modifed once initiazed. + discoKey tailcfg.DiscoKey + + // discoShort is discoKey.ShortString(). + // Not modifed once initiazed; + discoShort string + + // sharedKey is the precomputed nacl/box key for + // communication with the peer that has the DiscoKey + // used to look up this *discoInfo in Conn.discoInfo. + // Not modifed once initialized. + sharedKey *[32]byte + + // Mutable fields follow, owned by Conn.mu: + + // lastPingFrom is the src of a ping for discoKey. + lastPingFrom netaddr.IPPort + + // lastPingTime is the last time of a ping for discoKey. + lastPingTime time.Time +} diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index eb99c90ac..8b9df78b1 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -1158,7 +1158,7 @@ func TestDiscoMessage(t *testing.T) { pkt = append(pkt, nonce[:]...) pkt = box.Seal(pkt, []byte(payload), &nonce, c.discoPrivate.Public().B32(), peer1Priv.B32()) - got := c.handleDiscoMessage(pkt, netaddr.IPPort{}, key.Public{}) + got := c.handleDiscoMessage(pkt, netaddr.IPPort{}, tailcfg.NodeKey{}) if !got { t.Error("failed to open it") }