diff --git a/types/logger/logger.go b/types/logger/logger.go index 646becdcd..dfc9eb07b 100644 --- a/types/logger/logger.go +++ b/types/logger/logger.go @@ -192,3 +192,27 @@ func Filtered(logf Logf, allow func(s string) bool) Logf { logf(format, args...) } } + +// LogfCloser wraps logf to create a logger that can be closed. +// Calling close makes all future calls to newLogf into no-ops. +func LogfCloser(logf Logf) (newLogf Logf, close func()) { + var ( + mu sync.Mutex + closed bool + ) + close = func() { + mu.Lock() + defer mu.Unlock() + closed = true + } + newLogf = func(msg string, args ...interface{}) { + mu.Lock() + if closed { + mu.Unlock() + return + } + mu.Unlock() + logf(msg, args...) + } + return newLogf, close +} diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 5e7460588..05e4715ba 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -567,9 +567,12 @@ func TestConnClosing(t *testing.T) { t.Fatalf("generating private key: %v", err) } + logf, closeLogf := logger.LogfCloser(t.Logf) + defer closeLogf() + epCh := make(chan []string, 100) conn, err := NewConn(Options{ - Logf: t.Logf, + Logf: logf, PacketListener: nettype.Std{}, EndpointsFunc: func(eps []string) { epCh <- eps @@ -580,7 +583,7 @@ func TestConnClosing(t *testing.T) { t.Fatalf("constructing magicsock: %v", err) } - derpMap, cleanup := runDERPAndStun(t, t.Logf, nettype.Std{}, netaddr.IPv4(127, 0, 3, 1)) + derpMap, cleanup := runDERPAndStun(t, logf, nettype.Std{}, netaddr.IPv4(127, 0, 3, 1)) defer cleanup() // The point of this test case is to exercise handling in derpWriteChanOfAddr() which @@ -597,11 +600,11 @@ func TestConnClosing(t *testing.T) { } tun := tuntest.NewChannelTUN() - tsTun := tstun.WrapTUN(t.Logf, tun.TUN()) - tsTun.SetFilter(filter.NewAllowAllForTest(t.Logf)) + tsTun := tstun.WrapTUN(logf, tun.TUN()) + tsTun.SetFilter(filter.NewAllowAllForTest(logf)) dev := device.NewDevice(tsTun, &device.DeviceOptions{ - Logger: wireguardGoLogger(t.Logf), + Logger: wireguardGoLogger(logf), CreateEndpoint: conn.CreateEndpoint, CreateBind: conn.CreateBind, SkipBindUpdate: true, @@ -634,12 +637,15 @@ func TestConnClosed(t *testing.T) { stunIP: sif.V4(), } - derpMap, cleanup := runDERPAndStun(t, t.Logf, d.stun, d.stunIP) + logf, closeLogf := logger.LogfCloser(t.Logf) + defer closeLogf() + + derpMap, cleanup := runDERPAndStun(t, logf, d.stun, d.stunIP) defer cleanup() - ms1 := newMagicStack(t, logger.WithPrefix(t.Logf, "conn1: "), d.m1, derpMap) + ms1 := newMagicStack(t, logger.WithPrefix(logf, "conn1: "), d.m1, derpMap) defer ms1.Close() - ms2 := newMagicStack(t, logger.WithPrefix(t.Logf, "conn2: "), d.m2, derpMap) + ms2 := newMagicStack(t, logger.WithPrefix(logf, "conn2: "), d.m2, derpMap) defer ms2.Close() cleanup = meshStacks(t.Logf, []*magicStack{ms1, ms2})