diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index a9cf1de8a..dbd3e4694 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -520,6 +520,7 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons Data: vv, }) ns.linkEP.InjectInbound(pn, packetBuf) + packetBuf.DecRef() // We've now delivered this to netstack, so we're done. // Instead of returning a filter.Accept here (which would also diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go new file mode 100644 index 000000000..9cea5b842 --- /dev/null +++ b/wgengine/netstack/netstack_test.go @@ -0,0 +1,76 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netstack + +import ( + "runtime" + "testing" + + "inet.af/netaddr" + "tailscale.com/net/packet" + "tailscale.com/net/tsdial" + "tailscale.com/net/tstun" + "tailscale.com/wgengine" + "tailscale.com/wgengine/filter" +) + +// TestInjectInboundLeak tests that injectInbound doesn't leak memory. +// See https://github.com/tailscale/tailscale/issues/3762 +func TestInjectInboundLeak(t *testing.T) { + tunDev := tstun.NewFake() + dialer := new(tsdial.Dialer) + logf := func(format string, args ...interface{}) { + if !t.Failed() { + t.Logf(format, args...) + } + } + eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ + Tun: tunDev, + Dialer: dialer, + }) + if err != nil { + t.Fatal(err) + } + defer eng.Close() + ig, ok := eng.(wgengine.InternalsGetter) + if !ok { + t.Fatal("not an InternalsGetter") + } + tunWrap, magicSock, ok := ig.GetInternals() + if !ok { + t.Fatal("failed to get internals") + } + + ns, err := Create(logf, tunWrap, eng, magicSock, dialer) + if err != nil { + t.Fatal(err) + } + defer ns.Close() + ns.ProcessLocalIPs = true + if err := ns.Start(); err != nil { + t.Fatalf("Start: %v", err) + } + ns.atomicIsLocalIPFunc.Store(func(netaddr.IP) bool { return true }) + + pkt := &packet.Parsed{} + const N = 10_000 + ms0 := getMemStats() + for i := 0; i < N; i++ { + outcome := ns.injectInbound(pkt, tunWrap) + if outcome != filter.DropSilently { + t.Fatalf("got outcome %v; want DropSilently", outcome) + } + } + ms1 := getMemStats() + if grew := int64(ms1.HeapObjects) - int64(ms0.HeapObjects); grew >= N { + t.Fatalf("grew by %v (which is too much and >= the %v packets we sent)", grew, N) + } +} + +func getMemStats() (ms runtime.MemStats) { + runtime.GC() + runtime.ReadMemStats(&ms) + return +}