diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 99fbf9d96..b851807e4 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -533,8 +533,7 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logid string) (_ *ip return smallzstd.NewDecoder(nil) }) configureTaildrop(logf, lb) - ns.SetLocalBackend(lb) - if err := ns.Start(); err != nil { + if err := ns.Start(lb); err != nil { log.Fatalf("failed to start netstack: %v", err) } return lb, nil diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index b928ae7f9..aaaeef1ce 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -115,9 +115,7 @@ func newIPN(jsConfig js.Value) map[string]any { } ns.ProcessLocalIPs = true ns.ProcessSubnets = true - if err := ns.Start(); err != nil { - log.Fatalf("failed to start netstack: %v", err) - } + dialer.UseNetstackForIP = func(ip netip.Addr) bool { return true } @@ -127,16 +125,17 @@ func newIPN(jsConfig js.Value) map[string]any { logid := lpc.PublicID.String() srv := ipnserver.New(logf, logid) - lb, err := ipnlocal.NewLocalBackend(logf, logid, store, "wasm", dialer, eng, controlclient.LoginEphemeral) if err != nil { log.Fatalf("ipnlocal.NewLocalBackend: %v", err) } + if err := ns.Start(lb); err != nil { + log.Fatalf("failed to start netstack: %v", err) + } lb.SetDecompressor(func() (controlclient.Decompressor, error) { return smallzstd.NewDecoder(nil) }) srv.SetLocalBackend(lb) - ns.SetLocalBackend(lb) jsIPN := &jsIPN{ dialer: dialer, diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index ad7a28a1a..774340b23 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -317,9 +317,6 @@ func (s *Server) start() (reterr error) { } ns.ProcessLocalIPs = true ns.ForwardTCPIn = s.forwardTCP - if err := ns.Start(); err != nil { - return fmt.Errorf("failed to start netstack: %w", err) - } s.netstack = ns s.dialer.UseNetstackForIP = func(ip netip.Addr) bool { _, ok := eng.PeerForIP(ip) @@ -349,6 +346,9 @@ func (s *Server) start() (reterr error) { lb.SetVarRoot(s.rootPath) logf("tsnet starting with hostname %q, varRoot %q", s.hostname, s.rootPath) s.lb = lb + if err := ns.Start(lb); err != nil { + return fmt.Errorf("failed to start netstack: %w", err) + } closePool.addFunc(func() { s.lb.Shutdown() }) lb.SetDecompressor(func() (controlclient.Decompressor, error) { return smallzstd.NewDecoder(nil) diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 4f4ba6e63..992e99b00 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -204,12 +204,6 @@ func (ns *Impl) Close() error { return nil } -// SetLocalBackend sets the LocalBackend; it should only be run before -// the Start method is called. -func (ns *Impl) SetLocalBackend(lb *ipnlocal.LocalBackend) { - ns.lb = lb -} - // wrapProtoHandler returns protocol handler h wrapped in a version // that dynamically reconfigures ns's subnet addresses as needed for // outbound traffic. @@ -231,7 +225,11 @@ func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, stack.PacketB // Start sets up all the handlers so netstack can start working. Implements // wgengine.FakeImpl. -func (ns *Impl) Start() error { +func (ns *Impl) Start(lb *ipnlocal.LocalBackend) error { + if lb == nil { + panic("nil LocalBackend") + } + ns.lb = lb ns.e.AddNetworkMapCallback(ns.updateIPs) // size = 0 means use default buffer size const tcpReceiveBufferSize = 0 diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index c92f1edb6..61c5933eb 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -17,7 +17,6 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" "tailscale.com/net/tstun" - "tailscale.com/tstest" "tailscale.com/types/ipproto" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" @@ -50,13 +49,18 @@ func TestInjectInboundLeak(t *testing.T) { t.Fatal("failed to get internals") } + lb, err := ipnlocal.NewLocalBackend(logf, "logid", new(mem.Store), "", dialer, eng, 0) + if err != nil { + t.Fatal(err) + } + ns, err := Create(logf, tunWrap, eng, magicSock, dialer, dns) if err != nil { t.Fatal(err) } defer ns.Close() ns.ProcessLocalIPs = true - if err := ns.Start(); err != nil { + if err := ns.Start(lb); err != nil { t.Fatalf("Start: %v", err) } ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) @@ -114,10 +118,16 @@ func makeNetstack(t *testing.T, config func(*Impl)) *Impl { } t.Cleanup(func() { ns.Close() }) - ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) - config(ns) + lb, err := ipnlocal.NewLocalBackend(logf, "logid", new(mem.Store), "", dialer, eng, 0) + if err != nil { + t.Fatalf("NewLocalBackend: %v", err) + } - if err := ns.Start(); err != nil { + ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) + if config != nil { + config(ns) + } + if err := ns.Start(lb); err != nil { t.Fatalf("Start: %v", err) } return ns @@ -257,11 +267,12 @@ func looksLikeATailscaleSelfAddress(addr netip.Addr) bool { func TestShouldProcessInbound(t *testing.T) { testCases := []struct { - name string - pkt *packet.Parsed - setup func(*Impl) - want bool - runOnGOOS string + name string + pkt *packet.Parsed + afterStart func(*Impl) // optional; after Impl.Start is called + beforeStart func(*Impl) // optional; before Impl.Start is called + want bool + runOnGOOS string }{ { name: "ipv6-via", @@ -275,7 +286,7 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:109]:5678"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + afterStart: func(i *Impl) { prefs := ipn.NewPrefs() prefs.AdvertiseRoutes = []netip.Prefix{ // $ tailscale debug via 7 10.1.1.0/24 @@ -286,7 +297,8 @@ func TestShouldProcessInbound(t *testing.T) { LegacyMigrationPrefs: prefs, }) i.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress) - + }, + beforeStart: func(i *Impl) { // This should be handled even if we're // otherwise not processing local IPs or // subnets. @@ -307,7 +319,7 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:109]:5678"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + afterStart: func(i *Impl) { prefs := ipn.NewPrefs() prefs.AdvertiseRoutes = []netip.Prefix{ // tailscale debug via 7 10.1.2.0/24 @@ -329,7 +341,7 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("100.101.102.104:22"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + afterStart: func(i *Impl) { prefs := ipn.NewPrefs() prefs.RunSSH = true i.lb.Start(ipn.Options{ @@ -351,7 +363,7 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("100.101.102.104:22"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + afterStart: func(i *Impl) { prefs := ipn.NewPrefs() prefs.RunSSH = false // default, but to be explicit i.lb.Start(ipn.Options{ @@ -372,7 +384,7 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("100.101.102.104:4567"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + afterStart: func(i *Impl) { i.ProcessLocalIPs = true i.atomicIsLocalIPFunc.Store(func(addr netip.Addr) bool { return addr.String() == "100.101.102.104" // Dst, above @@ -389,9 +401,10 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("10.1.2.3:4567"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + beforeStart: func(i *Impl) { i.ProcessSubnets = true - + }, + afterStart: func(i *Impl) { // For testing purposes, assume all Tailscale // IPs are local; the Dst above is something // not in that range. @@ -408,7 +421,12 @@ func TestShouldProcessInbound(t *testing.T) { Dst: netip.MustParseAddrPort("10.0.0.23:5555"), TCPFlags: packet.TCPSyn, }, - setup: func(i *Impl) { + beforeStart: func(i *Impl) { + // As if we were running on Linux where netstack isn't used. + i.ProcessSubnets = false + i.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return false }) + }, + afterStart: func(i *Impl) { prefs := ipn.NewPrefs() prefs.AdvertiseRoutes = []netip.Prefix{ netip.MustParsePrefix("10.0.0.1/24"), @@ -417,10 +435,6 @@ func TestShouldProcessInbound(t *testing.T) { LegacyMigrationPrefs: prefs, }) - // As if we were running on Linux where netstack isn't used. - i.ProcessSubnets = false - i.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return false }) - // Set the PeerAPI port to the Dst port above. i.peerapiPort4Atomic.Store(5555) i.peerapiPort6Atomic.Store(5555) @@ -437,30 +451,11 @@ func TestShouldProcessInbound(t *testing.T) { if tc.runOnGOOS != "" && runtime.GOOS != tc.runOnGOOS { t.Skipf("skipping on GOOS=%v", runtime.GOOS) } - impl := makeNetstack(t, func(i *Impl) { - defer t.Logf("netstack setup finished") - - logf := tstest.WhileTestRunningLogger(t) - e, err := wgengine.NewFakeUserspaceEngine(logf, 0) - if err != nil { - t.Fatalf("NewFakeUserspaceEngine: %v", err) - } - t.Cleanup(e.Close) - - lb, err := ipnlocal.NewLocalBackend(logf, "logid", new(mem.Store), "", new(tsdial.Dialer), e, 0) - if err != nil { - t.Fatalf("NewLocalBackend: %v", err) - } - t.Cleanup(lb.Shutdown) - dir := t.TempDir() - lb.SetVarRoot(dir) - - i.SetLocalBackend(lb) + impl := makeNetstack(t, tc.beforeStart) + if tc.afterStart != nil { + tc.afterStart(impl) + } - if tc.setup != nil { - tc.setup(i) - } - }) got := impl.shouldProcessInbound(tc.pkt, nil) if got != tc.want { t.Errorf("got shouldProcessInbound()=%v; want %v", got, tc.want)