diff --git a/go.mod b/go.mod index 2f0bd690c..c0cd84264 100644 --- a/go.mod +++ b/go.mod @@ -28,8 +28,9 @@ require ( github.com/go-ole/go-ole v1.2.6 github.com/godbus/dbus/v5 v5.0.6 github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da - github.com/google/go-cmp v0.5.8 + github.com/google/go-cmp v0.5.9 github.com/google/go-containerregistry v0.9.0 + github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c github.com/google/uuid v1.3.0 github.com/goreleaser/nfpm v1.10.3 github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3 @@ -44,7 +45,7 @@ require ( github.com/mattn/go-colorable v0.1.12 github.com/mattn/go-isatty v0.0.14 github.com/mdlayher/genetlink v1.2.0 - github.com/mdlayher/netlink v1.6.0 + github.com/mdlayher/netlink v1.7.1 github.com/mdlayher/sdnotify v1.0.0 github.com/miekg/dns v1.1.43 github.com/mitchellh/go-ps v1.0.0 @@ -229,7 +230,7 @@ require ( github.com/mattn/go-runewidth v0.0.13 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/mbilski/exhaustivestruct v1.2.0 // indirect - github.com/mdlayher/socket v0.2.3 // indirect + github.com/mdlayher/socket v0.4.0 // indirect github.com/mgechev/dots v0.0.0-20210922191527-e955255bf517 // indirect github.com/mgechev/revive v1.1.2 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect diff --git a/go.sum b/go.sum index b386a7f34..3f325c181 100644 --- a/go.sum +++ b/go.sum @@ -521,8 +521,9 @@ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-containerregistry v0.9.0 h1:5Ths7RjxyFV0huKChQTgY6fLzvHhZMpLTFNja8U0/0w= github.com/google/go-containerregistry v0.9.0/go.mod h1:9eq4BnSufyT1kHNffX+vSXVonaJ7yaIOulrKZejMxnQ= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -534,6 +535,8 @@ github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXi github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.2.1/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= +github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c h1:06RMfw+TMMHtRuUOroMeatRCCgSMWXCJQeABvHU69YQ= +github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c/go.mod h1:BVIYo3cdnT4qSylnYqcd5YtmXhr51cJPGtnLBe/uLBU= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -839,15 +842,16 @@ github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M= github.com/mdlayher/netlink v1.1.0/go.mod h1:H4WCitaheIsdF9yOYu8CFmCgQthAPIWZmcKp9uZHgmY= github.com/mdlayher/netlink v1.1.1/go.mod h1:WTYpFb/WTvlRJAyKhZL5/uy69TDDpHHu2VZmb2XgV7o= -github.com/mdlayher/netlink v1.6.0 h1:rOHX5yl7qnlpiVkFWoqccueppMtXzeziFjWAjLg6sz0= github.com/mdlayher/netlink v1.6.0/go.mod h1:0o3PlBmGst1xve7wQ7j/hwpNaFaH4qCRyWCdcZk8/vA= +github.com/mdlayher/netlink v1.7.1 h1:FdUaT/e33HjEXagwELR8R3/KL1Fq5x3G5jgHLp/BTmg= +github.com/mdlayher/netlink v1.7.1/go.mod h1:nKO5CSjE/DJjVhk/TNp6vCE1ktVxEA8VEh8drhZzxsQ= github.com/mdlayher/raw v0.0.0-20190606142536-fef19f00fc18/go.mod h1:7EpbotpCmVZcu+KCX4g9WaRNuu11uyhiW7+Le1dKawg= github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065/go.mod h1:7EpbotpCmVZcu+KCX4g9WaRNuu11uyhiW7+Le1dKawg= github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c= github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE= github.com/mdlayher/socket v0.1.1/go.mod h1:mYV5YIZAfHh4dzDVzI8x8tWLWCliuX8Mon5Awbj+qDs= -github.com/mdlayher/socket v0.2.3 h1:XZA2X2TjdOwNoNPVPclRCURoX/hokBY8nkTmRZFEheM= -github.com/mdlayher/socket v0.2.3/go.mod h1:bz12/FozYNH/VbvC3q7TRIK/Y6dH1kCKsXaUeXi/FmY= +github.com/mdlayher/socket v0.4.0 h1:280wsy40IC9M9q1uPGcLBwXpcTQDtoGwVt+BNoITxIw= +github.com/mdlayher/socket v0.4.0/go.mod h1:xxFqz5GRCUN3UEOm9CZqEJsAbe1C8OwSK46NlmWuVoc= github.com/mgechev/dots v0.0.0-20210922191527-e955255bf517 h1:zpIH83+oKzcpryru8ceC6BxnoG8TBrhgAvRg8obzup0= github.com/mgechev/dots v0.0.0-20210922191527-e955255bf517/go.mod h1:KQ7+USdGKfpPjXk4Ga+5XxQM4Lm4e3gAogrreFAYpOg= github.com/mgechev/revive v1.1.2 h1:MiYA/o9M7REjvOF20QN43U8OtXDDHQFKLCtJnxLGLog= diff --git a/util/linuxfw/helpers.go b/util/linuxfw/helpers.go new file mode 100644 index 000000000..7526d68ed --- /dev/null +++ b/util/linuxfw/helpers.go @@ -0,0 +1,35 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package linuxfw + +import ( + "encoding/hex" + "fmt" + "strings" + "unicode" +) + +func formatMaybePrintable(b []byte) string { + // Remove a single trailing null, if any + if len(b) > 0 && b[len(b)-1] == 0 { + b = b[:len(b)-1] + } + + nonprintable := strings.IndexFunc(string(b), func(r rune) bool { + return r > unicode.MaxASCII || !unicode.IsPrint(r) + }) + if nonprintable >= 0 { + return "" + hex.EncodeToString(b) + } + return string(b) +} + +func formatPortRange(r [2]uint16) string { + if r == [2]uint16{0, 65535} { + return fmt.Sprintf(`any`) + } else if r[0] == r[1] { + return fmt.Sprintf(`%d`, r[0]) + } + return fmt.Sprintf(`%d-%d`, r[0], r[1]) +} diff --git a/util/linuxfw/iptables.go b/util/linuxfw/iptables.go new file mode 100644 index 000000000..6e7633215 --- /dev/null +++ b/util/linuxfw/iptables.go @@ -0,0 +1,825 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !(386 || loong64) + +package linuxfw + +import ( + "encoding/hex" + "errors" + "fmt" + "net" + "net/netip" + "strings" + "unsafe" + + "github.com/josharian/native" + "golang.org/x/sys/unix" + linuxabi "gvisor.dev/gvisor/pkg/abi/linux" + "tailscale.com/net/netaddr" + "tailscale.com/types/logger" +) + +type sockLen uint32 + +var ( + iptablesChainNames = map[int]string{ + linuxabi.NF_INET_PRE_ROUTING: "PREROUTING", + linuxabi.NF_INET_LOCAL_IN: "INPUT", + linuxabi.NF_INET_FORWARD: "FORWARD", + linuxabi.NF_INET_LOCAL_OUT: "OUTPUT", + linuxabi.NF_INET_POST_ROUTING: "POSTROUTING", + } + iptablesStandardChains = (func() map[string]bool { + ret := make(map[string]bool) + for _, v := range iptablesChainNames { + ret[v] = true + } + return ret + })() +) + +// DebugNetfilter prints debug information about iptables rules to the +// provided log function. +func DebugIptables(logf logger.Logf) error { + for _, table := range []string{"filter", "nat", "raw"} { + type chainAndEntry struct { + chain string + entry *entry + } + + // Collect all entries first so we can resolve jumps + var ( + lastChain string + ces []chainAndEntry + chainOffsets = make(map[int]string) + ) + err := enumerateIptablesTable(logf, table, func(chain string, entry *entry) error { + if chain != lastChain { + chainOffsets[entry.Offset] = chain + lastChain = chain + } + + ces = append(ces, chainAndEntry{ + chain: lastChain, + entry: entry, + }) + return nil + }) + if err != nil { + return err + } + + lastChain = "" + for _, ce := range ces { + if ce.chain != lastChain { + logf("iptables: table=%s chain=%s", table, ce.chain) + lastChain = ce.chain + } + + // Fixup jump + if std, ok := ce.entry.Target.Data.(standardTarget); ok { + if strings.HasPrefix(std.Verdict, "JUMP(") { + var off int + if _, err := fmt.Sscanf(std.Verdict, "JUMP(%d)", &off); err == nil { + if jt, ok := chainOffsets[off]; ok { + std.Verdict = "JUMP(" + jt + ")" + ce.entry.Target.Data = std + } + } + } + } + + logf("iptables: entry=%+v", ce.entry) + } + } + return nil +} + +// DetectIptables returns the number of iptables rules that are present in the +// system, ignoring the default "ACCEPT" rule present in the standard iptables +// chains. +// +// It only returns an error when the kernel returns an error (i.e. when a +// syscall fails); when there are no iptables rules, it is valid for this +// function to return 0, nil. +func DetectIptables() (int, error) { + dummyLog := func(string, ...any) {} + + var ( + validRules int + firstErr error + ) + for _, table := range []string{"filter", "nat", "raw"} { + err := enumerateIptablesTable(dummyLog, table, func(chain string, entry *entry) error { + // If we have any rules other than basic 'ACCEPT' entries in a + // standard chain, then we consider this a valid rule. + switch { + case !iptablesStandardChains[chain]: + validRules++ + case entry.Target.Name != "standard": + validRules++ + case entry.Target.Name == "standard" && entry.Target.Data.(standardTarget).Verdict != "ACCEPT": + validRules++ + } + return nil + }) + if err != nil && firstErr == nil { + firstErr = err + } + } + + return validRules, firstErr +} + +func enumerateIptablesTable(logf logger.Logf, table string, cb func(string, *entry) error) error { + ln, err := net.Listen("tcp4", ":0") + if err != nil { + return err + } + defer ln.Close() + + tcpLn := ln.(*net.TCPListener) + conn, err := tcpLn.SyscallConn() + if err != nil { + return err + } + + var tableName linuxabi.TableName + copy(tableName[:], []byte(table)) + + tbl := linuxabi.IPTGetinfo{ + Name: tableName, + } + slt := sockLen(linuxabi.SizeOfIPTGetinfo) + + var ctrlErr error + err = conn.Control(func(fd uintptr) { + _, _, errno := unix.Syscall6( + unix.SYS_GETSOCKOPT, + fd, + uintptr(unix.SOL_IP), + linuxabi.IPT_SO_GET_INFO, + uintptr(unsafe.Pointer(&tbl)), + uintptr(unsafe.Pointer(&slt)), + 0, + ) + if errno != 0 { + ctrlErr = errno + return + } + }) + if err != nil { + return err + } + if ctrlErr != nil { + return ctrlErr + } + + if tbl.Size < 1 { + return nil + } + + // Allocate enough space to be able to get all iptables information. + entsBuf := make([]byte, linuxabi.SizeOfIPTGetEntries+tbl.Size) + entsHdr := (*linuxabi.IPTGetEntries)(unsafe.Pointer(&entsBuf[0])) + entsHdr.Name = tableName + entsHdr.Size = tbl.Size + + slt = sockLen(len(entsBuf)) + + err = conn.Control(func(fd uintptr) { + _, _, errno := unix.Syscall6( + unix.SYS_GETSOCKOPT, + fd, + uintptr(unix.SOL_IP), + linuxabi.IPT_SO_GET_ENTRIES, + uintptr(unsafe.Pointer(&entsBuf[0])), + uintptr(unsafe.Pointer(&slt)), + 0, + ) + if errno != 0 { + ctrlErr = errno + return + } + }) + if err != nil { + return err + } + if ctrlErr != nil { + return ctrlErr + } + + // Skip header + entsBuf = entsBuf[linuxabi.SizeOfIPTGetEntries:] + + var ( + totalOffset int + currentChain string + ) + for len(entsBuf) > 0 { + parser := entryParser{ + buf: entsBuf, + logf: logf, + checkExtraBytes: true, + } + entry, err := parser.parseEntry(entsBuf) + if err != nil { + logf("iptables: err=%v", err) + break + } + entry.Offset += totalOffset + + // Don't pass 'ERROR' nodes to our caller + if entry.Target.Name == "ERROR" { + if parser.offset == len(entsBuf) { + // all done + break + } + + // New user-defined chain + currentChain = entry.Target.Data.(errorTarget).ErrorName + } else { + // Detect if we're at a new chain based on the hook + // offsets we fetched earlier. + for i, he := range tbl.HookEntry { + if int(he) == totalOffset { + currentChain = iptablesChainNames[i] + } + } + + // Now that we have everything, call our callback. + if err := cb(currentChain, &entry); err != nil { + return err + } + } + + entsBuf = entsBuf[parser.offset:] + totalOffset += parser.offset + } + return nil +} + +// TODO(andrew): convert to use cstruct +type entryParser struct { + buf []byte + offset int + + logf logger.Logf + + // Set to 'true' to print debug messages about unused bytes returned + // from the kernel + checkExtraBytes bool +} + +func (p *entryParser) haveLen(ln int) bool { + if len(p.buf)-p.offset < ln { + return false + } + return true +} + +func (p *entryParser) assertLen(ln int) error { + if !p.haveLen(ln) { + return fmt.Errorf("need %d bytes: %w", ln, errBufferTooSmall) + } + return nil +} + +func (p *entryParser) getBytes(amt int) []byte { + ret := p.buf[p.offset : p.offset+amt] + p.offset += amt + return ret +} + +func (p *entryParser) getByte() byte { + ret := p.buf[p.offset] + p.offset += 1 + return ret +} + +func (p *entryParser) get4() (ret [4]byte) { + ret[0] = p.buf[p.offset+0] + ret[1] = p.buf[p.offset+1] + ret[2] = p.buf[p.offset+2] + ret[3] = p.buf[p.offset+3] + p.offset += 4 + return +} + +func (p *entryParser) setOffset(off, max int) error { + // We can't go back + if off < p.offset { + return fmt.Errorf("invalid target offset (%d < %d): %w", off, p.offset, errMalformed) + } + + // Ensure we don't go beyond our maximum, if given + if max >= 0 && off >= max { + return fmt.Errorf("invalid target offset (%d >= %d): %w", off, max, errMalformed) + } + + // If we aren't already at this offset, move forward + if p.offset < off { + if p.checkExtraBytes { + extraData := p.buf[p.offset:off] + diff := off - p.offset + p.logf("%d bytes (%d, %d) are unused: %s", diff, p.offset, off, hex.EncodeToString(extraData)) + } + + p.offset = off + } + return nil +} + +var ( + errBufferTooSmall = errors.New("buffer too small") + errMalformed = errors.New("data malformed") +) + +type entry struct { + Offset int + IP iptip + NFCache uint32 + PacketCount uint64 + ByteCount uint64 + Matches []match + Target target +} + +func (e entry) String() string { + var sb strings.Builder + sb.WriteString("{") + + fmt.Fprintf(&sb, "Offset:%d IP:%v PacketCount:%d ByteCount:%d", e.Offset, e.IP, e.PacketCount, e.ByteCount) + if len(e.Matches) > 0 { + fmt.Fprintf(&sb, " Matches:%v", e.Matches) + } + fmt.Fprintf(&sb, " Target:%v", e.Target) + + sb.WriteString("}") + return sb.String() +} + +func (p *entryParser) parseEntry(b []byte) (entry, error) { + startOff := p.offset + + iptip, err := p.parseIPTIP() + if err != nil { + return entry{}, fmt.Errorf("parsing IPTIP: %w", err) + } + + ret := entry{ + Offset: startOff, + IP: iptip, + } + + // Must have space for the rest of the members + if err := p.assertLen(28); err != nil { + return entry{}, err + } + + ret.NFCache = native.Endian.Uint32(p.getBytes(4)) + targetOffset := int(native.Endian.Uint16(p.getBytes(2))) + nextOffset := int(native.Endian.Uint16(p.getBytes(2))) + /* unused field: Comeback = */ p.getBytes(4) + ret.PacketCount = native.Endian.Uint64(p.getBytes(8)) + ret.ByteCount = native.Endian.Uint64(p.getBytes(8)) + + // Must have at least enough space in our buffer to get to the target; + // doing this here means we can avoid bounds checks in parseMatches + if err := p.assertLen(targetOffset - p.offset); err != nil { + return entry{}, err + } + + // Matches are stored between the end of the entry structure and the + // start of the 'targets' structure. + ret.Matches, err = p.parseMatches(targetOffset) + if err != nil { + return entry{}, err + } + + if targetOffset > 0 { + if err := p.setOffset(targetOffset, nextOffset); err != nil { + return entry{}, err + } + + ret.Target, err = p.parseTarget(nextOffset) + if err != nil { + return entry{}, fmt.Errorf("parsing target: %w", err) + } + } + + if err := p.setOffset(nextOffset, -1); err != nil { + return entry{}, err + } + + return ret, nil +} + +type iptip struct { + Src netip.Addr + Dst netip.Addr + SrcMask netip.Addr + DstMask netip.Addr + InputInterface string + OutputInterface string + InputInterfaceMask []byte + OutputInterfaceMask []byte + Protocol uint16 + Flags uint8 + InverseFlags uint8 +} + +var protocolNames = map[uint16]string{ + unix.IPPROTO_ESP: "esp", + unix.IPPROTO_GRE: "gre", + unix.IPPROTO_ICMP: "icmp", + unix.IPPROTO_ICMPV6: "icmpv6", + unix.IPPROTO_IGMP: "igmp", + unix.IPPROTO_IP: "ip", + unix.IPPROTO_IPIP: "ipip", + unix.IPPROTO_IPV6: "ip6", + unix.IPPROTO_RAW: "raw", + unix.IPPROTO_TCP: "tcp", + unix.IPPROTO_UDP: "udp", +} + +func (ip iptip) String() string { + var sb strings.Builder + sb.WriteString("{") + + formatAddrMask := func(addr, mask netip.Addr) string { + if pref, ok := netaddr.FromStdIPNet(&net.IPNet{ + IP: addr.AsSlice(), + Mask: mask.AsSlice(), + }); ok { + return fmt.Sprint(pref) + } + return fmt.Sprintf("%s/%s", addr, mask) + } + + fmt.Fprintf(&sb, "Src:%s", formatAddrMask(ip.Src, ip.SrcMask)) + fmt.Fprintf(&sb, ", Dst:%s", formatAddrMask(ip.Dst, ip.DstMask)) + + translateMask := func(mask []byte) string { + var ret []byte + for _, b := range mask { + if b != 0 { + ret = append(ret, 'X') + } else { + ret = append(ret, '.') + } + } + return string(ret) + } + + if ip.InputInterface != "" { + fmt.Fprintf(&sb, ", InputInterface:%s/%s", ip.InputInterface, translateMask(ip.InputInterfaceMask)) + } + if ip.OutputInterface != "" { + fmt.Fprintf(&sb, ", OutputInterface:%s/%s", ip.OutputInterface, translateMask(ip.OutputInterfaceMask)) + } + if nm, ok := protocolNames[ip.Protocol]; ok { + fmt.Fprintf(&sb, ", Protocol:%s", nm) + } else { + fmt.Fprintf(&sb, ", Protocol:%d", ip.Protocol) + } + + if ip.Flags != 0 { + fmt.Fprintf(&sb, ", Flags:%d", ip.Flags) + } + if ip.InverseFlags != 0 { + fmt.Fprintf(&sb, ", InverseFlags:%d", ip.InverseFlags) + } + + sb.WriteString("}") + return sb.String() +} + +func (p *entryParser) parseIPTIP() (iptip, error) { + if err := p.assertLen(84); err != nil { + return iptip{}, err + } + + var ret iptip + + ret.Src = netip.AddrFrom4(p.get4()) + ret.Dst = netip.AddrFrom4(p.get4()) + ret.SrcMask = netip.AddrFrom4(p.get4()) + ret.DstMask = netip.AddrFrom4(p.get4()) + + const IFNAMSIZ = 16 + ret.InputInterface = unix.ByteSliceToString(p.getBytes(IFNAMSIZ)) + ret.OutputInterface = unix.ByteSliceToString(p.getBytes(IFNAMSIZ)) + + ret.InputInterfaceMask = p.getBytes(IFNAMSIZ) + ret.OutputInterfaceMask = p.getBytes(IFNAMSIZ) + + ret.Protocol = native.Endian.Uint16(p.getBytes(2)) + ret.Flags = p.getByte() + ret.InverseFlags = p.getByte() + return ret, nil +} + +type match struct { + Name string + Revision int + Data any + RawData []byte +} + +func (m match) String() string { + return fmt.Sprintf("{Name:%s, Data:%v}", m.Name, m.Data) +} + +type matchTCP struct { + SourcePortRange [2]uint16 + DestPortRange [2]uint16 + Option byte + FlagMask byte + FlagCompare byte + InverseFlags byte +} + +func (m matchTCP) String() string { + var sb strings.Builder + sb.WriteString("{") + + fmt.Fprintf(&sb, "SrcPort:%s, DstPort:%s", + formatPortRange(m.SourcePortRange), + formatPortRange(m.DestPortRange)) + + // TODO(andrew): format semantically + if m.Option != 0 { + fmt.Fprintf(&sb, ", Option:%d", m.Option) + } + if m.FlagMask != 0 { + fmt.Fprintf(&sb, ", FlagMask:%d", m.FlagMask) + } + if m.FlagCompare != 0 { + fmt.Fprintf(&sb, ", FlagCompare:%d", m.FlagCompare) + } + if m.InverseFlags != 0 { + fmt.Fprintf(&sb, ", InverseFlags:%d", m.InverseFlags) + } + + sb.WriteString("}") + return sb.String() +} + +func (p *entryParser) parseMatches(maxOffset int) ([]match, error) { + const XT_EXTENSION_MAXNAMELEN = 29 + const structSize = 2 + XT_EXTENSION_MAXNAMELEN + 1 + + var ret []match + for { + // If we don't have space for a single match structure, we're done + if p.offset+structSize > maxOffset { + break + } + + var curr match + + matchSize := int(native.Endian.Uint16(p.getBytes(2))) + curr.Name = unix.ByteSliceToString(p.getBytes(XT_EXTENSION_MAXNAMELEN)) + curr.Revision = int(p.getByte()) + + // The data size is the total match size minus what we've already consumed. + dataLen := matchSize - structSize + dataEnd := p.offset + dataLen + + // If we don't have space for the match data, then there's something wrong + if dataEnd > maxOffset { + return nil, fmt.Errorf("out of space for match (%d > max %d): %w", dataEnd, maxOffset, errMalformed) + } else if dataEnd > len(p.buf) { + return nil, fmt.Errorf("out of space for match (%d > buf %d): %w", dataEnd, len(p.buf), errMalformed) + } + + curr.RawData = p.getBytes(dataLen) + + // TODO(andrew): more here; UDP, etc. + switch curr.Name { + case "tcp": + /* + struct xt_tcp { + __u16 spts[2]; // Source port range. + __u16 dpts[2]; // Destination port range. + __u8 option; // TCP Option iff non-zero + __u8 flg_mask; // TCP flags mask byte + __u8 flg_cmp; // TCP flags compare byte + __u8 invflags; // Inverse flags + }; + */ + if len(curr.RawData) >= 12 { + curr.Data = matchTCP{ + SourcePortRange: [...]uint16{ + native.Endian.Uint16(curr.RawData[0:2]), + native.Endian.Uint16(curr.RawData[2:4]), + }, + DestPortRange: [...]uint16{ + native.Endian.Uint16(curr.RawData[4:6]), + native.Endian.Uint16(curr.RawData[6:8]), + }, + Option: curr.RawData[8], + FlagMask: curr.RawData[9], + FlagCompare: curr.RawData[10], + InverseFlags: curr.RawData[11], + } + } + } + + ret = append(ret, curr) + } + return ret, nil +} + +type target struct { + Name string + Revision int + Data any + RawData []byte +} + +func (t target) String() string { + return fmt.Sprintf("{Name:%s, Data:%v}", t.Name, t.Data) +} + +func (p *entryParser) parseTarget(nextOffset int) (target, error) { + const XT_EXTENSION_MAXNAMELEN = 29 + const structSize = 2 + XT_EXTENSION_MAXNAMELEN + 1 + + if err := p.assertLen(structSize); err != nil { + return target{}, err + } + + var ret target + + targetSize := int(native.Endian.Uint16(p.getBytes(2))) + ret.Name = unix.ByteSliceToString(p.getBytes(XT_EXTENSION_MAXNAMELEN)) + ret.Revision = int(p.getByte()) + + if targetSize > structSize { + dataLen := targetSize - structSize + if err := p.assertLen(dataLen); err != nil { + return target{}, err + } + + ret.RawData = p.getBytes(dataLen) + } + + // Special case; matches what iptables does + if ret.Name == "" { + ret.Name = "standard" + } + + switch ret.Name { + case "standard": + if len(ret.RawData) >= 4 { + verdict := int32(native.Endian.Uint32(ret.RawData)) + + var info string + switch verdict { + case -1: + info = "DROP" + case -2: + info = "ACCEPT" + case -4: + info = "QUEUE" + case -5: + info = "RETURN" + case int32(nextOffset): + info = "FALLTHROUGH" + default: + info = fmt.Sprintf("JUMP(%d)", verdict) + } + ret.Data = standardTarget{Verdict: info} + } + + case "ERROR": + ret.Data = errorTarget{ + ErrorName: unix.ByteSliceToString(ret.RawData), + } + + case "REJECT": + if len(ret.RawData) >= 4 { + ret.Data = rejectTarget{ + With: rejectWith(native.Endian.Uint32(ret.RawData)), + } + } + + case "MARK": + if len(ret.RawData) >= 8 { + mark := native.Endian.Uint32(ret.RawData[0:4]) + mask := native.Endian.Uint32(ret.RawData[4:8]) + + var mode markMode + switch { + case mark == 0: + mode = markModeAnd + mark = ^mask + + case mark == mask: + mode = markModeOr + + case mask == 0: + mode = markModeXor + + case mask == 0xffffffff: + mode = markModeSet + + default: + // TODO(andrew): handle xset? + } + + ret.Data = markTarget{ + Mark: mark, + Mode: mode, + } + } + } + + return ret, nil +} + +// Various types for things in iptables-land follow. + +type standardTarget struct { + Verdict string +} + +type errorTarget struct { + ErrorName string +} + +type rejectWith int + +const ( + rwIPT_ICMP_NET_UNREACHABLE rejectWith = iota + rwIPT_ICMP_HOST_UNREACHABLE + rwIPT_ICMP_PROT_UNREACHABLE + rwIPT_ICMP_PORT_UNREACHABLE + rwIPT_ICMP_ECHOREPLY + rwIPT_ICMP_NET_PROHIBITED + rwIPT_ICMP_HOST_PROHIBITED + rwIPT_TCP_RESET + rwIPT_ICMP_ADMIN_PROHIBITED +) + +func (rw rejectWith) String() string { + switch rw { + case rwIPT_ICMP_NET_UNREACHABLE: + return "icmp-net-unreachable" + case rwIPT_ICMP_HOST_UNREACHABLE: + return "icmp-host-unreachable" + case rwIPT_ICMP_PROT_UNREACHABLE: + return "icmp-prot-unreachable" + case rwIPT_ICMP_PORT_UNREACHABLE: + return "icmp-port-unreachable" + case rwIPT_ICMP_ECHOREPLY: + return "icmp-echo-reply" + case rwIPT_ICMP_NET_PROHIBITED: + return "icmp-net-prohibited" + case rwIPT_ICMP_HOST_PROHIBITED: + return "icmp-host-prohibited" + case rwIPT_TCP_RESET: + return "tcp-reset" + case rwIPT_ICMP_ADMIN_PROHIBITED: + return "icmp-admin-prohibited" + default: + return "UNKNOWN" + } +} + +type rejectTarget struct { + With rejectWith +} + +type markMode byte + +const ( + markModeSet markMode = iota + markModeAnd + markModeOr + markModeXor +) + +func (mm markMode) String() string { + switch mm { + case markModeSet: + return "set" + case markModeAnd: + return "and" + case markModeOr: + return "or" + case markModeXor: + return "xor" + default: + return "UNKNOWN" + } +} + +type markTarget struct { + Mode markMode + Mark uint32 +} diff --git a/util/linuxfw/linuxfw.go b/util/linuxfw/linuxfw.go new file mode 100644 index 000000000..f3d7b0561 --- /dev/null +++ b/util/linuxfw/linuxfw.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package linuxfw returns the kind of firewall being used by the kernel. +package linuxfw + +import "errors" + +// ErrUnsupported is the error returned from all functions on non-Linux +// platforms. +var ErrUnsupported = errors.New("unsupported") diff --git a/util/linuxfw/linuxfw_struct_linux_test.go b/util/linuxfw/linuxfw_struct_linux_test.go new file mode 100644 index 000000000..ae9a2f7e5 --- /dev/null +++ b/util/linuxfw/linuxfw_struct_linux_test.go @@ -0,0 +1,19 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !(386 || loong64) + +package linuxfw + +import ( + "testing" + "unsafe" + + "tailscale.com/util/linuxfw/linuxfwtest" +) + +func TestSizes(t *testing.T) { + linuxfwtest.TestSizes(t, &linuxfwtest.SizeInfo{ + SizeofSocklen: unsafe.Sizeof(sockLen(0)), + }) +} diff --git a/util/linuxfw/linuxfw_unsupported.go b/util/linuxfw/linuxfw_unsupported.go new file mode 100644 index 000000000..819d39ca5 --- /dev/null +++ b/util/linuxfw/linuxfw_unsupported.go @@ -0,0 +1,33 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// NOTE: linux_386 and linux_loong64 are currently unsupported due to missing +// support in upstream dependencies. + +//go:build !linux || (linux && (386 || loong64)) + +package linuxfw + +import ( + "tailscale.com/types/logger" +) + +// DebugNetfilter is not supported on non-Linux platforms. +func DebugNetfilter(logf logger.Logf) error { + return ErrUnsupported +} + +// DetectNetfilter is not supported on non-Linux platforms. +func DetectNetfilter() (int, error) { + return 0, ErrUnsupported +} + +// DebugIptables is not supported on non-Linux platforms. +func DebugIptables(logf logger.Logf) error { + return ErrUnsupported +} + +// DetectIptables is not supported on non-Linux platforms. +func DetectIptables() (int, error) { + return 0, ErrUnsupported +} diff --git a/util/linuxfw/linuxfwtest/linuxfwtest.go b/util/linuxfw/linuxfwtest/linuxfwtest.go new file mode 100644 index 000000000..ee2cbd1b2 --- /dev/null +++ b/util/linuxfw/linuxfwtest/linuxfwtest.go @@ -0,0 +1,31 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo && linux + +// Package linuxfwtest contains tests for the linuxfw package. Go does not +// support cgo in tests, and we don't want the main package to have a cgo +// dependency, so we put all the tests here and call them from the main package +// in tests intead. +package linuxfwtest + +import ( + "testing" + "unsafe" +) + +/* +#include // socket() +*/ +import "C" + +type SizeInfo struct { + SizeofSocklen uintptr +} + +func TestSizes(t *testing.T, si *SizeInfo) { + want := unsafe.Sizeof(C.socklen_t(0)) + if want != si.SizeofSocklen { + t.Errorf("sockLen has wrong size; want=%d got=%d", want, si.SizeofSocklen) + } +} diff --git a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go b/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go new file mode 100644 index 000000000..6e9569900 --- /dev/null +++ b/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !cgo || !linux + +package linuxfwtest + +import ( + "testing" +) + +type SizeInfo struct { + SizeofSocklen uintptr +} + +func TestSizes(t *testing.T, si *SizeInfo) { + t.Skip("not supported without cgo") +} diff --git a/util/linuxfw/nftables.go b/util/linuxfw/nftables.go new file mode 100644 index 000000000..35205975a --- /dev/null +++ b/util/linuxfw/nftables.go @@ -0,0 +1,269 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !(386 || loong64) + +package linuxfw + +import ( + "fmt" + "sort" + "strings" + + "github.com/google/nftables" + "github.com/google/nftables/expr" + "github.com/google/nftables/xt" + "github.com/josharian/native" + "golang.org/x/sys/unix" + "tailscale.com/types/logger" +) + +// DebugNetfilter prints debug information about netfilter rules to the +// provided log function. +func DebugNetfilter(logf logger.Logf) error { + conn, err := nftables.New() + if err != nil { + return err + } + + chains, err := conn.ListChains() + if err != nil { + return fmt.Errorf("cannot list chains: %w", err) + } + + if len(chains) == 0 { + logf("netfilter: no chains") + return nil + } + + for _, chain := range chains { + logf("netfilter: table=%s chain=%s", chain.Table.Name, chain.Name) + + rules, err := conn.GetRules(chain.Table, chain) + if err != nil { + continue + } + sort.Slice(rules, func(i, j int) bool { + return rules[i].Position < rules[j].Position + }) + + for i, rule := range rules { + logf("netfilter: rule[%d]: pos=%d flags=%d", i, rule.Position, rule.Flags) + for _, ex := range rule.Exprs { + switch v := ex.(type) { + case *expr.Meta: + key := metaKeyNames[v.Key] + if key == "" { + key = "UNKNOWN" + } + logf("netfilter: Meta: key=%s source_register=%v register=%d", key, v.SourceRegister, v.Register) + + case *expr.Cmp: + op := cmpOpNames[v.Op] + if op == "" { + op = "UNKNOWN" + } + logf("netfilter: Cmp: op=%s register=%d data=%s", op, v.Register, formatMaybePrintable(v.Data)) + + case *expr.Counter: + // don't print + + case *expr.Verdict: + kind := verdictNames[v.Kind] + if kind == "" { + kind = "UNKNOWN" + } + logf("netfilter: Verdict: kind=%s data=%s", kind, v.Chain) + + case *expr.Target: + logf("netfilter: Target: name=%s info=%s", v.Name, printTargetInfo(v.Name, v.Info)) + + case *expr.Match: + logf("netfilter: Match: name=%s info=%+v", v.Name, printMatchInfo(v.Name, v.Info)) + + case *expr.Payload: + logf("netfilter: Payload: op=%s src=%d dst=%d base=%s offset=%d len=%d", + payloadOperationTypeNames[v.OperationType], + v.SourceRegister, v.DestRegister, + payloadBaseNames[v.Base], + v.Offset, v.Len) + // TODO(andrew): csum + + case *expr.Bitwise: + var xor string + for _, b := range v.Xor { + if b != 0 { + xor = fmt.Sprintf(" xor=%v", v.Xor) + break + } + } + logf("netfilter: Bitwise: src=%d dst=%d len=%d mask=%v%s", + v.SourceRegister, v.DestRegister, v.Len, v.Mask, xor) + + default: + logf("netfilter: unknown %T: %+v", v, v) + } + } + } + } + + return nil +} + +// DetectNetfilter returns the number of nftables rules present in the system. +func DetectNetfilter() (int, error) { + conn, err := nftables.New() + if err != nil { + return 0, err + } + + chains, err := conn.ListChains() + if err != nil { + return 0, fmt.Errorf("cannot list chains: %w", err) + } + + var validRules int + for _, chain := range chains { + rules, err := conn.GetRules(chain.Table, chain) + if err != nil { + continue + } + validRules += len(rules) + } + return validRules, nil +} + +func printMatchInfo(name string, info xt.InfoAny) string { + var sb strings.Builder + sb.WriteString(`{`) + + var handled bool = true + switch v := info.(type) { + // TODO(andrew): we should support these common types + //case *xt.ConntrackMtinfo3: + //case *xt.ConntrackMtinfo2: + case *xt.Tcp: + fmt.Fprintf(&sb, "Src:%s Dst:%s", formatPortRange(v.SrcPorts), formatPortRange(v.DstPorts)) + if v.Option != 0 { + fmt.Fprintf(&sb, " Option:%d", v.Option) + } + if v.FlagsMask != 0 { + fmt.Fprintf(&sb, " FlagsMask:%d", v.FlagsMask) + } + if v.FlagsCmp != 0 { + fmt.Fprintf(&sb, " FlagsCmp:%d", v.FlagsCmp) + } + if v.InvFlags != 0 { + fmt.Fprintf(&sb, " InvFlags:%d", v.InvFlags) + } + + case *xt.Udp: + fmt.Fprintf(&sb, "Src:%s Dst:%s", formatPortRange(v.SrcPorts), formatPortRange(v.DstPorts)) + if v.InvFlags != 0 { + fmt.Fprintf(&sb, " InvFlags:%d", v.InvFlags) + } + + case *xt.AddrType: + var sprefix, dprefix string + if v.InvertSource { + sprefix = "!" + } + if v.InvertDest { + dprefix = "!" + } + // TODO(andrew): translate source/dest + fmt.Fprintf(&sb, "Source:%s%d Dest:%s%d", sprefix, v.Source, dprefix, v.Dest) + + case *xt.AddrTypeV1: + // TODO(andrew): translate source/dest + fmt.Fprintf(&sb, "Source:%d Dest:%d", v.Source, v.Dest) + + var flags []string + for flag, name := range addrTypeFlagNames { + if v.Flags&flag != 0 { + flags = append(flags, name) + } + } + if len(flags) > 0 { + sort.Strings(flags) + fmt.Fprintf(&sb, "Flags:%s", strings.Join(flags, ",")) + } + + default: + handled = false + } + if handled { + sb.WriteString(`}`) + return sb.String() + } + + unknown, ok := info.(*xt.Unknown) + if !ok { + return fmt.Sprintf("(%T)%+v", info, info) + } + data := []byte(*unknown) + + // Things where upstream has no type + handled = true + switch name { + case "pkttype": + if len(data) != 8 { + handled = false + break + } + + pkttype := int(native.Endian.Uint32(data[0:4])) + invert := int(native.Endian.Uint32(data[4:8])) + var invertPrefix string + if invert != 0 { + invertPrefix = "!" + } + + pkttypeName := packetTypeNames[pkttype] + if pkttypeName != "" { + fmt.Fprintf(&sb, "PktType:%s%s", invertPrefix, pkttypeName) + } else { + fmt.Fprintf(&sb, "PktType:%s%d", invertPrefix, pkttype) + } + + default: + handled = true + } + + if !handled { + return fmt.Sprintf("(%T)%+v", info, info) + } + + sb.WriteString(`}`) + return sb.String() +} + +func printTargetInfo(name string, info xt.InfoAny) string { + var sb strings.Builder + sb.WriteString(`{`) + + unknown, ok := info.(*xt.Unknown) + if !ok { + return fmt.Sprintf("(%T)%+v", info, info) + } + data := []byte(*unknown) + + // Things where upstream has no type + switch name { + case "LOG": + if len(data) != 32 { + fmt.Fprintf(&sb, `Error:"bad size; want 32, got %d"`, len(data)) + break + } + + level := data[0] + logflags := data[1] + prefix := unix.ByteSliceToString(data[2:]) + fmt.Fprintf(&sb, "Level:%d LogFlags:%d Prefix:%q", level, logflags, prefix) + default: + return fmt.Sprintf("(%T)%+v", info, info) + } + + sb.WriteString(`}`) + return sb.String() +} diff --git a/util/linuxfw/nftables_types.go b/util/linuxfw/nftables_types.go new file mode 100644 index 000000000..7af23baa1 --- /dev/null +++ b/util/linuxfw/nftables_types.go @@ -0,0 +1,94 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !(386 || loong64) + +package linuxfw + +import ( + "github.com/google/nftables/expr" + "github.com/google/nftables/xt" +) + +var metaKeyNames = map[expr.MetaKey]string{ + expr.MetaKeyLEN: "LEN", + expr.MetaKeyPROTOCOL: "PROTOCOL", + expr.MetaKeyPRIORITY: "PRIORITY", + expr.MetaKeyMARK: "MARK", + expr.MetaKeyIIF: "IIF", + expr.MetaKeyOIF: "OIF", + expr.MetaKeyIIFNAME: "IIFNAME", + expr.MetaKeyOIFNAME: "OIFNAME", + expr.MetaKeyIIFTYPE: "IIFTYPE", + expr.MetaKeyOIFTYPE: "OIFTYPE", + expr.MetaKeySKUID: "SKUID", + expr.MetaKeySKGID: "SKGID", + expr.MetaKeyNFTRACE: "NFTRACE", + expr.MetaKeyRTCLASSID: "RTCLASSID", + expr.MetaKeySECMARK: "SECMARK", + expr.MetaKeyNFPROTO: "NFPROTO", + expr.MetaKeyL4PROTO: "L4PROTO", + expr.MetaKeyBRIIIFNAME: "BRIIIFNAME", + expr.MetaKeyBRIOIFNAME: "BRIOIFNAME", + expr.MetaKeyPKTTYPE: "PKTTYPE", + expr.MetaKeyCPU: "CPU", + expr.MetaKeyIIFGROUP: "IIFGROUP", + expr.MetaKeyOIFGROUP: "OIFGROUP", + expr.MetaKeyCGROUP: "CGROUP", + expr.MetaKeyPRANDOM: "PRANDOM", +} + +var cmpOpNames = map[expr.CmpOp]string{ + expr.CmpOpEq: "EQ", + expr.CmpOpNeq: "NEQ", + expr.CmpOpLt: "LT", + expr.CmpOpLte: "LTE", + expr.CmpOpGt: "GT", + expr.CmpOpGte: "GTE", +} + +var verdictNames = map[expr.VerdictKind]string{ + expr.VerdictReturn: "RETURN", + expr.VerdictGoto: "GOTO", + expr.VerdictJump: "JUMP", + expr.VerdictBreak: "BREAK", + expr.VerdictContinue: "CONTINUE", + expr.VerdictDrop: "DROP", + expr.VerdictAccept: "ACCEPT", + expr.VerdictStolen: "STOLEN", + expr.VerdictQueue: "QUEUE", + expr.VerdictRepeat: "REPEAT", + expr.VerdictStop: "STOP", +} + +var payloadOperationTypeNames = map[expr.PayloadOperationType]string{ + expr.PayloadLoad: "LOAD", + expr.PayloadWrite: "WRITE", +} + +var payloadBaseNames = map[expr.PayloadBase]string{ + expr.PayloadBaseLLHeader: "ll-header", + expr.PayloadBaseNetworkHeader: "network-header", + expr.PayloadBaseTransportHeader: "transport-header", +} + +var packetTypeNames = map[int]string{ + 0 /* PACKET_HOST */ : "unicast", + 1 /* PACKET_BROADCAST */ : "broadcast", + 2 /* PACKET_MULTICAST */ : "multicast", +} + +var addrTypeFlagNames = map[xt.AddrTypeFlags]string{ + xt.AddrTypeUnspec: "unspec", + xt.AddrTypeUnicast: "unicast", + xt.AddrTypeLocal: "local", + xt.AddrTypeBroadcast: "broadcast", + xt.AddrTypeAnycast: "anycast", + xt.AddrTypeMulticast: "multicast", + xt.AddrTypeBlackhole: "blackhole", + xt.AddrTypeUnreachable: "unreachable", + xt.AddrTypeProhibit: "prohibit", + xt.AddrTypeThrow: "throw", + xt.AddrTypeNat: "nat", + xt.AddrTypeXresolve: "xresolve", +}