net/tstun: refactor natConfig to not be per-family

This was a holdover from the older, pre-BART days and is no longer
necessary.

Updates #cleanup

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I71b892bab1898077767b9ff51cef33d59c08faf8
pull/11958/head
Andrew Dunham 7 months ago
parent 13e1355546
commit 10497acc95

@ -550,61 +550,14 @@ func findV6(addrs []netip.Prefix) netip.Addr {
// //
// The nil value is a valid configuration. // The nil value is a valid configuration.
type natConfig struct { type natConfig struct {
v4, v6 *natFamilyConfig // nativeAddr4 and nativeAddr6 are the IPv4/IPv6 Tailscale Addresses of
} // the current node.
//
func (c *natConfig) String() string { // These are implicitly used as the address to rewrite to in the DNAT
if c == nil { // path (as configured by listenAddrs, below). The IPv4 address will be
return "<nil>" // used if the inbound packet is IPv4, and the IPv6 address if the
} // inbound packet is IPv6.
nativeAddr4, nativeAddr6 netip.Addr
var b strings.Builder
b.WriteString("natConfig{")
fmt.Fprintf(&b, "v4: %v, ", c.v4)
fmt.Fprintf(&b, "v6: %v", c.v6)
b.WriteString("}")
return b.String()
}
// mapDstIP returns the destination IP to use for a packet to dst.
// If dst is not one of the listen addresses, it is returned as-is,
// otherwise the native address is returned.
func (c *natConfig) mapDstIP(oldDst netip.Addr) netip.Addr {
if c == nil {
return oldDst
}
if oldDst.Is4() {
return c.v4.mapDstIP(oldDst)
}
if oldDst.Is6() {
return c.v6.mapDstIP(oldDst)
}
return oldDst
}
// selectSrcIP returns the source IP to use for a packet to dst.
// If the packet is not from the native address, it is returned as-is.
func (c *natConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
if c == nil {
return oldSrc
}
if oldSrc.Is4() {
return c.v4.selectSrcIP(oldSrc, dst)
}
if oldSrc.Is6() {
return c.v6.selectSrcIP(oldSrc, dst)
}
return oldSrc
}
// natFamilyConfig is the NAT configuration for a particular
// address family.
// It should be treated as immutable.
//
// The nil value is a valid configuration.
type natFamilyConfig struct {
// nativeAddr is the Tailscale Address of the current node.
nativeAddr netip.Addr
// listenAddrs is the set of addresses that should be // listenAddrs is the set of addresses that should be
// mapped to the native address. These are the addresses that // mapped to the native address. These are the addresses that
@ -620,13 +573,14 @@ type natFamilyConfig struct {
masqAddrCounts map[netip.Addr]int masqAddrCounts map[netip.Addr]int
} }
func (c *natFamilyConfig) String() string { func (c *natConfig) String() string {
if c == nil { if c == nil {
return "natFamilyConfig(nil)" return "natConfig(nil)"
} }
var b strings.Builder var b strings.Builder
b.WriteString("natFamilyConfig{") b.WriteString("natConfig{")
fmt.Fprintf(&b, "nativeAddr: %v, ", c.nativeAddr) fmt.Fprintf(&b, "nativeAddr4: %v, ", c.nativeAddr4)
fmt.Fprintf(&b, "nativeAddr6: %v, ", c.nativeAddr6)
fmt.Fprint(&b, "listenAddrs: [") fmt.Fprint(&b, "listenAddrs: [")
i := 0 i := 0
@ -656,23 +610,31 @@ func (c *natFamilyConfig) String() string {
// mapDstIP returns the destination IP to use for a packet to dst. // mapDstIP returns the destination IP to use for a packet to dst.
// If dst is not one of the listen addresses, it is returned as-is, // If dst is not one of the listen addresses, it is returned as-is,
// otherwise the native address is returned. // otherwise the native address is returned.
func (c *natFamilyConfig) mapDstIP(oldDst netip.Addr) netip.Addr { func (c *natConfig) mapDstIP(oldDst netip.Addr) netip.Addr {
if c == nil { if c == nil {
return oldDst return oldDst
} }
if _, ok := c.listenAddrs.GetOk(oldDst); ok { if _, ok := c.listenAddrs.GetOk(oldDst); ok {
return c.nativeAddr if oldDst.Is4() && c.nativeAddr4.IsValid() {
return c.nativeAddr4
}
if oldDst.Is6() && c.nativeAddr6.IsValid() {
return c.nativeAddr6
}
} }
return oldDst return oldDst
} }
// selectSrcIP returns the source IP to use for a packet to dst. // selectSrcIP returns the source IP to use for a packet to dst.
// If the packet is not from the native address, it is returned as-is. // If the packet is not from the native address, it is returned as-is.
func (c *natFamilyConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { func (c *natConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
if c == nil { if c == nil {
return oldSrc return oldSrc
} }
if oldSrc != c.nativeAddr { if oldSrc.Is4() && oldSrc != c.nativeAddr4 {
return oldSrc
}
if oldSrc.Is6() && oldSrc != c.nativeAddr6 {
return oldSrc return oldSrc
} }
eip, ok := c.dstMasqAddrs.Get(dst) eip, ok := c.dstMasqAddrs.Get(dst)
@ -682,22 +644,16 @@ func (c *natFamilyConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
return eip return eip
} }
// natConfigFromWGConfig generates a natFamilyConfig from nm, // natConfigFromWGConfig generates a natConfig from nm. If NAT is not required,
// for the indicated address family. // it returns nil.
// If NAT is not required for that address family, it returns nil. func natConfigFromWGConfig(wcfg *wgcfg.Config) *natConfig {
func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFamilyConfig {
if wcfg == nil { if wcfg == nil {
return nil return nil
} }
var nativeAddr netip.Addr nativeAddr4 := findV4(wcfg.Addresses)
switch addrFam { nativeAddr6 := findV6(wcfg.Addresses)
case ipproto.Version4: if !nativeAddr4.IsValid() && !nativeAddr6.IsValid() {
nativeAddr = findV4(wcfg.Addresses)
case ipproto.Version6:
nativeAddr = findV6(wcfg.Addresses)
}
if !nativeAddr.IsValid() {
return nil return nil
} }
@ -714,10 +670,10 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami
for _, p := range wcfg.Peers { for _, p := range wcfg.Peers {
isExitNode := slices.Contains(p.AllowedIPs, tsaddr.AllIPv4()) || slices.Contains(p.AllowedIPs, tsaddr.AllIPv6()) isExitNode := slices.Contains(p.AllowedIPs, tsaddr.AllIPv4()) || slices.Contains(p.AllowedIPs, tsaddr.AllIPv6())
if isExitNode { if isExitNode {
hasMasqAddrsForFamily := false || hasMasqAddr := false ||
(addrFam == ipproto.Version4 && p.V4MasqAddr != nil && p.V4MasqAddr.IsValid()) || (p.V4MasqAddr != nil && p.V4MasqAddr.IsValid()) ||
(addrFam == ipproto.Version6 && p.V6MasqAddr != nil && p.V6MasqAddr.IsValid()) (p.V6MasqAddr != nil && p.V6MasqAddr.IsValid())
if hasMasqAddrsForFamily { if hasMasqAddr {
exitNodeRequiresMasq = true exitNodeRequiresMasq = true
} }
break break
@ -725,29 +681,56 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami
} }
for i := range wcfg.Peers { for i := range wcfg.Peers {
p := &wcfg.Peers[i] p := &wcfg.Peers[i]
var addrToUse netip.Addr
if addrFam == ipproto.Version4 && p.V4MasqAddr != nil && p.V4MasqAddr.IsValid() { // Build a routing table that configures DNAT (i.e. changing
addrToUse = *p.V4MasqAddr // the V4MasqAddr/V6MasqAddr for a given peer to the current
mak.Set(&listenAddrs, addrToUse, struct{}{}) // peer's v4/v6 IP).
} else if addrFam == ipproto.Version6 && p.V6MasqAddr != nil && p.V6MasqAddr.IsValid() { var addrToUse4, addrToUse6 netip.Addr
addrToUse = *p.V6MasqAddr if p.V4MasqAddr != nil && p.V4MasqAddr.IsValid() {
mak.Set(&listenAddrs, addrToUse, struct{}{}) addrToUse4 = *p.V4MasqAddr
} else if exitNodeRequiresMasq { mak.Set(&listenAddrs, addrToUse4, struct{}{})
addrToUse = nativeAddr masqAddrCounts[addrToUse4]++
} else { }
if p.V6MasqAddr != nil && p.V6MasqAddr.IsValid() {
addrToUse6 = *p.V6MasqAddr
mak.Set(&listenAddrs, addrToUse6, struct{}{})
masqAddrCounts[addrToUse6]++
}
// If the exit node requires masquerading, set the masquerade
// addresses to our native addresses.
if exitNodeRequiresMasq {
if !addrToUse4.IsValid() && nativeAddr4.IsValid() {
addrToUse4 = nativeAddr4
}
if !addrToUse6.IsValid() && nativeAddr6.IsValid() {
addrToUse6 = nativeAddr6
}
}
if !addrToUse4.IsValid() && !addrToUse6.IsValid() {
// NAT not required for this peer.
continue continue
} }
masqAddrCounts[addrToUse]++ // Build the SNAT table that maps each AllowedIP to the
// masquerade address.
for _, ip := range p.AllowedIPs { for _, ip := range p.AllowedIPs {
rt.Insert(ip, addrToUse) is4 := ip.Addr().Is4()
if is4 && addrToUse4.IsValid() {
rt.Insert(ip, addrToUse4)
}
if !is4 && addrToUse6.IsValid() {
rt.Insert(ip, addrToUse6)
}
} }
} }
if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 { if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 {
return nil return nil
} }
return &natFamilyConfig{ return &natConfig{
nativeAddr: nativeAddr, nativeAddr4: nativeAddr4,
nativeAddr6: nativeAddr6,
listenAddrs: views.MapOf(listenAddrs), listenAddrs: views.MapOf(listenAddrs),
dstMasqAddrs: &rt, dstMasqAddrs: &rt,
masqAddrCounts: masqAddrCounts, masqAddrCounts: masqAddrCounts,
@ -756,11 +739,7 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami
// SetNetMap is called when a new NetworkMap is received. // SetNetMap is called when a new NetworkMap is received.
func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) { func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) {
v4, v6 := natConfigFromWGConfig(wcfg, ipproto.Version4), natConfigFromWGConfig(wcfg, ipproto.Version6) cfg := natConfigFromWGConfig(wcfg)
var cfg *natConfig
if v4 != nil || v6 != nil {
cfg = &natConfig{v4: v4, v6: v6}
}
old := t.natConfig.Swap(cfg) old := t.natConfig.Swap(cfg)
if !reflect.DeepEqual(old, cfg) { if !reflect.DeepEqual(old, cfg) {

@ -601,6 +601,8 @@ func TestFilterDiscoLoop(t *testing.T) {
} }
} }
// TODO(andrew-d): refactor this test to no longer use addrFam, after #11945
// removed it in natConfigFromWGConfig
func TestNATCfg(t *testing.T) { func TestNATCfg(t *testing.T) {
node := func(ip, masqIP netip.Addr, otherAllowedIPs ...netip.Prefix) wgcfg.Peer { node := func(ip, masqIP netip.Addr, otherAllowedIPs ...netip.Prefix) wgcfg.Peer {
p := wgcfg.Peer{ p := wgcfg.Peer{
@ -800,7 +802,7 @@ func TestNATCfg(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(fmt.Sprintf("%v/%v", addrFam, tc.name), func(t *testing.T) { t.Run(fmt.Sprintf("%v/%v", addrFam, tc.name), func(t *testing.T) {
ncfg := natConfigFromWGConfig(tc.wcfg, addrFam) ncfg := natConfigFromWGConfig(tc.wcfg)
for peer, want := range tc.snatMap { for peer, want := range tc.snatMap {
if got := ncfg.selectSrcIP(selfNativeIP, peer); got != want { if got := ncfg.selectSrcIP(selfNativeIP, peer); got != want {
t.Errorf("selectSrcIP[%v]: got %v; want %v", peer, got, want) t.Errorf("selectSrcIP[%v]: got %v; want %v", peer, got, want)

Loading…
Cancel
Save