diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index cf21d4885..48e0349de 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -1075,16 +1075,12 @@ func TestUpdatePrefs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.sshOverTailscale { - old := getSSHClientEnvVar - getSSHClientEnvVar = func() string { return "100.100.100.100 1 1" } - t.Cleanup(func() { getSSHClientEnvVar = old }) + tstest.Replace(t, &getSSHClientEnvVar, func() string { return "100.100.100.100 1 1" }) } else if isSSHOverTailscale() { // The test is being executed over a "real" tailscale SSH // session, but sshOverTailscale is unset. Make the test appear // as if it's not over tailscale SSH. - old := getSSHClientEnvVar - getSSHClientEnvVar = func() string { return "" } - t.Cleanup(func() { getSSHClientEnvVar = old }) + tstest.Replace(t, &getSSHClientEnvVar, func() string { return "" }) } if tt.env.goos == "" { tt.env.goos = "linux" diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index 3f6360a10..ef4249b74 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -13,6 +13,8 @@ import ( "reflect" "testing" "time" + + "tailscale.com/tstest" ) var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") @@ -142,9 +144,7 @@ func TestResolverAllHostStaticResult(t *testing.T) { } func TestShouldTryBootstrap(t *testing.T) { - oldDebug := debug - t.Cleanup(func() { debug = oldDebug }) - debug = func() bool { return true } + tstest.Replace(t, &debug, func() bool { return true }) type step struct { ip netip.Addr // IP we pretended to dial diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 31f50ab6b..797889926 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -22,6 +22,7 @@ import ( "tailscale.com/net/stun" "tailscale.com/net/stun/stuntest" "tailscale.com/tailcfg" + "tailscale.com/tstest" ) func TestHairpinSTUN(t *testing.T) { @@ -679,9 +680,7 @@ func TestNoCaptivePortalWhenUDP(t *testing.T) { } }) - oldTransport := noRedirectClient.Transport - t.Cleanup(func() { noRedirectClient.Transport = oldTransport }) - noRedirectClient.Transport = tr + tstest.Replace(t, &noRedirectClient.Transport, http.RoundTripper(tr)) stunAddr, cleanup := stuntest.Serve(t) defer cleanup() diff --git a/tstest/tstest.go b/tstest/tstest.go index 6cfcc6b36..52820ad4d 100644 --- a/tstest/tstest.go +++ b/tstest/tstest.go @@ -6,12 +6,29 @@ package tstest import ( "context" + "testing" "time" "tailscale.com/logtail/backoff" "tailscale.com/types/logger" ) +// Replace replaces the value of target with val. +// The old value is restored when the test ends. +func Replace[T any](t *testing.T, target *T, val T) { + t.Helper() + if target == nil { + t.Fatalf("Replace: nil pointer") + } + old := *target + t.Cleanup(func() { + *target = old + }) + + *target = val + return +} + // WaitFor retries try for up to maxWait. // It returns nil once try returns nil the first time. // If maxWait passes without success, it returns try's last error. diff --git a/tstest/tstest_test.go b/tstest/tstest_test.go new file mode 100644 index 000000000..e988d5d56 --- /dev/null +++ b/tstest/tstest_test.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import "testing" + +func TestReplace(t *testing.T) { + before := "before" + done := false + t.Run("replace", func(t *testing.T) { + Replace(t, &before, "after") + if before != "after" { + t.Errorf("before = %q; want %q", before, "after") + } + done = true + }) + if !done { + t.Fatal("subtest didn't run") + } + if before != "before" { + t.Errorf("before = %q; want %q", before, "before") + } +}