diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 0ef8d1baa..0805e6b44 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -550,61 +550,14 @@ func findV6(addrs []netip.Prefix) netip.Addr { // // The nil value is a valid configuration. type natConfig struct { - v4, v6 *natFamilyConfig -} - -func (c *natConfig) String() string { - if c == nil { - return "" - } - - var b strings.Builder - b.WriteString("natConfig{") - fmt.Fprintf(&b, "v4: %v, ", c.v4) - fmt.Fprintf(&b, "v6: %v", c.v6) - b.WriteString("}") - return b.String() -} - -// mapDstIP returns the destination IP to use for a packet to dst. -// If dst is not one of the listen addresses, it is returned as-is, -// otherwise the native address is returned. -func (c *natConfig) mapDstIP(oldDst netip.Addr) netip.Addr { - if c == nil { - return oldDst - } - if oldDst.Is4() { - return c.v4.mapDstIP(oldDst) - } - if oldDst.Is6() { - return c.v6.mapDstIP(oldDst) - } - return oldDst -} - -// selectSrcIP returns the source IP to use for a packet to dst. -// If the packet is not from the native address, it is returned as-is. -func (c *natConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { - if c == nil { - return oldSrc - } - if oldSrc.Is4() { - return c.v4.selectSrcIP(oldSrc, dst) - } - if oldSrc.Is6() { - return c.v6.selectSrcIP(oldSrc, dst) - } - return oldSrc -} - -// natFamilyConfig is the NAT configuration for a particular -// address family. -// It should be treated as immutable. -// -// The nil value is a valid configuration. -type natFamilyConfig struct { - // nativeAddr is the Tailscale Address of the current node. - nativeAddr netip.Addr + // nativeAddr4 and nativeAddr6 are the IPv4/IPv6 Tailscale Addresses of + // the current node. + // + // These are implicitly used as the address to rewrite to in the DNAT + // path (as configured by listenAddrs, below). The IPv4 address will be + // used if the inbound packet is IPv4, and the IPv6 address if the + // inbound packet is IPv6. + nativeAddr4, nativeAddr6 netip.Addr // listenAddrs is the set of addresses that should be // mapped to the native address. These are the addresses that @@ -620,13 +573,14 @@ type natFamilyConfig struct { masqAddrCounts map[netip.Addr]int } -func (c *natFamilyConfig) String() string { +func (c *natConfig) String() string { if c == nil { - return "natFamilyConfig(nil)" + return "natConfig(nil)" } var b strings.Builder - b.WriteString("natFamilyConfig{") - fmt.Fprintf(&b, "nativeAddr: %v, ", c.nativeAddr) + b.WriteString("natConfig{") + fmt.Fprintf(&b, "nativeAddr4: %v, ", c.nativeAddr4) + fmt.Fprintf(&b, "nativeAddr6: %v, ", c.nativeAddr6) fmt.Fprint(&b, "listenAddrs: [") i := 0 @@ -656,23 +610,31 @@ func (c *natFamilyConfig) String() string { // mapDstIP returns the destination IP to use for a packet to dst. // If dst is not one of the listen addresses, it is returned as-is, // otherwise the native address is returned. -func (c *natFamilyConfig) mapDstIP(oldDst netip.Addr) netip.Addr { +func (c *natConfig) mapDstIP(oldDst netip.Addr) netip.Addr { if c == nil { return oldDst } if _, ok := c.listenAddrs.GetOk(oldDst); ok { - return c.nativeAddr + if oldDst.Is4() && c.nativeAddr4.IsValid() { + return c.nativeAddr4 + } + if oldDst.Is6() && c.nativeAddr6.IsValid() { + return c.nativeAddr6 + } } return oldDst } // selectSrcIP returns the source IP to use for a packet to dst. // If the packet is not from the native address, it is returned as-is. -func (c *natFamilyConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { +func (c *natConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { if c == nil { return oldSrc } - if oldSrc != c.nativeAddr { + if oldSrc.Is4() && oldSrc != c.nativeAddr4 { + return oldSrc + } + if oldSrc.Is6() && oldSrc != c.nativeAddr6 { return oldSrc } eip, ok := c.dstMasqAddrs.Get(dst) @@ -682,22 +644,16 @@ func (c *natFamilyConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { return eip } -// natConfigFromWGConfig generates a natFamilyConfig from nm, -// for the indicated address family. -// If NAT is not required for that address family, it returns nil. -func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFamilyConfig { +// natConfigFromWGConfig generates a natConfig from nm. If NAT is not required, +// it returns nil. +func natConfigFromWGConfig(wcfg *wgcfg.Config) *natConfig { if wcfg == nil { return nil } - var nativeAddr netip.Addr - switch addrFam { - case ipproto.Version4: - nativeAddr = findV4(wcfg.Addresses) - case ipproto.Version6: - nativeAddr = findV6(wcfg.Addresses) - } - if !nativeAddr.IsValid() { + nativeAddr4 := findV4(wcfg.Addresses) + nativeAddr6 := findV6(wcfg.Addresses) + if !nativeAddr4.IsValid() && !nativeAddr6.IsValid() { return nil } @@ -714,10 +670,10 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami for _, p := range wcfg.Peers { isExitNode := slices.Contains(p.AllowedIPs, tsaddr.AllIPv4()) || slices.Contains(p.AllowedIPs, tsaddr.AllIPv6()) if isExitNode { - hasMasqAddrsForFamily := false || - (addrFam == ipproto.Version4 && p.V4MasqAddr != nil && p.V4MasqAddr.IsValid()) || - (addrFam == ipproto.Version6 && p.V6MasqAddr != nil && p.V6MasqAddr.IsValid()) - if hasMasqAddrsForFamily { + hasMasqAddr := false || + (p.V4MasqAddr != nil && p.V4MasqAddr.IsValid()) || + (p.V6MasqAddr != nil && p.V6MasqAddr.IsValid()) + if hasMasqAddr { exitNodeRequiresMasq = true } break @@ -725,29 +681,56 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami } for i := range wcfg.Peers { p := &wcfg.Peers[i] - var addrToUse netip.Addr - if addrFam == ipproto.Version4 && p.V4MasqAddr != nil && p.V4MasqAddr.IsValid() { - addrToUse = *p.V4MasqAddr - mak.Set(&listenAddrs, addrToUse, struct{}{}) - } else if addrFam == ipproto.Version6 && p.V6MasqAddr != nil && p.V6MasqAddr.IsValid() { - addrToUse = *p.V6MasqAddr - mak.Set(&listenAddrs, addrToUse, struct{}{}) - } else if exitNodeRequiresMasq { - addrToUse = nativeAddr - } else { + + // Build a routing table that configures DNAT (i.e. changing + // the V4MasqAddr/V6MasqAddr for a given peer to the current + // peer's v4/v6 IP). + var addrToUse4, addrToUse6 netip.Addr + if p.V4MasqAddr != nil && p.V4MasqAddr.IsValid() { + addrToUse4 = *p.V4MasqAddr + mak.Set(&listenAddrs, addrToUse4, struct{}{}) + masqAddrCounts[addrToUse4]++ + } + if p.V6MasqAddr != nil && p.V6MasqAddr.IsValid() { + addrToUse6 = *p.V6MasqAddr + mak.Set(&listenAddrs, addrToUse6, struct{}{}) + masqAddrCounts[addrToUse6]++ + } + + // If the exit node requires masquerading, set the masquerade + // addresses to our native addresses. + if exitNodeRequiresMasq { + if !addrToUse4.IsValid() && nativeAddr4.IsValid() { + addrToUse4 = nativeAddr4 + } + if !addrToUse6.IsValid() && nativeAddr6.IsValid() { + addrToUse6 = nativeAddr6 + } + } + + if !addrToUse4.IsValid() && !addrToUse6.IsValid() { + // NAT not required for this peer. continue } - masqAddrCounts[addrToUse]++ + // Build the SNAT table that maps each AllowedIP to the + // masquerade address. for _, ip := range p.AllowedIPs { - rt.Insert(ip, addrToUse) + is4 := ip.Addr().Is4() + if is4 && addrToUse4.IsValid() { + rt.Insert(ip, addrToUse4) + } + if !is4 && addrToUse6.IsValid() { + rt.Insert(ip, addrToUse6) + } } } if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 { return nil } - return &natFamilyConfig{ - nativeAddr: nativeAddr, + return &natConfig{ + nativeAddr4: nativeAddr4, + nativeAddr6: nativeAddr6, listenAddrs: views.MapOf(listenAddrs), dstMasqAddrs: &rt, masqAddrCounts: masqAddrCounts, @@ -756,11 +739,7 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami // SetNetMap is called when a new NetworkMap is received. func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) { - v4, v6 := natConfigFromWGConfig(wcfg, ipproto.Version4), natConfigFromWGConfig(wcfg, ipproto.Version6) - var cfg *natConfig - if v4 != nil || v6 != nil { - cfg = &natConfig{v4: v4, v6: v6} - } + cfg := natConfigFromWGConfig(wcfg) old := t.natConfig.Swap(cfg) if !reflect.DeepEqual(old, cfg) { diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index c959e0dc8..879033be3 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -601,6 +601,8 @@ func TestFilterDiscoLoop(t *testing.T) { } } +// TODO(andrew-d): refactor this test to no longer use addrFam, after #11945 +// removed it in natConfigFromWGConfig func TestNATCfg(t *testing.T) { node := func(ip, masqIP netip.Addr, otherAllowedIPs ...netip.Prefix) wgcfg.Peer { p := wgcfg.Peer{ @@ -800,7 +802,7 @@ func TestNATCfg(t *testing.T) { for _, tc := range tests { t.Run(fmt.Sprintf("%v/%v", addrFam, tc.name), func(t *testing.T) { - ncfg := natConfigFromWGConfig(tc.wcfg, addrFam) + ncfg := natConfigFromWGConfig(tc.wcfg) for peer, want := range tc.snatMap { if got := ncfg.selectSrcIP(selfNativeIP, peer); got != want { t.Errorf("selectSrcIP[%v]: got %v; want %v", peer, got, want)