// Copyright (c) 2019 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package magicsock import ( "bytes" "crypto/hmac" "crypto/subtle" "encoding/binary" "errors" "fmt" "hash" "net" "strings" "sync" "time" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/tai64n" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" "inet.af/netaddr" "tailscale.com/ipn/ipnstate" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/wgkey" "tailscale.com/wgengine/wgcfg" ) var ( errNoDestinations = errors.New("magicsock: no destinations") errDisabled = errors.New("magicsock: legacy networking disabled") ) func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs string) (conn.Endpoint, error) { if c.disableLegacy { return nil, errDisabled } a := &addrSet{ Logf: c.logf, publicKey: pk, curAddr: -1, } if addrs != "" { for _, ep := range strings.Split(addrs, ",") { ipp, err := netaddr.ParseIPPort(ep) if err != nil { return nil, fmt.Errorf("bogus address %q", ep) } a.ipPorts = append(a.ipPorts, ipp) } } // If this endpoint is being updated, remember its old set of // endpoints so we can remove any (from c.addrsByUDP) that are // not in the new set. var oldIPP []netaddr.IPPort if preva, ok := c.addrsByKey[pk]; ok { oldIPP = preva.ipPorts } c.addrsByKey[pk] = a // Add entries to c.addrsByUDP. for _, ipp := range a.ipPorts { if ipp.IP == derpMagicIPAddr { continue } c.addrsByUDP[ipp] = a } // Remove previous c.addrsByUDP entries that are no longer in the new set. for _, ipp := range oldIPP { if ipp.IP != derpMagicIPAddr && c.addrsByUDP[ipp] != a { delete(c.addrsByUDP, ipp) } } return a, nil } func (c *Conn) findLegacyEndpointLocked(ipp netaddr.IPPort, packet []byte) conn.Endpoint { if c.disableLegacy { return nil } // Pre-disco: look up their addrSet. if as, ok := c.addrsByUDP[ipp]; ok { as.updateDst(ipp) return as } // We don't know who this peer is. It's possible that it's one of // our legitimate peers and they've roamed to an address we don't // know. If this is a handshake packet, we can try to identify the // peer in question. if as := c.peerFromPacketLocked(packet); as != nil { as.updateDst(ipp) return as } // We have no idea who this is, drop the packet. // // In the past, when this magicsock implementation was the main // one, we tried harder to find a match here: we would pass the // packet into wireguard-go with a "singleEndpoint" implementation // that wrapped the UDPAddr. Then, a patch we added to // wireguard-go would call UpdateDst on that singleEndpoint after // decrypting the packet and identifying the peer (if any), // allowing us to update the relevant addrSet. // // This was a significant out of tree patch to wireguard-go, so we // got rid of it, and instead switched to this logic you're // reading now, which makes a best effort to identify sources for // handshake packets (because they're relatively easy to turn into // a peer public key statelessly), but otherwise drops packets // that come from "roaming" addresses that aren't known to // magicsock. // // The practical consequence of this is that some complex NAT // traversal cases will now fail between a very old Tailscale // client (0.96 and earlier) and a very new Tailscale // client. However, those scenarios were likely also failing on // all-old clients, because the probabilistic NAT opening didn't // work reliably. So, in practice, this simplification means // connectivity looks like this: // // - old+old client: unchanged // - old+new client (easy network topology): unchanged // - old+new client (hard network topology): was bad, now a bit worse // - new+new client: unchanged // // This degradation is acceptable in that it continues to support // the incremental upgrade of old clients that currently work // well, which is our primary goal for the <100 clients still left // on the oldest pre-DERP versions (as of 2021-01-12). return nil } func (c *Conn) resetAddrSetStatesLocked() { for _, as := range c.addrsByKey { as.curAddr = -1 as.stopSpray = as.timeNow().Add(sprayPeriod) } } func (c *Conn) sendAddrSet(b []byte, as *addrSet) error { if c.disableLegacy { return errDisabled } var addrBuf [8]netaddr.IPPort dsts, roamAddr := as.appendDests(addrBuf[:0], b) if len(dsts) == 0 { return errNoDestinations } var success bool var ret error for _, addr := range dsts { sent, err := c.sendAddr(addr, as.publicKey, b) if sent { success = true } else if ret == nil { ret = err } if err != nil && addr != roamAddr && c.sendLogLimit.Allow() { if c.connCtx.Err() == nil { // don't log if we're closed c.logf("magicsock: Conn.Send(%v): %v", addr, err) } } } if success { return nil } return ret } // peerFromPacketLocked extracts returns the addrSet for the peer who sent // packet, if derivable. // // The derived addrSet is a hint, not a cryptographically strong // assertion. The returned value MUST NOT be used for any security // critical function. Callers MUST assume that the addrset can be // picked by a remote attacker. func (c *Conn) peerFromPacketLocked(packet []byte) *addrSet { if len(packet) < 4 { return nil } msgType := binary.LittleEndian.Uint32(packet[:4]) if msgType != messageInitiationType { // Can't get peer out of a non-handshake packet. return nil } var msg messageInitiation reader := bytes.NewReader(packet) err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { return nil } // Process just enough of the handshake to extract the long-term // peer public key. We don't verify the handshake all the way, so // this may be a spoofed packet. The extracted peer MUST NOT be // used for any security critical function. In our case, we use it // as a hint for roaming addresses. var ( pub = c.privateKey.Public() hash [blake2s.Size]byte chainKey [blake2s.Size]byte peerPK key.Public boxKey [chacha20poly1305.KeySize]byte ) mixHash(&hash, &initialHash, pub[:]) mixHash(&hash, &hash, msg.Ephemeral[:]) mixKey(&chainKey, &initialChainKey, msg.Ephemeral[:]) ss := c.privateKey.SharedSecret(key.Public(msg.Ephemeral)) if isZero(ss[:]) { return nil } kdf2(&chainKey, &boxKey, chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(boxKey[:]) _, err = aead.Open(peerPK[:0], zeroNonce[:], msg.Static[:], hash[:]) if err != nil { return nil } return c.addrsByKey[peerPK] } func shouldSprayPacket(b []byte) bool { if len(b) < 4 { return false } msgType := binary.LittleEndian.Uint32(b[:4]) switch msgType { case messageInitiationType, messageResponseType, messageCookieReplyType: // TODO: necessary? return true } return false } const sprayPeriod = 3 * time.Second // appendDests appends to dsts the destinations that b should be // written to in order to reach as. Some of the returned IPPorts may // be fake addrs representing DERP servers. // // It also returns as's current roamAddr, if any. func (as *addrSet) appendDests(dsts []netaddr.IPPort, b []byte) (_ []netaddr.IPPort, roamAddr netaddr.IPPort) { spray := shouldSprayPacket(b) // true for handshakes now := as.timeNow() as.mu.Lock() defer as.mu.Unlock() as.lastSend = now // Spray logic. // // After exchanging a handshake with a peer, we send some outbound // packets to every endpoint of that peer. These packets are spaced out // over several seconds to make sure that our peer has an opportunity to // send its own spray packet to us before we are done spraying. // // Multiple packets are necessary because we have to both establish the // NAT mappings between two peers *and use* the mappings to switch away // from DERP to a higher-priority UDP endpoint. const sprayFreq = 250 * time.Millisecond if spray { as.lastSpray = now as.stopSpray = now.Add(sprayPeriod) // Reset our favorite route on new handshakes so we // can downgrade to a worse path if our better path // goes away. (https://github.com/tailscale/tailscale/issues/92) as.curAddr = -1 } else if now.Before(as.stopSpray) { // We are in the spray window. If it has been sprayFreq since we // last sprayed a packet, spray this packet. if now.Sub(as.lastSpray) >= sprayFreq { spray = true as.lastSpray = now } } // Pick our destination address(es). switch { case spray: // This packet is being sprayed to all addresses. for i := range as.ipPorts { dsts = append(dsts, as.ipPorts[i]) } if as.roamAddr != nil { dsts = append(dsts, *as.roamAddr) } case as.roamAddr != nil: // We have a roaming address, prefer it over other addrs. // TODO(danderson): this is not correct, there's no reason // roamAddr should be special like this. dsts = append(dsts, *as.roamAddr) case as.curAddr != -1: if as.curAddr >= len(as.ipPorts) { as.Logf("[unexpected] magicsock bug: as.curAddr >= len(as.ipPorts): %d >= %d", as.curAddr, len(as.ipPorts)) break } // No roaming addr, but we've seen packets from a known peer // addr, so keep using that one. dsts = append(dsts, as.ipPorts[as.curAddr]) default: // We know nothing about how to reach this peer, and we're not // spraying. Use the first address in the array, which will // usually be a DERP address that guarantees connectivity. if len(as.ipPorts) > 0 { dsts = append(dsts, as.ipPorts[0]) } } if logPacketDests { as.Logf("spray=%v; roam=%v; dests=%v", spray, as.roamAddr, dsts) } if as.roamAddr != nil { roamAddr = *as.roamAddr } return dsts, roamAddr } // addrSet is a set of UDP addresses that implements wireguard/conn.Endpoint. // // This is the legacy endpoint for peers that don't support discovery; // it predates discoEndpoint. type addrSet struct { publicKey key.Public // peer public key used for DERP communication // ipPorts is an ordered priority list provided by wgengine, // sorted from expensive+slow+reliable at the begnining to // fast+cheap at the end. More concretely, it's typically: // // [DERP fakeip:node, Global IP:port, LAN ip:port] // // But there could be multiple or none of each. ipPorts []netaddr.IPPort // clock, if non-nil, is used in tests instead of time.Now. clock func() time.Time Logf logger.Logf // must not be nil mu sync.Mutex // guards following fields lastSend time.Time // roamAddr is non-nil if/when we receive a correctly signed // WireGuard packet from an unexpected address. If so, we // remember it and send responses there in the future, but // this should hopefully never be used (or at least used // rarely) in the case that all the components of Tailscale // are correctly learning/sharing the network map details. roamAddr *netaddr.IPPort // curAddr is an index into addrs of the highest-priority // address a valid packet has been received from so far. // If no valid packet from addrs has been received, curAddr is -1. curAddr int // stopSpray is the time after which we stop spraying packets. stopSpray time.Time // lastSpray is the last time we sprayed a packet. lastSpray time.Time // loggedLogPriMask is a bit field of that tracks whether // we've already logged about receiving a packet from a low // priority ("low-pri") address when we already have curAddr // set to a better one. This is only to suppress some // redundant logs. loggedLogPriMask uint32 } // derpID returns this addrSet's home DERP node, or 0 if none is found. func (as *addrSet) derpID() int { for _, ua := range as.ipPorts { if ua.IP == derpMagicIPAddr { return int(ua.Port) } } return 0 } func (as *addrSet) timeNow() time.Time { if as.clock != nil { return as.clock() } return time.Now() } var noAddr, _ = netaddr.FromStdAddr(net.ParseIP("127.127.127.127"), 127, "") func (a *addrSet) dst() netaddr.IPPort { a.mu.Lock() defer a.mu.Unlock() if a.roamAddr != nil { return *a.roamAddr } if len(a.ipPorts) == 0 { return noAddr } i := a.curAddr if i == -1 { i = 0 } return a.ipPorts[i] } func (a *addrSet) DstToBytes() []byte { return packIPPort(a.dst()) } func (a *addrSet) DstToString() string { var addrs []string for _, addr := range a.ipPorts { addrs = append(addrs, addr.String()) } a.mu.Lock() defer a.mu.Unlock() if a.roamAddr != nil { addrs = append(addrs, a.roamAddr.String()) } return strings.Join(addrs, ",") } func (a *addrSet) DstIP() net.IP { return a.dst().IP.IPAddr().IP // TODO: add netaddr accessor to cut an alloc here? } func (a *addrSet) SrcIP() net.IP { return nil } func (a *addrSet) SrcToString() string { return "" } func (a *addrSet) ClearSrc() {} // updateDst records receipt of a packet from new. This is used to // potentially update the transmit address used for this addrSet. func (a *addrSet) updateDst(new netaddr.IPPort) error { if new.IP == derpMagicIPAddr { // Never consider DERP addresses as a viable candidate for // either curAddr or roamAddr. It's only ever a last resort // choice, never a preferred choice. // This is a hot path for established connections. return nil } a.mu.Lock() defer a.mu.Unlock() if a.roamAddr != nil && new == *a.roamAddr { // Packet from the current roaming address, no logging. // This is a hot path for established connections. return nil } if a.roamAddr == nil && a.curAddr >= 0 && new == a.ipPorts[a.curAddr] { // Packet from current-priority address, no logging. // This is a hot path for established connections. return nil } index := -1 for i := range a.ipPorts { if new == a.ipPorts[i] { index = i break } } publicKey := wgkey.Key(a.publicKey) pk := publicKey.ShortString() old := "" if a.curAddr >= 0 { old = a.ipPorts[a.curAddr].String() } switch { case index == -1: if a.roamAddr == nil { a.Logf("magicsock: rx %s from roaming address %s, set as new priority", pk, new) } else { a.Logf("magicsock: rx %s from roaming address %s, replaces roaming address %s", pk, new, a.roamAddr) } a.roamAddr = &new case a.roamAddr != nil: a.Logf("magicsock: rx %s from known %s (%d), replaces roaming address %s", pk, new, index, a.roamAddr) a.roamAddr = nil a.curAddr = index a.loggedLogPriMask = 0 case a.curAddr == -1: a.Logf("magicsock: rx %s from %s (%d/%d), set as new priority", pk, new, index, len(a.ipPorts)) a.curAddr = index a.loggedLogPriMask = 0 case index < a.curAddr: if 1 <= index && index <= 32 && (a.loggedLogPriMask&1<<(index-1)) == 0 { a.Logf("magicsock: rx %s from low-pri %s (%d), keeping current %s (%d)", pk, new, index, old, a.curAddr) a.loggedLogPriMask |= 1 << (index - 1) } default: // index > a.curAddr a.Logf("magicsock: rx %s from %s (%d/%d), replaces old priority %s", pk, new, index, len(a.ipPorts), old) a.curAddr = index a.loggedLogPriMask = 0 } return nil } func (a *addrSet) String() string { a.mu.Lock() defer a.mu.Unlock() buf := new(strings.Builder) buf.WriteByte('[') if a.roamAddr != nil { buf.WriteString("roam:") sbPrintAddr(buf, *a.roamAddr) } for i, addr := range a.ipPorts { if i > 0 || a.roamAddr != nil { buf.WriteString(", ") } sbPrintAddr(buf, addr) if a.curAddr == i { buf.WriteByte('*') } } buf.WriteByte(']') return buf.String() } func (as *addrSet) populatePeerStatus(ps *ipnstate.PeerStatus) { as.mu.Lock() defer as.mu.Unlock() ps.LastWrite = as.lastSend for i, ua := range as.ipPorts { if ua.IP == derpMagicIPAddr { continue } uaStr := ua.String() ps.Addrs = append(ps.Addrs, uaStr) if as.curAddr == i { ps.CurAddr = uaStr } } if as.roamAddr != nil { ps.CurAddr = ippDebugString(*as.roamAddr) } } // Message types copied from wireguard-go/device/noise-protocol.go const ( messageInitiationType = 1 messageResponseType = 2 messageCookieReplyType = 3 ) // Cryptographic constants copied from wireguard-go/device/noise-protocol.go var ( noiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" wgIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" initialChainKey [blake2s.Size]byte initialHash [blake2s.Size]byte zeroNonce [chacha20poly1305.NonceSize]byte ) func init() { initialChainKey = blake2s.Sum256([]byte(noiseConstruction)) mixHash(&initialHash, &initialChainKey, []byte(wgIdentifier)) } // messageInitiation is the same as wireguard-go's MessageInitiation, // from wireguard-go/device/noise-protocol.go. type messageInitiation struct { Type uint32 Sender uint32 Ephemeral wgcfg.Key Static [wgcfg.KeySize + poly1305.TagSize]byte Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte } func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { kdf1(dst, c[:], data) } func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { hash, _ := blake2s.New256(nil) hash.Write(h[:]) hash.Write(data) hash.Sum(dst[:0]) hash.Reset() } func hmac1(sum *[blake2s.Size]byte, key, in0 []byte) { mac := hmac.New(func() hash.Hash { h, _ := blake2s.New256(nil) return h }, key) mac.Write(in0) mac.Sum(sum[:0]) } func hmac2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { mac := hmac.New(func() hash.Hash { h, _ := blake2s.New256(nil) return h }, key) mac.Write(in0) mac.Write(in1) mac.Sum(sum[:0]) } func kdf1(t0 *[blake2s.Size]byte, key, input []byte) { hmac1(t0, key, input) hmac1(t0, t0[:], []byte{0x1}) } func kdf2(t0, t1 *[blake2s.Size]byte, key, input []byte) { var prk [blake2s.Size]byte hmac1(&prk, key, input) hmac1(t0, prk[:], []byte{0x1}) hmac2(t1, prk[:], t0[:], []byte{0x2}) for i := range prk[:] { prk[i] = 0 } } func isZero(val []byte) bool { acc := 1 for _, b := range val { acc &= subtle.ConstantTimeByteEq(b, 0) } return acc == 1 }