diff --git a/net/packet/packet.go b/net/packet/packet.go index aa71295c5..8dee86ea7 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -440,6 +440,40 @@ func (q *Parsed) IsEchoResponse() bool { } } +// UpdateSrcAddr updates the source address in the packet buffer (e.g. during +// SNAT). It also updates the checksum. Currently (2022-12-10) only TCP/UDP/ICMP +// over IPv4 is supported. It panics if called with IPv6 addr. +func (q *Parsed) UpdateSrcAddr(src netip.Addr) { + if q.IPVersion != 4 || src.Is6() { + panic("UpdateSrcAddr: only IPv4 is supported") + } + + old := q.Src.Addr() + q.Src = netip.AddrPortFrom(src, q.Src.Port()) + + b := q.Buffer() + v4 := src.As4() + copy(b[12:16], v4[:]) + updateV4PacketChecksums(q, old, src) +} + +// UpdateDstAddr updates the source address in the packet buffer (e.g. during +// DNAT). It also updates the checksum. Currently (2022-12-10) only TCP/UDP/ICMP +// over IPv4 is supported. It panics if called with IPv6 addr. +func (q *Parsed) UpdateDstAddr(dst netip.Addr) { + if q.IPVersion != 4 || dst.Is6() { + panic("UpdateDstAddr: only IPv4 is supported") + } + + old := q.Dst.Addr() + q.Dst = netip.AddrPortFrom(dst, q.Dst.Port()) + + b := q.Buffer() + v4 := dst.As4() + copy(b[16:20], v4[:]) + updateV4PacketChecksums(q, old, dst) +} + // EchoIDSeq extracts the identifier/sequence bytes from an ICMP Echo response, // and returns them as a uint32, used to lookup internally routed ICMP echo // responses. This function is intentionally lightweight as it is called on @@ -502,3 +536,69 @@ func withIP(ap netip.AddrPort, ip netip.Addr) netip.AddrPort { func withPort(ap netip.AddrPort, port uint16) netip.AddrPort { return netip.AddrPortFrom(ap.Addr(), port) } + +// updateV4PacketChecksums updates the checksums in the packet buffer. +// Currently (2023-03-01) only TCP/UDP/ICMP over IPv4 is supported. +// p is modified in place. +// If p.IPProto is unknown, only the IP header checksum is updated. +// TODO(maisem): more protocols (sctp, gre, dccp) +func updateV4PacketChecksums(p *Parsed, old, new netip.Addr) { + o4, n4 := old.As4(), new.As4() + updateV4Checksum(p.Buffer()[10:12], o4[:], n4[:]) // header + switch p.IPProto { + case ipproto.UDP: + updateV4Checksum(p.Transport()[6:8], o4[:], n4[:]) + case ipproto.TCP: + updateV4Checksum(p.Transport()[16:18], o4[:], n4[:]) + case ipproto.ICMPv4: + // Nothing to do. + } + // TODO(maisem): more protocols (sctp, gre, dccp) +} + +// updateV4Checksum calculates and updates the checksum in the packet buffer +// for a change between old and new. The checksum is updated in place. +func updateV4Checksum(oldSum, old, new []byte) { + if len(old) != len(new) { + panic("old and new must be the same length") + } + if len(old)%2 != 0 { + panic("old and new must be even length") + } + /* + RFC 1624 + Given the following notation: + + HC - old checksum in header + C - one's complement sum of old header + HC' - new checksum in header + C' - one's complement sum of new header + m - old value of a 16-bit field + m' - new value of a 16-bit field + + HC' = ~(C + (-m) + m') -- [Eqn. 3] + HC' = ~(~HC + ~m + m') + + This can be simplified to: + HC' = ~(C + ~m + m') -- [Eqn. 3] + HC' = ~C' + C' = C + ~m + m' + */ + + c := uint32(^binary.BigEndian.Uint16(oldSum)) + + cPrime := c + for len(new) > 0 { + mNot := uint32(^binary.BigEndian.Uint16(old[:2])) + mPrime := uint32(binary.BigEndian.Uint16(new[:2])) + cPrime += mPrime + mNot + new, old = new[2:], old[2:] + } + + // Account for overflows by adding the carry bits back into the sum. + for (cPrime >> 16) > 0 { + cPrime = cPrime&0xFFFF + cPrime>>16 + } + hcPrime := ^uint16(cPrime) + binary.BigEndian.PutUint16(oldSum, hcPrime) +} diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 24decbe29..ed8233169 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -29,7 +29,10 @@ import ( "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/types/views" "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" ) @@ -88,6 +91,9 @@ type Wrapper struct { destMACAtomic syncs.AtomicValue[[6]byte] discoKey syncs.AtomicValue[key.DiscoPublic] + // natV4Config stores the current NAT configuration. + natV4Config atomic.Pointer[natV4Config] + // vectorBuffer stores the oldest unconsumed packet vector from tdev. It is // allocated in wrap() and the underlying arrays should never grow. vectorBuffer [][]byte @@ -459,6 +465,139 @@ func (t *Wrapper) sendVectorOutbound(r tunVectorReadResult) { t.vectorOutbound <- r } +// snatV4 does SNAT on p if it's an IPv4 packet and the destination +// address requires a different source address. +func (t *Wrapper) snatV4(p *packet.Parsed) { + if p.IPVersion != 4 { + return + } + + nc := t.natV4Config.Load() + oldSrc := p.Src.Addr() + newSrc := nc.selectSrcIP(oldSrc, p.Dst.Addr()) + if oldSrc != newSrc { + p.UpdateSrcAddr(newSrc) + } +} + +// dnatV4 does destination NAT on p if it's an IPv4 packet. +func (t *Wrapper) dnatV4(p *packet.Parsed) { + if p.IPVersion != 4 { + return + } + + nc := t.natV4Config.Load() + oldDst := p.Dst.Addr() + newDst := nc.mapDstIP(oldDst) + if newDst != oldDst { + p.UpdateDstAddr(newDst) + } +} + +// findV4 returns the first Tailscale IPv4 address in addrs. +func findV4(addrs []netip.Prefix) netip.Addr { + for _, ap := range addrs { + a := ap.Addr() + if a.Is4() && tsaddr.IsTailscaleIP(a) { + return a + } + } + return netip.Addr{} +} + +// natV4Config is the configuration for IPv4 NAT. +// It should be treated as immutable. +// +// The nil value is a valid configuration. +type natV4Config struct { + // nativeAddr is the IPv4 Tailscale Address of the current node. + nativeAddr netip.Addr + + // listenAddrs is the set of IPv4 addresses that should be + // mapped to the native address. These are the addresses that + // peers will use to connect to this node. + listenAddrs views.Map[netip.Addr, struct{}] // masqAddr -> struct{} + + // dstMasqAddrs is map of dst addresses to their respective MasqueradeAsIP + // addresses. The MasqueradeAsIP address is the address that should be used + // as the source address for packets to dst. + dstMasqAddrs views.Map[netip.Addr, netip.Addr] // dst -> masqAddr + + // TODO(maisem/nyghtowl): add support for subnets and exit nodes and test them. + // Determine IP routing table algorithm to use - e.g. ART? +} + +// mapDstIP returns the destination IP to use for a packet to dst. +// If dst is not one of the listen addresses, it is returned as-is, +// otherwise the native address is returned. +func (c *natV4Config) mapDstIP(oldDst netip.Addr) netip.Addr { + if c == nil { + return oldDst + } + if _, ok := c.listenAddrs.GetOk(oldDst); ok { + return c.nativeAddr + } + return oldDst +} + +// selectSrcIP returns the source IP to use for a packet to dst. +// If the packet is not from the native address, it is returned as-is. +func (c *natV4Config) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { + if c == nil { + return oldSrc + } + if oldSrc != c.nativeAddr { + return oldSrc + } + if eip, ok := c.dstMasqAddrs.GetOk(dst); ok { + return eip + } + return oldSrc +} + +// natConfigFromNetMap generates a natV4Config from nm. +// If v4 NAT is not required, it returns nil. +func natConfigFromNetMap(nm *netmap.NetworkMap) *natV4Config { + if nm == nil || nm.SelfNode == nil { + return nil + } + nativeAddr := findV4(nm.SelfNode.Addresses) + if !nativeAddr.IsValid() { + return nil + } + var ( + dstMasqAddrs map[netip.Addr]netip.Addr + listenAddrs map[netip.Addr]struct{} + ) + for _, p := range nm.Peers { + if !p.SelfNodeV4MasqAddrForThisPeer.IsValid() { + continue + } + peerV4 := findV4(p.Addresses) + if !peerV4.IsValid() { + continue + } + mak.Set(&dstMasqAddrs, peerV4, p.SelfNodeV4MasqAddrForThisPeer) + mak.Set(&listenAddrs, p.SelfNodeV4MasqAddrForThisPeer, struct{}{}) + } + if len(listenAddrs) == 0 || len(dstMasqAddrs) == 0 { + return nil + } + return &natV4Config{ + nativeAddr: nativeAddr, + listenAddrs: views.MapOf(listenAddrs), + dstMasqAddrs: views.MapOf(dstMasqAddrs), + } +} + +// SetNetMap is called when a new NetworkMap is received. +// It currently (2023-03-01) only updates the IPv4 NAT configuration. +func (t *Wrapper) SetNetMap(nm *netmap.NetworkMap) { + cfg := natConfigFromNetMap(nm) + t.natV4Config.Store(cfg) + t.logf("nat config: %+v", cfg) +} + var ( magicDNSIPPort = netip.AddrPortFrom(tsaddr.TailscaleServiceIP(), 0) // 100.100.100.100:0 magicDNSIPPortv6 = netip.AddrPortFrom(tsaddr.TailscaleServiceIPv6(), 0) @@ -541,8 +680,8 @@ func (t *Wrapper) IdleDuration() time.Duration { } func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + // packet from OS read and sent to WG res, ok := <-t.vectorOutbound - if !ok { return 0, io.EOF } @@ -566,6 +705,8 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { defer parsedPacketPool.Put(p) for _, data := range res.data { p.Decode(data[res.dataOffset:]) + + t.snatV4(p) if m := t.destIPActivity.Load(); m != nil { if fn := m[p.Dst.Addr()]; fn != nil { fn() @@ -622,6 +763,7 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int p := parsedPacketPool.Get().(*packet.Parsed) defer parsedPacketPool.Put(p) p.Decode(buf[offset : offset+n]) + t.snatV4(p) if m := t.destIPActivity.Load(); m != nil { if fn := m[p.Dst.Addr()]; fn != nil { @@ -738,11 +880,12 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed) filter.Resp func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) { metricPacketIn.Add(int64(len(buffs))) i := 0 - if !t.disableFilter { - p := parsedPacketPool.Get().(*packet.Parsed) - defer parsedPacketPool.Put(p) - for _, buff := range buffs { - p.Decode(buff[offset:]) + p := parsedPacketPool.Get().(*packet.Parsed) + defer parsedPacketPool.Put(p) + for _, buff := range buffs { + p.Decode(buff[offset:]) + t.dnatV4(p) + if !t.disableFilter { if t.filterPacketInboundFromWireGuard(p) != filter.Accept { metricPacketInDrop.Add(1) } else { @@ -750,7 +893,8 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) { i++ } } - } else { + } + if t.disableFilter { i = len(buffs) } buffs = buffs[:i] @@ -801,6 +945,10 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt stack.PacketBufferPtr) error { if capt := t.captureHook.Load(); capt != nil { capt(capture.SynthesizedToLocal, time.Now(), buf[PacketStartOffset:]) } + p := parsedPacketPool.Get().(*packet.Parsed) + defer parsedPacketPool.Put(p) + p.Decode(buf[PacketStartOffset:]) + t.dnatV4(p) return t.InjectInboundDirect(buf, PacketStartOffset) } diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index c9c621ddf..a9cc2998c 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -25,12 +25,14 @@ import ( "tailscale.com/net/connstats" "tailscale.com/net/netaddr" "tailscale.com/net/packet" + "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstime/mono" "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netlogtype" + "tailscale.com/types/netmap" "tailscale.com/util/must" "tailscale.com/wgengine/filter" ) @@ -593,3 +595,129 @@ func TestFilterDiscoLoop(t *testing.T) { t.Errorf("log output mismatch\n got: %q\nwant: %q\n", got, want) } } + +func TestNATCfg(t *testing.T) { + node := func(ip, eip netip.Addr) *tailcfg.Node { + return &tailcfg.Node{ + Addresses: []netip.Prefix{ + netip.PrefixFrom(ip, ip.BitLen()), + }, + SelfNodeV4MasqAddrForThisPeer: eip, + } + } + var ( + noIP netip.Addr + + selfNativeIP = netip.MustParseAddr("100.64.0.1") + selfEIP1 = netip.MustParseAddr("100.64.1.1") + selfEIP2 = netip.MustParseAddr("100.64.1.2") + + peer1IP = netip.MustParseAddr("100.64.0.2") + peer2IP = netip.MustParseAddr("100.64.0.3") + + // subnets should not be impacted. + // TODO(maisem/nyghtowl): add support for subnets and exit nodes and test them. + subnet = netip.MustParseAddr("192.168.0.1") + ) + + tests := []struct { + name string + nm *netmap.NetworkMap + snatMap map[netip.Addr]netip.Addr // dst -> src + dnatMap map[netip.Addr]netip.Addr + }{ + { + name: "no-netmap", + nm: nil, + snatMap: map[netip.Addr]netip.Addr{ + peer1IP: selfNativeIP, + peer2IP: selfNativeIP, + subnet: selfNativeIP, + }, + dnatMap: map[netip.Addr]netip.Addr{ + selfNativeIP: selfNativeIP, + selfEIP1: selfEIP1, + selfEIP2: selfEIP2, + }, + }, + { + name: "single-peer-requires-nat", + nm: &netmap.NetworkMap{ + SelfNode: node(selfNativeIP, noIP), + Peers: []*tailcfg.Node{ + node(peer1IP, noIP), + node(peer2IP, selfEIP1), + }, + }, + snatMap: map[netip.Addr]netip.Addr{ + peer1IP: selfNativeIP, + peer2IP: selfEIP1, + subnet: selfNativeIP, + }, + dnatMap: map[netip.Addr]netip.Addr{ + selfNativeIP: selfNativeIP, + selfEIP1: selfNativeIP, + selfEIP2: selfEIP2, + subnet: subnet, + }, + }, + { + name: "multiple-peers-require-nat", + nm: &netmap.NetworkMap{ + SelfNode: node(selfNativeIP, noIP), + Peers: []*tailcfg.Node{ + node(peer1IP, selfEIP1), + node(peer2IP, selfEIP2), + }, + }, + snatMap: map[netip.Addr]netip.Addr{ + peer1IP: selfEIP1, + peer2IP: selfEIP2, + subnet: selfNativeIP, + }, + dnatMap: map[netip.Addr]netip.Addr{ + selfNativeIP: selfNativeIP, + selfEIP1: selfNativeIP, + selfEIP2: selfNativeIP, + subnet: subnet, + }, + }, + { + name: "no-nat", + nm: &netmap.NetworkMap{ + SelfNode: node(selfNativeIP, noIP), + Peers: []*tailcfg.Node{ + node(peer1IP, noIP), + node(peer2IP, noIP), + }, + }, + snatMap: map[netip.Addr]netip.Addr{ + peer1IP: selfNativeIP, + peer2IP: selfNativeIP, + subnet: selfNativeIP, + }, + dnatMap: map[netip.Addr]netip.Addr{ + selfNativeIP: selfNativeIP, + selfEIP1: selfEIP1, + selfEIP2: selfEIP2, + subnet: subnet, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ncfg := natConfigFromNetMap(tc.nm) + for peer, want := range tc.snatMap { + if got := ncfg.selectSrcIP(selfNativeIP, peer); got != want { + t.Errorf("selectSrcIP: got %v; want %v", got, want) + } + } + for dstIP, want := range tc.dnatMap { + if got := ncfg.mapDstIP(dstIP); got != want { + t.Errorf("mapDstIP: got %v; want %v", got, want) + } + } + }) + } +} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 8f0c2d226..9f23c7050 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -1205,6 +1205,7 @@ func (e *userspaceEngine) SetNetworkMap(nm *netmap.NetworkMap) { e.magicConn.SetNetworkMap(nm) e.mu.Lock() e.netMap = nm + e.tundev.SetNetMap(nm) callbacks := make([]NetworkMapCallback, 0, 4) for _, fn := range e.networkMapCallbacks { callbacks = append(callbacks, fn)