diff --git a/disco/disco.go b/disco/disco.go index d91db6a0d..b0476bba1 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -26,6 +26,7 @@ import ( "net" "inet.af/netaddr" + "tailscale.com/tailcfg" ) // Magic is the 6 byte header of all discovery messages. @@ -106,12 +107,28 @@ func appendMsgHeader(b []byte, t MessageType, ver uint8, dataLen int) (all, data } type Ping struct { + // TxID is a random client-generated per-ping transaction ID. TxID [12]byte + + // NodeKey is the ping sender's wireguard public key. Old + // clients (~1.16.0 and earlier) don't send this field. It + // shouldn't be trusted by itself. But if present and the + // netmap's peer for this NodeKey's DiscoKey matches the + // sender of this disco key, they it can be. + NodeKey tailcfg.NodeKey } func (m *Ping) AppendMarshal(b []byte) []byte { - ret, d := appendMsgHeader(b, TypePing, v0, 12) - copy(d, m.TxID[:]) + dataLen := 12 + hasKey := !m.NodeKey.IsZero() + if hasKey { + dataLen += len(m.NodeKey) + } + ret, d := appendMsgHeader(b, TypePing, v0, dataLen) + n := copy(d, m.TxID[:]) + if hasKey { + copy(d[n:], m.NodeKey[:]) + } return ret } @@ -120,7 +137,10 @@ func parsePing(ver uint8, p []byte) (m *Ping, err error) { return nil, errShort } m = new(Ping) - copy(m.TxID[:], p) + p = p[copy(m.TxID[:], p):] + if len(p) >= len(m.NodeKey) { + copy(m.NodeKey[:], p) + } return m, nil } diff --git a/disco/disco_test.go b/disco/disco_test.go index 9b16e62ba..a02622d79 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -11,6 +11,7 @@ import ( "testing" "inet.af/netaddr" + "tailscale.com/tailcfg" ) func TestMarshalAndParse(t *testing.T) { @@ -26,6 +27,19 @@ func TestMarshalAndParse(t *testing.T) { }, want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c", }, + { + name: "ping_with_nodekey_src", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + NodeKey: tailcfg.NodeKey{ + 1: 1, + 2: 2, + 30: 30, + 31: 31, + }, + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f", + }, { name: "pong", m: &Pong{ diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 9d16c6e8b..db9cfe00a 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -281,7 +281,8 @@ type Conn struct { networkUp syncs.AtomicBool // havePrivateKey is whether privateKey is non-zero. - havePrivateKey syncs.AtomicBool + havePrivateKey syncs.AtomicBool + publicKeyAtomic atomic.Value // of tailcfg.NodeKey (or NodeKey zero value if !havePrivateKey) // port is the preferred port from opts.Port; 0 means auto. port syncs.AtomicUint32 @@ -2053,6 +2054,12 @@ func (c *Conn) SetPrivateKey(privateKey wgkey.Private) error { c.privateKey = newKey c.havePrivateKey.Set(!newKey.IsZero()) + if newKey.IsZero() { + c.publicKeyAtomic.Store(tailcfg.NodeKey{}) + } else { + c.publicKeyAtomic.Store(tailcfg.NodeKey(newKey.Public())) + } + if oldKey.IsZero() { c.everHadKey = true c.logf("magicsock: SetPrivateKey called (init)") @@ -3401,7 +3408,11 @@ func (de *endpoint) removeSentPingLocked(txid stun.TxID, sp sentPing) { // The caller (startPingLocked) should've already been recorded the ping in // sentPing and set up the timer. func (de *endpoint) sendDiscoPing(ep netaddr.IPPort, txid stun.TxID, logLevel discoLogLevel) { - sent, _ := de.sendDiscoMessage(ep, &disco.Ping{TxID: [12]byte(txid)}, logLevel) + selfPubKey, _ := de.c.publicKeyAtomic.Load().(tailcfg.NodeKey) + sent, _ := de.sendDiscoMessage(ep, &disco.Ping{ + TxID: [12]byte(txid), + NodeKey: selfPubKey, + }, logLevel) if !sent { de.forgetPing(txid) }