net/tstun: refactor peerConfig to allow storing more details

This refactors the peerConfig struct to allow storing more
details about a peer and not just the masq addresses. To be
used in a follow up change.

As a side effect, this also makes the DNAT logic on the inbound
packet stricter. Previously it would only match against the packets
dst IP, not it also takes the src IP into consideration. The beahvior
is at parity with the SNAT case.

Updates tailscale/corp#19623

Co-authored-by: Andrew Dunham <andrew@du.nham.ca>
Signed-off-by: Maisem Ali <maisem@tailscale.com>
Change-Id: I5f40802bebbf0f055436eb8824e4511d0052772d
pull/12024/head
Maisem Ali 7 months ago committed by Maisem Ali
parent f3d2fd22ef
commit 5ef178fdca

@ -33,10 +33,7 @@ import (
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/views"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/wgengine/capture" "tailscale.com/wgengine/capture"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
"tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg"
@ -107,7 +104,7 @@ type Wrapper struct {
timeNow func() time.Time timeNow func() time.Time
// peerConfig stores the current NAT configuration. // peerConfig stores the current NAT configuration.
peerConfig atomic.Pointer[peerConfig] peerConfig atomic.Pointer[peerConfigTable]
// vectorBuffer stores the oldest unconsumed packet vector from tdev. It is // vectorBuffer stores the oldest unconsumed packet vector from tdev. It is
// allocated in wrap() and the underlying arrays should never grow. // allocated in wrap() and the underlying arrays should never grow.
@ -504,8 +501,7 @@ func (t *Wrapper) sendVectorOutbound(r tunVectorReadResult) {
} }
// snat does SNAT on p if the destination address requires a different source address. // snat does SNAT on p if the destination address requires a different source address.
func (t *Wrapper) snat(p *packet.Parsed) { func (pc *peerConfigTable) snat(p *packet.Parsed) {
pc := t.peerConfig.Load()
oldSrc := p.Src.Addr() oldSrc := p.Src.Addr()
newSrc := pc.selectSrcIP(oldSrc, p.Dst.Addr()) newSrc := pc.selectSrcIP(oldSrc, p.Dst.Addr())
if oldSrc != newSrc { if oldSrc != newSrc {
@ -514,10 +510,9 @@ func (t *Wrapper) snat(p *packet.Parsed) {
} }
// dnat does destination NAT on p. // dnat does destination NAT on p.
func (t *Wrapper) dnat(p *packet.Parsed) { func (pc *peerConfigTable) dnat(p *packet.Parsed) {
pc := t.peerConfig.Load()
oldDst := p.Dst.Addr() oldDst := p.Dst.Addr()
newDst := pc.mapDstIP(oldDst) newDst := pc.mapDstIP(p.Src.Addr(), oldDst)
if newDst != oldDst { if newDst != oldDst {
checksum.UpdateDstAddr(p, newDst) checksum.UpdateDstAddr(p, newDst)
} }
@ -545,11 +540,12 @@ func findV6(addrs []netip.Prefix) netip.Addr {
return netip.Addr{} return netip.Addr{}
} }
// peerConfig is the configuration for different peers. // peerConfigTable contains configuration for individual peers and related
// It should be treated as immutable. // information necessary to perform peer-specific operations. It should be
// treated as immutable.
// //
// The nil value is a valid configuration. // The nil value is a valid configuration.
type peerConfig struct { type peerConfigTable struct {
// nativeAddr4 and nativeAddr6 are the IPv4/IPv6 Tailscale Addresses of // nativeAddr4 and nativeAddr6 are the IPv4/IPv6 Tailscale Addresses of
// the current node. // the current node.
// //
@ -559,50 +555,49 @@ type peerConfig struct {
// inbound packet is IPv6. // inbound packet is IPv6.
nativeAddr4, nativeAddr6 netip.Addr nativeAddr4, nativeAddr6 netip.Addr
// listenAddrs is the set of addresses that should be // byIP contains configuration for each peer, indexed by a peer's IP
// mapped to the native address. These are the addresses that // address(es).
// peers will use to connect to this node. byIP bart.Table[*peerConfig]
listenAddrs views.Map[netip.Addr, struct{}] // masqAddr -> struct{}
// dstMasqAddrs is the routing table used to map a given dst IP to the
// respective MasqueradeAsIP address. The MasqueradeAsIP address is the
// address that should be used as the source address for packets to dst.
dstMasqAddrs *bart.Table[netip.Addr]
// masqAddrCounts is a count of peers by MasqueradeAsIP. // masqAddrCounts is a count of peers by MasqueradeAsIP.
// TODO? for logging
masqAddrCounts map[netip.Addr]int masqAddrCounts map[netip.Addr]int
} }
func (c *peerConfig) String() string { // peerConfig is the configuration for a single peer.
type peerConfig struct {
// dstMasqAddr{4,6} are the addresses that should be used as the
// source address when masquerading packets to this peer (i.e.
// SNAT). If an address is not valid, the packet should not be
// masqueraded for that address family.
dstMasqAddr4 netip.Addr
dstMasqAddr6 netip.Addr
}
func (c *peerConfigTable) String() string {
if c == nil { if c == nil {
return "peerConfig(nil)" return "peerConfigTable(nil)"
} }
var b strings.Builder var b strings.Builder
b.WriteString("peerConfig{") b.WriteString("peerConfigTable{")
fmt.Fprintf(&b, "nativeAddr4: %v, ", c.nativeAddr4) fmt.Fprintf(&b, "nativeAddr4: %v, ", c.nativeAddr4)
fmt.Fprintf(&b, "nativeAddr6: %v, ", c.nativeAddr6) fmt.Fprintf(&b, "nativeAddr6: %v, ", c.nativeAddr6)
fmt.Fprint(&b, "listenAddrs: [")
i := 0 // TODO: figure out how to iterate/debug/print c.byIP
c.listenAddrs.Range(func(k netip.Addr, _ struct{}) bool {
if i > 0 {
b.WriteString(", ")
}
b.WriteString(k.String())
i++
return true
})
i = 0 b.WriteString("}")
b.WriteString("], dstMasqAddrs: [")
for k, v := range c.masqAddrCounts { return b.String()
if i > 0 { }
b.WriteString(", ")
} func (c *peerConfig) String() string {
fmt.Fprintf(&b, "%v: %v peers", k, v) if c == nil {
i++ return "peerConfig(nil)"
} }
b.WriteString("]}") var b strings.Builder
b.WriteString("peerConfig{")
fmt.Fprintf(&b, "dstMasqAddr4: %v, ", c.dstMasqAddr4)
fmt.Fprintf(&b, "dstMasqAddr6: %v}", c.dstMasqAddr6)
return b.String() return b.String()
} }
@ -610,43 +605,70 @@ func (c *peerConfig) String() string {
// mapDstIP returns the destination IP to use for a packet to dst. // mapDstIP returns the destination IP to use for a packet to dst.
// If dst is not one of the listen addresses, it is returned as-is, // If dst is not one of the listen addresses, it is returned as-is,
// otherwise the native address is returned. // otherwise the native address is returned.
func (c *peerConfig) mapDstIP(oldDst netip.Addr) netip.Addr { func (pc *peerConfigTable) mapDstIP(src, oldDst netip.Addr) netip.Addr {
if c == nil { if pc == nil {
return oldDst return oldDst
} }
if _, ok := c.listenAddrs.GetOk(oldDst); ok {
if oldDst.Is4() && c.nativeAddr4.IsValid() { // The packet we're processing is inbound from WireGuard, received from
return c.nativeAddr4 // a peer. The 'src' of the packet is the remote peer's IP address,
} // possibly the masqueraded address (if the peer is shared/etc.).
if oldDst.Is6() && c.nativeAddr6.IsValid() { //
return c.nativeAddr6 // The 'dst' of the packet is the address for this local node. It could
} // be a masquerade address that we told other nodes to use, or one of
// our local node's Addresses.
c, ok := pc.byIP.Get(src)
if !ok {
return oldDst
}
if oldDst.Is4() && pc.nativeAddr4.IsValid() && c.dstMasqAddr4 == oldDst {
return pc.nativeAddr4
}
if oldDst.Is6() && pc.nativeAddr6.IsValid() && c.dstMasqAddr6 == oldDst {
return pc.nativeAddr6
} }
return oldDst return oldDst
} }
// selectSrcIP returns the source IP to use for a packet to dst. // selectSrcIP returns the source IP to use for a packet to dst.
// If the packet is not from the native address, it is returned as-is. // If the packet is not from the native address, it is returned as-is.
func (c *peerConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { func (pc *peerConfigTable) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
if c == nil { if pc == nil {
return oldSrc return oldSrc
} }
if oldSrc.Is4() && oldSrc != c.nativeAddr4 {
// If this packet doesn't originate from this Tailscale node, don't
// SNAT it (e.g. if we're a subnet router).
if oldSrc.Is4() && oldSrc != pc.nativeAddr4 {
return oldSrc return oldSrc
} }
if oldSrc.Is6() && oldSrc != c.nativeAddr6 { if oldSrc.Is6() && oldSrc != pc.nativeAddr6 {
return oldSrc return oldSrc
} }
eip, ok := c.dstMasqAddrs.Get(dst)
// Look up the configuration for the destination
c, ok := pc.byIP.Get(dst)
if !ok { if !ok {
return oldSrc return oldSrc
} }
return eip
// Perform SNAT based on the address family and whether we have a valid
// addr.
if oldSrc.Is4() && c.dstMasqAddr4.IsValid() {
return c.dstMasqAddr4
}
if oldSrc.Is6() && c.dstMasqAddr6.IsValid() {
return c.dstMasqAddr6
}
// No SNAT; use old src
return oldSrc
} }
// peerConfigFromWGConfig generates a peerConfig from nm. If NAT is not required, // peerConfigTableFromWGConfig generates a peerConfigTable from nm. If NAT is
// and no additional configuration is present, it returns nil. // not required, and no additional configuration is present, it returns nil.
func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig { func peerConfigTableFromWGConfig(wcfg *wgcfg.Config) *peerConfigTable {
if wcfg == nil { if wcfg == nil {
return nil return nil
} }
@ -657,11 +679,11 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig {
return nil return nil
} }
var ( ret := &peerConfigTable{
rt bart.Table[netip.Addr] nativeAddr4: nativeAddr4,
masqAddrCounts = map[netip.Addr]int{} nativeAddr6: nativeAddr6,
listenAddrs set.Set[netip.Addr] masqAddrCounts: make(map[netip.Addr]int),
) }
// When using an exit node that requires masquerading, we need to // When using an exit node that requires masquerading, we need to
// fill out the routing table with all peers not just the ones that // fill out the routing table with all peers not just the ones that
@ -679,6 +701,8 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig {
break break
} }
} }
byIPSize := 0
for i := range wcfg.Peers { for i := range wcfg.Peers {
p := &wcfg.Peers[i] p := &wcfg.Peers[i]
@ -688,13 +712,11 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig {
var addrToUse4, addrToUse6 netip.Addr var addrToUse4, addrToUse6 netip.Addr
if p.V4MasqAddr != nil && p.V4MasqAddr.IsValid() { if p.V4MasqAddr != nil && p.V4MasqAddr.IsValid() {
addrToUse4 = *p.V4MasqAddr addrToUse4 = *p.V4MasqAddr
mak.Set(&listenAddrs, addrToUse4, struct{}{}) ret.masqAddrCounts[addrToUse4]++
masqAddrCounts[addrToUse4]++
} }
if p.V6MasqAddr != nil && p.V6MasqAddr.IsValid() { if p.V6MasqAddr != nil && p.V6MasqAddr.IsValid() {
addrToUse6 = *p.V6MasqAddr addrToUse6 = *p.V6MasqAddr
mak.Set(&listenAddrs, addrToUse6, struct{}{}) ret.masqAddrCounts[addrToUse6]++
masqAddrCounts[addrToUse6]++
} }
// If the exit node requires masquerading, set the masquerade // If the exit node requires masquerading, set the masquerade
@ -713,33 +735,27 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig {
continue continue
} }
// Build the SNAT table that maps each AllowedIP to the // Use the same peer configuration for each address of the peer.
// masquerade address. pc := &peerConfig{
dstMasqAddr4: addrToUse4,
dstMasqAddr6: addrToUse6,
}
// Insert an entry into our routing table for each allowed IP.
for _, ip := range p.AllowedIPs { for _, ip := range p.AllowedIPs {
is4 := ip.Addr().Is4() ret.byIP.Insert(ip, pc)
if is4 && addrToUse4.IsValid() { byIPSize++
rt.Insert(ip, addrToUse4)
}
if !is4 && addrToUse6.IsValid() {
rt.Insert(ip, addrToUse6)
}
} }
} }
if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 { if byIPSize == 0 && len(ret.masqAddrCounts) == 0 {
return nil return nil
} }
return &peerConfig{ return ret
nativeAddr4: nativeAddr4,
nativeAddr6: nativeAddr6,
listenAddrs: views.MapOf(listenAddrs),
dstMasqAddrs: &rt,
masqAddrCounts: masqAddrCounts,
}
} }
// SetNetMap is called when a new NetworkMap is received. // SetNetMap is called when a new NetworkMap is received.
func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) { func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) {
cfg := peerConfigFromWGConfig(wcfg) cfg := peerConfigTableFromWGConfig(wcfg)
old := t.peerConfig.Swap(cfg) old := t.peerConfig.Swap(cfg)
if !reflect.DeepEqual(old, cfg) { if !reflect.DeepEqual(old, cfg) {
@ -752,7 +768,7 @@ var (
magicDNSIPPortv6 = netip.AddrPortFrom(tsaddr.TailscaleServiceIPv6(), 0) magicDNSIPPortv6 = netip.AddrPortFrom(tsaddr.TailscaleServiceIPv6(), 0)
) )
func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed) filter.Response { func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConfigTable) filter.Response {
// Fake ICMP echo responses to MagicDNS (100.100.100.100). // Fake ICMP echo responses to MagicDNS (100.100.100.100).
if p.IsEchoRequest() { if p.IsEchoRequest() {
switch p.Dst { switch p.Dst {
@ -856,10 +872,12 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
p := parsedPacketPool.Get().(*packet.Parsed) p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p) defer parsedPacketPool.Put(p)
captHook := t.captureHook.Load() captHook := t.captureHook.Load()
pc := t.peerConfig.Load()
for _, data := range res.data { for _, data := range res.data {
p.Decode(data[res.dataOffset:]) p.Decode(data[res.dataOffset:])
t.snat(p) pc.snat(p)
if m := t.destIPActivity.Load(); m != nil { if m := t.destIPActivity.Load(); m != nil {
if fn := m[p.Dst.Addr()]; fn != nil { if fn := m[p.Dst.Addr()]; fn != nil {
fn() fn()
@ -869,7 +887,7 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
captHook(capture.FromLocal, t.now(), p.Buffer(), p.CaptureMeta) captHook(capture.FromLocal, t.now(), p.Buffer(), p.CaptureMeta)
} }
if !t.disableFilter { if !t.disableFilter {
response := t.filterPacketOutboundToWireGuard(p) response := t.filterPacketOutboundToWireGuard(p, pc)
if response != filter.Accept { if response != filter.Accept {
metricPacketOutDrop.Add(1) metricPacketOutDrop.Add(1)
continue continue
@ -913,10 +931,12 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int
n = copy(buf[offset:], res.data) n = copy(buf[offset:], res.data)
} }
pc := t.peerConfig.Load()
p := parsedPacketPool.Get().(*packet.Parsed) p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p) defer parsedPacketPool.Put(p)
p.Decode(buf[offset : offset+n]) p.Decode(buf[offset : offset+n])
t.snat(p) pc.snat(p)
if m := t.destIPActivity.Load(); m != nil { if m := t.destIPActivity.Load(); m != nil {
if fn := m[p.Dst.Addr()]; fn != nil { if fn := m[p.Dst.Addr()]; fn != nil {
@ -931,7 +951,7 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int
return n, nil return n, nil
} }
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback) filter.Response { func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback, pc *peerConfigTable) filter.Response {
if captHook != nil { if captHook != nil {
captHook(capture.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) captHook(capture.FromPeer, t.now(), p.Buffer(), p.CaptureMeta)
} }
@ -977,7 +997,6 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca
if filt == nil { if filt == nil {
return filter.Drop return filter.Drop
} }
outcome := filt.RunIn(p, t.filterFlags) outcome := filt.RunIn(p, t.filterFlags)
// Let peerapi through the filter; its ACLs are handled at L7, // Let peerapi through the filter; its ACLs are handled at L7,
@ -1036,11 +1055,12 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
p := parsedPacketPool.Get().(*packet.Parsed) p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p) defer parsedPacketPool.Put(p)
captHook := t.captureHook.Load() captHook := t.captureHook.Load()
pc := t.peerConfig.Load()
for _, buff := range buffs { for _, buff := range buffs {
p.Decode(buff[offset:]) p.Decode(buff[offset:])
t.dnat(p) pc.dnat(p)
if !t.disableFilter { if !t.disableFilter {
if t.filterPacketInboundFromWireGuard(p, captHook) != filter.Accept { if t.filterPacketInboundFromWireGuard(p, captHook, pc) != filter.Accept {
metricPacketInDrop.Add(1) metricPacketInDrop.Add(1)
} else { } else {
buffs[i] = buff buffs[i] = buff
@ -1096,6 +1116,8 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt *stack.PacketBuffer) error {
} }
pkt.DecRef() pkt.DecRef()
pc := t.peerConfig.Load()
p := parsedPacketPool.Get().(*packet.Parsed) p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p) defer parsedPacketPool.Put(p)
p.Decode(buf[PacketStartOffset:]) p.Decode(buf[PacketStartOffset:])
@ -1103,7 +1125,8 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt *stack.PacketBuffer) error {
if captHook != nil { if captHook != nil {
captHook(capture.SynthesizedToLocal, t.now(), p.Buffer(), p.CaptureMeta) captHook(capture.SynthesizedToLocal, t.now(), p.Buffer(), p.CaptureMeta)
} }
t.dnat(p)
pc.dnat(p)
return t.InjectInboundDirect(buf, PacketStartOffset) return t.InjectInboundDirect(buf, PacketStartOffset)
} }

@ -551,7 +551,7 @@ func TestPeerAPIBypass(t *testing.T) {
tt.w.SetFilter(tt.filter) tt.w.SetFilter(tt.filter)
tt.w.disableTSMPRejected = true tt.w.disableTSMPRejected = true
tt.w.logf = t.Logf tt.w.logf = t.Logf
if got := tt.w.filterPacketInboundFromWireGuard(p, nil); got != tt.want { if got := tt.w.filterPacketInboundFromWireGuard(p, nil, nil); got != tt.want {
t.Errorf("got = %v; want %v", got, tt.want) t.Errorf("got = %v; want %v", got, tt.want)
} }
}) })
@ -581,7 +581,7 @@ func TestFilterDiscoLoop(t *testing.T) {
p := new(packet.Parsed) p := new(packet.Parsed)
p.Decode(pkt) p.Decode(pkt)
got := tw.filterPacketInboundFromWireGuard(p, nil) got := tw.filterPacketInboundFromWireGuard(p, nil, nil)
if got != filter.DropSilently { if got != filter.DropSilently {
t.Errorf("got %v; want DropSilently", got) t.Errorf("got %v; want DropSilently", got)
} }
@ -592,7 +592,7 @@ func TestFilterDiscoLoop(t *testing.T) {
memLog.Reset() memLog.Reset()
pp := new(packet.Parsed) pp := new(packet.Parsed)
pp.Decode(pkt) pp.Decode(pkt)
got = tw.filterPacketOutboundToWireGuard(pp) got = tw.filterPacketOutboundToWireGuard(pp, nil)
if got != filter.DropSilently { if got != filter.DropSilently {
t.Errorf("got %v; want DropSilently", got) t.Errorf("got %v; want DropSilently", got)
} }
@ -653,11 +653,17 @@ func TestPeerCfg_NAT(t *testing.T) {
publicIP = netip.MustParseAddr("2001:4860:4860::8888") publicIP = netip.MustParseAddr("2001:4860:4860::8888")
} }
type dnatTest struct {
src netip.Addr
dst netip.Addr
want netip.Addr // new destination after DNAT
}
tests := []struct { tests := []struct {
name string name string
wcfg *wgcfg.Config wcfg *wgcfg.Config
snatMap map[netip.Addr]netip.Addr // dst -> src snatMap map[netip.Addr]netip.Addr // dst -> src
dnatMap map[netip.Addr]netip.Addr dnat []dnatTest
}{ }{
{ {
name: "no-cfg", name: "no-cfg",
@ -667,10 +673,10 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfNativeIP, peer2IP: selfNativeIP,
subnetIP: selfNativeIP, subnetIP: selfNativeIP,
}, },
dnatMap: map[netip.Addr]netip.Addr{ dnat: []dnatTest{
selfNativeIP: selfNativeIP, {selfNativeIP, selfNativeIP, selfNativeIP},
selfEIP1: selfEIP1, {peer1IP, selfEIP1, selfEIP1},
selfEIP2: selfEIP2, {peer2IP, selfEIP2, selfEIP2},
}, },
}, },
{ {
@ -679,19 +685,19 @@ func TestPeerCfg_NAT(t *testing.T) {
Addresses: selfAddrs, Addresses: selfAddrs,
Peers: []wgcfg.Peer{ Peers: []wgcfg.Peer{
node(peer1IP, noIP), node(peer1IP, noIP),
node(peer2IP, selfEIP1), node(peer2IP, selfEIP2),
}, },
}, },
snatMap: map[netip.Addr]netip.Addr{ snatMap: map[netip.Addr]netip.Addr{
peer1IP: selfNativeIP, peer1IP: selfNativeIP,
peer2IP: selfEIP1, peer2IP: selfEIP2,
subnetIP: selfNativeIP, subnetIP: selfNativeIP,
}, },
dnatMap: map[netip.Addr]netip.Addr{ dnat: []dnatTest{
selfNativeIP: selfNativeIP, {selfNativeIP, selfNativeIP, selfNativeIP},
selfEIP1: selfNativeIP, {peer1IP, selfEIP1, selfEIP1},
selfEIP2: selfEIP2, {peer2IP, selfEIP2, selfNativeIP}, // NATed
subnetIP: subnetIP, {peer2IP, subnetIP, subnetIP},
}, },
}, },
{ {
@ -708,11 +714,11 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfEIP2, peer2IP: selfEIP2,
subnetIP: selfNativeIP, subnetIP: selfNativeIP,
}, },
dnatMap: map[netip.Addr]netip.Addr{ dnat: []dnatTest{
selfNativeIP: selfNativeIP, {selfNativeIP, selfNativeIP, selfNativeIP},
selfEIP1: selfNativeIP, {peer1IP, selfEIP1, selfNativeIP},
selfEIP2: selfNativeIP, {peer2IP, selfEIP2, selfNativeIP},
subnetIP: subnetIP, {peer2IP, subnetIP, subnetIP},
}, },
}, },
{ {
@ -729,11 +735,11 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfEIP2, peer2IP: selfEIP2,
subnetIP: selfEIP2, subnetIP: selfEIP2,
}, },
dnatMap: map[netip.Addr]netip.Addr{ dnat: []dnatTest{
selfNativeIP: selfNativeIP, {selfNativeIP, selfNativeIP, selfNativeIP},
selfEIP1: selfNativeIP, {peer1IP, selfEIP1, selfNativeIP},
selfEIP2: selfNativeIP, {peer2IP, selfEIP2, selfNativeIP},
subnetIP: subnetIP, {peer2IP, subnetIP, subnetIP},
}, },
}, },
{ {
@ -750,11 +756,11 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfEIP2, peer2IP: selfEIP2,
publicIP: selfEIP2, publicIP: selfEIP2,
}, },
dnatMap: map[netip.Addr]netip.Addr{ dnat: []dnatTest{
selfNativeIP: selfNativeIP, {selfNativeIP, selfNativeIP, selfNativeIP},
selfEIP1: selfNativeIP, {peer1IP, selfEIP1, selfNativeIP},
selfEIP2: selfNativeIP, {peer2IP, selfEIP2, selfNativeIP},
subnetIP: subnetIP, {peer2IP, subnetIP, subnetIP},
}, },
}, },
{ {
@ -771,11 +777,11 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfNativeIP, peer2IP: selfNativeIP,
subnetIP: selfNativeIP, subnetIP: selfNativeIP,
}, },
dnatMap: map[netip.Addr]netip.Addr{ dnat: []dnatTest{
selfNativeIP: selfNativeIP, {selfNativeIP, selfNativeIP, selfNativeIP},
selfEIP1: selfEIP1, {peer1IP, selfEIP1, selfEIP1},
selfEIP2: selfEIP2, {peer2IP, selfEIP2, selfEIP2},
subnetIP: subnetIP, {peer2IP, subnetIP, subnetIP},
}, },
}, },
{ {
@ -792,25 +798,25 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfEIP2, peer2IP: selfEIP2,
publicIP: selfEIP2, publicIP: selfEIP2,
}, },
dnatMap: map[netip.Addr]netip.Addr{ dnat: []dnatTest{
selfNativeIP: selfNativeIP, {selfNativeIP, selfNativeIP, selfNativeIP},
selfEIP2: selfNativeIP, {peer2IP, selfEIP2, selfNativeIP},
subnetIP: subnetIP, {peer2IP, subnetIP, subnetIP},
}, },
}, },
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(fmt.Sprintf("%v/%v", addrFam, tc.name), func(t *testing.T) { t.Run(fmt.Sprintf("%v/%v", addrFam, tc.name), func(t *testing.T) {
pcfg := peerConfigFromWGConfig(tc.wcfg) pcfg := peerConfigTableFromWGConfig(tc.wcfg)
for peer, want := range tc.snatMap { for peer, want := range tc.snatMap {
if got := pcfg.selectSrcIP(selfNativeIP, peer); got != want { if got := pcfg.selectSrcIP(selfNativeIP, peer); got != want {
t.Errorf("selectSrcIP[%v]: got %v; want %v", peer, got, want) t.Errorf("selectSrcIP[%v]: got %v; want %v", peer, got, want)
} }
} }
for dstIP, want := range tc.dnatMap { for i, dt := range tc.dnat {
if got := pcfg.mapDstIP(dstIP); got != want { if got := pcfg.mapDstIP(dt.src, dt.dst); got != dt.want {
t.Errorf("mapDstIP[%v]: got %v; want %v", dstIP, got, want) t.Errorf("dnat[%d]: mapDstIP[%v, %v]: got %v; want %v", i, dt.src, dt.dst, got, dt.want)
} }
} }
if t.Failed() { if t.Failed() {

Loading…
Cancel
Save