diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index d7a9fc8f9..ade2e6be7 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -33,10 +33,7 @@ import ( "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" - "tailscale.com/types/views" "tailscale.com/util/clientmetric" - "tailscale.com/util/mak" - "tailscale.com/util/set" "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/wgcfg" @@ -107,7 +104,7 @@ type Wrapper struct { timeNow func() time.Time // peerConfig stores the current NAT configuration. - peerConfig atomic.Pointer[peerConfig] + peerConfig atomic.Pointer[peerConfigTable] // vectorBuffer stores the oldest unconsumed packet vector from tdev. It is // allocated in wrap() and the underlying arrays should never grow. @@ -504,8 +501,7 @@ func (t *Wrapper) sendVectorOutbound(r tunVectorReadResult) { } // snat does SNAT on p if the destination address requires a different source address. -func (t *Wrapper) snat(p *packet.Parsed) { - pc := t.peerConfig.Load() +func (pc *peerConfigTable) snat(p *packet.Parsed) { oldSrc := p.Src.Addr() newSrc := pc.selectSrcIP(oldSrc, p.Dst.Addr()) if oldSrc != newSrc { @@ -514,10 +510,9 @@ func (t *Wrapper) snat(p *packet.Parsed) { } // dnat does destination NAT on p. -func (t *Wrapper) dnat(p *packet.Parsed) { - pc := t.peerConfig.Load() +func (pc *peerConfigTable) dnat(p *packet.Parsed) { oldDst := p.Dst.Addr() - newDst := pc.mapDstIP(oldDst) + newDst := pc.mapDstIP(p.Src.Addr(), oldDst) if newDst != oldDst { checksum.UpdateDstAddr(p, newDst) } @@ -545,11 +540,12 @@ func findV6(addrs []netip.Prefix) netip.Addr { return netip.Addr{} } -// peerConfig is the configuration for different peers. -// It should be treated as immutable. +// peerConfigTable contains configuration for individual peers and related +// information necessary to perform peer-specific operations. It should be +// treated as immutable. // // The nil value is a valid configuration. -type peerConfig struct { +type peerConfigTable struct { // nativeAddr4 and nativeAddr6 are the IPv4/IPv6 Tailscale Addresses of // the current node. // @@ -559,50 +555,49 @@ type peerConfig struct { // inbound packet is IPv6. nativeAddr4, nativeAddr6 netip.Addr - // listenAddrs is the set of addresses that should be - // mapped to the native address. These are the addresses that - // peers will use to connect to this node. - listenAddrs views.Map[netip.Addr, struct{}] // masqAddr -> struct{} - - // dstMasqAddrs is the routing table used to map a given dst IP to the - // respective MasqueradeAsIP address. The MasqueradeAsIP address is the - // address that should be used as the source address for packets to dst. - dstMasqAddrs *bart.Table[netip.Addr] + // byIP contains configuration for each peer, indexed by a peer's IP + // address(es). + byIP bart.Table[*peerConfig] // masqAddrCounts is a count of peers by MasqueradeAsIP. + // TODO? for logging masqAddrCounts map[netip.Addr]int } -func (c *peerConfig) String() string { +// peerConfig is the configuration for a single peer. +type peerConfig struct { + // dstMasqAddr{4,6} are the addresses that should be used as the + // source address when masquerading packets to this peer (i.e. + // SNAT). If an address is not valid, the packet should not be + // masqueraded for that address family. + dstMasqAddr4 netip.Addr + dstMasqAddr6 netip.Addr +} + +func (c *peerConfigTable) String() string { if c == nil { - return "peerConfig(nil)" + return "peerConfigTable(nil)" } var b strings.Builder - b.WriteString("peerConfig{") + b.WriteString("peerConfigTable{") fmt.Fprintf(&b, "nativeAddr4: %v, ", c.nativeAddr4) fmt.Fprintf(&b, "nativeAddr6: %v, ", c.nativeAddr6) - fmt.Fprint(&b, "listenAddrs: [") - i := 0 - c.listenAddrs.Range(func(k netip.Addr, _ struct{}) bool { - if i > 0 { - b.WriteString(", ") - } - b.WriteString(k.String()) - i++ - return true - }) + // TODO: figure out how to iterate/debug/print c.byIP - i = 0 - b.WriteString("], dstMasqAddrs: [") - for k, v := range c.masqAddrCounts { - if i > 0 { - b.WriteString(", ") - } - fmt.Fprintf(&b, "%v: %v peers", k, v) - i++ + b.WriteString("}") + + return b.String() +} + +func (c *peerConfig) String() string { + if c == nil { + return "peerConfig(nil)" } - b.WriteString("]}") + var b strings.Builder + b.WriteString("peerConfig{") + fmt.Fprintf(&b, "dstMasqAddr4: %v, ", c.dstMasqAddr4) + fmt.Fprintf(&b, "dstMasqAddr6: %v}", c.dstMasqAddr6) return b.String() } @@ -610,43 +605,70 @@ func (c *peerConfig) String() string { // mapDstIP returns the destination IP to use for a packet to dst. // If dst is not one of the listen addresses, it is returned as-is, // otherwise the native address is returned. -func (c *peerConfig) mapDstIP(oldDst netip.Addr) netip.Addr { - if c == nil { +func (pc *peerConfigTable) mapDstIP(src, oldDst netip.Addr) netip.Addr { + if pc == nil { return oldDst } - if _, ok := c.listenAddrs.GetOk(oldDst); ok { - if oldDst.Is4() && c.nativeAddr4.IsValid() { - return c.nativeAddr4 - } - if oldDst.Is6() && c.nativeAddr6.IsValid() { - return c.nativeAddr6 - } + + // The packet we're processing is inbound from WireGuard, received from + // a peer. The 'src' of the packet is the remote peer's IP address, + // possibly the masqueraded address (if the peer is shared/etc.). + // + // The 'dst' of the packet is the address for this local node. It could + // be a masquerade address that we told other nodes to use, or one of + // our local node's Addresses. + c, ok := pc.byIP.Get(src) + if !ok { + return oldDst + } + + if oldDst.Is4() && pc.nativeAddr4.IsValid() && c.dstMasqAddr4 == oldDst { + return pc.nativeAddr4 + } + if oldDst.Is6() && pc.nativeAddr6.IsValid() && c.dstMasqAddr6 == oldDst { + return pc.nativeAddr6 } return oldDst } // selectSrcIP returns the source IP to use for a packet to dst. // If the packet is not from the native address, it is returned as-is. -func (c *peerConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { - if c == nil { +func (pc *peerConfigTable) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { + if pc == nil { return oldSrc } - if oldSrc.Is4() && oldSrc != c.nativeAddr4 { + + // If this packet doesn't originate from this Tailscale node, don't + // SNAT it (e.g. if we're a subnet router). + if oldSrc.Is4() && oldSrc != pc.nativeAddr4 { return oldSrc } - if oldSrc.Is6() && oldSrc != c.nativeAddr6 { + if oldSrc.Is6() && oldSrc != pc.nativeAddr6 { return oldSrc } - eip, ok := c.dstMasqAddrs.Get(dst) + + // Look up the configuration for the destination + c, ok := pc.byIP.Get(dst) if !ok { return oldSrc } - return eip + + // Perform SNAT based on the address family and whether we have a valid + // addr. + if oldSrc.Is4() && c.dstMasqAddr4.IsValid() { + return c.dstMasqAddr4 + } + if oldSrc.Is6() && c.dstMasqAddr6.IsValid() { + return c.dstMasqAddr6 + } + + // No SNAT; use old src + return oldSrc } -// peerConfigFromWGConfig generates a peerConfig from nm. If NAT is not required, -// and no additional configuration is present, it returns nil. -func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig { +// peerConfigTableFromWGConfig generates a peerConfigTable from nm. If NAT is +// not required, and no additional configuration is present, it returns nil. +func peerConfigTableFromWGConfig(wcfg *wgcfg.Config) *peerConfigTable { if wcfg == nil { return nil } @@ -657,11 +679,11 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig { return nil } - var ( - rt bart.Table[netip.Addr] - masqAddrCounts = map[netip.Addr]int{} - listenAddrs set.Set[netip.Addr] - ) + ret := &peerConfigTable{ + nativeAddr4: nativeAddr4, + nativeAddr6: nativeAddr6, + masqAddrCounts: make(map[netip.Addr]int), + } // When using an exit node that requires masquerading, we need to // fill out the routing table with all peers not just the ones that @@ -679,6 +701,8 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig { break } } + + byIPSize := 0 for i := range wcfg.Peers { p := &wcfg.Peers[i] @@ -688,13 +712,11 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig { var addrToUse4, addrToUse6 netip.Addr if p.V4MasqAddr != nil && p.V4MasqAddr.IsValid() { addrToUse4 = *p.V4MasqAddr - mak.Set(&listenAddrs, addrToUse4, struct{}{}) - masqAddrCounts[addrToUse4]++ + ret.masqAddrCounts[addrToUse4]++ } if p.V6MasqAddr != nil && p.V6MasqAddr.IsValid() { addrToUse6 = *p.V6MasqAddr - mak.Set(&listenAddrs, addrToUse6, struct{}{}) - masqAddrCounts[addrToUse6]++ + ret.masqAddrCounts[addrToUse6]++ } // If the exit node requires masquerading, set the masquerade @@ -713,33 +735,27 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig { continue } - // Build the SNAT table that maps each AllowedIP to the - // masquerade address. + // Use the same peer configuration for each address of the peer. + pc := &peerConfig{ + dstMasqAddr4: addrToUse4, + dstMasqAddr6: addrToUse6, + } + + // Insert an entry into our routing table for each allowed IP. for _, ip := range p.AllowedIPs { - is4 := ip.Addr().Is4() - if is4 && addrToUse4.IsValid() { - rt.Insert(ip, addrToUse4) - } - if !is4 && addrToUse6.IsValid() { - rt.Insert(ip, addrToUse6) - } + ret.byIP.Insert(ip, pc) + byIPSize++ } } - if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 { + if byIPSize == 0 && len(ret.masqAddrCounts) == 0 { return nil } - return &peerConfig{ - nativeAddr4: nativeAddr4, - nativeAddr6: nativeAddr6, - listenAddrs: views.MapOf(listenAddrs), - dstMasqAddrs: &rt, - masqAddrCounts: masqAddrCounts, - } + return ret } // SetNetMap is called when a new NetworkMap is received. func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) { - cfg := peerConfigFromWGConfig(wcfg) + cfg := peerConfigTableFromWGConfig(wcfg) old := t.peerConfig.Swap(cfg) if !reflect.DeepEqual(old, cfg) { @@ -752,7 +768,7 @@ var ( magicDNSIPPortv6 = netip.AddrPortFrom(tsaddr.TailscaleServiceIPv6(), 0) ) -func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed) filter.Response { +func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConfigTable) filter.Response { // Fake ICMP echo responses to MagicDNS (100.100.100.100). if p.IsEchoRequest() { switch p.Dst { @@ -856,10 +872,12 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { p := parsedPacketPool.Get().(*packet.Parsed) defer parsedPacketPool.Put(p) captHook := t.captureHook.Load() + pc := t.peerConfig.Load() for _, data := range res.data { p.Decode(data[res.dataOffset:]) - t.snat(p) + pc.snat(p) + if m := t.destIPActivity.Load(); m != nil { if fn := m[p.Dst.Addr()]; fn != nil { fn() @@ -869,7 +887,7 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { captHook(capture.FromLocal, t.now(), p.Buffer(), p.CaptureMeta) } if !t.disableFilter { - response := t.filterPacketOutboundToWireGuard(p) + response := t.filterPacketOutboundToWireGuard(p, pc) if response != filter.Accept { metricPacketOutDrop.Add(1) continue @@ -913,10 +931,12 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int n = copy(buf[offset:], res.data) } + pc := t.peerConfig.Load() + p := parsedPacketPool.Get().(*packet.Parsed) defer parsedPacketPool.Put(p) p.Decode(buf[offset : offset+n]) - t.snat(p) + pc.snat(p) if m := t.destIPActivity.Load(); m != nil { if fn := m[p.Dst.Addr()]; fn != nil { @@ -931,7 +951,7 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int return n, nil } -func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback) filter.Response { +func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback, pc *peerConfigTable) filter.Response { if captHook != nil { captHook(capture.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) } @@ -977,7 +997,6 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca if filt == nil { return filter.Drop } - outcome := filt.RunIn(p, t.filterFlags) // Let peerapi through the filter; its ACLs are handled at L7, @@ -1036,11 +1055,12 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) { p := parsedPacketPool.Get().(*packet.Parsed) defer parsedPacketPool.Put(p) captHook := t.captureHook.Load() + pc := t.peerConfig.Load() for _, buff := range buffs { p.Decode(buff[offset:]) - t.dnat(p) + pc.dnat(p) if !t.disableFilter { - if t.filterPacketInboundFromWireGuard(p, captHook) != filter.Accept { + if t.filterPacketInboundFromWireGuard(p, captHook, pc) != filter.Accept { metricPacketInDrop.Add(1) } else { buffs[i] = buff @@ -1096,6 +1116,8 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt *stack.PacketBuffer) error { } pkt.DecRef() + pc := t.peerConfig.Load() + p := parsedPacketPool.Get().(*packet.Parsed) defer parsedPacketPool.Put(p) p.Decode(buf[PacketStartOffset:]) @@ -1103,7 +1125,8 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt *stack.PacketBuffer) error { if captHook != nil { captHook(capture.SynthesizedToLocal, t.now(), p.Buffer(), p.CaptureMeta) } - t.dnat(p) + + pc.dnat(p) return t.InjectInboundDirect(buf, PacketStartOffset) } diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 417082e10..5e3685c62 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -551,7 +551,7 @@ func TestPeerAPIBypass(t *testing.T) { tt.w.SetFilter(tt.filter) tt.w.disableTSMPRejected = true tt.w.logf = t.Logf - if got := tt.w.filterPacketInboundFromWireGuard(p, nil); got != tt.want { + if got := tt.w.filterPacketInboundFromWireGuard(p, nil, nil); got != tt.want { t.Errorf("got = %v; want %v", got, tt.want) } }) @@ -581,7 +581,7 @@ func TestFilterDiscoLoop(t *testing.T) { p := new(packet.Parsed) p.Decode(pkt) - got := tw.filterPacketInboundFromWireGuard(p, nil) + got := tw.filterPacketInboundFromWireGuard(p, nil, nil) if got != filter.DropSilently { t.Errorf("got %v; want DropSilently", got) } @@ -592,7 +592,7 @@ func TestFilterDiscoLoop(t *testing.T) { memLog.Reset() pp := new(packet.Parsed) pp.Decode(pkt) - got = tw.filterPacketOutboundToWireGuard(pp) + got = tw.filterPacketOutboundToWireGuard(pp, nil) if got != filter.DropSilently { t.Errorf("got %v; want DropSilently", got) } @@ -653,11 +653,17 @@ func TestPeerCfg_NAT(t *testing.T) { publicIP = netip.MustParseAddr("2001:4860:4860::8888") } + type dnatTest struct { + src netip.Addr + dst netip.Addr + want netip.Addr // new destination after DNAT + } + tests := []struct { name string wcfg *wgcfg.Config snatMap map[netip.Addr]netip.Addr // dst -> src - dnatMap map[netip.Addr]netip.Addr + dnat []dnatTest }{ { name: "no-cfg", @@ -667,10 +673,10 @@ func TestPeerCfg_NAT(t *testing.T) { peer2IP: selfNativeIP, subnetIP: selfNativeIP, }, - dnatMap: map[netip.Addr]netip.Addr{ - selfNativeIP: selfNativeIP, - selfEIP1: selfEIP1, - selfEIP2: selfEIP2, + dnat: []dnatTest{ + {selfNativeIP, selfNativeIP, selfNativeIP}, + {peer1IP, selfEIP1, selfEIP1}, + {peer2IP, selfEIP2, selfEIP2}, }, }, { @@ -679,19 +685,19 @@ func TestPeerCfg_NAT(t *testing.T) { Addresses: selfAddrs, Peers: []wgcfg.Peer{ node(peer1IP, noIP), - node(peer2IP, selfEIP1), + node(peer2IP, selfEIP2), }, }, snatMap: map[netip.Addr]netip.Addr{ peer1IP: selfNativeIP, - peer2IP: selfEIP1, + peer2IP: selfEIP2, subnetIP: selfNativeIP, }, - dnatMap: map[netip.Addr]netip.Addr{ - selfNativeIP: selfNativeIP, - selfEIP1: selfNativeIP, - selfEIP2: selfEIP2, - subnetIP: subnetIP, + dnat: []dnatTest{ + {selfNativeIP, selfNativeIP, selfNativeIP}, + {peer1IP, selfEIP1, selfEIP1}, + {peer2IP, selfEIP2, selfNativeIP}, // NATed + {peer2IP, subnetIP, subnetIP}, }, }, { @@ -708,11 +714,11 @@ func TestPeerCfg_NAT(t *testing.T) { peer2IP: selfEIP2, subnetIP: selfNativeIP, }, - dnatMap: map[netip.Addr]netip.Addr{ - selfNativeIP: selfNativeIP, - selfEIP1: selfNativeIP, - selfEIP2: selfNativeIP, - subnetIP: subnetIP, + dnat: []dnatTest{ + {selfNativeIP, selfNativeIP, selfNativeIP}, + {peer1IP, selfEIP1, selfNativeIP}, + {peer2IP, selfEIP2, selfNativeIP}, + {peer2IP, subnetIP, subnetIP}, }, }, { @@ -729,11 +735,11 @@ func TestPeerCfg_NAT(t *testing.T) { peer2IP: selfEIP2, subnetIP: selfEIP2, }, - dnatMap: map[netip.Addr]netip.Addr{ - selfNativeIP: selfNativeIP, - selfEIP1: selfNativeIP, - selfEIP2: selfNativeIP, - subnetIP: subnetIP, + dnat: []dnatTest{ + {selfNativeIP, selfNativeIP, selfNativeIP}, + {peer1IP, selfEIP1, selfNativeIP}, + {peer2IP, selfEIP2, selfNativeIP}, + {peer2IP, subnetIP, subnetIP}, }, }, { @@ -750,11 +756,11 @@ func TestPeerCfg_NAT(t *testing.T) { peer2IP: selfEIP2, publicIP: selfEIP2, }, - dnatMap: map[netip.Addr]netip.Addr{ - selfNativeIP: selfNativeIP, - selfEIP1: selfNativeIP, - selfEIP2: selfNativeIP, - subnetIP: subnetIP, + dnat: []dnatTest{ + {selfNativeIP, selfNativeIP, selfNativeIP}, + {peer1IP, selfEIP1, selfNativeIP}, + {peer2IP, selfEIP2, selfNativeIP}, + {peer2IP, subnetIP, subnetIP}, }, }, { @@ -771,11 +777,11 @@ func TestPeerCfg_NAT(t *testing.T) { peer2IP: selfNativeIP, subnetIP: selfNativeIP, }, - dnatMap: map[netip.Addr]netip.Addr{ - selfNativeIP: selfNativeIP, - selfEIP1: selfEIP1, - selfEIP2: selfEIP2, - subnetIP: subnetIP, + dnat: []dnatTest{ + {selfNativeIP, selfNativeIP, selfNativeIP}, + {peer1IP, selfEIP1, selfEIP1}, + {peer2IP, selfEIP2, selfEIP2}, + {peer2IP, subnetIP, subnetIP}, }, }, { @@ -792,25 +798,25 @@ func TestPeerCfg_NAT(t *testing.T) { peer2IP: selfEIP2, publicIP: selfEIP2, }, - dnatMap: map[netip.Addr]netip.Addr{ - selfNativeIP: selfNativeIP, - selfEIP2: selfNativeIP, - subnetIP: subnetIP, + dnat: []dnatTest{ + {selfNativeIP, selfNativeIP, selfNativeIP}, + {peer2IP, selfEIP2, selfNativeIP}, + {peer2IP, subnetIP, subnetIP}, }, }, } for _, tc := range tests { t.Run(fmt.Sprintf("%v/%v", addrFam, tc.name), func(t *testing.T) { - pcfg := peerConfigFromWGConfig(tc.wcfg) + pcfg := peerConfigTableFromWGConfig(tc.wcfg) for peer, want := range tc.snatMap { if got := pcfg.selectSrcIP(selfNativeIP, peer); got != want { t.Errorf("selectSrcIP[%v]: got %v; want %v", peer, got, want) } } - for dstIP, want := range tc.dnatMap { - if got := pcfg.mapDstIP(dstIP); got != want { - t.Errorf("mapDstIP[%v]: got %v; want %v", dstIP, got, want) + for i, dt := range tc.dnat { + if got := pcfg.mapDstIP(dt.src, dt.dst); got != dt.want { + t.Errorf("dnat[%d]: mapDstIP[%v, %v]: got %v; want %v", i, dt.src, dt.dst, got, dt.want) } } if t.Failed() {