util/linuxfw: initial implementation of package

This package is an initial implementation of something that can read
netfilter and iptables rules from the Linux kernel without needing to
shell out to an external utility; it speaks directly to the kernel using
syscalls and parses the data returned.

Currently this is read-only since it only knows how to parse a subset of
the available data.

Signed-off-by: Andrew Dunham <andrew@tailscale.com>
Change-Id: Iccadf5dcc081b73268d8ccf8884c24eb6a6f1ff5
pull/7232/head
Andrew Dunham 2 years ago committed by Andrew Dunham
parent 3c107ff301
commit ba48ec5e39

@ -28,8 +28,9 @@ require (
github.com/go-ole/go-ole v1.2.6
github.com/godbus/dbus/v5 v5.0.6
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da
github.com/google/go-cmp v0.5.8
github.com/google/go-cmp v0.5.9
github.com/google/go-containerregistry v0.9.0
github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c
github.com/google/uuid v1.3.0
github.com/goreleaser/nfpm v1.10.3
github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3
@ -44,7 +45,7 @@ require (
github.com/mattn/go-colorable v0.1.12
github.com/mattn/go-isatty v0.0.14
github.com/mdlayher/genetlink v1.2.0
github.com/mdlayher/netlink v1.6.0
github.com/mdlayher/netlink v1.7.1
github.com/mdlayher/sdnotify v1.0.0
github.com/miekg/dns v1.1.43
github.com/mitchellh/go-ps v1.0.0
@ -229,7 +230,7 @@ require (
github.com/mattn/go-runewidth v0.0.13 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
github.com/mbilski/exhaustivestruct v1.2.0 // indirect
github.com/mdlayher/socket v0.2.3 // indirect
github.com/mdlayher/socket v0.4.0 // indirect
github.com/mgechev/dots v0.0.0-20210922191527-e955255bf517 // indirect
github.com/mgechev/revive v1.1.2 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect

@ -521,8 +521,9 @@ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-containerregistry v0.9.0 h1:5Ths7RjxyFV0huKChQTgY6fLzvHhZMpLTFNja8U0/0w=
github.com/google/go-containerregistry v0.9.0/go.mod h1:9eq4BnSufyT1kHNffX+vSXVonaJ7yaIOulrKZejMxnQ=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@ -534,6 +535,8 @@ github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXi
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/martian/v3 v3.2.1/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk=
github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c h1:06RMfw+TMMHtRuUOroMeatRCCgSMWXCJQeABvHU69YQ=
github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c/go.mod h1:BVIYo3cdnT4qSylnYqcd5YtmXhr51cJPGtnLBe/uLBU=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
@ -839,15 +842,16 @@ github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE
github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M=
github.com/mdlayher/netlink v1.1.0/go.mod h1:H4WCitaheIsdF9yOYu8CFmCgQthAPIWZmcKp9uZHgmY=
github.com/mdlayher/netlink v1.1.1/go.mod h1:WTYpFb/WTvlRJAyKhZL5/uy69TDDpHHu2VZmb2XgV7o=
github.com/mdlayher/netlink v1.6.0 h1:rOHX5yl7qnlpiVkFWoqccueppMtXzeziFjWAjLg6sz0=
github.com/mdlayher/netlink v1.6.0/go.mod h1:0o3PlBmGst1xve7wQ7j/hwpNaFaH4qCRyWCdcZk8/vA=
github.com/mdlayher/netlink v1.7.1 h1:FdUaT/e33HjEXagwELR8R3/KL1Fq5x3G5jgHLp/BTmg=
github.com/mdlayher/netlink v1.7.1/go.mod h1:nKO5CSjE/DJjVhk/TNp6vCE1ktVxEA8VEh8drhZzxsQ=
github.com/mdlayher/raw v0.0.0-20190606142536-fef19f00fc18/go.mod h1:7EpbotpCmVZcu+KCX4g9WaRNuu11uyhiW7+Le1dKawg=
github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065/go.mod h1:7EpbotpCmVZcu+KCX4g9WaRNuu11uyhiW7+Le1dKawg=
github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c=
github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE=
github.com/mdlayher/socket v0.1.1/go.mod h1:mYV5YIZAfHh4dzDVzI8x8tWLWCliuX8Mon5Awbj+qDs=
github.com/mdlayher/socket v0.2.3 h1:XZA2X2TjdOwNoNPVPclRCURoX/hokBY8nkTmRZFEheM=
github.com/mdlayher/socket v0.2.3/go.mod h1:bz12/FozYNH/VbvC3q7TRIK/Y6dH1kCKsXaUeXi/FmY=
github.com/mdlayher/socket v0.4.0 h1:280wsy40IC9M9q1uPGcLBwXpcTQDtoGwVt+BNoITxIw=
github.com/mdlayher/socket v0.4.0/go.mod h1:xxFqz5GRCUN3UEOm9CZqEJsAbe1C8OwSK46NlmWuVoc=
github.com/mgechev/dots v0.0.0-20210922191527-e955255bf517 h1:zpIH83+oKzcpryru8ceC6BxnoG8TBrhgAvRg8obzup0=
github.com/mgechev/dots v0.0.0-20210922191527-e955255bf517/go.mod h1:KQ7+USdGKfpPjXk4Ga+5XxQM4Lm4e3gAogrreFAYpOg=
github.com/mgechev/revive v1.1.2 h1:MiYA/o9M7REjvOF20QN43U8OtXDDHQFKLCtJnxLGLog=

@ -0,0 +1,35 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package linuxfw
import (
"encoding/hex"
"fmt"
"strings"
"unicode"
)
func formatMaybePrintable(b []byte) string {
// Remove a single trailing null, if any
if len(b) > 0 && b[len(b)-1] == 0 {
b = b[:len(b)-1]
}
nonprintable := strings.IndexFunc(string(b), func(r rune) bool {
return r > unicode.MaxASCII || !unicode.IsPrint(r)
})
if nonprintable >= 0 {
return "<hex>" + hex.EncodeToString(b)
}
return string(b)
}
func formatPortRange(r [2]uint16) string {
if r == [2]uint16{0, 65535} {
return fmt.Sprintf(`any`)
} else if r[0] == r[1] {
return fmt.Sprintf(`%d`, r[0])
}
return fmt.Sprintf(`%d-%d`, r[0], r[1])
}

@ -0,0 +1,825 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux && !(386 || loong64)
package linuxfw
import (
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"unsafe"
"github.com/josharian/native"
"golang.org/x/sys/unix"
linuxabi "gvisor.dev/gvisor/pkg/abi/linux"
"tailscale.com/net/netaddr"
"tailscale.com/types/logger"
)
type sockLen uint32
var (
iptablesChainNames = map[int]string{
linuxabi.NF_INET_PRE_ROUTING: "PREROUTING",
linuxabi.NF_INET_LOCAL_IN: "INPUT",
linuxabi.NF_INET_FORWARD: "FORWARD",
linuxabi.NF_INET_LOCAL_OUT: "OUTPUT",
linuxabi.NF_INET_POST_ROUTING: "POSTROUTING",
}
iptablesStandardChains = (func() map[string]bool {
ret := make(map[string]bool)
for _, v := range iptablesChainNames {
ret[v] = true
}
return ret
})()
)
// DebugNetfilter prints debug information about iptables rules to the
// provided log function.
func DebugIptables(logf logger.Logf) error {
for _, table := range []string{"filter", "nat", "raw"} {
type chainAndEntry struct {
chain string
entry *entry
}
// Collect all entries first so we can resolve jumps
var (
lastChain string
ces []chainAndEntry
chainOffsets = make(map[int]string)
)
err := enumerateIptablesTable(logf, table, func(chain string, entry *entry) error {
if chain != lastChain {
chainOffsets[entry.Offset] = chain
lastChain = chain
}
ces = append(ces, chainAndEntry{
chain: lastChain,
entry: entry,
})
return nil
})
if err != nil {
return err
}
lastChain = ""
for _, ce := range ces {
if ce.chain != lastChain {
logf("iptables: table=%s chain=%s", table, ce.chain)
lastChain = ce.chain
}
// Fixup jump
if std, ok := ce.entry.Target.Data.(standardTarget); ok {
if strings.HasPrefix(std.Verdict, "JUMP(") {
var off int
if _, err := fmt.Sscanf(std.Verdict, "JUMP(%d)", &off); err == nil {
if jt, ok := chainOffsets[off]; ok {
std.Verdict = "JUMP(" + jt + ")"
ce.entry.Target.Data = std
}
}
}
}
logf("iptables: entry=%+v", ce.entry)
}
}
return nil
}
// DetectIptables returns the number of iptables rules that are present in the
// system, ignoring the default "ACCEPT" rule present in the standard iptables
// chains.
//
// It only returns an error when the kernel returns an error (i.e. when a
// syscall fails); when there are no iptables rules, it is valid for this
// function to return 0, nil.
func DetectIptables() (int, error) {
dummyLog := func(string, ...any) {}
var (
validRules int
firstErr error
)
for _, table := range []string{"filter", "nat", "raw"} {
err := enumerateIptablesTable(dummyLog, table, func(chain string, entry *entry) error {
// If we have any rules other than basic 'ACCEPT' entries in a
// standard chain, then we consider this a valid rule.
switch {
case !iptablesStandardChains[chain]:
validRules++
case entry.Target.Name != "standard":
validRules++
case entry.Target.Name == "standard" && entry.Target.Data.(standardTarget).Verdict != "ACCEPT":
validRules++
}
return nil
})
if err != nil && firstErr == nil {
firstErr = err
}
}
return validRules, firstErr
}
func enumerateIptablesTable(logf logger.Logf, table string, cb func(string, *entry) error) error {
ln, err := net.Listen("tcp4", ":0")
if err != nil {
return err
}
defer ln.Close()
tcpLn := ln.(*net.TCPListener)
conn, err := tcpLn.SyscallConn()
if err != nil {
return err
}
var tableName linuxabi.TableName
copy(tableName[:], []byte(table))
tbl := linuxabi.IPTGetinfo{
Name: tableName,
}
slt := sockLen(linuxabi.SizeOfIPTGetinfo)
var ctrlErr error
err = conn.Control(func(fd uintptr) {
_, _, errno := unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
uintptr(unix.SOL_IP),
linuxabi.IPT_SO_GET_INFO,
uintptr(unsafe.Pointer(&tbl)),
uintptr(unsafe.Pointer(&slt)),
0,
)
if errno != 0 {
ctrlErr = errno
return
}
})
if err != nil {
return err
}
if ctrlErr != nil {
return ctrlErr
}
if tbl.Size < 1 {
return nil
}
// Allocate enough space to be able to get all iptables information.
entsBuf := make([]byte, linuxabi.SizeOfIPTGetEntries+tbl.Size)
entsHdr := (*linuxabi.IPTGetEntries)(unsafe.Pointer(&entsBuf[0]))
entsHdr.Name = tableName
entsHdr.Size = tbl.Size
slt = sockLen(len(entsBuf))
err = conn.Control(func(fd uintptr) {
_, _, errno := unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
uintptr(unix.SOL_IP),
linuxabi.IPT_SO_GET_ENTRIES,
uintptr(unsafe.Pointer(&entsBuf[0])),
uintptr(unsafe.Pointer(&slt)),
0,
)
if errno != 0 {
ctrlErr = errno
return
}
})
if err != nil {
return err
}
if ctrlErr != nil {
return ctrlErr
}
// Skip header
entsBuf = entsBuf[linuxabi.SizeOfIPTGetEntries:]
var (
totalOffset int
currentChain string
)
for len(entsBuf) > 0 {
parser := entryParser{
buf: entsBuf,
logf: logf,
checkExtraBytes: true,
}
entry, err := parser.parseEntry(entsBuf)
if err != nil {
logf("iptables: err=%v", err)
break
}
entry.Offset += totalOffset
// Don't pass 'ERROR' nodes to our caller
if entry.Target.Name == "ERROR" {
if parser.offset == len(entsBuf) {
// all done
break
}
// New user-defined chain
currentChain = entry.Target.Data.(errorTarget).ErrorName
} else {
// Detect if we're at a new chain based on the hook
// offsets we fetched earlier.
for i, he := range tbl.HookEntry {
if int(he) == totalOffset {
currentChain = iptablesChainNames[i]
}
}
// Now that we have everything, call our callback.
if err := cb(currentChain, &entry); err != nil {
return err
}
}
entsBuf = entsBuf[parser.offset:]
totalOffset += parser.offset
}
return nil
}
// TODO(andrew): convert to use cstruct
type entryParser struct {
buf []byte
offset int
logf logger.Logf
// Set to 'true' to print debug messages about unused bytes returned
// from the kernel
checkExtraBytes bool
}
func (p *entryParser) haveLen(ln int) bool {
if len(p.buf)-p.offset < ln {
return false
}
return true
}
func (p *entryParser) assertLen(ln int) error {
if !p.haveLen(ln) {
return fmt.Errorf("need %d bytes: %w", ln, errBufferTooSmall)
}
return nil
}
func (p *entryParser) getBytes(amt int) []byte {
ret := p.buf[p.offset : p.offset+amt]
p.offset += amt
return ret
}
func (p *entryParser) getByte() byte {
ret := p.buf[p.offset]
p.offset += 1
return ret
}
func (p *entryParser) get4() (ret [4]byte) {
ret[0] = p.buf[p.offset+0]
ret[1] = p.buf[p.offset+1]
ret[2] = p.buf[p.offset+2]
ret[3] = p.buf[p.offset+3]
p.offset += 4
return
}
func (p *entryParser) setOffset(off, max int) error {
// We can't go back
if off < p.offset {
return fmt.Errorf("invalid target offset (%d < %d): %w", off, p.offset, errMalformed)
}
// Ensure we don't go beyond our maximum, if given
if max >= 0 && off >= max {
return fmt.Errorf("invalid target offset (%d >= %d): %w", off, max, errMalformed)
}
// If we aren't already at this offset, move forward
if p.offset < off {
if p.checkExtraBytes {
extraData := p.buf[p.offset:off]
diff := off - p.offset
p.logf("%d bytes (%d, %d) are unused: %s", diff, p.offset, off, hex.EncodeToString(extraData))
}
p.offset = off
}
return nil
}
var (
errBufferTooSmall = errors.New("buffer too small")
errMalformed = errors.New("data malformed")
)
type entry struct {
Offset int
IP iptip
NFCache uint32
PacketCount uint64
ByteCount uint64
Matches []match
Target target
}
func (e entry) String() string {
var sb strings.Builder
sb.WriteString("{")
fmt.Fprintf(&sb, "Offset:%d IP:%v PacketCount:%d ByteCount:%d", e.Offset, e.IP, e.PacketCount, e.ByteCount)
if len(e.Matches) > 0 {
fmt.Fprintf(&sb, " Matches:%v", e.Matches)
}
fmt.Fprintf(&sb, " Target:%v", e.Target)
sb.WriteString("}")
return sb.String()
}
func (p *entryParser) parseEntry(b []byte) (entry, error) {
startOff := p.offset
iptip, err := p.parseIPTIP()
if err != nil {
return entry{}, fmt.Errorf("parsing IPTIP: %w", err)
}
ret := entry{
Offset: startOff,
IP: iptip,
}
// Must have space for the rest of the members
if err := p.assertLen(28); err != nil {
return entry{}, err
}
ret.NFCache = native.Endian.Uint32(p.getBytes(4))
targetOffset := int(native.Endian.Uint16(p.getBytes(2)))
nextOffset := int(native.Endian.Uint16(p.getBytes(2)))
/* unused field: Comeback = */ p.getBytes(4)
ret.PacketCount = native.Endian.Uint64(p.getBytes(8))
ret.ByteCount = native.Endian.Uint64(p.getBytes(8))
// Must have at least enough space in our buffer to get to the target;
// doing this here means we can avoid bounds checks in parseMatches
if err := p.assertLen(targetOffset - p.offset); err != nil {
return entry{}, err
}
// Matches are stored between the end of the entry structure and the
// start of the 'targets' structure.
ret.Matches, err = p.parseMatches(targetOffset)
if err != nil {
return entry{}, err
}
if targetOffset > 0 {
if err := p.setOffset(targetOffset, nextOffset); err != nil {
return entry{}, err
}
ret.Target, err = p.parseTarget(nextOffset)
if err != nil {
return entry{}, fmt.Errorf("parsing target: %w", err)
}
}
if err := p.setOffset(nextOffset, -1); err != nil {
return entry{}, err
}
return ret, nil
}
type iptip struct {
Src netip.Addr
Dst netip.Addr
SrcMask netip.Addr
DstMask netip.Addr
InputInterface string
OutputInterface string
InputInterfaceMask []byte
OutputInterfaceMask []byte
Protocol uint16
Flags uint8
InverseFlags uint8
}
var protocolNames = map[uint16]string{
unix.IPPROTO_ESP: "esp",
unix.IPPROTO_GRE: "gre",
unix.IPPROTO_ICMP: "icmp",
unix.IPPROTO_ICMPV6: "icmpv6",
unix.IPPROTO_IGMP: "igmp",
unix.IPPROTO_IP: "ip",
unix.IPPROTO_IPIP: "ipip",
unix.IPPROTO_IPV6: "ip6",
unix.IPPROTO_RAW: "raw",
unix.IPPROTO_TCP: "tcp",
unix.IPPROTO_UDP: "udp",
}
func (ip iptip) String() string {
var sb strings.Builder
sb.WriteString("{")
formatAddrMask := func(addr, mask netip.Addr) string {
if pref, ok := netaddr.FromStdIPNet(&net.IPNet{
IP: addr.AsSlice(),
Mask: mask.AsSlice(),
}); ok {
return fmt.Sprint(pref)
}
return fmt.Sprintf("%s/%s", addr, mask)
}
fmt.Fprintf(&sb, "Src:%s", formatAddrMask(ip.Src, ip.SrcMask))
fmt.Fprintf(&sb, ", Dst:%s", formatAddrMask(ip.Dst, ip.DstMask))
translateMask := func(mask []byte) string {
var ret []byte
for _, b := range mask {
if b != 0 {
ret = append(ret, 'X')
} else {
ret = append(ret, '.')
}
}
return string(ret)
}
if ip.InputInterface != "" {
fmt.Fprintf(&sb, ", InputInterface:%s/%s", ip.InputInterface, translateMask(ip.InputInterfaceMask))
}
if ip.OutputInterface != "" {
fmt.Fprintf(&sb, ", OutputInterface:%s/%s", ip.OutputInterface, translateMask(ip.OutputInterfaceMask))
}
if nm, ok := protocolNames[ip.Protocol]; ok {
fmt.Fprintf(&sb, ", Protocol:%s", nm)
} else {
fmt.Fprintf(&sb, ", Protocol:%d", ip.Protocol)
}
if ip.Flags != 0 {
fmt.Fprintf(&sb, ", Flags:%d", ip.Flags)
}
if ip.InverseFlags != 0 {
fmt.Fprintf(&sb, ", InverseFlags:%d", ip.InverseFlags)
}
sb.WriteString("}")
return sb.String()
}
func (p *entryParser) parseIPTIP() (iptip, error) {
if err := p.assertLen(84); err != nil {
return iptip{}, err
}
var ret iptip
ret.Src = netip.AddrFrom4(p.get4())
ret.Dst = netip.AddrFrom4(p.get4())
ret.SrcMask = netip.AddrFrom4(p.get4())
ret.DstMask = netip.AddrFrom4(p.get4())
const IFNAMSIZ = 16
ret.InputInterface = unix.ByteSliceToString(p.getBytes(IFNAMSIZ))
ret.OutputInterface = unix.ByteSliceToString(p.getBytes(IFNAMSIZ))
ret.InputInterfaceMask = p.getBytes(IFNAMSIZ)
ret.OutputInterfaceMask = p.getBytes(IFNAMSIZ)
ret.Protocol = native.Endian.Uint16(p.getBytes(2))
ret.Flags = p.getByte()
ret.InverseFlags = p.getByte()
return ret, nil
}
type match struct {
Name string
Revision int
Data any
RawData []byte
}
func (m match) String() string {
return fmt.Sprintf("{Name:%s, Data:%v}", m.Name, m.Data)
}
type matchTCP struct {
SourcePortRange [2]uint16
DestPortRange [2]uint16
Option byte
FlagMask byte
FlagCompare byte
InverseFlags byte
}
func (m matchTCP) String() string {
var sb strings.Builder
sb.WriteString("{")
fmt.Fprintf(&sb, "SrcPort:%s, DstPort:%s",
formatPortRange(m.SourcePortRange),
formatPortRange(m.DestPortRange))
// TODO(andrew): format semantically
if m.Option != 0 {
fmt.Fprintf(&sb, ", Option:%d", m.Option)
}
if m.FlagMask != 0 {
fmt.Fprintf(&sb, ", FlagMask:%d", m.FlagMask)
}
if m.FlagCompare != 0 {
fmt.Fprintf(&sb, ", FlagCompare:%d", m.FlagCompare)
}
if m.InverseFlags != 0 {
fmt.Fprintf(&sb, ", InverseFlags:%d", m.InverseFlags)
}
sb.WriteString("}")
return sb.String()
}
func (p *entryParser) parseMatches(maxOffset int) ([]match, error) {
const XT_EXTENSION_MAXNAMELEN = 29
const structSize = 2 + XT_EXTENSION_MAXNAMELEN + 1
var ret []match
for {
// If we don't have space for a single match structure, we're done
if p.offset+structSize > maxOffset {
break
}
var curr match
matchSize := int(native.Endian.Uint16(p.getBytes(2)))
curr.Name = unix.ByteSliceToString(p.getBytes(XT_EXTENSION_MAXNAMELEN))
curr.Revision = int(p.getByte())
// The data size is the total match size minus what we've already consumed.
dataLen := matchSize - structSize
dataEnd := p.offset + dataLen
// If we don't have space for the match data, then there's something wrong
if dataEnd > maxOffset {
return nil, fmt.Errorf("out of space for match (%d > max %d): %w", dataEnd, maxOffset, errMalformed)
} else if dataEnd > len(p.buf) {
return nil, fmt.Errorf("out of space for match (%d > buf %d): %w", dataEnd, len(p.buf), errMalformed)
}
curr.RawData = p.getBytes(dataLen)
// TODO(andrew): more here; UDP, etc.
switch curr.Name {
case "tcp":
/*
struct xt_tcp {
__u16 spts[2]; // Source port range.
__u16 dpts[2]; // Destination port range.
__u8 option; // TCP Option iff non-zero
__u8 flg_mask; // TCP flags mask byte
__u8 flg_cmp; // TCP flags compare byte
__u8 invflags; // Inverse flags
};
*/
if len(curr.RawData) >= 12 {
curr.Data = matchTCP{
SourcePortRange: [...]uint16{
native.Endian.Uint16(curr.RawData[0:2]),
native.Endian.Uint16(curr.RawData[2:4]),
},
DestPortRange: [...]uint16{
native.Endian.Uint16(curr.RawData[4:6]),
native.Endian.Uint16(curr.RawData[6:8]),
},
Option: curr.RawData[8],
FlagMask: curr.RawData[9],
FlagCompare: curr.RawData[10],
InverseFlags: curr.RawData[11],
}
}
}
ret = append(ret, curr)
}
return ret, nil
}
type target struct {
Name string
Revision int
Data any
RawData []byte
}
func (t target) String() string {
return fmt.Sprintf("{Name:%s, Data:%v}", t.Name, t.Data)
}
func (p *entryParser) parseTarget(nextOffset int) (target, error) {
const XT_EXTENSION_MAXNAMELEN = 29
const structSize = 2 + XT_EXTENSION_MAXNAMELEN + 1
if err := p.assertLen(structSize); err != nil {
return target{}, err
}
var ret target
targetSize := int(native.Endian.Uint16(p.getBytes(2)))
ret.Name = unix.ByteSliceToString(p.getBytes(XT_EXTENSION_MAXNAMELEN))
ret.Revision = int(p.getByte())
if targetSize > structSize {
dataLen := targetSize - structSize
if err := p.assertLen(dataLen); err != nil {
return target{}, err
}
ret.RawData = p.getBytes(dataLen)
}
// Special case; matches what iptables does
if ret.Name == "" {
ret.Name = "standard"
}
switch ret.Name {
case "standard":
if len(ret.RawData) >= 4 {
verdict := int32(native.Endian.Uint32(ret.RawData))
var info string
switch verdict {
case -1:
info = "DROP"
case -2:
info = "ACCEPT"
case -4:
info = "QUEUE"
case -5:
info = "RETURN"
case int32(nextOffset):
info = "FALLTHROUGH"
default:
info = fmt.Sprintf("JUMP(%d)", verdict)
}
ret.Data = standardTarget{Verdict: info}
}
case "ERROR":
ret.Data = errorTarget{
ErrorName: unix.ByteSliceToString(ret.RawData),
}
case "REJECT":
if len(ret.RawData) >= 4 {
ret.Data = rejectTarget{
With: rejectWith(native.Endian.Uint32(ret.RawData)),
}
}
case "MARK":
if len(ret.RawData) >= 8 {
mark := native.Endian.Uint32(ret.RawData[0:4])
mask := native.Endian.Uint32(ret.RawData[4:8])
var mode markMode
switch {
case mark == 0:
mode = markModeAnd
mark = ^mask
case mark == mask:
mode = markModeOr
case mask == 0:
mode = markModeXor
case mask == 0xffffffff:
mode = markModeSet
default:
// TODO(andrew): handle xset?
}
ret.Data = markTarget{
Mark: mark,
Mode: mode,
}
}
}
return ret, nil
}
// Various types for things in iptables-land follow.
type standardTarget struct {
Verdict string
}
type errorTarget struct {
ErrorName string
}
type rejectWith int
const (
rwIPT_ICMP_NET_UNREACHABLE rejectWith = iota
rwIPT_ICMP_HOST_UNREACHABLE
rwIPT_ICMP_PROT_UNREACHABLE
rwIPT_ICMP_PORT_UNREACHABLE
rwIPT_ICMP_ECHOREPLY
rwIPT_ICMP_NET_PROHIBITED
rwIPT_ICMP_HOST_PROHIBITED
rwIPT_TCP_RESET
rwIPT_ICMP_ADMIN_PROHIBITED
)
func (rw rejectWith) String() string {
switch rw {
case rwIPT_ICMP_NET_UNREACHABLE:
return "icmp-net-unreachable"
case rwIPT_ICMP_HOST_UNREACHABLE:
return "icmp-host-unreachable"
case rwIPT_ICMP_PROT_UNREACHABLE:
return "icmp-prot-unreachable"
case rwIPT_ICMP_PORT_UNREACHABLE:
return "icmp-port-unreachable"
case rwIPT_ICMP_ECHOREPLY:
return "icmp-echo-reply"
case rwIPT_ICMP_NET_PROHIBITED:
return "icmp-net-prohibited"
case rwIPT_ICMP_HOST_PROHIBITED:
return "icmp-host-prohibited"
case rwIPT_TCP_RESET:
return "tcp-reset"
case rwIPT_ICMP_ADMIN_PROHIBITED:
return "icmp-admin-prohibited"
default:
return "UNKNOWN"
}
}
type rejectTarget struct {
With rejectWith
}
type markMode byte
const (
markModeSet markMode = iota
markModeAnd
markModeOr
markModeXor
)
func (mm markMode) String() string {
switch mm {
case markModeSet:
return "set"
case markModeAnd:
return "and"
case markModeOr:
return "or"
case markModeXor:
return "xor"
default:
return "UNKNOWN"
}
}
type markTarget struct {
Mode markMode
Mark uint32
}

@ -0,0 +1,11 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package linuxfw returns the kind of firewall being used by the kernel.
package linuxfw
import "errors"
// ErrUnsupported is the error returned from all functions on non-Linux
// platforms.
var ErrUnsupported = errors.New("unsupported")

@ -0,0 +1,19 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux && !(386 || loong64)
package linuxfw
import (
"testing"
"unsafe"
"tailscale.com/util/linuxfw/linuxfwtest"
)
func TestSizes(t *testing.T) {
linuxfwtest.TestSizes(t, &linuxfwtest.SizeInfo{
SizeofSocklen: unsafe.Sizeof(sockLen(0)),
})
}

@ -0,0 +1,33 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// NOTE: linux_386 and linux_loong64 are currently unsupported due to missing
// support in upstream dependencies.
//go:build !linux || (linux && (386 || loong64))
package linuxfw
import (
"tailscale.com/types/logger"
)
// DebugNetfilter is not supported on non-Linux platforms.
func DebugNetfilter(logf logger.Logf) error {
return ErrUnsupported
}
// DetectNetfilter is not supported on non-Linux platforms.
func DetectNetfilter() (int, error) {
return 0, ErrUnsupported
}
// DebugIptables is not supported on non-Linux platforms.
func DebugIptables(logf logger.Logf) error {
return ErrUnsupported
}
// DetectIptables is not supported on non-Linux platforms.
func DetectIptables() (int, error) {
return 0, ErrUnsupported
}

@ -0,0 +1,31 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build cgo && linux
// Package linuxfwtest contains tests for the linuxfw package. Go does not
// support cgo in tests, and we don't want the main package to have a cgo
// dependency, so we put all the tests here and call them from the main package
// in tests intead.
package linuxfwtest
import (
"testing"
"unsafe"
)
/*
#include <sys/socket.h> // socket()
*/
import "C"
type SizeInfo struct {
SizeofSocklen uintptr
}
func TestSizes(t *testing.T, si *SizeInfo) {
want := unsafe.Sizeof(C.socklen_t(0))
if want != si.SizeofSocklen {
t.Errorf("sockLen has wrong size; want=%d got=%d", want, si.SizeofSocklen)
}
}

@ -0,0 +1,18 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !cgo || !linux
package linuxfwtest
import (
"testing"
)
type SizeInfo struct {
SizeofSocklen uintptr
}
func TestSizes(t *testing.T, si *SizeInfo) {
t.Skip("not supported without cgo")
}

@ -0,0 +1,269 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux && !(386 || loong64)
package linuxfw
import (
"fmt"
"sort"
"strings"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/josharian/native"
"golang.org/x/sys/unix"
"tailscale.com/types/logger"
)
// DebugNetfilter prints debug information about netfilter rules to the
// provided log function.
func DebugNetfilter(logf logger.Logf) error {
conn, err := nftables.New()
if err != nil {
return err
}
chains, err := conn.ListChains()
if err != nil {
return fmt.Errorf("cannot list chains: %w", err)
}
if len(chains) == 0 {
logf("netfilter: no chains")
return nil
}
for _, chain := range chains {
logf("netfilter: table=%s chain=%s", chain.Table.Name, chain.Name)
rules, err := conn.GetRules(chain.Table, chain)
if err != nil {
continue
}
sort.Slice(rules, func(i, j int) bool {
return rules[i].Position < rules[j].Position
})
for i, rule := range rules {
logf("netfilter: rule[%d]: pos=%d flags=%d", i, rule.Position, rule.Flags)
for _, ex := range rule.Exprs {
switch v := ex.(type) {
case *expr.Meta:
key := metaKeyNames[v.Key]
if key == "" {
key = "UNKNOWN"
}
logf("netfilter: Meta: key=%s source_register=%v register=%d", key, v.SourceRegister, v.Register)
case *expr.Cmp:
op := cmpOpNames[v.Op]
if op == "" {
op = "UNKNOWN"
}
logf("netfilter: Cmp: op=%s register=%d data=%s", op, v.Register, formatMaybePrintable(v.Data))
case *expr.Counter:
// don't print
case *expr.Verdict:
kind := verdictNames[v.Kind]
if kind == "" {
kind = "UNKNOWN"
}
logf("netfilter: Verdict: kind=%s data=%s", kind, v.Chain)
case *expr.Target:
logf("netfilter: Target: name=%s info=%s", v.Name, printTargetInfo(v.Name, v.Info))
case *expr.Match:
logf("netfilter: Match: name=%s info=%+v", v.Name, printMatchInfo(v.Name, v.Info))
case *expr.Payload:
logf("netfilter: Payload: op=%s src=%d dst=%d base=%s offset=%d len=%d",
payloadOperationTypeNames[v.OperationType],
v.SourceRegister, v.DestRegister,
payloadBaseNames[v.Base],
v.Offset, v.Len)
// TODO(andrew): csum
case *expr.Bitwise:
var xor string
for _, b := range v.Xor {
if b != 0 {
xor = fmt.Sprintf(" xor=%v", v.Xor)
break
}
}
logf("netfilter: Bitwise: src=%d dst=%d len=%d mask=%v%s",
v.SourceRegister, v.DestRegister, v.Len, v.Mask, xor)
default:
logf("netfilter: unknown %T: %+v", v, v)
}
}
}
}
return nil
}
// DetectNetfilter returns the number of nftables rules present in the system.
func DetectNetfilter() (int, error) {
conn, err := nftables.New()
if err != nil {
return 0, err
}
chains, err := conn.ListChains()
if err != nil {
return 0, fmt.Errorf("cannot list chains: %w", err)
}
var validRules int
for _, chain := range chains {
rules, err := conn.GetRules(chain.Table, chain)
if err != nil {
continue
}
validRules += len(rules)
}
return validRules, nil
}
func printMatchInfo(name string, info xt.InfoAny) string {
var sb strings.Builder
sb.WriteString(`{`)
var handled bool = true
switch v := info.(type) {
// TODO(andrew): we should support these common types
//case *xt.ConntrackMtinfo3:
//case *xt.ConntrackMtinfo2:
case *xt.Tcp:
fmt.Fprintf(&sb, "Src:%s Dst:%s", formatPortRange(v.SrcPorts), formatPortRange(v.DstPorts))
if v.Option != 0 {
fmt.Fprintf(&sb, " Option:%d", v.Option)
}
if v.FlagsMask != 0 {
fmt.Fprintf(&sb, " FlagsMask:%d", v.FlagsMask)
}
if v.FlagsCmp != 0 {
fmt.Fprintf(&sb, " FlagsCmp:%d", v.FlagsCmp)
}
if v.InvFlags != 0 {
fmt.Fprintf(&sb, " InvFlags:%d", v.InvFlags)
}
case *xt.Udp:
fmt.Fprintf(&sb, "Src:%s Dst:%s", formatPortRange(v.SrcPorts), formatPortRange(v.DstPorts))
if v.InvFlags != 0 {
fmt.Fprintf(&sb, " InvFlags:%d", v.InvFlags)
}
case *xt.AddrType:
var sprefix, dprefix string
if v.InvertSource {
sprefix = "!"
}
if v.InvertDest {
dprefix = "!"
}
// TODO(andrew): translate source/dest
fmt.Fprintf(&sb, "Source:%s%d Dest:%s%d", sprefix, v.Source, dprefix, v.Dest)
case *xt.AddrTypeV1:
// TODO(andrew): translate source/dest
fmt.Fprintf(&sb, "Source:%d Dest:%d", v.Source, v.Dest)
var flags []string
for flag, name := range addrTypeFlagNames {
if v.Flags&flag != 0 {
flags = append(flags, name)
}
}
if len(flags) > 0 {
sort.Strings(flags)
fmt.Fprintf(&sb, "Flags:%s", strings.Join(flags, ","))
}
default:
handled = false
}
if handled {
sb.WriteString(`}`)
return sb.String()
}
unknown, ok := info.(*xt.Unknown)
if !ok {
return fmt.Sprintf("(%T)%+v", info, info)
}
data := []byte(*unknown)
// Things where upstream has no type
handled = true
switch name {
case "pkttype":
if len(data) != 8 {
handled = false
break
}
pkttype := int(native.Endian.Uint32(data[0:4]))
invert := int(native.Endian.Uint32(data[4:8]))
var invertPrefix string
if invert != 0 {
invertPrefix = "!"
}
pkttypeName := packetTypeNames[pkttype]
if pkttypeName != "" {
fmt.Fprintf(&sb, "PktType:%s%s", invertPrefix, pkttypeName)
} else {
fmt.Fprintf(&sb, "PktType:%s%d", invertPrefix, pkttype)
}
default:
handled = true
}
if !handled {
return fmt.Sprintf("(%T)%+v", info, info)
}
sb.WriteString(`}`)
return sb.String()
}
func printTargetInfo(name string, info xt.InfoAny) string {
var sb strings.Builder
sb.WriteString(`{`)
unknown, ok := info.(*xt.Unknown)
if !ok {
return fmt.Sprintf("(%T)%+v", info, info)
}
data := []byte(*unknown)
// Things where upstream has no type
switch name {
case "LOG":
if len(data) != 32 {
fmt.Fprintf(&sb, `Error:"bad size; want 32, got %d"`, len(data))
break
}
level := data[0]
logflags := data[1]
prefix := unix.ByteSliceToString(data[2:])
fmt.Fprintf(&sb, "Level:%d LogFlags:%d Prefix:%q", level, logflags, prefix)
default:
return fmt.Sprintf("(%T)%+v", info, info)
}
sb.WriteString(`}`)
return sb.String()
}

@ -0,0 +1,94 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux && !(386 || loong64)
package linuxfw
import (
"github.com/google/nftables/expr"
"github.com/google/nftables/xt"
)
var metaKeyNames = map[expr.MetaKey]string{
expr.MetaKeyLEN: "LEN",
expr.MetaKeyPROTOCOL: "PROTOCOL",
expr.MetaKeyPRIORITY: "PRIORITY",
expr.MetaKeyMARK: "MARK",
expr.MetaKeyIIF: "IIF",
expr.MetaKeyOIF: "OIF",
expr.MetaKeyIIFNAME: "IIFNAME",
expr.MetaKeyOIFNAME: "OIFNAME",
expr.MetaKeyIIFTYPE: "IIFTYPE",
expr.MetaKeyOIFTYPE: "OIFTYPE",
expr.MetaKeySKUID: "SKUID",
expr.MetaKeySKGID: "SKGID",
expr.MetaKeyNFTRACE: "NFTRACE",
expr.MetaKeyRTCLASSID: "RTCLASSID",
expr.MetaKeySECMARK: "SECMARK",
expr.MetaKeyNFPROTO: "NFPROTO",
expr.MetaKeyL4PROTO: "L4PROTO",
expr.MetaKeyBRIIIFNAME: "BRIIIFNAME",
expr.MetaKeyBRIOIFNAME: "BRIOIFNAME",
expr.MetaKeyPKTTYPE: "PKTTYPE",
expr.MetaKeyCPU: "CPU",
expr.MetaKeyIIFGROUP: "IIFGROUP",
expr.MetaKeyOIFGROUP: "OIFGROUP",
expr.MetaKeyCGROUP: "CGROUP",
expr.MetaKeyPRANDOM: "PRANDOM",
}
var cmpOpNames = map[expr.CmpOp]string{
expr.CmpOpEq: "EQ",
expr.CmpOpNeq: "NEQ",
expr.CmpOpLt: "LT",
expr.CmpOpLte: "LTE",
expr.CmpOpGt: "GT",
expr.CmpOpGte: "GTE",
}
var verdictNames = map[expr.VerdictKind]string{
expr.VerdictReturn: "RETURN",
expr.VerdictGoto: "GOTO",
expr.VerdictJump: "JUMP",
expr.VerdictBreak: "BREAK",
expr.VerdictContinue: "CONTINUE",
expr.VerdictDrop: "DROP",
expr.VerdictAccept: "ACCEPT",
expr.VerdictStolen: "STOLEN",
expr.VerdictQueue: "QUEUE",
expr.VerdictRepeat: "REPEAT",
expr.VerdictStop: "STOP",
}
var payloadOperationTypeNames = map[expr.PayloadOperationType]string{
expr.PayloadLoad: "LOAD",
expr.PayloadWrite: "WRITE",
}
var payloadBaseNames = map[expr.PayloadBase]string{
expr.PayloadBaseLLHeader: "ll-header",
expr.PayloadBaseNetworkHeader: "network-header",
expr.PayloadBaseTransportHeader: "transport-header",
}
var packetTypeNames = map[int]string{
0 /* PACKET_HOST */ : "unicast",
1 /* PACKET_BROADCAST */ : "broadcast",
2 /* PACKET_MULTICAST */ : "multicast",
}
var addrTypeFlagNames = map[xt.AddrTypeFlags]string{
xt.AddrTypeUnspec: "unspec",
xt.AddrTypeUnicast: "unicast",
xt.AddrTypeLocal: "local",
xt.AddrTypeBroadcast: "broadcast",
xt.AddrTypeAnycast: "anycast",
xt.AddrTypeMulticast: "multicast",
xt.AddrTypeBlackhole: "blackhole",
xt.AddrTypeUnreachable: "unreachable",
xt.AddrTypeProhibit: "prohibit",
xt.AddrTypeThrow: "throw",
xt.AddrTypeNat: "nat",
xt.AddrTypeXresolve: "xresolve",
}
Loading…
Cancel
Save