net/tstun: add inital support for NAT v4

This adds support in tstun to utitilize the SelfNodeV4MasqAddrForThisPeer and
perform the necessary modifications to the packet as it passes through tstun.

Currently this only handles ICMP, UDP and TCP traffic.
Subnet routers and Exit Nodes are also unsupported.

Updates tailscale/corp#8020

Co-authored-by: Melanie Warrick <warrick@tailscale.com>
Signed-off-by: Maisem Ali <maisem@tailscale.com>
maisem/k8s-cache
Maisem Ali 2 years ago committed by Maisem Ali
parent 535fad16f8
commit bb31fd7d1c

@ -440,6 +440,40 @@ func (q *Parsed) IsEchoResponse() bool {
} }
} }
// UpdateSrcAddr updates the source address in the packet buffer (e.g. during
// SNAT). It also updates the checksum. Currently (2022-12-10) only TCP/UDP/ICMP
// over IPv4 is supported. It panics if called with IPv6 addr.
func (q *Parsed) UpdateSrcAddr(src netip.Addr) {
if q.IPVersion != 4 || src.Is6() {
panic("UpdateSrcAddr: only IPv4 is supported")
}
old := q.Src.Addr()
q.Src = netip.AddrPortFrom(src, q.Src.Port())
b := q.Buffer()
v4 := src.As4()
copy(b[12:16], v4[:])
updateV4PacketChecksums(q, old, src)
}
// UpdateDstAddr updates the source address in the packet buffer (e.g. during
// DNAT). It also updates the checksum. Currently (2022-12-10) only TCP/UDP/ICMP
// over IPv4 is supported. It panics if called with IPv6 addr.
func (q *Parsed) UpdateDstAddr(dst netip.Addr) {
if q.IPVersion != 4 || dst.Is6() {
panic("UpdateDstAddr: only IPv4 is supported")
}
old := q.Dst.Addr()
q.Dst = netip.AddrPortFrom(dst, q.Dst.Port())
b := q.Buffer()
v4 := dst.As4()
copy(b[16:20], v4[:])
updateV4PacketChecksums(q, old, dst)
}
// EchoIDSeq extracts the identifier/sequence bytes from an ICMP Echo response, // EchoIDSeq extracts the identifier/sequence bytes from an ICMP Echo response,
// and returns them as a uint32, used to lookup internally routed ICMP echo // and returns them as a uint32, used to lookup internally routed ICMP echo
// responses. This function is intentionally lightweight as it is called on // responses. This function is intentionally lightweight as it is called on
@ -502,3 +536,69 @@ func withIP(ap netip.AddrPort, ip netip.Addr) netip.AddrPort {
func withPort(ap netip.AddrPort, port uint16) netip.AddrPort { func withPort(ap netip.AddrPort, port uint16) netip.AddrPort {
return netip.AddrPortFrom(ap.Addr(), port) return netip.AddrPortFrom(ap.Addr(), port)
} }
// updateV4PacketChecksums updates the checksums in the packet buffer.
// Currently (2023-03-01) only TCP/UDP/ICMP over IPv4 is supported.
// p is modified in place.
// If p.IPProto is unknown, only the IP header checksum is updated.
// TODO(maisem): more protocols (sctp, gre, dccp)
func updateV4PacketChecksums(p *Parsed, old, new netip.Addr) {
o4, n4 := old.As4(), new.As4()
updateV4Checksum(p.Buffer()[10:12], o4[:], n4[:]) // header
switch p.IPProto {
case ipproto.UDP:
updateV4Checksum(p.Transport()[6:8], o4[:], n4[:])
case ipproto.TCP:
updateV4Checksum(p.Transport()[16:18], o4[:], n4[:])
case ipproto.ICMPv4:
// Nothing to do.
}
// TODO(maisem): more protocols (sctp, gre, dccp)
}
// updateV4Checksum calculates and updates the checksum in the packet buffer
// for a change between old and new. The checksum is updated in place.
func updateV4Checksum(oldSum, old, new []byte) {
if len(old) != len(new) {
panic("old and new must be the same length")
}
if len(old)%2 != 0 {
panic("old and new must be even length")
}
/*
RFC 1624
Given the following notation:
HC - old checksum in header
C - one's complement sum of old header
HC' - new checksum in header
C' - one's complement sum of new header
m - old value of a 16-bit field
m' - new value of a 16-bit field
HC' = ~(C + (-m) + m') -- [Eqn. 3]
HC' = ~(~HC + ~m + m')
This can be simplified to:
HC' = ~(C + ~m + m') -- [Eqn. 3]
HC' = ~C'
C' = C + ~m + m'
*/
c := uint32(^binary.BigEndian.Uint16(oldSum))
cPrime := c
for len(new) > 0 {
mNot := uint32(^binary.BigEndian.Uint16(old[:2]))
mPrime := uint32(binary.BigEndian.Uint16(new[:2]))
cPrime += mPrime + mNot
new, old = new[2:], old[2:]
}
// Account for overflows by adding the carry bits back into the sum.
for (cPrime >> 16) > 0 {
cPrime = cPrime&0xFFFF + cPrime>>16
}
hcPrime := ^uint16(cPrime)
binary.BigEndian.PutUint16(oldSum, hcPrime)
}

@ -29,7 +29,10 @@ import (
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/types/views"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/mak"
"tailscale.com/wgengine/capture" "tailscale.com/wgengine/capture"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
) )
@ -88,6 +91,9 @@ type Wrapper struct {
destMACAtomic syncs.AtomicValue[[6]byte] destMACAtomic syncs.AtomicValue[[6]byte]
discoKey syncs.AtomicValue[key.DiscoPublic] discoKey syncs.AtomicValue[key.DiscoPublic]
// natV4Config stores the current NAT configuration.
natV4Config atomic.Pointer[natV4Config]
// vectorBuffer stores the oldest unconsumed packet vector from tdev. It is // vectorBuffer stores the oldest unconsumed packet vector from tdev. It is
// allocated in wrap() and the underlying arrays should never grow. // allocated in wrap() and the underlying arrays should never grow.
vectorBuffer [][]byte vectorBuffer [][]byte
@ -459,6 +465,139 @@ func (t *Wrapper) sendVectorOutbound(r tunVectorReadResult) {
t.vectorOutbound <- r t.vectorOutbound <- r
} }
// snatV4 does SNAT on p if it's an IPv4 packet and the destination
// address requires a different source address.
func (t *Wrapper) snatV4(p *packet.Parsed) {
if p.IPVersion != 4 {
return
}
nc := t.natV4Config.Load()
oldSrc := p.Src.Addr()
newSrc := nc.selectSrcIP(oldSrc, p.Dst.Addr())
if oldSrc != newSrc {
p.UpdateSrcAddr(newSrc)
}
}
// dnatV4 does destination NAT on p if it's an IPv4 packet.
func (t *Wrapper) dnatV4(p *packet.Parsed) {
if p.IPVersion != 4 {
return
}
nc := t.natV4Config.Load()
oldDst := p.Dst.Addr()
newDst := nc.mapDstIP(oldDst)
if newDst != oldDst {
p.UpdateDstAddr(newDst)
}
}
// findV4 returns the first Tailscale IPv4 address in addrs.
func findV4(addrs []netip.Prefix) netip.Addr {
for _, ap := range addrs {
a := ap.Addr()
if a.Is4() && tsaddr.IsTailscaleIP(a) {
return a
}
}
return netip.Addr{}
}
// natV4Config is the configuration for IPv4 NAT.
// It should be treated as immutable.
//
// The nil value is a valid configuration.
type natV4Config struct {
// nativeAddr is the IPv4 Tailscale Address of the current node.
nativeAddr netip.Addr
// listenAddrs is the set of IPv4 addresses that should be
// mapped to the native address. These are the addresses that
// peers will use to connect to this node.
listenAddrs views.Map[netip.Addr, struct{}] // masqAddr -> struct{}
// dstMasqAddrs is map of dst addresses to their respective MasqueradeAsIP
// addresses. The MasqueradeAsIP address is the address that should be used
// as the source address for packets to dst.
dstMasqAddrs views.Map[netip.Addr, netip.Addr] // dst -> masqAddr
// TODO(maisem/nyghtowl): add support for subnets and exit nodes and test them.
// Determine IP routing table algorithm to use - e.g. ART?
}
// mapDstIP returns the destination IP to use for a packet to dst.
// If dst is not one of the listen addresses, it is returned as-is,
// otherwise the native address is returned.
func (c *natV4Config) mapDstIP(oldDst netip.Addr) netip.Addr {
if c == nil {
return oldDst
}
if _, ok := c.listenAddrs.GetOk(oldDst); ok {
return c.nativeAddr
}
return oldDst
}
// selectSrcIP returns the source IP to use for a packet to dst.
// If the packet is not from the native address, it is returned as-is.
func (c *natV4Config) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
if c == nil {
return oldSrc
}
if oldSrc != c.nativeAddr {
return oldSrc
}
if eip, ok := c.dstMasqAddrs.GetOk(dst); ok {
return eip
}
return oldSrc
}
// natConfigFromNetMap generates a natV4Config from nm.
// If v4 NAT is not required, it returns nil.
func natConfigFromNetMap(nm *netmap.NetworkMap) *natV4Config {
if nm == nil || nm.SelfNode == nil {
return nil
}
nativeAddr := findV4(nm.SelfNode.Addresses)
if !nativeAddr.IsValid() {
return nil
}
var (
dstMasqAddrs map[netip.Addr]netip.Addr
listenAddrs map[netip.Addr]struct{}
)
for _, p := range nm.Peers {
if !p.SelfNodeV4MasqAddrForThisPeer.IsValid() {
continue
}
peerV4 := findV4(p.Addresses)
if !peerV4.IsValid() {
continue
}
mak.Set(&dstMasqAddrs, peerV4, p.SelfNodeV4MasqAddrForThisPeer)
mak.Set(&listenAddrs, p.SelfNodeV4MasqAddrForThisPeer, struct{}{})
}
if len(listenAddrs) == 0 || len(dstMasqAddrs) == 0 {
return nil
}
return &natV4Config{
nativeAddr: nativeAddr,
listenAddrs: views.MapOf(listenAddrs),
dstMasqAddrs: views.MapOf(dstMasqAddrs),
}
}
// SetNetMap is called when a new NetworkMap is received.
// It currently (2023-03-01) only updates the IPv4 NAT configuration.
func (t *Wrapper) SetNetMap(nm *netmap.NetworkMap) {
cfg := natConfigFromNetMap(nm)
t.natV4Config.Store(cfg)
t.logf("nat config: %+v", cfg)
}
var ( var (
magicDNSIPPort = netip.AddrPortFrom(tsaddr.TailscaleServiceIP(), 0) // 100.100.100.100:0 magicDNSIPPort = netip.AddrPortFrom(tsaddr.TailscaleServiceIP(), 0) // 100.100.100.100:0
magicDNSIPPortv6 = netip.AddrPortFrom(tsaddr.TailscaleServiceIPv6(), 0) magicDNSIPPortv6 = netip.AddrPortFrom(tsaddr.TailscaleServiceIPv6(), 0)
@ -541,8 +680,8 @@ func (t *Wrapper) IdleDuration() time.Duration {
} }
func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
// packet from OS read and sent to WG
res, ok := <-t.vectorOutbound res, ok := <-t.vectorOutbound
if !ok { if !ok {
return 0, io.EOF return 0, io.EOF
} }
@ -566,6 +705,8 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
defer parsedPacketPool.Put(p) defer parsedPacketPool.Put(p)
for _, data := range res.data { for _, data := range res.data {
p.Decode(data[res.dataOffset:]) p.Decode(data[res.dataOffset:])
t.snatV4(p)
if m := t.destIPActivity.Load(); m != nil { if m := t.destIPActivity.Load(); m != nil {
if fn := m[p.Dst.Addr()]; fn != nil { if fn := m[p.Dst.Addr()]; fn != nil {
fn() fn()
@ -622,6 +763,7 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int
p := parsedPacketPool.Get().(*packet.Parsed) p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p) defer parsedPacketPool.Put(p)
p.Decode(buf[offset : offset+n]) p.Decode(buf[offset : offset+n])
t.snatV4(p)
if m := t.destIPActivity.Load(); m != nil { if m := t.destIPActivity.Load(); m != nil {
if fn := m[p.Dst.Addr()]; fn != nil { if fn := m[p.Dst.Addr()]; fn != nil {
@ -738,11 +880,12 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed) filter.Resp
func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) { func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
metricPacketIn.Add(int64(len(buffs))) metricPacketIn.Add(int64(len(buffs)))
i := 0 i := 0
if !t.disableFilter {
p := parsedPacketPool.Get().(*packet.Parsed) p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p) defer parsedPacketPool.Put(p)
for _, buff := range buffs { for _, buff := range buffs {
p.Decode(buff[offset:]) p.Decode(buff[offset:])
t.dnatV4(p)
if !t.disableFilter {
if t.filterPacketInboundFromWireGuard(p) != filter.Accept { if t.filterPacketInboundFromWireGuard(p) != filter.Accept {
metricPacketInDrop.Add(1) metricPacketInDrop.Add(1)
} else { } else {
@ -750,7 +893,8 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
i++ i++
} }
} }
} else { }
if t.disableFilter {
i = len(buffs) i = len(buffs)
} }
buffs = buffs[:i] buffs = buffs[:i]
@ -801,6 +945,10 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt stack.PacketBufferPtr) error {
if capt := t.captureHook.Load(); capt != nil { if capt := t.captureHook.Load(); capt != nil {
capt(capture.SynthesizedToLocal, time.Now(), buf[PacketStartOffset:]) capt(capture.SynthesizedToLocal, time.Now(), buf[PacketStartOffset:])
} }
p := parsedPacketPool.Get().(*packet.Parsed)
defer parsedPacketPool.Put(p)
p.Decode(buf[PacketStartOffset:])
t.dnatV4(p)
return t.InjectInboundDirect(buf, PacketStartOffset) return t.InjectInboundDirect(buf, PacketStartOffset)
} }

@ -25,12 +25,14 @@ import (
"tailscale.com/net/connstats" "tailscale.com/net/connstats"
"tailscale.com/net/netaddr" "tailscale.com/net/netaddr"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/tailcfg"
"tailscale.com/tstest" "tailscale.com/tstest"
"tailscale.com/tstime/mono" "tailscale.com/tstime/mono"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/netlogtype" "tailscale.com/types/netlogtype"
"tailscale.com/types/netmap"
"tailscale.com/util/must" "tailscale.com/util/must"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
) )
@ -593,3 +595,129 @@ func TestFilterDiscoLoop(t *testing.T) {
t.Errorf("log output mismatch\n got: %q\nwant: %q\n", got, want) t.Errorf("log output mismatch\n got: %q\nwant: %q\n", got, want)
} }
} }
func TestNATCfg(t *testing.T) {
node := func(ip, eip netip.Addr) *tailcfg.Node {
return &tailcfg.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(ip, ip.BitLen()),
},
SelfNodeV4MasqAddrForThisPeer: eip,
}
}
var (
noIP netip.Addr
selfNativeIP = netip.MustParseAddr("100.64.0.1")
selfEIP1 = netip.MustParseAddr("100.64.1.1")
selfEIP2 = netip.MustParseAddr("100.64.1.2")
peer1IP = netip.MustParseAddr("100.64.0.2")
peer2IP = netip.MustParseAddr("100.64.0.3")
// subnets should not be impacted.
// TODO(maisem/nyghtowl): add support for subnets and exit nodes and test them.
subnet = netip.MustParseAddr("192.168.0.1")
)
tests := []struct {
name string
nm *netmap.NetworkMap
snatMap map[netip.Addr]netip.Addr // dst -> src
dnatMap map[netip.Addr]netip.Addr
}{
{
name: "no-netmap",
nm: nil,
snatMap: map[netip.Addr]netip.Addr{
peer1IP: selfNativeIP,
peer2IP: selfNativeIP,
subnet: selfNativeIP,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP1: selfEIP1,
selfEIP2: selfEIP2,
},
},
{
name: "single-peer-requires-nat",
nm: &netmap.NetworkMap{
SelfNode: node(selfNativeIP, noIP),
Peers: []*tailcfg.Node{
node(peer1IP, noIP),
node(peer2IP, selfEIP1),
},
},
snatMap: map[netip.Addr]netip.Addr{
peer1IP: selfNativeIP,
peer2IP: selfEIP1,
subnet: selfNativeIP,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP1: selfNativeIP,
selfEIP2: selfEIP2,
subnet: subnet,
},
},
{
name: "multiple-peers-require-nat",
nm: &netmap.NetworkMap{
SelfNode: node(selfNativeIP, noIP),
Peers: []*tailcfg.Node{
node(peer1IP, selfEIP1),
node(peer2IP, selfEIP2),
},
},
snatMap: map[netip.Addr]netip.Addr{
peer1IP: selfEIP1,
peer2IP: selfEIP2,
subnet: selfNativeIP,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP1: selfNativeIP,
selfEIP2: selfNativeIP,
subnet: subnet,
},
},
{
name: "no-nat",
nm: &netmap.NetworkMap{
SelfNode: node(selfNativeIP, noIP),
Peers: []*tailcfg.Node{
node(peer1IP, noIP),
node(peer2IP, noIP),
},
},
snatMap: map[netip.Addr]netip.Addr{
peer1IP: selfNativeIP,
peer2IP: selfNativeIP,
subnet: selfNativeIP,
},
dnatMap: map[netip.Addr]netip.Addr{
selfNativeIP: selfNativeIP,
selfEIP1: selfEIP1,
selfEIP2: selfEIP2,
subnet: subnet,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ncfg := natConfigFromNetMap(tc.nm)
for peer, want := range tc.snatMap {
if got := ncfg.selectSrcIP(selfNativeIP, peer); got != want {
t.Errorf("selectSrcIP: got %v; want %v", got, want)
}
}
for dstIP, want := range tc.dnatMap {
if got := ncfg.mapDstIP(dstIP); got != want {
t.Errorf("mapDstIP: got %v; want %v", got, want)
}
}
})
}
}

@ -1205,6 +1205,7 @@ func (e *userspaceEngine) SetNetworkMap(nm *netmap.NetworkMap) {
e.magicConn.SetNetworkMap(nm) e.magicConn.SetNetworkMap(nm)
e.mu.Lock() e.mu.Lock()
e.netMap = nm e.netMap = nm
e.tundev.SetNetMap(nm)
callbacks := make([]NetworkMapCallback, 0, 4) callbacks := make([]NetworkMapCallback, 0, 4)
for _, fn := range e.networkMapCallbacks { for _, fn := range e.networkMapCallbacks {
callbacks = append(callbacks, fn) callbacks = append(callbacks, fn)

Loading…
Cancel
Save