wgengine/netstack: change netstack API to require LocalBackend

The macOS client was forgetting to call netstack.Impl.SetLocalBackend.
Change the API so that it can't be started without one, eliminating this
class of bug. Then update all the callers.

Updates #6764

Change-Id: I2b3a4f31fdfd9fdbbbbfe25a42db0c505373562f
Signed-off-by: Claire Wang <claire@tailscale.com>
Co-authored-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/6838/head
Claire Wang 2 years ago committed by Brad Fitzpatrick
parent 84eaef0bbb
commit a45c9f982a

@ -533,8 +533,7 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logid string) (_ *ip
return smallzstd.NewDecoder(nil) return smallzstd.NewDecoder(nil)
}) })
configureTaildrop(logf, lb) configureTaildrop(logf, lb)
ns.SetLocalBackend(lb) if err := ns.Start(lb); err != nil {
if err := ns.Start(); err != nil {
log.Fatalf("failed to start netstack: %v", err) log.Fatalf("failed to start netstack: %v", err)
} }
return lb, nil return lb, nil

@ -115,9 +115,7 @@ func newIPN(jsConfig js.Value) map[string]any {
} }
ns.ProcessLocalIPs = true ns.ProcessLocalIPs = true
ns.ProcessSubnets = true ns.ProcessSubnets = true
if err := ns.Start(); err != nil {
log.Fatalf("failed to start netstack: %v", err)
}
dialer.UseNetstackForIP = func(ip netip.Addr) bool { dialer.UseNetstackForIP = func(ip netip.Addr) bool {
return true return true
} }
@ -127,16 +125,17 @@ func newIPN(jsConfig js.Value) map[string]any {
logid := lpc.PublicID.String() logid := lpc.PublicID.String()
srv := ipnserver.New(logf, logid) srv := ipnserver.New(logf, logid)
lb, err := ipnlocal.NewLocalBackend(logf, logid, store, "wasm", dialer, eng, controlclient.LoginEphemeral) lb, err := ipnlocal.NewLocalBackend(logf, logid, store, "wasm", dialer, eng, controlclient.LoginEphemeral)
if err != nil { if err != nil {
log.Fatalf("ipnlocal.NewLocalBackend: %v", err) log.Fatalf("ipnlocal.NewLocalBackend: %v", err)
} }
if err := ns.Start(lb); err != nil {
log.Fatalf("failed to start netstack: %v", err)
}
lb.SetDecompressor(func() (controlclient.Decompressor, error) { lb.SetDecompressor(func() (controlclient.Decompressor, error) {
return smallzstd.NewDecoder(nil) return smallzstd.NewDecoder(nil)
}) })
srv.SetLocalBackend(lb) srv.SetLocalBackend(lb)
ns.SetLocalBackend(lb)
jsIPN := &jsIPN{ jsIPN := &jsIPN{
dialer: dialer, dialer: dialer,

@ -317,9 +317,6 @@ func (s *Server) start() (reterr error) {
} }
ns.ProcessLocalIPs = true ns.ProcessLocalIPs = true
ns.ForwardTCPIn = s.forwardTCP ns.ForwardTCPIn = s.forwardTCP
if err := ns.Start(); err != nil {
return fmt.Errorf("failed to start netstack: %w", err)
}
s.netstack = ns s.netstack = ns
s.dialer.UseNetstackForIP = func(ip netip.Addr) bool { s.dialer.UseNetstackForIP = func(ip netip.Addr) bool {
_, ok := eng.PeerForIP(ip) _, ok := eng.PeerForIP(ip)
@ -349,6 +346,9 @@ func (s *Server) start() (reterr error) {
lb.SetVarRoot(s.rootPath) lb.SetVarRoot(s.rootPath)
logf("tsnet starting with hostname %q, varRoot %q", s.hostname, s.rootPath) logf("tsnet starting with hostname %q, varRoot %q", s.hostname, s.rootPath)
s.lb = lb s.lb = lb
if err := ns.Start(lb); err != nil {
return fmt.Errorf("failed to start netstack: %w", err)
}
closePool.addFunc(func() { s.lb.Shutdown() }) closePool.addFunc(func() { s.lb.Shutdown() })
lb.SetDecompressor(func() (controlclient.Decompressor, error) { lb.SetDecompressor(func() (controlclient.Decompressor, error) {
return smallzstd.NewDecoder(nil) return smallzstd.NewDecoder(nil)

@ -204,12 +204,6 @@ func (ns *Impl) Close() error {
return nil return nil
} }
// SetLocalBackend sets the LocalBackend; it should only be run before
// the Start method is called.
func (ns *Impl) SetLocalBackend(lb *ipnlocal.LocalBackend) {
ns.lb = lb
}
// wrapProtoHandler returns protocol handler h wrapped in a version // wrapProtoHandler returns protocol handler h wrapped in a version
// that dynamically reconfigures ns's subnet addresses as needed for // that dynamically reconfigures ns's subnet addresses as needed for
// outbound traffic. // outbound traffic.
@ -231,7 +225,11 @@ func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, stack.PacketB
// Start sets up all the handlers so netstack can start working. Implements // Start sets up all the handlers so netstack can start working. Implements
// wgengine.FakeImpl. // wgengine.FakeImpl.
func (ns *Impl) Start() error { func (ns *Impl) Start(lb *ipnlocal.LocalBackend) error {
if lb == nil {
panic("nil LocalBackend")
}
ns.lb = lb
ns.e.AddNetworkMapCallback(ns.updateIPs) ns.e.AddNetworkMapCallback(ns.updateIPs)
// size = 0 means use default buffer size // size = 0 means use default buffer size
const tcpReceiveBufferSize = 0 const tcpReceiveBufferSize = 0

@ -17,7 +17,6 @@ import (
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/net/tsdial" "tailscale.com/net/tsdial"
"tailscale.com/net/tstun" "tailscale.com/net/tstun"
"tailscale.com/tstest"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/wgengine" "tailscale.com/wgengine"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
@ -50,13 +49,18 @@ func TestInjectInboundLeak(t *testing.T) {
t.Fatal("failed to get internals") t.Fatal("failed to get internals")
} }
lb, err := ipnlocal.NewLocalBackend(logf, "logid", new(mem.Store), "", dialer, eng, 0)
if err != nil {
t.Fatal(err)
}
ns, err := Create(logf, tunWrap, eng, magicSock, dialer, dns) ns, err := Create(logf, tunWrap, eng, magicSock, dialer, dns)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer ns.Close() defer ns.Close()
ns.ProcessLocalIPs = true ns.ProcessLocalIPs = true
if err := ns.Start(); err != nil { if err := ns.Start(lb); err != nil {
t.Fatalf("Start: %v", err) t.Fatalf("Start: %v", err)
} }
ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true })
@ -114,10 +118,16 @@ func makeNetstack(t *testing.T, config func(*Impl)) *Impl {
} }
t.Cleanup(func() { ns.Close() }) t.Cleanup(func() { ns.Close() })
ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) lb, err := ipnlocal.NewLocalBackend(logf, "logid", new(mem.Store), "", dialer, eng, 0)
config(ns) if err != nil {
t.Fatalf("NewLocalBackend: %v", err)
}
if err := ns.Start(); err != nil { ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true })
if config != nil {
config(ns)
}
if err := ns.Start(lb); err != nil {
t.Fatalf("Start: %v", err) t.Fatalf("Start: %v", err)
} }
return ns return ns
@ -257,11 +267,12 @@ func looksLikeATailscaleSelfAddress(addr netip.Addr) bool {
func TestShouldProcessInbound(t *testing.T) { func TestShouldProcessInbound(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
pkt *packet.Parsed pkt *packet.Parsed
setup func(*Impl) afterStart func(*Impl) // optional; after Impl.Start is called
want bool beforeStart func(*Impl) // optional; before Impl.Start is called
runOnGOOS string want bool
runOnGOOS string
}{ }{
{ {
name: "ipv6-via", name: "ipv6-via",
@ -275,7 +286,7 @@ func TestShouldProcessInbound(t *testing.T) {
Dst: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:109]:5678"), Dst: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:109]:5678"),
TCPFlags: packet.TCPSyn, TCPFlags: packet.TCPSyn,
}, },
setup: func(i *Impl) { afterStart: func(i *Impl) {
prefs := ipn.NewPrefs() prefs := ipn.NewPrefs()
prefs.AdvertiseRoutes = []netip.Prefix{ prefs.AdvertiseRoutes = []netip.Prefix{
// $ tailscale debug via 7 10.1.1.0/24 // $ tailscale debug via 7 10.1.1.0/24
@ -286,7 +297,8 @@ func TestShouldProcessInbound(t *testing.T) {
LegacyMigrationPrefs: prefs, LegacyMigrationPrefs: prefs,
}) })
i.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress) i.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress)
},
beforeStart: func(i *Impl) {
// This should be handled even if we're // This should be handled even if we're
// otherwise not processing local IPs or // otherwise not processing local IPs or
// subnets. // subnets.
@ -307,7 +319,7 @@ func TestShouldProcessInbound(t *testing.T) {
Dst: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:109]:5678"), Dst: netip.MustParseAddrPort("[fd7a:115c:a1e0:b1a:0:7:a01:109]:5678"),
TCPFlags: packet.TCPSyn, TCPFlags: packet.TCPSyn,
}, },
setup: func(i *Impl) { afterStart: func(i *Impl) {
prefs := ipn.NewPrefs() prefs := ipn.NewPrefs()
prefs.AdvertiseRoutes = []netip.Prefix{ prefs.AdvertiseRoutes = []netip.Prefix{
// tailscale debug via 7 10.1.2.0/24 // tailscale debug via 7 10.1.2.0/24
@ -329,7 +341,7 @@ func TestShouldProcessInbound(t *testing.T) {
Dst: netip.MustParseAddrPort("100.101.102.104:22"), Dst: netip.MustParseAddrPort("100.101.102.104:22"),
TCPFlags: packet.TCPSyn, TCPFlags: packet.TCPSyn,
}, },
setup: func(i *Impl) { afterStart: func(i *Impl) {
prefs := ipn.NewPrefs() prefs := ipn.NewPrefs()
prefs.RunSSH = true prefs.RunSSH = true
i.lb.Start(ipn.Options{ i.lb.Start(ipn.Options{
@ -351,7 +363,7 @@ func TestShouldProcessInbound(t *testing.T) {
Dst: netip.MustParseAddrPort("100.101.102.104:22"), Dst: netip.MustParseAddrPort("100.101.102.104:22"),
TCPFlags: packet.TCPSyn, TCPFlags: packet.TCPSyn,
}, },
setup: func(i *Impl) { afterStart: func(i *Impl) {
prefs := ipn.NewPrefs() prefs := ipn.NewPrefs()
prefs.RunSSH = false // default, but to be explicit prefs.RunSSH = false // default, but to be explicit
i.lb.Start(ipn.Options{ i.lb.Start(ipn.Options{
@ -372,7 +384,7 @@ func TestShouldProcessInbound(t *testing.T) {
Dst: netip.MustParseAddrPort("100.101.102.104:4567"), Dst: netip.MustParseAddrPort("100.101.102.104:4567"),
TCPFlags: packet.TCPSyn, TCPFlags: packet.TCPSyn,
}, },
setup: func(i *Impl) { afterStart: func(i *Impl) {
i.ProcessLocalIPs = true i.ProcessLocalIPs = true
i.atomicIsLocalIPFunc.Store(func(addr netip.Addr) bool { i.atomicIsLocalIPFunc.Store(func(addr netip.Addr) bool {
return addr.String() == "100.101.102.104" // Dst, above return addr.String() == "100.101.102.104" // Dst, above
@ -389,9 +401,10 @@ func TestShouldProcessInbound(t *testing.T) {
Dst: netip.MustParseAddrPort("10.1.2.3:4567"), Dst: netip.MustParseAddrPort("10.1.2.3:4567"),
TCPFlags: packet.TCPSyn, TCPFlags: packet.TCPSyn,
}, },
setup: func(i *Impl) { beforeStart: func(i *Impl) {
i.ProcessSubnets = true i.ProcessSubnets = true
},
afterStart: func(i *Impl) {
// For testing purposes, assume all Tailscale // For testing purposes, assume all Tailscale
// IPs are local; the Dst above is something // IPs are local; the Dst above is something
// not in that range. // not in that range.
@ -408,7 +421,12 @@ func TestShouldProcessInbound(t *testing.T) {
Dst: netip.MustParseAddrPort("10.0.0.23:5555"), Dst: netip.MustParseAddrPort("10.0.0.23:5555"),
TCPFlags: packet.TCPSyn, TCPFlags: packet.TCPSyn,
}, },
setup: func(i *Impl) { beforeStart: func(i *Impl) {
// As if we were running on Linux where netstack isn't used.
i.ProcessSubnets = false
i.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return false })
},
afterStart: func(i *Impl) {
prefs := ipn.NewPrefs() prefs := ipn.NewPrefs()
prefs.AdvertiseRoutes = []netip.Prefix{ prefs.AdvertiseRoutes = []netip.Prefix{
netip.MustParsePrefix("10.0.0.1/24"), netip.MustParsePrefix("10.0.0.1/24"),
@ -417,10 +435,6 @@ func TestShouldProcessInbound(t *testing.T) {
LegacyMigrationPrefs: prefs, LegacyMigrationPrefs: prefs,
}) })
// As if we were running on Linux where netstack isn't used.
i.ProcessSubnets = false
i.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return false })
// Set the PeerAPI port to the Dst port above. // Set the PeerAPI port to the Dst port above.
i.peerapiPort4Atomic.Store(5555) i.peerapiPort4Atomic.Store(5555)
i.peerapiPort6Atomic.Store(5555) i.peerapiPort6Atomic.Store(5555)
@ -437,30 +451,11 @@ func TestShouldProcessInbound(t *testing.T) {
if tc.runOnGOOS != "" && runtime.GOOS != tc.runOnGOOS { if tc.runOnGOOS != "" && runtime.GOOS != tc.runOnGOOS {
t.Skipf("skipping on GOOS=%v", runtime.GOOS) t.Skipf("skipping on GOOS=%v", runtime.GOOS)
} }
impl := makeNetstack(t, func(i *Impl) { impl := makeNetstack(t, tc.beforeStart)
defer t.Logf("netstack setup finished") if tc.afterStart != nil {
tc.afterStart(impl)
logf := tstest.WhileTestRunningLogger(t) }
e, err := wgengine.NewFakeUserspaceEngine(logf, 0)
if err != nil {
t.Fatalf("NewFakeUserspaceEngine: %v", err)
}
t.Cleanup(e.Close)
lb, err := ipnlocal.NewLocalBackend(logf, "logid", new(mem.Store), "", new(tsdial.Dialer), e, 0)
if err != nil {
t.Fatalf("NewLocalBackend: %v", err)
}
t.Cleanup(lb.Shutdown)
dir := t.TempDir()
lb.SetVarRoot(dir)
i.SetLocalBackend(lb)
if tc.setup != nil {
tc.setup(i)
}
})
got := impl.shouldProcessInbound(tc.pkt, nil) got := impl.shouldProcessInbound(tc.pkt, nil)
if got != tc.want { if got != tc.want {
t.Errorf("got shouldProcessInbound()=%v; want %v", got, tc.want) t.Errorf("got shouldProcessInbound()=%v; want %v", got, tc.want)

Loading…
Cancel
Save