From 243ce6ccc1a3e032e3d4a014e7dd09a9379c3c42 Mon Sep 17 00:00:00 2001 From: KevinLiang10 Date: Fri, 16 Jun 2023 18:54:58 +0000 Subject: [PATCH] util/linuxfw: decoupling IPTables logic from linux router This change is introducing new netfilterRunner interface and moving iptables manipulation to a lower leveled iptables runner. For #391 Signed-off-by: KevinLiang10 --- cmd/derper/depaware.txt | 28 +- cmd/tailscale/depaware.txt | 26 ++ cmd/tailscaled/depaware.txt | 19 +- net/netns/netns_linux.go | 11 +- net/netns/netns_linux_test.go | 42 -- util/linuxfw/iptables_runner.go | 475 +++++++++++++++++++ util/linuxfw/iptables_runner_test.go | 420 +++++++++++++++++ util/linuxfw/linuxfw.go | 177 +++++++- util/linuxfw/linuxfw_unsupported.go | 6 + wgengine/router/router_linux.go | 657 +++++---------------------- wgengine/router/router_linux_test.go | 291 ++++++++---- 11 files changed, 1454 insertions(+), 698 deletions(-) create mode 100644 util/linuxfw/iptables_runner.go create mode 100644 util/linuxfw/iptables_runner_test.go diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 78d7b1b30..6b1aba2c7 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -12,9 +12,16 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa W 💣 github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy github.com/beorn7/perks/quantile from github.com/prometheus/client_golang/prometheus 💣 github.com/cespare/xxhash/v2 from github.com/prometheus/client_golang/prometheus + L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw github.com/fxamacker/cbor/v2 from tailscale.com/tka github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/golang/protobuf/proto from github.com/matttproud/golang_protobuf_extensions/pbutil+ + L github.com/google/nftables from tailscale.com/util/linuxfw + L 💣 github.com/google/nftables/alignedbuff from github.com/google/nftables/xt + L 💣 
github.com/google/nftables/binaryutil from github.com/google/nftables+ + L github.com/google/nftables/expr from github.com/google/nftables+ + L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ + L github.com/google/nftables/xt from github.com/google/nftables/expr+ github.com/hdevalence/ed25519consensus from tailscale.com/tka L github.com/josharian/native from github.com/mdlayher/netlink+ L 💣 github.com/jsimonetti/rtnetlink from tailscale.com/net/interfaces+ @@ -23,6 +30,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa github.com/matttproud/golang_protobuf_extensions/pbutil from github.com/prometheus/common/expfmt L 💣 github.com/mdlayher/netlink from github.com/jsimonetti/rtnetlink+ L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ + L github.com/mdlayher/netlink/nltest from github.com/google/nftables L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink 💣 github.com/mitchellh/go-ps from tailscale.com/safesocket 💣 github.com/prometheus/client_golang/prometheus from tailscale.com/tsweb/promvarz @@ -34,6 +42,9 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus LD github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs LD github.com/prometheus/procfs/internal/util from github.com/prometheus/procfs + L 💣 github.com/tailscale/netlink from tailscale.com/util/linuxfw + L 💣 github.com/vishvananda/netlink/nl from github.com/tailscale/netlink + L github.com/vishvananda/netns from github.com/tailscale/netlink+ github.com/x448/float16 from github.com/fxamacker/cbor/v2 💣 go4.org/mem from tailscale.com/client/tailscale+ go4.org/netipx from tailscale.com/wgengine/filter @@ -66,6 +77,20 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa google.golang.org/protobuf/runtime/protoimpl from 
github.com/golang/protobuf/proto+ google.golang.org/protobuf/types/descriptorpb from google.golang.org/protobuf/reflect/protodesc google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ + L gvisor.dev/gvisor/pkg/abi from gvisor.dev/gvisor/pkg/abi/linux + L 💣 gvisor.dev/gvisor/pkg/abi/linux from tailscale.com/util/linuxfw + L gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/abi/linux + L gvisor.dev/gvisor/pkg/context from gvisor.dev/gvisor/pkg/abi/linux + L 💣 gvisor.dev/gvisor/pkg/gohacks from gvisor.dev/gvisor/pkg/abi/linux+ + L 💣 gvisor.dev/gvisor/pkg/hostarch from gvisor.dev/gvisor/pkg/abi/linux+ + L gvisor.dev/gvisor/pkg/linewriter from gvisor.dev/gvisor/pkg/log + L gvisor.dev/gvisor/pkg/log from gvisor.dev/gvisor/pkg/context + L gvisor.dev/gvisor/pkg/marshal from gvisor.dev/gvisor/pkg/abi/linux+ + L 💣 gvisor.dev/gvisor/pkg/marshal/primitive from gvisor.dev/gvisor/pkg/abi/linux + L 💣 gvisor.dev/gvisor/pkg/state from gvisor.dev/gvisor/pkg/abi/linux+ + L gvisor.dev/gvisor/pkg/state/wire from gvisor.dev/gvisor/pkg/state + L 💣 gvisor.dev/gvisor/pkg/sync from gvisor.dev/gvisor/pkg/linewriter+ + L gvisor.dev/gvisor/pkg/waiter from gvisor.dev/gvisor/pkg/context nhooyr.io/websocket from tailscale.com/cmd/derper+ nhooyr.io/websocket/internal/errd from nhooyr.io/websocket nhooyr.io/websocket/internal/xsync from nhooyr.io/websocket @@ -130,8 +155,9 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/util/dnsname from tailscale.com/hostinfo+ tailscale.com/util/httpm from tailscale.com/client/tailscale tailscale.com/util/lineread from tailscale.com/hostinfo+ + L 💣 tailscale.com/util/linuxfw from tailscale.com/net/netns tailscale.com/util/mak from tailscale.com/syncs+ - tailscale.com/util/multierr from tailscale.com/health + tailscale.com/util/multierr from tailscale.com/health+ tailscale.com/util/set from tailscale.com/health+ tailscale.com/util/singleflight from 
tailscale.com/net/dnscache tailscale.com/util/slicesx from tailscale.com/cmd/derper+ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index bac8b185c..720e129a6 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -10,8 +10,15 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep W 💣 github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/negotiate+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W 💣 github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy + L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw github.com/fxamacker/cbor/v2 from tailscale.com/tka github.com/golang/groupcache/lru from tailscale.com/net/dnscache + L github.com/google/nftables from tailscale.com/util/linuxfw + L 💣 github.com/google/nftables/alignedbuff from github.com/google/nftables/xt + L 💣 github.com/google/nftables/binaryutil from github.com/google/nftables+ + L github.com/google/nftables/expr from github.com/google/nftables+ + L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ + L github.com/google/nftables/xt from github.com/google/nftables/expr+ github.com/google/uuid from tailscale.com/util/quarantine+ github.com/hdevalence/ed25519consensus from tailscale.com/tka L github.com/josharian/native from github.com/mdlayher/netlink+ @@ -23,6 +30,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep 💣 github.com/mattn/go-isatty from github.com/mattn/go-colorable+ L 💣 github.com/mdlayher/netlink from github.com/jsimonetti/rtnetlink+ L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ + L github.com/mdlayher/netlink/nltest from github.com/google/nftables L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink 💣 github.com/mitchellh/go-ps from tailscale.com/cmd/tailscale/cli+ github.com/peterbourgon/ff/v3 from 
github.com/peterbourgon/ff/v3/ffcli @@ -36,13 +44,30 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep github.com/tailscale/goupnp/scpd from github.com/tailscale/goupnp github.com/tailscale/goupnp/soap from github.com/tailscale/goupnp+ github.com/tailscale/goupnp/ssdp from github.com/tailscale/goupnp + L 💣 github.com/tailscale/netlink from tailscale.com/util/linuxfw github.com/tcnksm/go-httpstat from tailscale.com/net/netcheck github.com/toqueteos/webbrowser from tailscale.com/cmd/tailscale/cli + L 💣 github.com/vishvananda/netlink/nl from github.com/tailscale/netlink + L github.com/vishvananda/netns from github.com/tailscale/netlink+ github.com/x448/float16 from github.com/fxamacker/cbor/v2 💣 go4.org/mem from tailscale.com/derp+ go4.org/netipx from tailscale.com/wgengine/filter W 💣 golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/interfaces+ gopkg.in/yaml.v2 from sigs.k8s.io/yaml + L gvisor.dev/gvisor/pkg/abi from gvisor.dev/gvisor/pkg/abi/linux + L 💣 gvisor.dev/gvisor/pkg/abi/linux from tailscale.com/util/linuxfw + L gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/abi/linux + L gvisor.dev/gvisor/pkg/context from gvisor.dev/gvisor/pkg/abi/linux + L 💣 gvisor.dev/gvisor/pkg/gohacks from gvisor.dev/gvisor/pkg/abi/linux+ + L 💣 gvisor.dev/gvisor/pkg/hostarch from gvisor.dev/gvisor/pkg/abi/linux+ + L gvisor.dev/gvisor/pkg/linewriter from gvisor.dev/gvisor/pkg/log + L gvisor.dev/gvisor/pkg/log from gvisor.dev/gvisor/pkg/context + L gvisor.dev/gvisor/pkg/marshal from gvisor.dev/gvisor/pkg/abi/linux+ + L 💣 gvisor.dev/gvisor/pkg/marshal/primitive from gvisor.dev/gvisor/pkg/abi/linux + L 💣 gvisor.dev/gvisor/pkg/state from gvisor.dev/gvisor/pkg/abi/linux+ + L gvisor.dev/gvisor/pkg/state/wire from gvisor.dev/gvisor/pkg/state + L 💣 gvisor.dev/gvisor/pkg/sync from gvisor.dev/gvisor/pkg/linewriter+ + L gvisor.dev/gvisor/pkg/waiter from gvisor.dev/gvisor/pkg/context k8s.io/client-go/util/homedir from 
tailscale.com/cmd/tailscale/cli nhooyr.io/websocket from tailscale.com/derp/derphttp+ nhooyr.io/websocket/internal/errd from nhooyr.io/websocket @@ -120,6 +145,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/groupmember from tailscale.com/cmd/tailscale/cli tailscale.com/util/httpm from tailscale.com/client/tailscale tailscale.com/util/lineread from tailscale.com/net/interfaces+ + L 💣 tailscale.com/util/linuxfw from tailscale.com/net/netns tailscale.com/util/mak from tailscale.com/net/netcheck+ tailscale.com/util/multierr from tailscale.com/control/controlhttp+ tailscale.com/util/must from tailscale.com/cmd/tailscale/cli diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 4355a0297..8219faa21 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -75,7 +75,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http L github.com/aws/smithy-go/waiter from github.com/aws/aws-sdk-go-v2/service/ssm - L github.com/coreos/go-iptables/iptables from tailscale.com/wgengine/router + L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw LD 💣 github.com/creack/pty from tailscale.com/ssh/tailssh W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/com W 💣 github.com/dblohm7/wingoes/com from tailscale.com/cmd/tailscaled @@ -86,6 +86,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L 💣 github.com/godbus/dbus/v5 from tailscale.com/net/dns+ github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header+ + L github.com/google/nftables from tailscale.com/util/linuxfw + L 💣 github.com/google/nftables/alignedbuff from 
github.com/google/nftables/xt + L 💣 github.com/google/nftables/binaryutil from github.com/google/nftables+ + L github.com/google/nftables/expr from github.com/google/nftables+ + L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ + L github.com/google/nftables/xt from github.com/google/nftables/expr+ github.com/hdevalence/ed25519consensus from tailscale.com/tka L 💣 github.com/illarion/gonotify from tailscale.com/net/dns L github.com/insomniacslk/dhcp/dhcpv4 from tailscale.com/net/tstun @@ -109,6 +115,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/mdlayher/genetlink from tailscale.com/net/tstun L 💣 github.com/mdlayher/netlink from github.com/jsimonetti/rtnetlink+ L 💣 github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ + L github.com/mdlayher/netlink/nltest from github.com/google/nftables L github.com/mdlayher/sdnotify from tailscale.com/util/systemd L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink 💣 github.com/mitchellh/go-ps from tailscale.com/safesocket @@ -153,13 +160,18 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de go4.org/netipx from tailscale.com/ipn/ipnlocal+ W 💣 golang.zx2c4.com/wintun from github.com/tailscale/wireguard-go/tun+ W 💣 golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/dns+ + L gvisor.dev/gvisor/pkg/abi from gvisor.dev/gvisor/pkg/abi/linux + L 💣 gvisor.dev/gvisor/pkg/abi/linux from tailscale.com/util/linuxfw gvisor.dev/gvisor/pkg/atomicbitops from gvisor.dev/gvisor/pkg/tcpip+ - gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/bufferv2 + gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/bufferv2+ 💣 gvisor.dev/gvisor/pkg/bufferv2 from gvisor.dev/gvisor/pkg/tcpip+ - gvisor.dev/gvisor/pkg/context from gvisor.dev/gvisor/pkg/refs + gvisor.dev/gvisor/pkg/context from gvisor.dev/gvisor/pkg/refs+ 💣 gvisor.dev/gvisor/pkg/gohacks from gvisor.dev/gvisor/pkg/state/wire+ 
+ L 💣 gvisor.dev/gvisor/pkg/hostarch from gvisor.dev/gvisor/pkg/abi/linux+ gvisor.dev/gvisor/pkg/linewriter from gvisor.dev/gvisor/pkg/log gvisor.dev/gvisor/pkg/log from gvisor.dev/gvisor/pkg/context+ + L gvisor.dev/gvisor/pkg/marshal from gvisor.dev/gvisor/pkg/abi/linux+ + L 💣 gvisor.dev/gvisor/pkg/marshal/primitive from gvisor.dev/gvisor/pkg/abi/linux gvisor.dev/gvisor/pkg/rand from gvisor.dev/gvisor/pkg/tcpip/network/hash+ gvisor.dev/gvisor/pkg/refs from gvisor.dev/gvisor/pkg/bufferv2+ 💣 gvisor.dev/gvisor/pkg/sleep from gvisor.dev/gvisor/pkg/tcpip/transport/tcp @@ -317,6 +329,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de 💣 tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httpm from tailscale.com/client/tailscale+ tailscale.com/util/lineread from tailscale.com/hostinfo+ + L 💣 tailscale.com/util/linuxfw from tailscale.com/net/netns+ tailscale.com/util/mak from tailscale.com/control/controlclient+ tailscale.com/util/multierr from tailscale.com/control/controlclient+ tailscale.com/util/must from tailscale.com/logpolicy diff --git a/net/netns/netns_linux.go b/net/netns/netns_linux.go index 5d09d7d19..bac14e9d7 100644 --- a/net/netns/netns_linux.go +++ b/net/netns/netns_linux.go @@ -17,16 +17,9 @@ import ( "tailscale.com/net/interfaces" "tailscale.com/net/netmon" "tailscale.com/types/logger" + "tailscale.com/util/linuxfw" ) -// tailscaleBypassMark is the mark indicating that packets originating -// from a socket should bypass Tailscale-managed routes during routing -// table lookups. -// -// Keep this in sync with tailscaleBypassMark in -// wgengine/router/router_linux.go. -const tailscaleBypassMark = 0x80000 - // socketMarkWorksOnce is the sync.Once & cached value for useSocketMark. 
var socketMarkWorksOnce struct { sync.Once @@ -119,7 +112,7 @@ func controlC(network, address string, c syscall.RawConn) error { } func setBypassMark(fd uintptr) error { - if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, tailscaleBypassMark); err != nil { + if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, linuxfw.TailscaleBypassMarkNum); err != nil { return fmt.Errorf("setting SO_MARK bypass: %w", err) } return nil diff --git a/net/netns/netns_linux_test.go b/net/netns/netns_linux_test.go index 5a6b1bbda..a5000f37f 100644 --- a/net/netns/netns_linux_test.go +++ b/net/netns/netns_linux_test.go @@ -4,51 +4,9 @@ package netns import ( - "fmt" - "go/ast" - "go/parser" - "go/token" "testing" ) -// verifies tailscaleBypassMark is in sync with wgengine. -func TestBypassMarkInSync(t *testing.T) { - want := fmt.Sprintf("%q", fmt.Sprintf("0x%x", tailscaleBypassMark)) - fset := token.NewFileSet() - f, err := parser.ParseFile(fset, "../../wgengine/router/router_linux.go", nil, 0) - if err != nil { - t.Fatal(err) - } - for _, decl := range f.Decls { - gd, ok := decl.(*ast.GenDecl) - if !ok || gd.Tok != token.CONST { - continue - } - for _, spec := range gd.Specs { - vs, ok := spec.(*ast.ValueSpec) - if !ok { - continue - } - for i, ident := range vs.Names { - if ident.Name != "tailscaleBypassMark" { - continue - } - valExpr := vs.Values[i] - lit, ok := valExpr.(*ast.BasicLit) - if !ok { - t.Errorf("tailscaleBypassMark = %T, expected *ast.BasicLit", valExpr) - } - if lit.Value == want { - // Pass. 
- return - } - t.Fatalf("router_linux.go's tailscaleBypassMark = %s; not in sync with netns's %s", lit.Value, want) - } - } - } - t.Errorf("tailscaleBypassMark not found in router_linux.go") -} - func TestSocketMarkWorks(t *testing.T) { _ = socketMarkWorks() // we cannot actually assert whether the test runner has SO_MARK available diff --git a/util/linuxfw/iptables_runner.go b/util/linuxfw/iptables_runner.go new file mode 100644 index 000000000..754a22b22 --- /dev/null +++ b/util/linuxfw/iptables_runner.go @@ -0,0 +1,475 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "fmt" + "net/netip" + "strings" + + "github.com/coreos/go-iptables/iptables" + "tailscale.com/net/tsaddr" + "tailscale.com/types/logger" + "tailscale.com/util/multierr" +) + +type iptablesInterface interface { + // Adding this interface for testing purposes so we can mock out + // the iptables library, in reality this is a wrapper to *iptables.IPTables. + Insert(table, chain string, pos int, args ...string) error + Append(table, chain string, args ...string) error + Exists(table, chain string, args ...string) (bool, error) + Delete(table, chain string, args ...string) error + ClearChain(table, chain string) error + NewChain(table, chain string) error + DeleteChain(table, chain string) error +} + +type iptablesRunner struct { + ipt4 iptablesInterface + ipt6 iptablesInterface + + v6Available bool + v6NATAvailable bool +} + +// NewIPTablesRunner constructs a NetfilterRunner that programs iptables rules. +// If the underlying iptables library fails to initialize, that error is +// returned. The runner probes for IPv6 support once at initialization time and +// if not found, no IPv6 rules will be modified for the lifetime of the runner. 
+func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { + ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err != nil { + return nil, err + } + + supportsV6, supportsV6NAT := false, false + v6err := checkIPv6(logf) + if v6err != nil { + logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) + } else { + supportsV6 = true + supportsV6NAT = supportsV6 && checkSupportsV6NAT() + logf("v6nat = %v", supportsV6NAT) + } + + var ipt6 *iptables.IPTables + if supportsV6 { + ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + return nil, err + } + } + return &iptablesRunner{ipt4, ipt6, supportsV6, supportsV6NAT}, nil +} + +// HasIPV6 returns true if the system supports IPv6. +func (i *iptablesRunner) HasIPV6() bool { + return i.v6Available +} + +// HasIPV6NAT returns true if the system supports IPv6 NAT. +func (i *iptablesRunner) HasIPV6NAT() bool { + return i.v6NATAvailable +} + +func isErrChainNotExist(err error) bool { + return errCode(err) == 1 +} + +// getIPTByAddr returns the iptablesInterface with correct IP family +// that we will be using for the given address. +func (i *iptablesRunner) getIPTByAddr(addr netip.Addr) iptablesInterface { + nf := i.ipt4 + if addr.Is6() { + nf = i.ipt6 + } + return nf +} + +// AddLoopbackRule adds an iptables rule to permit loopback traffic to +// a local Tailscale IP. +func (i *iptablesRunner) AddLoopbackRule(addr netip.Addr) error { + if err := i.getIPTByAddr(addr).Insert("filter", "ts-input", 1, "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil { + return fmt.Errorf("adding loopback allow rule for %q: %w", addr, err) + } + + return nil +} + +// tsChain returns the name of the tailscale sub-chain corresponding +// to the given "parent" chain (e.g. INPUT, FORWARD, ...). +func tsChain(chain string) string { + return "ts-" + strings.ToLower(chain) +} + +// DelLoopbackRule removes the iptables rule permitting loopback +// traffic to a Tailscale IP. 
+func (i *iptablesRunner) DelLoopbackRule(addr netip.Addr) error { + if err := i.getIPTByAddr(addr).Delete("filter", "ts-input", "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil { + return fmt.Errorf("deleting loopback allow rule for %q: %w", addr, err) + } + + return nil +} + +// getTables gets the available iptablesInterface in iptables runner. +func (i *iptablesRunner) getTables() []iptablesInterface { + if i.HasIPV6() { + return []iptablesInterface{i.ipt4, i.ipt6} + } + return []iptablesInterface{i.ipt4} +} + +// getNATTables gets the available iptablesInterface in iptables runner. +// If the system does not support IPv6 NAT, only the IPv4 iptablesInterface +// is returned. +func (i *iptablesRunner) getNATTables() []iptablesInterface { + if i.HasIPV6NAT() { + return i.getTables() + } + return []iptablesInterface{i.ipt4} +} + +// AddHooks inserts calls to tailscale's netfilter chains in +// the relevant main netfilter chains. The tailscale chains must +// already exist. If they do not, an error is returned. +func (i *iptablesRunner) AddHooks() error { + // divert inserts a jump to the tailscale chain in the given table/chain. + // If the jump already exists, it is a no-op. + divert := func(ipt iptablesInterface, table, chain string) error { + tsChain := tsChain(chain) + + args := []string{"-j", tsChain} + exists, err := ipt.Exists(table, chain, args...) 
+ if err != nil { + return fmt.Errorf("checking for %v in %s/%s: %w", args, table, chain, err) + } + if exists { + return nil + } + if err := ipt.Insert(table, chain, 1, args...); err != nil { + return fmt.Errorf("adding %v in %s/%s: %w", args, table, chain, err) + } + return nil + } + + for _, ipt := range i.getTables() { + if err := divert(ipt, "filter", "INPUT"); err != nil { + return err + } + if err := divert(ipt, "filter", "FORWARD"); err != nil { + return err + } + } + + for _, ipt := range i.getNATTables() { + if err := divert(ipt, "nat", "POSTROUTING"); err != nil { + return err + } + } + return nil +} + +// AddChains creates custom Tailscale chains in netfilter via iptables +// if the ts-chain doesn't already exist. +func (i *iptablesRunner) AddChains() error { + // create creates a chain in the given table if it doesn't already exist. + // If the chain already exists, its rules are flushed via ClearChain. + create := func(ipt iptablesInterface, table, chain string) error { + err := ipt.ClearChain(table, chain) + if isErrChainNotExist(err) { + // nonexistent chain. let's create it! + return ipt.NewChain(table, chain) + } + if err != nil { + return fmt.Errorf("setting up %s/%s: %w", table, chain, err) + } + return nil + } + + for _, ipt := range i.getTables() { + if err := create(ipt, "filter", "ts-input"); err != nil { + return err + } + if err := create(ipt, "filter", "ts-forward"); err != nil { + return err + } + } + + for _, ipt := range i.getNATTables() { + if err := create(ipt, "nat", "ts-postrouting"); err != nil { + return err + } + } + + return nil +} + +// AddBase adds some basic processing rules to be supplemented by +// later calls to other helpers. 
+func (i *iptablesRunner) AddBase(tunname string) error { + if err := i.addBase4(tunname); err != nil { + return err + } + if i.HasIPV6() { + if err := i.addBase6(tunname); err != nil { + return err + } + } + return nil +} + +// addBase4 adds some basic IPv4 processing rules to be +// supplemented by later calls to other helpers. +func (i *iptablesRunner) addBase4(tunname string) error { + // Only allow CGNAT range traffic to come from tailscale0. There + // is an exception carved out for ranges used by ChromeOS, for + // which we fall out of the Tailscale chain. + // + // Note, this will definitely break nodes that end up using the + // CGNAT range for other purposes :(. + args := []string{"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"} + if err := i.ipt4.Append("filter", "ts-input", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) + } + args = []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} + if err := i.ipt4.Append("filter", "ts-input", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) + } + + // Forward all traffic from the Tailscale interface, and drop + // traffic to the tailscale interface by default. We use packet + // marks here so both filter/FORWARD and nat/POSTROUTING can match + // on these packets of interest. + // + // In particular, we only want to apply SNAT rules in + // nat/POSTROUTING to packets that originated from the Tailscale + // interface, but we can't match on the inbound interface in + // POSTROUTING. So instead, we match on the inbound interface in + // filter/FORWARD, and set a packet mark that nat/POSTROUTING can + // use to effectively run that same test again. 
+ args = []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask} + if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) + } + args = []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"} + if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) + } + args = []string{"-o", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} + if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) + } + args = []string{"-o", tunname, "-j", "ACCEPT"} + if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) + } + + return nil +} + +// addBase6 adds some basic IPv6 processing rules to be +// supplemented by later calls to other helpers. +func (i *iptablesRunner) addBase6(tunname string) error { + // TODO: only allow traffic from Tailscale's ULA range to come + // from tailscale0. + + args := []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask} + if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) + } + args = []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"} + if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) + } + // TODO: drop forwarded traffic to tailscale0 from tailscale's ULA + // (see corresponding IPv4 CGNAT rule). 
+ args = []string{"-o", tunname, "-j", "ACCEPT"} + if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { + return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) + } + + return nil +} + +// DelChains removes the custom Tailscale chains from netfilter via iptables. +func (i *iptablesRunner) DelChains() error { + for _, ipt := range i.getTables() { + if err := delChain(ipt, "filter", "ts-input"); err != nil { + return err + } + if err := delChain(ipt, "filter", "ts-forward"); err != nil { + return err + } + } + + for _, ipt := range i.getNATTables() { + if err := delChain(ipt, "nat", "ts-postrouting"); err != nil { + return err + } + } + + return nil +} + +// DelBase empties but does not remove custom Tailscale chains from +// netfilter via iptables. +func (i *iptablesRunner) DelBase() error { + del := func(ipt iptablesInterface, table, chain string) error { + if err := ipt.ClearChain(table, chain); err != nil { + if isErrChainNotExist(err) { + // nonexistent chain. That's fine, since it's + // the desired state anyway. + return nil + } + return fmt.Errorf("flushing %s/%s: %w", table, chain, err) + } + return nil + } + + for _, ipt := range i.getTables() { + if err := del(ipt, "filter", "ts-input"); err != nil { + return err + } + if err := del(ipt, "filter", "ts-forward"); err != nil { + return err + } + } + for _, ipt := range i.getNATTables() { + if err := del(ipt, "nat", "ts-postrouting"); err != nil { + return err + } + } + + return nil +} + +// DelHooks deletes the calls to tailscale's netfilter chains +// in the relevant main netfilter chains. 
+func (i *iptablesRunner) DelHooks(logf logger.Logf) error { + for _, ipt := range i.getTables() { + if err := delTSHook(ipt, "filter", "INPUT", logf); err != nil { + return err + } + if err := delTSHook(ipt, "filter", "FORWARD", logf); err != nil { + return err + } + } + for _, ipt := range i.getNATTables() { + if err := delTSHook(ipt, "nat", "POSTROUTING", logf); err != nil { + return err + } + } + + return nil +} + +// AddSNATRule adds a netfilter rule to SNAT traffic destined for +// local subnets. +func (i *iptablesRunner) AddSNATRule() error { + args := []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"} + for _, ipt := range i.getNATTables() { + if err := ipt.Append("nat", "ts-postrouting", args...); err != nil { + return fmt.Errorf("adding %v in nat/ts-postrouting: %w", args, err) + } + } + return nil +} + +// DelSNATRule removes the netfilter rule to SNAT traffic destined for +// local subnets. An error is returned if the rule does not exist. +func (i *iptablesRunner) DelSNATRule() error { + args := []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"} + for _, ipt := range i.getNATTables() { + if err := ipt.Delete("nat", "ts-postrouting", args...); err != nil { + return fmt.Errorf("deleting %v in nat/ts-postrouting: %w", args, err) + } + } + return nil +} + +// IPTablesCleanup removes all Tailscale added iptables rules. +// Any errors that occur are logged to the provided logf. +func IPTablesCleanup(logf logger.Logf) { + err := clearRules(iptables.ProtocolIPv4, logf) + if err != nil { + logf("linuxfw: clear iptables: %v", err) + } + + err = clearRules(iptables.ProtocolIPv6, logf) + if err != nil { + logf("linuxfw: clear ip6tables: %v", err) + } +} + +// delTSHook deletes hook in a chain that jumps to a ts-chain. 
If the hook does not +// exist, it's a no-op since the desired state is already achieved but we log the +// error because error code from the iptables module resists unwrapping. +func delTSHook(ipt iptablesInterface, table, chain string, logf logger.Logf) error { + tsChain := tsChain(chain) + args := []string{"-j", tsChain} + if err := ipt.Delete(table, chain, args...); err != nil { + // TODO(apenwarr): check for errCode(1) here. + // Unfortunately the error code from the iptables + // module resists unwrapping, unlike with other + // calls. So we have to assume if Delete fails, + // it's because there is no such rule. + logf("deleting %v in %s/%s: %v", args, table, chain, err) + return nil + } + return nil +} + +// delChain flushes and deletes a chain. If the chain does not exist, it's a no-op +// since the desired state is already achieved. Any other failure is returned as an error. +func delChain(ipt iptablesInterface, table, chain string) error { + if err := ipt.ClearChain(table, chain); err != nil { + if isErrChainNotExist(err) { + // nonexistent chain. nothing to do. + return nil + } + return fmt.Errorf("flushing %s/%s: %w", table, chain, err) + } + if err := ipt.DeleteChain(table, chain); err != nil { + return fmt.Errorf("deleting %s/%s: %w", table, chain, err) + } + return nil +} + +// clearRules clears all the iptables rules created by Tailscale +// for the given protocol. Errors encountered are collected and returned. 
+func clearRules(proto iptables.Protocol, logf logger.Logf) error { + ipt, err := iptables.NewWithProtocol(proto) + if err != nil { + return err + } + + var errs []error + + if err := delTSHook(ipt, "filter", "INPUT", logf); err != nil { + errs = append(errs, err) + } + if err := delTSHook(ipt, "filter", "FORWARD", logf); err != nil { + errs = append(errs, err) + } + if err := delTSHook(ipt, "nat", "POSTROUTING", logf); err != nil { + errs = append(errs, err) + } + + if err := delChain(ipt, "filter", "ts-input"); err != nil { + errs = append(errs, err) + } + if err := delChain(ipt, "filter", "ts-forward"); err != nil { + errs = append(errs, err) + } + + if err := delChain(ipt, "nat", "ts-postrouting"); err != nil { + errs = append(errs, err) + } + + return multierr.New(errs...) +} diff --git a/util/linuxfw/iptables_runner_test.go b/util/linuxfw/iptables_runner_test.go new file mode 100644 index 000000000..e294f064b --- /dev/null +++ b/util/linuxfw/iptables_runner_test.go @@ -0,0 +1,420 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "errors" + "net/netip" + "strings" + "testing" + + "tailscale.com/net/tsaddr" +) + +var errExec = errors.New("execution failed") + +type fakeIPTables struct { + t *testing.T + n map[string][]string +} + +type fakeRule struct { + table, chain string + args []string +} + +func newIPTables(t *testing.T) *fakeIPTables { + return &fakeIPTables{ + t: t, + n: map[string][]string{ + "filter/INPUT": nil, + "filter/OUTPUT": nil, + "filter/FORWARD": nil, + "nat/PREROUTING": nil, + "nat/OUTPUT": nil, + "nat/POSTROUTING": nil, + }, + } +} + +func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + if pos > len(rules)+1 { + n.t.Errorf("bad position %d in %s", pos, k) + return errExec + } + rules = append(rules, "") + copy(rules[pos:], rules[pos-1:]) + rules[pos-1] = 
strings.Join(args, " ") + n.n[k] = rules + } else { + n.t.Errorf("unknown table/chain %s", k) + return errExec + } + return nil +} + +func (n *fakeIPTables) Append(table, chain string, args ...string) error { + k := table + "/" + chain + return n.Insert(table, chain, len(n.n[k])+1, args...) +} + +func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + for _, rule := range rules { + if rule == strings.Join(args, " ") { + return true, nil + } + } + return false, nil + } else { + n.t.Logf("unknown table/chain %s", k) + return false, errExec + } +} + +func hasChain(n *fakeIPTables, table, chain string) bool { + k := table + "/" + chain + if _, ok := n.n[k]; ok { + return true + } else { + return false + } +} + +func (n *fakeIPTables) Delete(table, chain string, args ...string) error { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + for i, rule := range rules { + if rule == strings.Join(args, " ") { + rules = append(rules[:i], rules[i+1:]...) 
+ n.n[k] = rules + return nil + } + } + n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k) + return errExec + } else { + n.t.Errorf("unknown table/chain %s", k) + return errExec + } +} + +func (n *fakeIPTables) ClearChain(table, chain string) error { + k := table + "/" + chain + if _, ok := n.n[k]; ok { + n.n[k] = nil + return nil + } else { + n.t.Logf("note: ClearChain: unknown table/chain %s", k) + return errors.New("exitcode:1") + } +} + +func (n *fakeIPTables) NewChain(table, chain string) error { + k := table + "/" + chain + if _, ok := n.n[k]; ok { + n.t.Errorf("table/chain %s already exists", k) + return errExec + } + n.n[k] = nil + return nil +} + +func (n *fakeIPTables) DeleteChain(table, chain string) error { + k := table + "/" + chain + if rules, ok := n.n[k]; ok { + if len(rules) != 0 { + n.t.Errorf("%s is not empty", k) + return errExec + } + delete(n.n, k) + return nil + } else { + n.t.Errorf("%s does not exist", k) + return errExec + } +} + +func newFakeIPTablesRunner(t *testing.T) *iptablesRunner { + ipt4 := newIPTables(t) + ipt6 := newIPTables(t) + + iptr := &iptablesRunner{ipt4, ipt6, true, true} + return iptr +} + +func TestAddAndDeleteChains(t *testing.T) { + iptr := newFakeIPTablesRunner(t) + err := iptr.AddChains() + if err != nil { + t.Fatal(err) + } + + // Check that the chains were created. + tsChains := []struct{ table, chain string }{ // table/chain + {"filter", "ts-input"}, + {"filter", "ts-forward"}, + {"nat", "ts-postrouting"}, + } + + for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} { + for _, tc := range tsChains { + // Exists returns error if the chain doesn't exist. + if _, err := proto.Exists(tc.table, tc.chain); err != nil { + t.Errorf("chain %s/%s doesn't exist", tc.table, tc.chain) + } + } + } + + err = iptr.DelChains() + if err != nil { + t.Fatal(err) + } + + // Check that the chains were deleted. 
+ for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} { + for _, tc := range tsChains { + if _, err = proto.Exists(tc.table, tc.chain); err == nil { + t.Errorf("chain %s/%s still exists", tc.table, tc.chain) + } + } + } + +} + +func TestAddAndDeleteHooks(t *testing.T) { + iptr := newFakeIPTablesRunner(t) + // don't need to test what happens if the chains don't exist, because + // this is handled by fake iptables, in real life iptables would return an error. + if err := iptr.AddChains(); err != nil { + t.Fatal(err) + } + defer iptr.DelChains() + + if err := iptr.AddHooks(); err != nil { + t.Fatal(err) + } + + // Check that the rules were created. + tsRules := []fakeRule{ // table/chain/rule + {"filter", "INPUT", []string{"-j", "ts-input"}}, + {"filter", "FORWARD", []string{"-j", "ts-forward"}}, + {"nat", "POSTROUTING", []string{"-j", "ts-postrouting"}}, + } + + for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} { + for _, tr := range tsRules { + if exists, err := proto.Exists(tr.table, tr.chain, tr.args...); err != nil { + t.Fatal(err) + } else if !exists { + t.Errorf("rule %s/%s/%s doesn't exist", tr.table, tr.chain, strings.Join(tr.args, " ")) + } + // check if the rule is at front of the chain + if proto.(*fakeIPTables).n[tr.table+"/"+tr.chain][0] != strings.Join(tr.args, " ") { + t.Errorf("v4 rule %s/%s/%s is not at the top", tr.table, tr.chain, strings.Join(tr.args, " ")) + } + } + } + + if err := iptr.DelHooks(t.Logf); err != nil { + t.Fatal(err) + } + + // Check that the rules were deleted.
+ for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} { + for _, tr := range tsRules { + if exists, err := proto.Exists(tr.table, tr.chain, tr.args...); err != nil { + t.Fatal(err) + } else if exists { + t.Errorf("rule %s/%s/%s still exists", tr.table, tr.chain, strings.Join(tr.args, " ")) + } + } + } + + if err := iptr.AddHooks(); err != nil { + t.Fatal(err) + } +} + +func TestAddAndDeleteBase(t *testing.T) { + iptr := newFakeIPTablesRunner(t) + tunname := "tun0" + if err := iptr.AddChains(); err != nil { + t.Fatal(err) + } + + if err := iptr.AddBase(tunname); err != nil { + t.Fatal(err) + } + + // Check that the rules were created. + tsRulesV4 := []fakeRule{ // table/chain/rule + {"filter", "ts-input", []string{"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"}}, + {"filter", "ts-input", []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}}, + {"filter", "ts-forward", []string{"-o", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}}, + } + + tsRulesCommon := []fakeRule{ // table/chain/rule + {"filter", "ts-forward", []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask}}, + {"filter", "ts-forward", []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"}}, + {"filter", "ts-forward", []string{"-o", tunname, "-j", "ACCEPT"}}, + } + + // check that the rules were created for ipt4 + for _, tr := range append(tsRulesV4, tsRulesCommon...) 
{ + if exists, err := iptr.ipt4.Exists(tr.table, tr.chain, tr.args...); err != nil { + t.Fatal(err) + } else if !exists { + t.Errorf("rule %s/%s/%s doesn't exist", tr.table, tr.chain, strings.Join(tr.args, " ")) + } + } + + // check that the rules were created for ipt6 + for _, tr := range tsRulesCommon { + if exists, err := iptr.ipt6.Exists(tr.table, tr.chain, tr.args...); err != nil { + t.Fatal(err) + } else if !exists { + t.Errorf("rule %s/%s/%s doesn't exist", tr.table, tr.chain, strings.Join(tr.args, " ")) + } + } + + if err := iptr.DelBase(); err != nil { + t.Fatal(err) + } + + // Check that the rules were deleted. + for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} { + for _, tr := range append(tsRulesV4, tsRulesCommon...) { + if exists, err := proto.Exists(tr.table, tr.chain, tr.args...); err != nil { + t.Fatal(err) + } else if exists { + t.Errorf("rule %s/%s/%s still exists", tr.table, tr.chain, strings.Join(tr.args, " ")) + } + } + } + + if err := iptr.DelChains(); err != nil { + t.Fatal(err) + } +} + +func TestAddAndDelLoopbackRule(t *testing.T) { + iptr := newFakeIPTablesRunner(t) + // We don't need to test for malformed addresses, AddLoopbackRule + // takes in a netip.Addr, which is already valid. + fakeAddrV4 := netip.MustParseAddr("192.168.0.2") + fakeAddrV6 := netip.MustParseAddr("2001:db8::2") + + if err := iptr.AddChains(); err != nil { + t.Fatal(err) + } + if err := iptr.AddLoopbackRule(fakeAddrV4); err != nil { + t.Fatal(err) + } + if err := iptr.AddLoopbackRule(fakeAddrV6); err != nil { + t.Fatal(err) + } + + // Check that the rules were created. 
+ tsRulesV4 := fakeRule{ // table/chain/rule + "filter", "ts-input", []string{"-i", "lo", "-s", fakeAddrV4.String(), "-j", "ACCEPT"}} + + tsRulesV6 := fakeRule{ // table/chain/rule + "filter", "ts-input", []string{"-i", "lo", "-s", fakeAddrV6.String(), "-j", "ACCEPT"}} + + // check that the rules were created for ipt4 and ipt6 + if exist, err := iptr.ipt4.Exists(tsRulesV4.table, tsRulesV4.chain, tsRulesV4.args...); err != nil { + t.Fatal(err) + } else if !exist { + t.Errorf("rule %s/%s/%s doesn't exist", tsRulesV4.table, tsRulesV4.chain, strings.Join(tsRulesV4.args, " ")) + } + if exist, err := iptr.ipt6.Exists(tsRulesV6.table, tsRulesV6.chain, tsRulesV6.args...); err != nil { + t.Fatal(err) + } else if !exist { + t.Errorf("rule %s/%s/%s doesn't exist", tsRulesV6.table, tsRulesV6.chain, strings.Join(tsRulesV6.args, " ")) + } + + // check that the rule is at the top + chain := "filter/ts-input" + if iptr.ipt4.(*fakeIPTables).n[chain][0] != strings.Join(tsRulesV4.args, " ") { + t.Errorf("v4 rule %s/%s/%s is not at the top", tsRulesV4.table, tsRulesV4.chain, strings.Join(tsRulesV4.args, " ")) + } + if iptr.ipt6.(*fakeIPTables).n[chain][0] != strings.Join(tsRulesV6.args, " ") { + t.Errorf("v6 rule %s/%s/%s is not at the top", tsRulesV6.table, tsRulesV6.chain, strings.Join(tsRulesV6.args, " ")) + } + + // delete the rules + if err := iptr.DelLoopbackRule(fakeAddrV4); err != nil { + t.Fatal(err) + } + if err := iptr.DelLoopbackRule(fakeAddrV6); err != nil { + t.Fatal(err) + } + + // Check that the rules were deleted. 
+ if exist, err := iptr.ipt4.Exists(tsRulesV4.table, tsRulesV4.chain, tsRulesV4.args...); err != nil { + t.Fatal(err) + } else if exist { + t.Errorf("rule %s/%s/%s still exists", tsRulesV4.table, tsRulesV4.chain, strings.Join(tsRulesV4.args, " ")) + } + + if exist, err := iptr.ipt6.Exists(tsRulesV6.table, tsRulesV6.chain, tsRulesV6.args...); err != nil { + t.Fatal(err) + } else if exist { + t.Errorf("rule %s/%s/%s still exists", tsRulesV6.table, tsRulesV6.chain, strings.Join(tsRulesV6.args, " ")) + } + + if err := iptr.DelChains(); err != nil { + t.Fatal(err) + } +} + +func TestAddAndDelSNATRule(t *testing.T) { + iptr := newFakeIPTablesRunner(t) + + if err := iptr.AddChains(); err != nil { + t.Fatal(err) + } + + rule := fakeRule{ // table/chain/rule + "nat", "ts-postrouting", []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"}, + } + + // Add SNAT rule + if err := iptr.AddSNATRule(); err != nil { + t.Fatal(err) + } + + // Check that the rule was created for ipt4 and ipt6 + for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} { + if exist, err := proto.Exists(rule.table, rule.chain, rule.args...); err != nil { + t.Fatal(err) + } else if !exist { + t.Errorf("rule %s/%s/%s doesn't exist", rule.table, rule.chain, strings.Join(rule.args, " ")) + } + } + + // Delete SNAT rule + if err := iptr.DelSNATRule(); err != nil { + t.Fatal(err) + } + + // Check that the rule was deleted for ipt4 and ipt6 + for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} { + if exist, err := proto.Exists(rule.table, rule.chain, rule.args...); err != nil { + t.Fatal(err) + } else if exist { + t.Errorf("rule %s/%s/%s still exists", rule.table, rule.chain, strings.Join(rule.args, " ")) + } + } + + if err := iptr.DelChains(); err != nil { + t.Fatal(err) + } +} diff --git a/util/linuxfw/linuxfw.go b/util/linuxfw/linuxfw.go index f3d7b0561..dc50aa6cc 100644 --- a/util/linuxfw/linuxfw.go +++ b/util/linuxfw/linuxfw.go @@ 
-2,10 +2,179 @@ // SPDX-License-Identifier: BSD-3-Clause // Package linuxfw returns the kind of firewall being used by the kernel. + +//go:build linux + package linuxfw -import "errors" +import ( + "bytes" + "errors" + "fmt" + "os" + "os/exec" + "strconv" + "strings" + + "github.com/tailscale/netlink" + "tailscale.com/types/logger" +) + +// The following bits are added to packet marks for Tailscale use. +// +// We tried to pick bits sufficiently out of the way that it's +// unlikely to collide with existing uses. We have 4 bytes of mark +// bits to play with. We leave the lower byte alone on the assumption +// that sysadmins would use those. Kubernetes uses a few bits in the +// second byte, so we steer clear of that too. +// +// Empirically, most of the documentation on packet marks on the +// internet gives the impression that the marks are 16 bits +// wide. Based on this, we theorize that the upper two bytes are +// relatively unused in the wild, and so we consume bits 16:23 (the +// third byte). +// +// The constants are in the iptables/iproute2 string format for +// matching and setting the bits, so they can be directly embedded in +// commands. +const ( + // The mask for reading/writing the 'firewall mark' bits on a packet. + // See the comment on the const block on why we only use the third byte. + // + // We claim bits 16:23 entirely. For now we only use the lower four + // bits, leaving the higher 4 bits for future use. + TailscaleFwmarkMask = "0xff0000" + TailscaleFwmarkMaskNeg = "0xff00ffff" + TailscaleFwmarkMaskNum = 0xff0000 + + // Packet is from Tailscale and to a subnet route destination, so + // is allowed to be routed through this machine. + TailscaleSubnetRouteMark = "0x40000" + TailscaleSubnetRouteMarkNum = 0x40000 + // This one is the same value but padded to an even number of digits, so + // hex decoding can work correctly.
+ TailscaleSubnetRouteMarkHexStr = "0x040000" + + // Packet was originated by tailscaled itself, and must not be + // routed over the Tailscale network. + TailscaleBypassMark = "0x80000" + TailscaleBypassMarkNum = 0x80000 +) + +// errCode extracts and returns the process exit code from err, or +// zero if err is nil. +func errCode(err error) int { + if err == nil { + return 0 + } + var e *exec.ExitError + if ok := errors.As(err, &e); ok { + return e.ExitCode() + } + s := err.Error() + if strings.HasPrefix(s, "exitcode:") { + code, err := strconv.Atoi(s[9:]) + if err == nil { + return code + } + } + return -42 +} + +// checkIPv6 checks whether the system appears to have a working IPv6 +// network stack. It returns an error explaining what looks wrong or +// missing. It does not check that IPv6 is currently functional or +// that there's a global address, just that the system would support +// IPv6 if it were on an IPv6 network. +func checkIPv6(logf logger.Logf) error { + _, err := os.Stat("/proc/sys/net/ipv6") + if os.IsNotExist(err) { + return err + } + bs, err := os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_ipv6") + if err != nil { + // Be conservative if we can't find the IPv6 configuration knob. + return err + } + disabled, err := strconv.ParseBool(strings.TrimSpace(string(bs))) + if err != nil { + return errors.New("disable_ipv6 has invalid bool") + } + if disabled { + return errors.New("disable_ipv6 is set") + } + + // Older kernels don't support IPv6 policy routing. Some kernels + // support policy routing but don't have this knob, so absence of + // the knob is not fatal. 
+ bs, err = os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_policy") + if err == nil { + disabled, err = strconv.ParseBool(strings.TrimSpace(string(bs))) + if err != nil { + return errors.New("disable_policy has invalid bool") + } + if disabled { + return errors.New("disable_policy is set") + } + } + + if err := CheckIPRuleSupportsV6(logf); err != nil { + return fmt.Errorf("kernel doesn't support IPv6 policy routing: %w", err) + } + + // Some distros ship ip6tables separately from iptables. + if _, err := exec.LookPath("ip6tables"); err != nil { + return err + } + + return nil +} + +// checkSupportsV6NAT returns whether the system has a "nat" table in the +// IPv6 netfilter stack. +// +// The nat table was added after the initial release of ipv6 +// netfilter, so some older distros ship a kernel that can't NAT IPv6 +// traffic. +func checkSupportsV6NAT() bool { + bs, err := os.ReadFile("/proc/net/ip6_tables_names") + if err != nil { + // Can't read the file. Assume SNAT works. + return true + } + if bytes.Contains(bs, []byte("nat\n")) { + return true + } + // In nftables mode, that proc file will be empty. Try another thing: + if exec.Command("modprobe", "ip6table_nat").Run() == nil { + return true + } + return false +} + +func CheckIPRuleSupportsV6(logf logger.Logf) error { + // First try just a read-only operation to ideally avoid + // having to modify any state. + if rules, err := netlink.RuleList(netlink.FAMILY_V6); err != nil { + return fmt.Errorf("querying IPv6 policy routing rules: %w", err) + } else { + if len(rules) > 0 { + logf("[v1] kernel supports IPv6 policy routing (found %d rules)", len(rules)) + return nil + } + } -// ErrUnsupported is the error returned from all functions on non-Linux -// platforms. -var ErrUnsupported = errors.New("unsupported") + // Try to actually create & delete one as a test. 
+ rule := netlink.NewRule() + rule.Priority = 1234 + rule.Mark = TailscaleBypassMarkNum + rule.Table = 52 + rule.Family = netlink.FAMILY_V6 + // First delete the rule unconditionally, and don't check for + // errors. This is just cleaning up anything that might be already + // there. + netlink.RuleDel(rule) + // And clean up on exit. + defer netlink.RuleDel(rule) + return netlink.RuleAdd(rule) +} diff --git a/util/linuxfw/linuxfw_unsupported.go b/util/linuxfw/linuxfw_unsupported.go index 246b61147..84ba2ecbb 100644 --- a/util/linuxfw/linuxfw_unsupported.go +++ b/util/linuxfw/linuxfw_unsupported.go @@ -9,9 +9,15 @@ package linuxfw import ( + "errors" + "tailscale.com/types/logger" ) +// ErrUnsupported is the error returned from all functions on non-Linux +// platforms. +var ErrUnsupported = errors.New("linuxfw:unsupported") + // DebugNetfilter is not supported on non-Linux platforms. func DebugNetfilter(logf logger.Logf) error { return ErrUnsupported diff --git a/wgengine/router/router_linux.go b/wgengine/router/router_linux.go index 34f31c6e8..ee39849e6 100644 --- a/wgengine/router/router_linux.go +++ b/wgengine/router/router_linux.go @@ -4,7 +4,6 @@ package router import ( - "bytes" "errors" "fmt" "net" @@ -17,7 +16,6 @@ import ( "syscall" "time" - "github.com/coreos/go-iptables/iptables" "github.com/tailscale/netlink" "github.com/tailscale/wireguard-go/tun" "go4.org/netipx" @@ -25,9 +23,9 @@ import ( "golang.org/x/time/rate" "tailscale.com/envknob" "tailscale.com/net/netmon" - "tailscale.com/net/tsaddr" "tailscale.com/types/logger" "tailscale.com/types/preftype" + "tailscale.com/util/linuxfw" "tailscale.com/util/multierr" "tailscale.com/version/distro" ) @@ -38,56 +36,34 @@ const ( netfilterOn = preftype.NetfilterOn ) -// The following bits are added to packet marks for Tailscale use. -// -// We tried to pick bits sufficiently out of the way that it's -// unlikely to collide with existing uses. We have 4 bytes of mark -// bits to play with. 
We leave the lower byte alone on the assumption -// that sysadmins would use those. Kubernetes uses a few bits in the -// second byte, so we steer clear of that too. -// -// Empirically, most of the documentation on packet marks on the -// internet gives the impression that the marks are 16 bits -// wide. Based on this, we theorize that the upper two bytes are -// relatively unused in the wild, and so we consume bits 16:23 (the -// third byte). -// -// The constants are in the iptables/iproute2 string format for -// matching and setting the bits, so they can be directly embedded in -// commands. -const ( - // The mask for reading/writing the 'firewall mask' bits on a packet. - // See the comment on the const block on why we only use the third byte. - // - // We claim bits 16:23 entirely. For now we only use the lower four - // bits, leaving the higher 4 bits for future use. - tailscaleFwmarkMask = "0xff0000" - tailscaleFwmarkMaskNum = 0xff0000 - - // Packet is from Tailscale and to a subnet route destination, so - // is allowed to be routed through this machine. - tailscaleSubnetRouteMark = "0x40000" - - // Packet was originated by tailscaled itself, and must not be - // routed over the Tailscale network. - // - // Keep this in sync with tailscaleBypassMark in - // net/netns/netns_linux.go. - tailscaleBypassMark = "0x80000" - tailscaleBypassMarkNum = 0x80000 -) - // netfilterRunner abstracts helpers to run netfilter commands. It // exists purely to swap out go-iptables for a fake implementation in // tests. 
type netfilterRunner interface { - Insert(table, chain string, pos int, args ...string) error - Append(table, chain string, args ...string) error - Exists(table, chain string, args ...string) (bool, error) - Delete(table, chain string, args ...string) error - ClearChain(table, chain string) error - NewChain(table, chain string) error - DeleteChain(table, chain string) error + AddLoopbackRule(addr netip.Addr) error + DelLoopbackRule(addr netip.Addr) error + AddHooks() error + DelHooks(logf logger.Logf) error + AddChains() error + DelChains() error + AddBase(tunname string) error + DelBase() error + AddSNATRule() error + DelSNATRule() error + + HasIPV6() bool + HasIPV6NAT() bool +} + +func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) { + var nfr netfilterRunner + var err error + nfr, err = linuxfw.NewIPTablesRunner(logf) + if err != nil { + return nil, err + } + + return nfr, nil } type linuxRouter struct { @@ -109,16 +85,13 @@ type linuxRouter struct { // Various feature checks for the network stack. ipRuleAvailable bool // whether kernel was built with IP_MULTIPLE_TABLES - v6Available bool - v6NATAvailable bool fwmaskWorks bool // whether we can use 'ip rule...fwmark /' // ipPolicyPrefBase is the base priority at which ip rules are installed. 
ipPolicyPrefBase int - ipt4 netfilterRunner - ipt6 netfilterRunner - cmd commandRunner + nfr netfilterRunner + cmd commandRunner } func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor) (Router, error) { @@ -127,51 +100,27 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni return nil, err } - ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + nfr, err := newNetfilterRunner(logf) if err != nil { return nil, err } - v6err := checkIPv6(logf) - if v6err != nil { - logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) - } - supportsV6 := v6err == nil - supportsV6NAT := supportsV6 && supportsV6NAT() - if supportsV6 { - logf("v6nat = %v", supportsV6NAT) - } - - var ipt6 netfilterRunner - if supportsV6 { - // The iptables package probes for `ip6tables` and errors out - // if unavailable. We want that to be a non-fatal error. - ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) - if err != nil { - return nil, err - } - } - cmd := osCommandRunner{ ambientCapNetAdmin: useAmbientCaps(), } - return newUserspaceRouterAdvanced(logf, tunname, netMon, ipt4, ipt6, cmd, supportsV6, supportsV6NAT) + return newUserspaceRouterAdvanced(logf, tunname, netMon, nfr, cmd) } -func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, netfilter4, netfilter6 netfilterRunner, cmd commandRunner, supportsV6, supportsV6NAT bool) (Router, error) { +func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, nfr netfilterRunner, cmd commandRunner) (Router, error) { r := &linuxRouter{ logf: logf, tunname: tunname, netfilterMode: netfilterOff, netMon: netMon, - v6Available: supportsV6, - v6NATAvailable: supportsV6NAT, - - ipt4: netfilter4, - ipt6: netfilter6, - cmd: cmd, + nfr: nfr, + cmd: cmd, ipRuleFixLimiter: rate.NewLimiter(rate.Every(5*time.Second), 10), ipPolicyPrefBase: 5200, @@ -484,23 +433,23 @@ func (r *linuxRouter) 
setNetfilterMode(mode preftype.NetfilterMode) error { case netfilterOff: switch r.netfilterMode { case netfilterNoDivert: - if err := r.delNetfilterBase(); err != nil { + if err := r.nfr.DelBase(); err != nil { return err } - if err := r.delNetfilterChains(); err != nil { + if err := r.nfr.DelChains(); err != nil { r.logf("note: %v", err) // harmless, continue. // This can happen if someone left a ref to // this table somewhere else. } case netfilterOn: - if err := r.delNetfilterHooks(); err != nil { + if err := r.nfr.DelHooks(r.logf); err != nil { return err } - if err := r.delNetfilterBase(); err != nil { + if err := r.nfr.DelBase(); err != nil { return err } - if err := r.delNetfilterChains(); err != nil { + if err := r.nfr.DelChains(); err != nil { r.logf("note: %v", err) // harmless, continue. // This can happen if someone left a ref to @@ -512,15 +461,15 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error { switch r.netfilterMode { case netfilterOff: reprocess = true - if err := r.addNetfilterChains(); err != nil { + if err := r.nfr.AddChains(); err != nil { return err } - if err := r.addNetfilterBase(); err != nil { + if err := r.nfr.AddBase(r.tunname); err != nil { return err } r.snatSubnetRoutes = false case netfilterOn: - if err := r.delNetfilterHooks(); err != nil { + if err := r.nfr.DelHooks(r.logf); err != nil { return err } } @@ -529,33 +478,33 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error { // we can't add a "-j ts-forward" rule to FORWARD // while ts-forward contains an "-m mark" rule. But // we can add the row *before* populating ts-forward. - // So we have to delNetFilterBase, then add the hooks, - // then re-addNetFilterBase, just in case. + // So we have to delBase, then add the hooks, + // then re-addBase, just in case. 
switch r.netfilterMode { case netfilterOff: reprocess = true - if err := r.addNetfilterChains(); err != nil { + if err := r.nfr.AddChains(); err != nil { return err } - if err := r.delNetfilterBase(); err != nil { + if err := r.nfr.DelBase(); err != nil { return err } - if err := r.addNetfilterHooks(); err != nil { + if err := r.nfr.AddHooks(); err != nil { return err } - if err := r.addNetfilterBase(); err != nil { + if err := r.nfr.AddBase(r.tunname); err != nil { return err } r.snatSubnetRoutes = false case netfilterNoDivert: reprocess = true - if err := r.delNetfilterBase(); err != nil { + if err := r.nfr.DelBase(); err != nil { return err } - if err := r.addNetfilterHooks(); err != nil { + if err := r.nfr.AddHooks(); err != nil { return err } - if err := r.addNetfilterBase(); err != nil { + if err := r.nfr.AddBase(r.tunname); err != nil { return err } r.snatSubnetRoutes = false @@ -579,11 +528,19 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error { return nil } +func (r *linuxRouter) getV6Available() bool { + return r.nfr.HasIPV6() +} + +func (r *linuxRouter) getV6NATAvailable() bool { + return r.nfr.HasIPV6NAT() +} + // addAddress adds an IP/mask to the tunnel interface. Fails if the // address is already assigned to the interface, or if the addition // fails. func (r *linuxRouter) addAddress(addr netip.Prefix) error { - if !r.v6Available && addr.Addr().Is6() { + if !r.getV6Available() && addr.Addr().Is6() { return nil } if r.useIPCommand() { @@ -609,7 +566,7 @@ func (r *linuxRouter) addAddress(addr netip.Prefix) error { // the address is not assigned to the interface, or if the removal // fails. 
func (r *linuxRouter) delAddress(addr netip.Prefix) error { - if !r.v6Available && addr.Addr().Is6() { + if !r.getV6Available() && addr.Addr().Is6() { return nil } if err := r.delLoopbackRule(addr.Addr()); err != nil { @@ -638,17 +595,8 @@ func (r *linuxRouter) addLoopbackRule(addr netip.Addr) error { return nil } - nf := r.ipt4 - if addr.Is6() { - if !r.v6Available { - // IPv6 not available, ignore. - return nil - } - nf = r.ipt6 - } - - if err := nf.Insert("filter", "ts-input", 1, "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil { - return fmt.Errorf("adding loopback allow rule for %q: %w", addr, err) + if err := r.nfr.AddLoopbackRule(addr); err != nil { + return err } return nil } @@ -660,17 +608,8 @@ func (r *linuxRouter) delLoopbackRule(addr netip.Addr) error { return nil } - nf := r.ipt4 - if addr.Is6() { - if !r.v6Available { - // IPv6 not available, ignore. - return nil - } - nf = r.ipt6 - } - - if err := nf.Delete("filter", "ts-input", "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil { - return fmt.Errorf("deleting loopback allow rule for %q: %w", addr, err) + if err := r.nfr.DelLoopbackRule(addr); err != nil { + return err } return nil } @@ -679,7 +618,7 @@ func (r *linuxRouter) delLoopbackRule(addr netip.Addr) error { // interface. Fails if the route already exists, or if adding the // route fails. 
func (r *linuxRouter) addRoute(cidr netip.Prefix) error { - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } if r.useIPCommand() { @@ -704,7 +643,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error { if !r.ipRuleAvailable { return nil } - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } if r.useIPCommand() { @@ -712,7 +651,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error { } err := netlink.RouteReplace(&netlink.Route{ Dst: netipx.PrefixIPNet(cidr.Masked()), - Table: tailscaleRouteTable.num, + Table: tailscaleRouteTable.Num, Type: unix.RTN_THROW, }) if err != nil { @@ -722,7 +661,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error { } func (r *linuxRouter) addRouteDef(routeDef []string, cidr netip.Prefix) error { - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } args := append([]string{"ip", "route", "add"}, routeDef...) @@ -756,7 +695,7 @@ var ( // interface. Fails if the route doesn't exist, or if removing the // route fails. func (r *linuxRouter) delRoute(cidr netip.Prefix) error { - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } if r.useIPCommand() { @@ -784,7 +723,7 @@ func (r *linuxRouter) delThrowRoute(cidr netip.Prefix) error { if !r.ipRuleAvailable { return nil } - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } if r.useIPCommand() { @@ -803,7 +742,7 @@ func (r *linuxRouter) delThrowRoute(cidr netip.Prefix) error { } func (r *linuxRouter) delRouteDef(routeDef []string, cidr netip.Prefix) error { - if !r.v6Available && cidr.Addr().Is6() { + if !r.getV6Available() && cidr.Addr().Is6() { return nil } args := append([]string{"ip", "route", "del"}, routeDef...) 
@@ -865,7 +804,7 @@ func (r *linuxRouter) linkIndex() (int, error) { // routeTable returns the route table to use. func (r *linuxRouter) routeTable() int { if r.ipRuleAvailable { - return tailscaleRouteTable.num + return tailscaleRouteTable.Num } return 0 } @@ -962,7 +901,7 @@ func (f addrFamily) netlinkInt() int { } func (r *linuxRouter) addrFamilies() []addrFamily { - if r.v6Available { + if r.getV6Available() { return []addrFamily{v4, v6} } return []addrFamily{v4} @@ -985,30 +924,34 @@ func (r *linuxRouter) addIPRules() error { return r.justAddIPRules() } -// routeTable is a Linux routing table: both its name and number. +// RouteTable is a Linux routing table: both its name and number. // See /etc/iproute2/rt_tables. -type routeTable struct { - name string - num int +type RouteTable struct { + Name string + Num int } -// ipCmdArg returns the string form of the table to pass to the "ip" command. -func (rt routeTable) ipCmdArg() string { - if rt.num >= 253 { - return rt.name +var routeTableByNumber = map[int]RouteTable{} + +// ipCmdArg returns the string form of the table to pass to the "ip" command. +func (rt RouteTable) ipCmdArg() string { + if rt.Num >= 253 { + return rt.Name } - return strconv.Itoa(rt.num) + return strconv.Itoa(rt.Num) } -var routeTableByNumber = map[int]routeTable{} - -func newRouteTable(name string, num int) routeTable { - rt := routeTable{name, num} +func newRouteTable(name string, num int) RouteTable { + rt := RouteTable{name, num} routeTableByNumber[num] = rt return rt } -func mustRouteTable(num int) routeTable { +// mustRouteTable returns the RouteTable registered under the given number. +// It panics if the number is unknown, because the result becomes part +// of an IP rule argument and we don't want to continue with an argument +// that names a table that does not exist.
+func mustRouteTable(num int) RouteTable { rt, ok := routeTableByNumber[num] if !ok { panic(fmt.Sprintf("unknown route table %v", num)) @@ -1059,22 +1002,22 @@ var ipRules = []netlink.Rule{ // main routing table. { Priority: 10, - Mark: tailscaleBypassMarkNum, - Table: mainRouteTable.num, + Mark: linuxfw.TailscaleBypassMarkNum, + Table: mainRouteTable.Num, }, // ...and then we try the 'default' table, for correctness, // even though it's been empty on every Linux system I've ever seen. { Priority: 30, - Mark: tailscaleBypassMarkNum, - Table: defaultRouteTable.num, + Mark: linuxfw.TailscaleBypassMarkNum, + Table: defaultRouteTable.Num, }, // If neither of those matched (no default route on this system?) // then packets from us should be aborted rather than falling through // to the tailscale routes, because that would create routing loops. { Priority: 50, - Mark: tailscaleBypassMarkNum, + Mark: linuxfw.TailscaleBypassMarkNum, Type: unix.RTN_UNREACHABLE, }, // If we get to this point, capture all packets and send them @@ -1084,7 +1027,7 @@ var ipRules = []netlink.Rule{ // beat non-VPN routes. { Priority: 70, - Table: tailscaleRouteTable.num, + Table: tailscaleRouteTable.Num, }, // If that didn't match, then non-fwmark packets fall through to the // usual rules (pref 32766 and 32767, ie. main and default). @@ -1105,7 +1048,7 @@ func (r *linuxRouter) justAddIPRules() error { // Note: r is a value type here; safe to mutate it. 
ru.Family = family.netlinkInt() if ru.Mark != 0 { - ru.Mask = tailscaleFwmarkMaskNum + ru.Mask = linuxfw.TailscaleFwmarkMaskNum } ru.Goto = -1 ru.SuppressIfgroup = -1 @@ -1138,7 +1081,7 @@ func (r *linuxRouter) addIPRulesWithIPCommand() error { } if rule.Mark != 0 { if r.fwmaskWorks { - args = append(args, "fwmark", fmt.Sprintf("0x%x/%s", rule.Mark, tailscaleFwmarkMask)) + args = append(args, "fwmark", fmt.Sprintf("0x%x/%s", rule.Mark, linuxfw.TailscaleFwmarkMask)) } else { args = append(args, "fwmark", fmt.Sprintf("0x%x", rule.Mark)) } @@ -1239,284 +1182,6 @@ func (r *linuxRouter) delIPRulesWithIPCommand() error { return rg.ErrAcc } -func (r *linuxRouter) netfilterFamilies() []netfilterRunner { - if r.v6Available { - return []netfilterRunner{r.ipt4, r.ipt6} - } - return []netfilterRunner{r.ipt4} -} - -// addNetfilterChains creates custom Tailscale chains in netfilter. -func (r *linuxRouter) addNetfilterChains() error { - create := func(ipt netfilterRunner, table, chain string) error { - err := ipt.ClearChain(table, chain) - if errCode(err) == 1 { - // nonexistent chain. let's create it! - return ipt.NewChain(table, chain) - } - if err != nil { - return fmt.Errorf("setting up %s/%s: %w", table, chain, err) - } - return nil - } - - for _, ipt := range r.netfilterFamilies() { - if err := create(ipt, "filter", "ts-input"); err != nil { - return err - } - if err := create(ipt, "filter", "ts-forward"); err != nil { - return err - } - } - if err := create(r.ipt4, "nat", "ts-postrouting"); err != nil { - return err - } - if r.v6NATAvailable { - if err := create(r.ipt6, "nat", "ts-postrouting"); err != nil { - return err - } - } - return nil -} - -// addNetfilterBase adds some basic processing rules to be -// supplemented by later calls to other helpers. 
-func (r *linuxRouter) addNetfilterBase() error { - if err := r.addNetfilterBase4(); err != nil { - return err - } - if r.v6Available { - if err := r.addNetfilterBase6(); err != nil { - return err - } - } - return nil -} - -// addNetfilterBase4 adds some basic IPv4 processing rules to be -// supplemented by later calls to other helpers. -func (r *linuxRouter) addNetfilterBase4() error { - // Only allow CGNAT range traffic to come from tailscale0. There - // is an exception carved out for ranges used by ChromeOS, for - // which we fall out of the Tailscale chain. - // - // Note, this will definitely break nodes that end up using the - // CGNAT range for other purposes :(. - args := []string{"!", "-i", r.tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"} - if err := r.ipt4.Append("filter", "ts-input", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) - } - args = []string{"!", "-i", r.tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} - if err := r.ipt4.Append("filter", "ts-input", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) - } - - // Forward all traffic from the Tailscale interface, and drop - // traffic to the tailscale interface by default. We use packet - // marks here so both filter/FORWARD and nat/POSTROUTING can match - // on these packets of interest. - // - // In particular, we only want to apply SNAT rules in - // nat/POSTROUTING to packets that originated from the Tailscale - // interface, but we can't match on the inbound interface in - // POSTROUTING. So instead, we match on the inbound interface in - // filter/FORWARD, and set a packet mark that nat/POSTROUTING can - // use to effectively run that same test again. 
- args = []string{"-i", r.tunname, "-j", "MARK", "--set-mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask} - if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) - } - args = []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "ACCEPT"} - if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) - } - args = []string{"-o", r.tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} - if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) - } - args = []string{"-o", r.tunname, "-j", "ACCEPT"} - if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) - } - - return nil -} - -// addNetfilterBase4 adds some basic IPv6 processing rules to be -// supplemented by later calls to other helpers. -func (r *linuxRouter) addNetfilterBase6() error { - // TODO: only allow traffic from Tailscale's ULA range to come - // from tailscale0. - - args := []string{"-i", r.tunname, "-j", "MARK", "--set-mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask} - if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) - } - args = []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "ACCEPT"} - if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) - } - // TODO: drop forwarded traffic to tailscale0 from tailscale's ULA - // (see corresponding IPv4 CGNAT rule). 
- args = []string{"-o", r.tunname, "-j", "ACCEPT"} - if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil { - return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) - } - - return nil -} - -// delNetfilterChains removes the custom Tailscale chains from netfilter. -func (r *linuxRouter) delNetfilterChains() error { - del := func(ipt netfilterRunner, table, chain string) error { - if err := ipt.ClearChain(table, chain); err != nil { - if errCode(err) == 1 { - // nonexistent chain. That's fine, since it's - // the desired state anyway. - return nil - } - return fmt.Errorf("flushing %s/%s: %w", table, chain, err) - } - if err := ipt.DeleteChain(table, chain); err != nil { - // this shouldn't fail, because if the chain didn't - // exist, we would have returned after ClearChain. - return fmt.Errorf("deleting %s/%s: %v", table, chain, err) - } - return nil - } - - for _, ipt := range r.netfilterFamilies() { - if err := del(ipt, "filter", "ts-input"); err != nil { - return err - } - if err := del(ipt, "filter", "ts-forward"); err != nil { - return err - } - } - if err := del(r.ipt4, "nat", "ts-postrouting"); err != nil { - return err - } - if r.v6NATAvailable { - if err := del(r.ipt6, "nat", "ts-postrouting"); err != nil { - return err - } - } - - return nil -} - -// delNetfilterBase empties but does not remove custom Tailscale chains from -// netfilter. -func (r *linuxRouter) delNetfilterBase() error { - del := func(ipt netfilterRunner, table, chain string) error { - if err := ipt.ClearChain(table, chain); err != nil { - if errCode(err) == 1 { - // nonexistent chain. That's fine, since it's - // the desired state anyway. 
- return nil - } - return fmt.Errorf("flushing %s/%s: %w", table, chain, err) - } - return nil - } - - for _, ipt := range r.netfilterFamilies() { - if err := del(ipt, "filter", "ts-input"); err != nil { - return err - } - if err := del(ipt, "filter", "ts-forward"); err != nil { - return err - } - } - if err := del(r.ipt4, "nat", "ts-postrouting"); err != nil { - return err - } - if r.v6NATAvailable { - if err := del(r.ipt6, "nat", "ts-postrouting"); err != nil { - return err - } - } - - return nil -} - -// addNetfilterHooks inserts calls to tailscale's netfilter chains in -// the relevant main netfilter chains. The tailscale chains must -// already exist. -func (r *linuxRouter) addNetfilterHooks() error { - divert := func(ipt netfilterRunner, table, chain string) error { - tsChain := tsChain(chain) - - args := []string{"-j", tsChain} - exists, err := ipt.Exists(table, chain, args...) - if err != nil { - return fmt.Errorf("checking for %v in %s/%s: %w", args, table, chain, err) - } - if exists { - return nil - } - if err := ipt.Insert(table, chain, 1, args...); err != nil { - return fmt.Errorf("adding %v in %s/%s: %w", args, table, chain, err) - } - return nil - } - - for _, ipt := range r.netfilterFamilies() { - if err := divert(ipt, "filter", "INPUT"); err != nil { - return err - } - if err := divert(ipt, "filter", "FORWARD"); err != nil { - return err - } - } - if err := divert(r.ipt4, "nat", "POSTROUTING"); err != nil { - return err - } - if r.v6NATAvailable { - if err := divert(r.ipt6, "nat", "POSTROUTING"); err != nil { - return err - } - } - return nil -} - -// delNetfilterHooks deletes the calls to tailscale's netfilter chains -// in the relevant main netfilter chains. -func (r *linuxRouter) delNetfilterHooks() error { - del := func(ipt netfilterRunner, table, chain string) error { - tsChain := tsChain(chain) - args := []string{"-j", tsChain} - if err := ipt.Delete(table, chain, args...); err != nil { - // TODO(apenwarr): check for errCode(1) here. 
- // Unfortunately the error code from the iptables - // module resists unwrapping, unlike with other - // calls. So we have to assume if Delete fails, - // it's because there is no such rule. - r.logf("note: deleting %v in %s/%s: %w", args, table, chain, err) - return nil - } - return nil - } - - for _, ipt := range r.netfilterFamilies() { - if err := del(ipt, "filter", "INPUT"); err != nil { - return err - } - if err := del(ipt, "filter", "FORWARD"); err != nil { - return err - } - } - if err := del(r.ipt4, "nat", "POSTROUTING"); err != nil { - return err - } - if r.v6NATAvailable { - if err := del(r.ipt6, "nat", "POSTROUTING"); err != nil { - return err - } - } - return nil -} - // addSNATRule adds a netfilter rule to SNAT traffic destined for // local subnets. func (r *linuxRouter) addSNATRule() error { @@ -1524,14 +1189,8 @@ func (r *linuxRouter) addSNATRule() error { return nil } - args := []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "MASQUERADE"} - if err := r.ipt4.Append("nat", "ts-postrouting", args...); err != nil { - return fmt.Errorf("adding %v in v4/nat/ts-postrouting: %w", args, err) - } - if r.v6NATAvailable { - if err := r.ipt6.Append("nat", "ts-postrouting", args...); err != nil { - return fmt.Errorf("adding %v in v6/nat/ts-postrouting: %w", args, err) - } + if err := r.nfr.AddSNATRule(); err != nil { + return err } return nil } @@ -1543,14 +1202,8 @@ func (r *linuxRouter) delSNATRule() error { return nil } - args := []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "MASQUERADE"} - if err := r.ipt4.Delete("nat", "ts-postrouting", args...); err != nil { - return fmt.Errorf("deleting %v in v4/nat/ts-postrouting: %w", args, err) - } - if r.v6NATAvailable { - if err := r.ipt6.Delete("nat", "ts-postrouting", args...); err != nil { - return fmt.Errorf("deleting %v in v6/nat/ts-postrouting: %w", args, err) - } + if err := r.nfr.DelSNATRule(); err != nil { + return 
err } return nil } @@ -1619,12 +1272,6 @@ func cidrDiff(kind string, old map[netip.Prefix]bool, new []netip.Prefix, add, d return ret, nil } -// tsChain returns the name of the tailscale sub-chain corresponding -// to the given "parent" chain (e.g. INPUT, FORWARD, ...). -func tsChain(chain string) string { - return "ts-" + strings.ToLower(chain) -} - // normalizeCIDR returns cidr as an ip/mask string, with the host bits // of the IP address zeroed out. func normalizeCIDR(cidr netip.Prefix) string { @@ -1632,105 +1279,9 @@ func normalizeCIDR(cidr netip.Prefix) string { } func cleanup(logf logger.Logf, interfaceName string) { - // TODO(dmytro): clean up iptables. -} - -// checkIPv6 checks whether the system appears to have a working IPv6 -// network stack. It returns an error explaining what looks wrong or -// missing. It does not check that IPv6 is currently functional or -// that there's a global address, just that the system would support -// IPv6 if it were on an IPv6 network. -func checkIPv6(logf logger.Logf) error { - _, err := os.Stat("/proc/sys/net/ipv6") - if os.IsNotExist(err) { - return err - } - bs, err := os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_ipv6") - if err != nil { - // Be conservative if we can't find the IPv6 configuration knob. - return err + if interfaceName != "userspace-networking" { + linuxfw.IPTablesCleanup(logf) } - disabled, err := strconv.ParseBool(strings.TrimSpace(string(bs))) - if err != nil { - return errors.New("disable_ipv6 has invalid bool") - } - if disabled { - return errors.New("disable_ipv6 is set") - } - - // Older kernels don't support IPv6 policy routing. Some kernels - // support policy routing but don't have this knob, so absence of - // the knob is not fatal. 
- bs, err = os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_policy") - if err == nil { - disabled, err = strconv.ParseBool(strings.TrimSpace(string(bs))) - if err != nil { - return errors.New("disable_policy has invalid bool") - } - if disabled { - return errors.New("disable_policy is set") - } - } - - if err := checkIPRuleSupportsV6(logf); err != nil { - return fmt.Errorf("kernel doesn't support IPv6 policy routing: %w", err) - } - - // Some distros ship ip6tables separately from iptables. - if _, err := exec.LookPath("ip6tables"); err != nil { - return err - } - - return nil -} - -// supportsV6NAT returns whether the system has a "nat" table in the -// IPv6 netfilter stack. -// -// The nat table was added after the initial release of ipv6 -// netfilter, so some older distros ship a kernel that can't NAT IPv6 -// traffic. -func supportsV6NAT() bool { - bs, err := os.ReadFile("/proc/net/ip6_tables_names") - if err != nil { - // Can't read the file. Assume SNAT works. - return true - } - if bytes.Contains(bs, []byte("nat\n")) { - return true - } - // In nftables mode, that proc file will be empty. Try another thing: - if exec.Command("modprobe", "ip6table_nat").Run() == nil { - return true - } - return false -} - -func checkIPRuleSupportsV6(logf logger.Logf) error { - // First try just a read-only operation to ideally avoid - // having to modify any state. - if rules, err := netlink.RuleList(netlink.FAMILY_V6); err != nil { - return fmt.Errorf("querying IPv6 policy routing rules: %w", err) - } else { - if len(rules) > 0 { - logf("[v1] kernel supports IPv6 policy routing (found %d rules)", len(rules)) - return nil - } - } - - // Try to actually create & delete one as a test. - rule := netlink.NewRule() - rule.Priority = 1234 - rule.Mark = tailscaleBypassMarkNum - rule.Table = tailscaleRouteTable.num - rule.Family = netlink.FAMILY_V6 - // First delete the rule unconditionally, and don't check for - // errors. 
This is just cleaning up anything that might be already - // there. - netlink.RuleDel(rule) - // And clean up on exit. - defer netlink.RuleDel(rule) - return netlink.RuleAdd(rule) } // Checks if the running openWRT system is using mwan3, based on the heuristic diff --git a/wgengine/router/router_linux_test.go b/wgengine/router/router_linux_test.go index acd2e6c10..d5b3219ec 100644 --- a/wgengine/router/router_linux_test.go +++ b/wgengine/router/router_linux_test.go @@ -22,8 +22,10 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/exp/slices" "tailscale.com/net/netmon" + "tailscale.com/net/tsaddr" "tailscale.com/tstest" "tailscale.com/types/logger" + "tailscale.com/util/linuxfw" ) func TestRouterStates(t *testing.T) { @@ -328,7 +330,7 @@ ip route add throw 192.168.0.0/24 table 52` + basic, defer mon.Close() fake := NewFakeOS(t) - router, err := newUserspaceRouterAdvanced(t.Logf, "tailscale0", mon, fake.netfilter4, fake.netfilter6, fake, true, true) + router, err := newUserspaceRouterAdvanced(t.Logf, "tailscale0", mon, fake.nfr, fake) if err != nil { t.Fatalf("failed to create router: %v", err) } @@ -362,15 +364,17 @@ ip route add throw 192.168.0.0/24 table 52` + basic, } } -type fakeNetfilter struct { - t *testing.T - n map[string][]string +type fakeIPTablesRunner struct { + t *testing.T + ipt4 map[string][]string + ipt6 map[string][]string + // We always assume IPv6 and IPv6 NAT are enabled when testing. } -func newNetfilter(t *testing.T) *fakeNetfilter { - return &fakeNetfilter{ +func newIPTablesRunner(t *testing.T) netfilterRunner { + return &fakeIPTablesRunner{ t: t, - n: map[string][]string{ + ipt4: map[string][]string{ "filter/INPUT": nil, "filter/OUTPUT": nil, "filter/FORWARD": nil, @@ -378,118 +382,233 @@ func newNetfilter(t *testing.T) *fakeNetfilter { "nat/OUTPUT": nil, "nat/POSTROUTING": nil, }, + ipt6: map[string][]string{ + "filter/INPUT": nil, + "filter/OUTPUT": nil, + "filter/FORWARD": nil, + "nat/PREROUTING": nil, + "nat/OUTPUT": nil, +
"nat/POSTROUTING": nil, + }, + } +} + +func insertRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, newRule string) error { + // Get the current rules for the given chain in the selected IP version's table + curTSInputRules, ok := curIPT[chain] + if !ok { + n.t.Fatalf("no %s chain exists", chain) + return fmt.Errorf("no %s chain exists", chain) + } + + // Add the new rule to the top of the chain + curTSInputRules = append(curTSInputRules, "") + copy(curTSInputRules[1:], curTSInputRules) + curTSInputRules[0] = newRule + curIPT[chain] = curTSInputRules + return nil +} + +func appendRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, newRule string) error { + // Get the current rules for the given chain in the selected IP version's table + curTSInputRules, ok := curIPT[chain] + if !ok { + n.t.Fatalf("no %s chain exists", chain) + return fmt.Errorf("no %s chain exists", chain) + } + + // Add the new rule to the end of the chain + curTSInputRules = append(curTSInputRules, newRule) + curIPT[chain] = curTSInputRules + return nil +} + +func deleteRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, delRule string) error { + // Get the current rules for the given chain in the selected IP version's table + curTSInputRules, ok := curIPT[chain] + if !ok { + n.t.Fatalf("no %s chain exists", chain) + return fmt.Errorf("no %s chain exists", chain) + } + + // Remove the first matching rule from the chain + for i, rule := range curTSInputRules { + if rule == delRule { + curTSInputRules = append(curTSInputRules[:i], curTSInputRules[i+1:]...)
+ break + } + } + curIPT[chain] = curTSInputRules + return nil +} + +func (n *fakeIPTablesRunner) AddLoopbackRule(addr netip.Addr) error { + curIPT := n.ipt4 + if addr.Is6() { + curIPT = n.ipt6 } + newRule := fmt.Sprintf("-i lo -s %s -j ACCEPT", addr.String()) + + return insertRule(n, curIPT, "filter/ts-input", newRule) } -func (n *fakeNetfilter) Insert(table, chain string, pos int, args ...string) error { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - if pos > len(rules)+1 { - n.t.Errorf("bad position %d in %s", pos, k) - return errExec +func (n *fakeIPTablesRunner) AddBase(tunname string) error { + if err := n.AddBase4(tunname); err != nil { + return err + } + if n.HasIPV6() { + if err := n.AddBase6(tunname); err != nil { + return err } - rules = append(rules, "") - copy(rules[pos:], rules[pos-1:]) - rules[pos-1] = strings.Join(args, " ") - n.n[k] = rules - } else { - n.t.Errorf("unknown table/chain %s", k) - return errExec } return nil } -func (n *fakeNetfilter) Append(table, chain string, args ...string) error { - k := table + "/" + chain - return n.Insert(table, chain, len(n.n[k])+1, args...) +func (n *fakeIPTablesRunner) AddBase4(tunname string) error { + curIPT := n.ipt4 + newRules := []struct{ chain, rule string }{ + {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.ChromeOSVMRange().String())}, + {"filter/ts-input", fmt.Sprintf("! 
-i %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())}, + {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)}, + {"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)}, + {"filter/ts-forward", fmt.Sprintf("-o %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())}, + {"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)}, + } + for _, rule := range newRules { + if err := appendRule(n, curIPT, rule.chain, rule.rule); err != nil { + return fmt.Errorf("add rule %q to chain %q: %w", rule.rule, rule.chain, err) + } + } + return nil } -func (n *fakeNetfilter) Exists(table, chain string, args ...string) (bool, error) { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - for _, rule := range rules { - if rule == strings.Join(args, " ") { - return true, nil +func (n *fakeIPTablesRunner) AddBase6(tunname string) error { + curIPT := n.ipt6 + newRules := []struct{ chain, rule string }{ + {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)}, + {"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)}, + {"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)}, + } + for _, rule := range newRules { + if err := appendRule(n, curIPT, rule.chain, rule.rule); err != nil { + return fmt.Errorf("add rule %q to chain %q: %w", rule.rule, rule.chain, err) + } + } + return nil +} + +func (n *fakeIPTablesRunner) DelLoopbackRule(addr netip.Addr) error { + curIPT := n.ipt4 + if addr.Is6() { + curIPT = n.ipt6 + } + + delRule := fmt.Sprintf("-i lo -s %s -j ACCEPT", addr.String()) + + return deleteRule(n, curIPT, "filter/ts-input", delRule) +} + +func (n *fakeIPTablesRunner) AddHooks() error { + newRules := []struct{ chain, rule string }{ + 
{"filter/INPUT", "-j ts-input"}, + {"filter/FORWARD", "-j ts-forward"}, + {"nat/POSTROUTING", "-j ts-postrouting"}, + } + for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { + for _, r := range newRules { + if err := insertRule(n, ipt, r.chain, r.rule); err != nil { + return err } } - return false, nil - } else { - n.t.Errorf("unknown table/chain %s", k) - return false, errExec } + return nil } -func (n *fakeNetfilter) Delete(table, chain string, args ...string) error { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - for i, rule := range rules { - if rule == strings.Join(args, " ") { - rules = append(rules[:i], rules[i+1:]...) - n.n[k] = rules - return nil +func (n *fakeIPTablesRunner) DelHooks(logf logger.Logf) error { + delRules := []struct{ chain, rule string }{ + {"filter/INPUT", "-j ts-input"}, + {"filter/FORWARD", "-j ts-forward"}, + {"nat/POSTROUTING", "-j ts-postrouting"}, + } + for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { + for _, r := range delRules { + if err := deleteRule(n, ipt, r.chain, r.rule); err != nil { + return err } } - n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k) - return errExec - } else { - n.t.Errorf("unknown table/chain %s", k) - return errExec } + return nil } -func (n *fakeNetfilter) ClearChain(table, chain string) error { - k := table + "/" + chain - if _, ok := n.n[k]; ok { - n.n[k] = nil - return nil - } else { - n.t.Logf("note: ClearChain: unknown table/chain %s", k) - return errors.New("exitcode:1") +func (n *fakeIPTablesRunner) AddChains() error { + for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { + for _, chain := range []string{"filter/ts-input", "filter/ts-forward", "nat/ts-postrouting"} { + ipt[chain] = nil + } } + return nil } -func (n *fakeNetfilter) NewChain(table, chain string) error { - k := table + "/" + chain - if _, ok := n.n[k]; ok { - n.t.Errorf("table/chain %s already exists", k) - return errExec +func (n *fakeIPTablesRunner) DelChains() 
error { + for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { + for chain := range ipt { + if strings.HasPrefix(chain, "filter/ts-") || strings.HasPrefix(chain, "nat/ts-") { + delete(ipt, chain) + } + } } - n.n[k] = nil return nil } -func (n *fakeNetfilter) DeleteChain(table, chain string) error { - k := table + "/" + chain - if rules, ok := n.n[k]; ok { - if len(rules) != 0 { - n.t.Errorf("%s is not empty", k) - return errExec +func (n *fakeIPTablesRunner) DelBase() error { + for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { + for _, chain := range []string{"filter/ts-input", "filter/ts-forward", "nat/ts-postrouting"} { + ipt[chain] = nil } - delete(n.n, k) - return nil - } else { - n.t.Errorf("%s does not exist", k) - return errExec } + return nil } +func (n *fakeIPTablesRunner) AddSNATRule() error { + newRule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask) + for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { + if err := appendRule(n, ipt, "nat/ts-postrouting", newRule); err != nil { + return err + } + } + return nil +} + +func (n *fakeIPTablesRunner) DelSNATRule() error { + delRule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask) + for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { + if err := deleteRule(n, ipt, "nat/ts-postrouting", delRule); err != nil { + return err + } + } + return nil +} + +func (n *fakeIPTablesRunner) HasIPV6() bool { return true } +func (n *fakeIPTablesRunner) HasIPV6NAT() bool { return true } + // fakeOS implements commandRunner and provides v4 and v6 // netfilterRunners, but captures changes without touching the OS. 
type fakeOS struct { - t *testing.T - up bool - ips []string - routes []string - rules []string - netfilter4 *fakeNetfilter - netfilter6 *fakeNetfilter + t *testing.T + up bool + ips []string + routes []string + rules []string + // These tests exercise the router level, so they do not care whether + // iptables or nftables is used underneath; we chose the simpler one. + nfr netfilterRunner } func NewFakeOS(t *testing.T) *fakeOS { return &fakeOS{ - t: t, - netfilter4: newNetfilter(t), - netfilter6: newNetfilter(t), + t: t, + nfr: newIPTablesRunner(t), } } @@ -516,23 +635,23 @@ func (o *fakeOS) String() string { } var chains []string - for chain := range o.netfilter4.n { + for chain := range o.nfr.(*fakeIPTablesRunner).ipt4 { chains = append(chains, chain) } sort.Strings(chains) for _, chain := range chains { - for _, rule := range o.netfilter4.n[chain] { + for _, rule := range o.nfr.(*fakeIPTablesRunner).ipt4[chain] { fmt.Fprintf(&b, "v4/%s %s\n", chain, rule) } } chains = nil - for chain := range o.netfilter6.n { + for chain := range o.nfr.(*fakeIPTablesRunner).ipt6 { chains = append(chains, chain) } sort.Strings(chains) for _, chain := range chains { - for _, rule := range o.netfilter6.n[chain] { + for _, rule := range o.nfr.(*fakeIPTablesRunner).ipt6[chain] { fmt.Fprintf(&b, "v6/%s %s\n", chain, rule) } } @@ -806,7 +925,7 @@ func TestDebugListRules(t *testing.T) { } func TestCheckIPRuleSupportsV6(t *testing.T) { - err := checkIPRuleSupportsV6(t.Logf) + err := linuxfw.CheckIPRuleSupportsV6(t.Logf) if err != nil && os.Getuid() != 0 { t.Skipf("skipping, error when not root: %v", err) }