diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 141bdb707..6450bff9c 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -158,6 +158,18 @@ func (m *peerMap) forEachDiscoEndpoint(f func(ep *endpoint)) { } } +// forEachEndpointWithDiscoKey invokes f on every endpoint in m +// that has the provided DiscoKey. +func (m *peerMap) forEachEndpointWithDiscoKey(dk tailcfg.DiscoKey, f func(ep *endpoint)) { + // TODO(bradfitz): once byDiscoKey is a set of endpoints, then range + // over that instead. + for _, pi := range m.byNodeKey { + if pi.ep != nil && pi.ep.discoKey == dk { + f(pi.ep) + } + } +} + // upsertDiscoEndpoint stores endpoint in the peerInfo for // ep.publicKey, and updates indexes. m must already have a // tailcfg.Node for ep.publicKey. @@ -1689,6 +1701,12 @@ const ( discoVerboseLog ) +// sendDiscoMessage sends discovery message m to dstDisco at dst. +// +// If dst is a DERP IP:port, then dstKey must be non-zero. +// +// The dstKey should only be non-zero if the dstDisco key +// unambiguously maps to exactly one peer. func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstDisco tailcfg.DiscoKey, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { c.mu.Lock() if c.closed { @@ -1710,7 +1728,11 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD sent, err = c.sendAddr(dst, key.Public(dstKey), pkt) if sent { if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco) { - c.logf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v", c.discoShort, dstDisco.ShortString(), dstKey.ShortString(), derpStr(dst.String()), disco.MessageSummary(m)) + node := "?" + if !dstKey.IsZero() { + node = dstKey.ShortString() + } + c.logf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v", c.discoShort, dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m)) } } else if err == nil { // Can't send. (e.g. no IPv6 locally) @@ -1836,7 +1858,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ta switch dm := dm.(type) { case *disco.Ping: - c.handlePingLocked(dm, ep, src, di, derpNodeSrc) + c.handlePingLocked(dm, src, di, derpNodeSrc) case *disco.Pong: ep.handlePongConnLocked(dm, src) case *disco.CallMeMaybe: @@ -1864,21 +1886,57 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ta // di is the discoInfo of the source of the ping. // derpNodeSrc is non-zero if the ping arrived via DERP. -func (c *Conn) handlePingLocked(dm *disco.Ping, de *endpoint, src netaddr.IPPort, di *discoInfo, derpNodeSrc tailcfg.NodeKey) { +func (c *Conn) handlePingLocked(dm *disco.Ping, src netaddr.IPPort, di *discoInfo, derpNodeSrc tailcfg.NodeKey) { likelyHeartBeat := src == di.lastPingFrom && time.Since(di.lastPingTime) < 5*time.Second di.lastPingFrom = src di.lastPingTime = time.Now() - if !likelyHeartBeat || debugDisco { - c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, di.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6]) - } + + // If we got a ping over DERP, then derpNodeSrc is non-zero and we reply + // over DERP (in which case ipDst is also a DERP address). + // But if the ping was over UDP (ipDst is not a DERP address), then dstKey + // will be zero here, but that's fine: sendDiscoMessage only requires + // a dstKey if the dst ip:port is DERP. + dstKey := derpNodeSrc // Remember this route if not present. c.setAddrToDiscoLocked(src, di.discoKey) - de.addCandidateEndpoint(src) + var numNodes int + if !derpNodeSrc.IsZero() { + if ep, ok := c.peerMap.endpointForNodeKey(derpNodeSrc); ok { + ep.addCandidateEndpoint(src) + numNodes = 1 + } + } else { + c.peerMap.forEachEndpointWithDiscoKey(di.discoKey, func(ep *endpoint) { + ep.addCandidateEndpoint(src) + numNodes++ + if numNodes == 1 && dstKey.IsZero() { + dstKey = ep.publicKey + } + }) + if numNodes > 1 { + // Zero it out if it's ambiguous, so sendDiscoMessage logging + // isn't confusing. + dstKey = tailcfg.NodeKey{} + } + } + + if numNodes == 0 { + c.logf("[unexpected] got disco ping from %v/%v for node not in peers", src, derpNodeSrc) + return + } + + if !likelyHeartBeat || debugDisco { + pingNodeSrcStr := dstKey.ShortString() + if numNodes > 1 { + pingNodeSrcStr = "[one-of-multi]" + } + c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, di.discoShort, pingNodeSrcStr, src, dm.TxID[:6]) + } ipDst := src discoDest := di.discoKey - go c.sendDiscoMessage(ipDst, de.publicKey, discoDest, &disco.Pong{ + go c.sendDiscoMessage(ipDst, dstKey, discoDest, &disco.Pong{ TxID: dm.TxID, Src: src, }, discoVerboseLog)