From a45c9f982ae2b10f790f37abd8c8e6cfe495cbe5 Mon Sep 17 00:00:00 2001 From: Claire Wang Date: Fri, 23 Dec 2022 13:22:39 -0500 Subject: [PATCH] wgengine/netstack: change netstack API to require LocalBackend The macOS client was forgetting to call netstack.Impl.SetLocalBackend. Change the API so that it can't be started without one, eliminating this class of bug. Then update all the callers. Updates #6764 Change-Id: I2b3a4f31fdfd9fdbbbbfe25a42db0c505373562f Signed-off-by: Claire Wang Co-authored-by: Brad Fitzpatrick Signed-off-by: Brad Fitzpatrick --- cmd/tailscaled/tailscaled.go | 3 +- cmd/tsconnect/wasm/wasm_js.go | 9 ++-- tsnet/tsnet.go | 6 +-- wgengine/netstack/netstack.go | 12 ++--- wgengine/netstack/netstack_test.go | 87 ++++++++++++++---------------- 5 files changed, 54 insertions(+), 63 deletions(-) diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 99fbf9d96..b851807e4 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -533,8 +533,7 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logid string) (_ *ip return smallzstd.NewDecoder(nil) }) configureTaildrop(logf, lb) - ns.SetLocalBackend(lb) - if err := ns.Start(); err != nil { + if err := ns.Start(lb); err != nil { log.Fatalf("failed to start netstack: %v", err) } return lb, nil diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index b928ae7f9..aaaeef1ce 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -115,9 +115,7 @@ func newIPN(jsConfig js.Value) map[string]any { } ns.ProcessLocalIPs = true ns.ProcessSubnets = true - if err := ns.Start(); err != nil { - log.Fatalf("failed to start netstack: %v", err) - } + dialer.UseNetstackForIP = func(ip netip.Addr) bool { return true } @@ -127,16 +125,17 @@ func newIPN(jsConfig js.Value) map[string]any { logid := lpc.PublicID.String() srv := ipnserver.New(logf, logid) - lb, err := ipnlocal.NewLocalBackend(logf, logid, store, "wasm", dialer, eng, controlclient.LoginEphemeral) if err != nil { log.Fatalf("ipnlocal.NewLocalBackend: %v", err) } + if err := ns.Start(lb); err != nil { + log.Fatalf("failed to start netstack: %v", err) + } lb.SetDecompressor(func() (controlclient.Decompressor, error) { return smallzstd.NewDecoder(nil) }) srv.SetLocalBackend(lb) - ns.SetLocalBackend(lb) jsIPN := &jsIPN{ dialer: dialer, diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index ad7a28a1a..774340b23 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -317,9 +317,6 @@ func (s *Server) start() (reterr error) { } ns.ProcessLocalIPs = true ns.ForwardTCPIn = s.forwardTCP - if err := ns.Start(); err != nil { - return fmt.Errorf("failed to start netstack: %w", err) - } s.netstack = ns s.dialer.UseNetstackForIP = func(ip netip.Addr) bool { _, ok := eng.PeerForIP(ip) @@ -349,6 +346,9 @@ func (s *Server) start() (reterr error) { lb.SetVarRoot(s.rootPath) logf("tsnet starting with hostname %q, varRoot %q", s.hostname, s.rootPath) s.lb = lb + if err := ns.Start(lb); err != nil { + return fmt.Errorf("failed to start netstack: %w", err) + } closePool.addFunc(func() { s.lb.Shutdown() }) lb.SetDecompressor(func() (controlclient.Decompressor, error) { return smallzstd.NewDecoder(nil) diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 4f4ba6e63..992e99b00 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -204,12 +204,6 @@ func (ns *Impl) Close() error { return nil } -// SetLocalBackend sets the LocalBackend; it should only be run before -// the Start method is called. -func (ns *Impl) SetLocalBackend(lb *ipnlocal.LocalBackend) { - ns.lb = lb -} - // wrapProtoHandler returns protocol handler h wrapped in a version // that dynamically reconfigures ns's subnet addresses as needed for // outbound traffic. @@ -231,7 +225,11 @@ func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, stack.PacketB // Start sets up all the handlers so netstack can start working. Implements // wgengine.FakeImpl. -func (ns *Impl) Start() error { +func (ns *Impl) Start(lb *ipnlocal.LocalBackend) error { + if lb == nil { + panic("nil LocalBackend") + } + ns.lb = lb ns.e.AddNetworkMapCallback(ns.updateIPs) // size = 0 means use default buffer size const tcpReceiveBufferSize = 0 diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index c92f1edb6..61c5933eb 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -17,7 +17,6 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" "tailscale.com/net/tstun" - "tailscale.com/tstest" "tailscale.com/types/ipproto" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" @@ -50,13 +49,18 @@ func TestInjectInboundLeak(t *testing.T) { t.Fatal("failed to get internals") } + lb, err := ipnlocal.NewLocalBackend(logf, "logid", new(mem.Store), "", dialer, eng, 0) + if err != nil { + t.Fatal(err) + } + ns, err := Create(logf, tunWrap, eng, magicSock, dialer, dns) if err != nil { t.Fatal(err) } defer ns.Close() ns.ProcessLocalIPs = true - if err := ns.Start(); err != nil { + if err := ns.Start(lb); err != nil { t.Fatalf("Start: %v", err) } ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) @@ -114,10 +118,16 @@ func makeNetstack(t *testing.T, config func(*Impl)) *Impl { } t.Cleanup(func() { ns.Close() }) - ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) - config(ns) + lb, err := ipnlocal.NewLocalBackend(logf, "logid", new(mem.Store), "", dialer, eng, 0) + if err != nil { + t.Fatalf("NewLocalBackend: %v", err) + } - if err := ns.Start(); err != nil { + ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) + if config != nil { + config(ns) + } + if err := ns.Start(lb); err != nil { t.Fatalf("Start: %v", err) } return ns @@ -257,11 +267,12 @@ func looksLikeATailscaleSelfAddress(addr netip.Addr) bool { func TestShouldProcessInbound(t *testing.T) { testCases := []struct { - name string - pkt *packet.Parsed - setup func(*Impl) - want bool - runOnGOOS string + name string + pkt *packet.Parsed + afterStart func(*Impl) // optional; after Impl.Start is called + beforeStart func(*Impl) // optional; before Impl.Start is called + want bool + runOnGOOS string }{ { name: "ipv6-via", @@ -275,7 +286,7 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:109]:5678"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + afterStart: func(i *Impl) { prefs := ipn.NewPrefs() prefs.AdvertiseRoutes = []netip.Prefix{ // $ tailscale debug via 7 10.1.1.0/24 @@ -286,7 +297,8 @@ func TestShouldProcessInbound(t *testing.T) { LegacyMigrationPrefs: prefs, }) i.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress) - + }, + beforeStart: func(i *Impl) { // This should be handled even if we're // otherwise not processing local IPs or // subnets. @@ -307,7 +319,7 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:109]:5678"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + afterStart: func(i *Impl) { prefs := ipn.NewPrefs() prefs.AdvertiseRoutes = []netip.Prefix{ // tailscale debug via 7 10.1.2.0/24 @@ -329,7 +341,7 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("100.101.102.104:22"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + afterStart: func(i *Impl) { prefs := ipn.NewPrefs() prefs.RunSSH = true i.lb.Start(ipn.Options{ @@ -351,7 +363,7 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("100.101.102.104:22"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + afterStart: func(i *Impl) { prefs := ipn.NewPrefs() prefs.RunSSH = false // default, but to be explicit i.lb.Start(ipn.Options{ @@ -372,7 +384,7 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("100.101.102.104:4567"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + afterStart: func(i *Impl) { i.ProcessLocalIPs = true i.atomicIsLocalIPFunc.Store(func(addr netip.Addr) bool { return addr.String() == "100.101.102.104" // Dst, above @@ -389,9 +401,10 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("10.1.2.3:4567"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + beforeStart: func(i *Impl) { i.ProcessSubnets = true - + }, + afterStart: func(i *Impl) { // For testing purposes, assume all Tailscale // IPs are local; the Dst above is something // not in that range. @@ -408,7 +421,12 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("10.0.0.23:5555"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + beforeStart: func(i *Impl) { + // As if we were running on Linux where netstack isn't used. + i.ProcessSubnets = false + i.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return false }) + }, + afterStart: func(i *Impl) { prefs := ipn.NewPrefs() prefs.AdvertiseRoutes = []netip.Prefix{ netip.MustParsePrefix("10.0.0.1/24"), @@ -417,10 +435,6 @@ func TestShouldProcessInbound(t *testing.T) { LegacyMigrationPrefs: prefs, }) - // As if we were running on Linux where netstack isn't used. - i.ProcessSubnets = false - i.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return false }) - // Set the PeerAPI port to the Dst port above. i.peerapiPort4Atomic.Store(5555) i.peerapiPort6Atomic.Store(5555) @@ -437,30 +451,11 @@ func TestShouldProcessInbound(t *testing.T) { if tc.runOnGOOS != "" && runtime.GOOS != tc.runOnGOOS { t.Skipf("skipping on GOOS=%v", runtime.GOOS) } - impl := makeNetstack(t, func(i *Impl) { - defer t.Logf("netstack setup finished") - - logf := tstest.WhileTestRunningLogger(t) - e, err := wgengine.NewFakeUserspaceEngine(logf, 0) - if err != nil { - t.Fatalf("NewFakeUserspaceEngine: %v", err) - } - t.Cleanup(e.Close) - - lb, err := ipnlocal.NewLocalBackend(logf, "logid", new(mem.Store), "", new(tsdial.Dialer), e, 0) - if err != nil { - t.Fatalf("NewLocalBackend: %v", err) - } - t.Cleanup(lb.Shutdown) - dir := t.TempDir() - lb.SetVarRoot(dir) - - i.SetLocalBackend(lb) + impl := makeNetstack(t, tc.beforeStart) + if tc.afterStart != nil { + tc.afterStart(impl) + } - if tc.setup != nil { - tc.setup(i) - } - }) got := impl.shouldProcessInbound(tc.pkt, nil) if got != tc.want { t.Errorf("got shouldProcessInbound()=%v; want %v", got, tc.want)