net/tstun: refactor peerConfig to allow storing more details

This refactors the peerConfig struct to allow storing more
details about a peer and not just the masq addresses. This will
be used in a follow-up change.

As a side effect, this also makes the DNAT logic on the inbound
packet stricter. Previously it would only match against the packet's
dst IP; now it also takes the src IP into consideration. The behavior
is at parity with the SNAT case.

Updates tailscale/corp#19623

Co-authored-by: Andrew Dunham <andrew@du.nham.ca>
Signed-off-by: Maisem Ali <maisem@tailscale.com>
Change-Id: I5f40802bebbf0f055436eb8824e4511d0052772d
pull/12024/head
Maisem Ali 6 months ago committed by Maisem Ali
parent f3d2fd22ef
commit 5ef178fdca

@ -33,10 +33,7 @@ import (
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/views"
"tailscale.com/util/clientmetric"
"tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/wgengine/capture"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/wgcfg"
@ -107,7 +104,7 @@ type Wrapper struct {
timeNow func() time.Time
// peerConfig stores the current NAT configuration.
peerConfig atomic.Pointer[peerConfig]
peerConfig atomic.Pointer[peerConfigTable]
// vectorBuffer stores the oldest unconsumed packet vector from tdev. It is
// allocated in wrap() and the underlying arrays should never grow.
@ -504,8 +501,7 @@ func (t *Wrapper) sendVectorOutbound(r tunVectorReadResult) {
}
// snat does SNAT on p if the destination address requires a different source address.
func (t *Wrapper) snat(p *packet.Parsed) {
pc := t.peerConfig.Load()
func (pc *peerConfigTable) snat(p *packet.Parsed) {
oldSrc := p.Src.Addr()
newSrc := pc.selectSrcIP(oldSrc, p.Dst.Addr())
if oldSrc != newSrc {
@ -514,10 +510,9 @@ func (t *Wrapper) snat(p *packet.Parsed) {
}
// dnat does destination NAT on p.
func (t *Wrapper) dnat(p *packet.Parsed) {
pc := t.peerConfig.Load()
func (pc *peerConfigTable) dnat(p *packet.Parsed) {
oldDst := p.Dst.Addr()
newDst := pc.mapDstIP(oldDst)
newDst := pc.mapDstIP(p.Src.Addr(), oldDst)
if newDst != oldDst {
checksum.UpdateDstAddr(p, newDst)
}
@ -545,11 +540,12 @@ func findV6(addrs []netip.Prefix) netip.Addr {
return netip.Addr{}
}
// peerConfig is the configuration for different peers.
// It should be treated as immutable.
// peerConfigTable contains configuration for individual peers and related
// information necessary to perform peer-specific operations. It should be
// treated as immutable.
//
// The nil value is a valid configuration.
type peerConfig struct {
type peerConfigTable struct {
// nativeAddr4 and nativeAddr6 are the IPv4/IPv6 Tailscale Addresses of
// the current node.
//
@ -559,50 +555,49 @@ type peerConfig struct {
// inbound packet is IPv6.
nativeAddr4, nativeAddr6 netip.Addr
// listenAddrs is the set of addresses that should be
// mapped to the native address. These are the addresses that
// peers will use to connect to this node.
listenAddrs views.Map[netip.Addr, struct{}] // masqAddr -> struct{}
// dstMasqAddrs is the routing table used to map a given dst IP to the
// respective MasqueradeAsIP address. The MasqueradeAsIP address is the
// address that should be used as the source address for packets to dst.
dstMasqAddrs *bart.Table[netip.Addr]
// byIP contains configuration for each peer, indexed by a peer's IP
// address(es).
byIP bart.Table[*peerConfig]
// masqAddrCounts is a count of peers by MasqueradeAsIP.
// TODO? for logging
masqAddrCounts map[netip.Addr]int
}
func (c *peerConfig) String() string {
// peerConfig is the configuration for a single peer.
type peerConfig struct {
// dstMasqAddr{4,6} are the addresses that should be used as the
// source address when masquerading packets to this peer (i.e.
// SNAT). If an address is not valid, the packet should not be
// masqueraded for that address family.
dstMasqAddr4 netip.Addr
dstMasqAddr6 netip.Addr
}
func (c *peerConfigTable) String() string {
if c == nil {
return "peerConfig(nil)"
return "peerConfigTable(nil)"
}
var b strings.Builder
b.WriteString("peerConfig{")
b.WriteString("peerConfigTable{")
fmt.Fprintf(&b, "nativeAddr4: %v, ", c.nativeAddr4)
fmt.Fprintf(&b, "nativeAddr6: %v, ", c.nativeAddr6)
fmt.Fprint(&b, "listenAddrs: [")
i := 0
c.listenAddrs.Range(func(k netip.Addr, _ struct{}) bool {
if i > 0 {
b.WriteString(", ")
}
b.WriteString(k.String())
i++
return true
})
// TODO: figure out how to iterate/debug/print c.byIP
i = 0
b.WriteString("], dstMasqAddrs: [")
for k, v := range c.masqAddrCounts {
if i > 0 {
b.WriteString(", ")
}
fmt.Fprintf(&b, "%v: %v peers", k, v)
i++
b.WriteString("}")
return b.String()
}
func (c *peerConfig) String() string {
if c == nil {
return "peerConfig(nil)"
}
b.WriteString("]}")
var b strings.Builder
b.WriteString("peerConfig{")
fmt.Fprintf(&b, "dstMasqAddr4: %v, ", c.dstMasqAddr4)
fmt.Fprintf(&b, "dstMasqAddr6: %v}", c.dstMasqAddr6)
return b.String()
}
@ -610,43 +605,70 @@ func (c *peerConfig) String() string {
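// mapDstIP returns the destination IP to use for an inbound packet from
// src addressed to oldDst. If oldDst is the masquerade address configured
// for the peer that owns src, and a valid native address of the same
// address family exists, that native address is returned; otherwise oldDst
// is returned as-is.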
func (c *peerConfig) mapDstIP(oldDst netip.Addr) netip.Addr {
if c == nil {
func (pc *peerConfigTable) mapDstIP(src, oldDst netip.Addr) netip.Addr {
if pc == nil {
return oldDst
}
if _, ok := c.listenAddrs.GetOk(oldDst); ok {
if oldDst.Is4() && c.nativeAddr4.IsValid() {
return c.nativeAddr4
}
if oldDst.Is6() && c.nativeAddr6.IsValid() {
return c.nativeAddr6
}
// The packet we're processing is inbound from WireGuard, received from
// a peer. The 'src' of the packet is the remote peer's IP address,
// possibly the masqueraded address (if the peer is shared/etc.).
//
// The 'dst' of the packet is the address for this local node. It could
// be a masquerade address that we told other nodes to use, or one of
// our local node's Addresses.
c, ok := pc.byIP.Get(src)
if !ok {
return oldDst
}
if oldDst.Is4() && pc.nativeAddr4.IsValid() && c.dstMasqAddr4 == oldDst {
return pc.nativeAddr4
}
if oldDst.Is6() && pc.nativeAddr6.IsValid() && c.dstMasqAddr6 == oldDst {
return pc.nativeAddr6
}
return oldDst
}
// selectSrcIP returns the source IP to use for a packet to dst.
// If the packet is not from the native address, it is returned as-is.
func (c *peerConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
if c == nil {
func (pc *peerConfigTable) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
if pc == nil {
return oldSrc
}
if oldSrc.Is4() && oldSrc != c.nativeAddr4 {
// If this packet doesn't originate from this Tailscale node, don't
// SNAT it (e.g. if we're a subnet router).
if oldSrc.Is4() && oldSrc != pc.nativeAddr4 {
return oldSrc
}
if oldSrc.Is6() && oldSrc != c.nativeAddr6 {
if oldSrc.Is6() && oldSrc != pc.nativeAddr6 {
return oldSrc
}
eip, ok := c.dstMasqAddrs.Get(dst)
// Look up the configuration for the destination
c, ok := pc.byIP.Get(dst)
if !ok {
return oldSrc
}
return eip
// Perform SNAT based on the address family and whether we have a valid
// addr.
if oldSrc.Is4() && c.dstMasqAddr4.IsValid() {
return c.dstMasqAddr4
}
if oldSrc.Is6() && c.dstMasqAddr6.IsValid() {
return c.dstMasqAddr6
}
// No SNAT; use old src
return oldSrc
}
// peerConfigFromWGConfig generates a peerConfig from nm. If NAT is not required,
// and no additional configuration is present, it returns nil.
func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig {
// peerConfigTableFromWGConfig generates a peerConfigTable from wcfg. If NAT is
// not required, and no additional configuration is present, it returns nil.
func peerConfigTableFromWGConfig(wcfg *wgcfg.Config) *peerConfigTable {
if wcfg == nil {
return nil
}
@ -657,11 +679,11 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig {
return nil
}
var (
rt bart.Table[netip.Addr]
masqAddrCounts = map[netip.Addr]int{}
listenAddrs set.Set[netip.Addr]
)
ret := &peerConfigTable{
nativeAddr4: nativeAddr4,
nativeAddr6: nativeAddr6,
masqAddrCounts: make(map[netip.Addr]int),
}
// When using an exit node that requires masquerading, we need to
// fill out the routing table with all peers not just the ones that
@ -679,6 +701,8 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig {
break
}
}
byIPSize := 0
for i := range wcfg.Peers {
p := &wcfg.Peers[i]
@ -688,13 +712,11 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig {
var addrToUse4, addrToUse6 netip.Addr
if p.V4MasqAddr != nil && p.V4MasqAddr.IsValid() {
addrToUse4 = *p.V4MasqAddr
mak.Set(&listenAddrs, addrToUse4, struct{}{})
masqAddrCounts[addrToUse4]++
ret.masqAddrCounts[addrToUse4]++
}
if p.V6MasqAddr != nil && p.V6MasqAddr.IsValid() {
addrToUse6 = *p.V6MasqAddr
mak.Set(&listenAddrs, addrToUse6, struct{}{})
masqAddrCounts[addrToUse6]++
ret.masqAddrCounts[addrToUse6]++
}
// If the exit node requires masquerading, set the masquerade
@ -713,33 +735,27 @@ func peerConfigFromWGConfig(wcfg *wgcfg.Config) *peerConfig {
continue
}
// Build the SNAT table that maps each AllowedIP to the
// masquerade address.
// Use the same peer configuration for each address of the peer.
pc := &peerConfig{
dstMasqAddr4: addrToUse4,
dstMasqAddr6: addrToUse6,
}
// Insert an entry into our routing table for each allowed IP.
for _, ip := range p.AllowedIPs {
is4 := ip.Addr().Is4()
if is4 && addrToUse4.IsValid() {
rt.Insert(ip, addrToUse4)
}
if !is4 && addrToUse6.IsValid() {
rt.Insert(ip, addrToUse6)
}
ret.byIP.Insert(ip, pc)
byIPSize++
}
}
if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 {
if byIPSize == 0 && len(ret.masqAddrCounts) == 0 {
return nil
}
return &peerConfig{
nativeAddr4: nativeAddr4,
nativeAddr6: nativeAddr6,
listenAddrs: views.MapOf(listenAddrs),
dstMasqAddrs: &rt,
masqAddrCounts: masqAddrCounts,
}
return ret
}
// SetWGConfig is called when a new WireGuard configuration is received.
func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) {
cfg := peerConfigFromWGConfig(wcfg)
cfg := peerConfigTableFromWGConfig(wcfg)
old := t.peerConfig.Swap(cfg)
if !reflect.DeepEqual(old, cfg) {
@ -752,7 +768,7 @@ var (
magicDNSIPPortv6 = netip.AddrPortFrom(tsaddr.TailscaleServiceIPv6(), 0)
)
func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed) filter.Response {
func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConfigTable) filter.Response {
// Fake ICMP echo responses to MagicDNS (100.100.100.100).
if p.IsEchoRequest() {
switch p.Dst {
@ -856,10 +872,12 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p)
captHook := t.captureHook.Load()
pc := t.peerConfig.Load()
for _, data := range res.data {
p.Decode(data[res.dataOffset:])
t.snat(p)
pc.snat(p)
if m := t.destIPActivity.Load(); m != nil {
if fn := m[p.Dst.Addr()]; fn != nil {
fn()
@ -869,7 +887,7 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
captHook(capture.FromLocal, t.now(), p.Buffer(), p.CaptureMeta)
}
if !t.disableFilter {
response := t.filterPacketOutboundToWireGuard(p)
response := t.filterPacketOutboundToWireGuard(p, pc)
if response != filter.Accept {
metricPacketOutDrop.Add(1)
continue
@ -913,10 +931,12 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int
n = copy(buf[offset:], res.data)
}
pc := t.peerConfig.Load()
p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p)
p.Decode(buf[offset : offset+n])
t.snat(p)
pc.snat(p)
if m := t.destIPActivity.Load(); m != nil {
if fn := m[p.Dst.Addr()]; fn != nil {
@ -931,7 +951,7 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int
return n, nil
}
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback) filter.Response {
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback, pc *peerConfigTable) filter.Response {
if captHook != nil {
captHook(capture.FromPeer, t.now(), p.Buffer(), p.CaptureMeta)
}
@ -977,7 +997,6 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca
if filt == nil {
return filter.Drop
}
outcome := filt.RunIn(p, t.filterFlags)
// Let peerapi through the filter; its ACLs are handled at L7,
@ -1036,11 +1055,12 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p)
captHook := t.captureHook.Load()
pc := t.peerConfig.Load()
for _, buff := range buffs {
p.Decode(buff[offset:])
t.dnat(p)
pc.dnat(p)
if !t.disableFilter {
if t.filterPacketInboundFromWireGuard(p, captHook) != filter.Accept {
if t.filterPacketInboundFromWireGuard(p, captHook, pc) != filter.Accept {
metricPacketInDrop.Add(1)
} else {
buffs[i] = buff
@ -1096,6 +1116,8 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt *stack.PacketBuffer) error {
}
pkt.DecRef()
pc := t.peerConfig.Load()
p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p)
p.Decode(buf[PacketStartOffset:])
@ -1103,7 +1125,8 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt *stack.PacketBuffer) error {
if captHook != nil {
captHook(capture.SynthesizedToLocal, t.now(), p.Buffer(), p.CaptureMeta)
}
t.dnat(p)
pc.dnat(p)
return t.InjectInboundDirect(buf, PacketStartOffset)
}

@ -551,7 +551,7 @@ func TestPeerAPIBypass(t *testing.T) {
tt.w.SetFilter(tt.filter)
tt.w.disableTSMPRejected = true
tt.w.logf = t.Logf
if got := tt.w.filterPacketInboundFromWireGuard(p, nil); got != tt.want {
if got := tt.w.filterPacketInboundFromWireGuard(p, nil, nil); got != tt.want {
t.Errorf("got = %v; want %v", got, tt.want)
}
})
@ -581,7 +581,7 @@ func TestFilterDiscoLoop(t *testing.T) {
p := new(packet.Parsed)
p.Decode(pkt)
got := tw.filterPacketInboundFromWireGuard(p, nil)
got := tw.filterPacketInboundFromWireGuard(p, nil, nil)
if got != filter.DropSilently {
t.Errorf("got %v; want DropSilently", got)
}
@ -592,7 +592,7 @@ func TestFilterDiscoLoop(t *testing.T) {
memLog.Reset()
pp := new(packet.Parsed)
pp.Decode(pkt)
got = tw.filterPacketOutboundToWireGuard(pp)
got = tw.filterPacketOutboundToWireGuard(pp, nil)
if got != filter.DropSilently {
t.Errorf("got %v; want DropSilently", got)
}
@ -653,11 +653,17 @@ func TestPeerCfg_NAT(t *testing.T) {
publicIP = netip.MustParseAddr("2001:4860:4860::8888")
}
type dnatTest struct {
src netip.Addr
dst netip.Addr
want netip.Addr // new destination after DNAT
}
tests := []struct {
name string
wcfg *wgcfg.Config
snatMap map[netip.Addr]netip.Addr // dst -> src
dnatMap map[netip.Addr]netip.Addr
dnat []dnatTest
}{
{
name: "no-cfg",
@ -667,10 +673,10 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfNativeIP,
subnetIP: selfNativeIP,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP1: selfEIP1,
selfEIP2: selfEIP2,
dnat: []dnatTest{
{selfNativeIP, selfNativeIP, selfNativeIP},
{peer1IP, selfEIP1, selfEIP1},
{peer2IP, selfEIP2, selfEIP2},
},
},
{
@ -679,19 +685,19 @@ func TestPeerCfg_NAT(t *testing.T) {
Addresses: selfAddrs,
Peers: []wgcfg.Peer{
node(peer1IP, noIP),
node(peer2IP, selfEIP1),
node(peer2IP, selfEIP2),
},
},
snatMap: map[netip.Addr]netip.Addr{
peer1IP: selfNativeIP,
peer2IP: selfEIP1,
peer2IP: selfEIP2,
subnetIP: selfNativeIP,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP1: selfNativeIP,
selfEIP2: selfEIP2,
subnetIP: subnetIP,
dnat: []dnatTest{
{selfNativeIP, selfNativeIP, selfNativeIP},
{peer1IP, selfEIP1, selfEIP1},
{peer2IP, selfEIP2, selfNativeIP}, // NATed
{peer2IP, subnetIP, subnetIP},
},
},
{
@ -708,11 +714,11 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfEIP2,
subnetIP: selfNativeIP,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP1: selfNativeIP,
selfEIP2: selfNativeIP,
subnetIP: subnetIP,
dnat: []dnatTest{
{selfNativeIP, selfNativeIP, selfNativeIP},
{peer1IP, selfEIP1, selfNativeIP},
{peer2IP, selfEIP2, selfNativeIP},
{peer2IP, subnetIP, subnetIP},
},
},
{
@ -729,11 +735,11 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfEIP2,
subnetIP: selfEIP2,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP1: selfNativeIP,
selfEIP2: selfNativeIP,
subnetIP: subnetIP,
dnat: []dnatTest{
{selfNativeIP, selfNativeIP, selfNativeIP},
{peer1IP, selfEIP1, selfNativeIP},
{peer2IP, selfEIP2, selfNativeIP},
{peer2IP, subnetIP, subnetIP},
},
},
{
@ -750,11 +756,11 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfEIP2,
publicIP: selfEIP2,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP1: selfNativeIP,
selfEIP2: selfNativeIP,
subnetIP: subnetIP,
dnat: []dnatTest{
{selfNativeIP, selfNativeIP, selfNativeIP},
{peer1IP, selfEIP1, selfNativeIP},
{peer2IP, selfEIP2, selfNativeIP},
{peer2IP, subnetIP, subnetIP},
},
},
{
@ -771,11 +777,11 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfNativeIP,
subnetIP: selfNativeIP,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP1: selfEIP1,
selfEIP2: selfEIP2,
subnetIP: subnetIP,
dnat: []dnatTest{
{selfNativeIP, selfNativeIP, selfNativeIP},
{peer1IP, selfEIP1, selfEIP1},
{peer2IP, selfEIP2, selfEIP2},
{peer2IP, subnetIP, subnetIP},
},
},
{
@ -792,25 +798,25 @@ func TestPeerCfg_NAT(t *testing.T) {
peer2IP: selfEIP2,
publicIP: selfEIP2,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP2: selfNativeIP,
subnetIP: subnetIP,
dnat: []dnatTest{
{selfNativeIP, selfNativeIP, selfNativeIP},
{peer2IP, selfEIP2, selfNativeIP},
{peer2IP, subnetIP, subnetIP},
},
},
}
for _, tc := range tests {
t.Run(fmt.Sprintf("%v/%v", addrFam, tc.name), func(t *testing.T) {
pcfg := peerConfigFromWGConfig(tc.wcfg)
pcfg := peerConfigTableFromWGConfig(tc.wcfg)
for peer, want := range tc.snatMap {
if got := pcfg.selectSrcIP(selfNativeIP, peer); got != want {
t.Errorf("selectSrcIP[%v]: got %v; want %v", peer, got, want)
}
}
for dstIP, want := range tc.dnatMap {
if got := pcfg.mapDstIP(dstIP); got != want {
t.Errorf("mapDstIP[%v]: got %v; want %v", dstIP, got, want)
for i, dt := range tc.dnat {
if got := pcfg.mapDstIP(dt.src, dt.dst); got != dt.want {
t.Errorf("dnat[%d]: mapDstIP[%v, %v]: got %v; want %v", i, dt.src, dt.dst, got, dt.want)
}
}
if t.Failed() {

Loading…
Cancel
Save