diff --git a/cmd/derper/bootstrap_dns_test.go b/cmd/derper/bootstrap_dns_test.go index 70d3c8678..cbede8587 100644 --- a/cmd/derper/bootstrap_dns_test.go +++ b/cmd/derper/bootstrap_dns_test.go @@ -11,14 +11,12 @@ import ( "net/url" "reflect" "testing" + + "tailscale.com/tstest" ) func BenchmarkHandleBootstrapDNS(b *testing.B) { - prev := *bootstrapDNS - *bootstrapDNS = "log.tailscale.io,login.tailscale.com,controlplane.tailscale.com,login.us.tailscale.com" - defer func() { - *bootstrapDNS = prev - }() + tstest.Replace(b, bootstrapDNS, "log.tailscale.io,login.tailscale.com,controlplane.tailscale.com,login.us.tailscale.com") refreshBootstrapDNS() w := new(bitbucketResponseWriter) req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.io"), nil) diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index b2f605600..fa130c91b 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -13,6 +13,7 @@ import ( "go4.org/mem" "tailscale.com/tailcfg" + "tailscale.com/tstest" "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/types/opt" @@ -21,12 +22,11 @@ import ( ) func TestUndeltaPeers(t *testing.T) { - defer func(old func() time.Time) { clockNow = old }(clockNow) - var curTime time.Time - clockNow = func() time.Time { + tstest.Replace(t, &clockNow, func() time.Time { return curTime - } + }) + online := func(v bool) func(*tailcfg.Node) { return func(n *tailcfg.Node) { n.Online = &v diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 556a7f17e..cc105f0a8 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -497,8 +497,7 @@ func (panicOnUseTransport) RoundTrip(*http.Request) (*http.Response, error) { // Issue 1573: don't generate a machine key if we don't want to be running. func TestLazyMachineKeyGeneration(t *testing.T) { - defer func(old func() bool) { panicOnMachineKeyGeneration = old }(panicOnMachineKeyGeneration) - panicOnMachineKeyGeneration = func() bool { return true } + tstest.Replace(t, &panicOnMachineKeyGeneration, func() bool { return true }) var logf logger.Logf = logger.Discard store := new(mem.Store) diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index 8d7d317b0..057da9039 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -14,6 +14,7 @@ import ( "tailscale.com/client/tailscale/apitype" "tailscale.com/hostinfo" "tailscale.com/ipn/ipnlocal" + "tailscale.com/tstest" ) func TestValidHost(t *testing.T) { @@ -42,10 +43,7 @@ func TestValidHost(t *testing.T) { } func TestSetPushDeviceToken(t *testing.T) { - validLocalHostForTesting = true - defer func() { - validLocalHostForTesting = false - }() + tstest.Replace(t, &validLocalHostForTesting, true) h := &Handler{ PermitWrite: true, diff --git a/logtail/filch/filch_test.go b/logtail/filch/filch_test.go index ed02e8058..5c33efbee 100644 --- a/logtail/filch/filch_test.go +++ b/logtail/filch/filch_test.go @@ -12,6 +12,8 @@ import ( "testing" "unicode" "unsafe" + + "tailscale.com/tstest" ) type filchTest struct { @@ -177,10 +179,7 @@ func TestFilchStderr(t *testing.T) { defer pipeR.Close() defer pipeW.Close() - stderrFD = int(pipeW.Fd()) - defer func() { - stderrFD = 2 - }() + tstest.Replace(t, &stderrFD, int(pipeW.Fd())) filePrefix := t.TempDir() f := newFilchTest(t, filePrefix, Options{ReplaceStderr: true}) diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index ac6b09b4a..3980cfa33 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -997,11 +997,8 @@ func TestMarshalResponseFormatError(t *testing.T) { } func TestForwardLinkSelection(t *testing.T) { - old := initListenConfig - defer func() { initListenConfig = old }() - configCall := make(chan string, 1) - initListenConfig = func(nc *net.ListenConfig, mon *monitor.Mon, tunName string) error { + tstest.Replace(t, &initListenConfig, func(nc *net.ListenConfig, mon *monitor.Mon, tunName string) error { select { case configCall <- tunName: return nil @@ -1009,7 +1006,7 @@ func TestForwardLinkSelection(t *testing.T) { t.Error("buffer full") return errors.New("buffer full") } - } + }) // specialIP is some IP we pretend that our link selector // routes differently. diff --git a/net/interfaces/interfaces_linux_test.go b/net/interfaces/interfaces_linux_test.go index 87be5610f..59249c67d 100644 --- a/net/interfaces/interfaces_linux_test.go +++ b/net/interfaces/interfaces_linux_test.go @@ -10,14 +10,14 @@ import ( "os" "path/filepath" "testing" + + "tailscale.com/tstest" ) // test the specific /proc/net/route path as found on Google Cloud Run instances func TestGoogleCloudRunDefaultRouteInterface(t *testing.T) { dir := t.TempDir() - savedProcNetRoutePath := procNetRoutePath - defer func() { procNetRoutePath = savedProcNetRoutePath }() - procNetRoutePath = filepath.Join(dir, "CloudRun") + tstest.Replace(t, &procNetRoutePath, filepath.Join(dir, "CloudRun")) buf := []byte("Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT\n" + "eth0\t8008FEA9\t00000000\t0001\t0\t0\t0\t01FFFFFF\t0\t0\t0\n" + "eth1\t00000000\t00000000\t0001\t0\t0\t0\t00000000\t0\t0\t0\n") @@ -39,9 +39,7 @@ func TestGoogleCloudRunDefaultRouteInterface(t *testing.T) { // size can be handled. func TestExtremelyLongProcNetRoute(t *testing.T) { dir := t.TempDir() - savedProcNetRoutePath := procNetRoutePath - defer func() { procNetRoutePath = savedProcNetRoutePath }() - procNetRoutePath = filepath.Join(dir, "VeryLong") + tstest.Replace(t, &procNetRoutePath, filepath.Join(dir, "VeryLong")) f, err := os.Create(procNetRoutePath) if err != nil { t.Fatal(err) @@ -76,9 +74,7 @@ func TestExtremelyLongProcNetRoute(t *testing.T) { // test the specific /proc/net/route path as found on AWS App Runner instances func TestAwsAppRunnerDefaultRouteInterface(t *testing.T) { dir := t.TempDir() - savedProcNetRoutePath := procNetRoutePath - defer func() { procNetRoutePath = savedProcNetRoutePath }() - procNetRoutePath = filepath.Join(dir, "CloudRun") + tstest.Replace(t, &procNetRoutePath, filepath.Join(dir, "CloudRun")) buf := []byte("Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT\n" + "eth0\t00000000\tF9AFFEA9\t0003\t0\t0\t0\t00000000\t0\t0\t0\n" + "*\tFEA9FEA9\t00000000\t0005\t0\t0\t0\tFFFFFFFF\t0\t0\t0\n" + diff --git a/net/tshttpproxy/tshttpproxy_synology_test.go b/net/tshttpproxy/tshttpproxy_synology_test.go index 2eae348ae..3061740f3 100644 --- a/net/tshttpproxy/tshttpproxy_synology_test.go +++ b/net/tshttpproxy/tshttpproxy_synology_test.go @@ -16,6 +16,8 @@ import ( "strings" "testing" "time" + + "tailscale.com/tstest" ) func TestSynologyProxyFromConfigCached(t *testing.T) { @@ -24,9 +26,7 @@ func TestSynologyProxyFromConfigCached(t *testing.T) { t.Fatal(err) } - var orig string - orig, synologyProxyConfigPath = synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf") - defer func() { synologyProxyConfigPath = orig }() + tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf")) t.Run("no config file", func(t *testing.T) { if _, err := os.Stat(synologyProxyConfigPath); err == nil { @@ -160,11 +160,9 @@ func TestSynologyProxiesFromConfig(t *testing.T) { openReader io.ReadCloser openErr error ) - var origOpen func() (io.ReadCloser, error) - origOpen, openSynologyProxyConf = openSynologyProxyConf, func() (io.ReadCloser, error) { + tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) { return openReader, openErr - } - defer func() { openSynologyProxyConf = origOpen }() + }) t.Run("with config", func(t *testing.T) { mc := &mustCloser{Reader: strings.NewReader(` diff --git a/tstest/tstest.go b/tstest/tstest.go index 52820ad4d..7ccba8004 100644 --- a/tstest/tstest.go +++ b/tstest/tstest.go @@ -15,10 +15,11 @@ import ( // Replace replaces the value of target with val. // The old value is restored when the test ends. -func Replace[T any](t *testing.T, target *T, val T) { +func Replace[T any](t testing.TB, target *T, val T) { t.Helper() if target == nil { t.Fatalf("Replace: nil pointer") + panic("unreachable") // pacify staticcheck } old := *target t.Cleanup(func() { diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index 2c8f48b52..37522029a 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -600,10 +600,9 @@ foo_foo_b 1 } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defer func() { expvarDo = expvar.Do }() - expvarDo = func(f func(expvar.KeyValue)) { + tstest.Replace(t, &expvarDo, func(f func(expvar.KeyValue)) { f(expvar.KeyValue{Key: tt.k, Value: tt.v}) - } + }) rec := httptest.NewRecorder() VarzHandler(rec, httptest.NewRequest("GET", "/", nil)) if got := rec.Body.Bytes(); string(got) != tt.want { @@ -792,11 +791,10 @@ func TestSortedStructAllocs(t *testing.T) { } func TestVarzHandlerSorting(t *testing.T) { - defer func() { expvarDo = expvar.Do }() - expvarDo = func(f func(expvar.KeyValue)) { + tstest.Replace(t, &expvarDo, func(f func(expvar.KeyValue)) { f(expvar.KeyValue{Key: "counter_zz", Value: new(expvar.Int)}) f(expvar.KeyValue{Key: "gauge_aa", Value: new(expvar.Int)}) - } + }) rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", nil) VarzHandler(rec, req) diff --git a/util/dirwalk/dirwalk_test.go b/util/dirwalk/dirwalk_test.go index 2bec13ba2..15ebc13dd 100644 --- a/util/dirwalk/dirwalk_test.go +++ b/util/dirwalk/dirwalk_test.go @@ -13,6 +13,7 @@ import ( "testing" "go4.org/mem" + "tailscale.com/tstest" ) func TestWalkShallowOSSpecific(t *testing.T) { @@ -28,9 +29,7 @@ func TestWalkShallowPortable(t *testing.T) { func testWalkShallow(t *testing.T, portable bool) { if portable { - old := osWalkShallow - defer func() { osWalkShallow = old }() - osWalkShallow = nil + tstest.Replace(t, &osWalkShallow, nil) } d := t.TempDir() diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index d7f408135..5edd505f4 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -419,14 +419,8 @@ func TestOmitDropLogging(t *testing.T) { } func TestLoggingPrivacy(t *testing.T) { - oldDrop := dropBucket - oldAccept := acceptBucket - dropBucket = rate.NewLimiter(2^32, 2^32) - acceptBucket = dropBucket - defer func() { - dropBucket = oldDrop - acceptBucket = oldAccept - }() + tstest.Replace(t, &dropBucket, rate.NewLimiter(2^32, 2^32)) + tstest.Replace(t, &acceptBucket, dropBucket) var ( logged bool