From b9ebf7cf14659180e71c939d1e121afe6d6595af Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Fri, 3 Mar 2023 16:18:59 -0800 Subject: [PATCH] tstest: add method to Replace values for tests We have many function pointers that we replace for the duration of test and restore it on test completion, add method to do that. Signed-off-by: Maisem Ali --- cmd/tailscale/cli/cli_test.go | 8 ++------ net/dnscache/dnscache_test.go | 6 +++--- net/netcheck/netcheck_test.go | 5 ++--- tstest/tstest.go | 17 +++++++++++++++++ tstest/tstest_test.go | 24 ++++++++++++++++++++++++ 5 files changed, 48 insertions(+), 12 deletions(-) create mode 100644 tstest/tstest_test.go diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index cf21d4885..48e0349de 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -1075,16 +1075,12 @@ func TestUpdatePrefs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.sshOverTailscale { - old := getSSHClientEnvVar - getSSHClientEnvVar = func() string { return "100.100.100.100 1 1" } - t.Cleanup(func() { getSSHClientEnvVar = old }) + tstest.Replace(t, &getSSHClientEnvVar, func() string { return "100.100.100.100 1 1" }) } else if isSSHOverTailscale() { // The test is being executed over a "real" tailscale SSH // session, but sshOverTailscale is unset. Make the test appear // as if it's not over tailscale SSH. - old := getSSHClientEnvVar - getSSHClientEnvVar = func() string { return "" } - t.Cleanup(func() { getSSHClientEnvVar = old }) + tstest.Replace(t, &getSSHClientEnvVar, func() string { return "" }) } if tt.env.goos == "" { tt.env.goos = "linux" diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index 3f6360a10..ef4249b74 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -13,6 +13,8 @@ import ( "reflect" "testing" "time" + + "tailscale.com/tstest" ) var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") @@ -142,9 +144,7 @@ func TestResolverAllHostStaticResult(t *testing.T) { } func TestShouldTryBootstrap(t *testing.T) { - oldDebug := debug - t.Cleanup(func() { debug = oldDebug }) - debug = func() bool { return true } + tstest.Replace(t, &debug, func() bool { return true }) type step struct { ip netip.Addr // IP we pretended to dial diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 31f50ab6b..797889926 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -22,6 +22,7 @@ import ( "tailscale.com/net/stun" "tailscale.com/net/stun/stuntest" "tailscale.com/tailcfg" + "tailscale.com/tstest" ) func TestHairpinSTUN(t *testing.T) { @@ -679,9 +680,7 @@ func TestNoCaptivePortalWhenUDP(t *testing.T) { } }) - oldTransport := noRedirectClient.Transport - t.Cleanup(func() { noRedirectClient.Transport = oldTransport }) - noRedirectClient.Transport = tr + tstest.Replace(t, &noRedirectClient.Transport, http.RoundTripper(tr)) stunAddr, cleanup := stuntest.Serve(t) defer cleanup() diff --git a/tstest/tstest.go b/tstest/tstest.go index 6cfcc6b36..52820ad4d 100644 --- a/tstest/tstest.go +++ b/tstest/tstest.go @@ -6,12 +6,29 @@ package tstest import ( "context" + "testing" "time" "tailscale.com/logtail/backoff" "tailscale.com/types/logger" ) +// Replace replaces the value of target with val. +// The old value is restored when the test ends. +func Replace[T any](t *testing.T, target *T, val T) { + t.Helper() + if target == nil { + t.Fatalf("Replace: nil pointer") + } + old := *target + t.Cleanup(func() { + *target = old + }) + + *target = val + return +} + // WaitFor retries try for up to maxWait. // It returns nil once try returns nil the first time. // If maxWait passes without success, it returns try's last error. diff --git a/tstest/tstest_test.go b/tstest/tstest_test.go new file mode 100644 index 000000000..e988d5d56 --- /dev/null +++ b/tstest/tstest_test.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import "testing" + +func TestReplace(t *testing.T) { + before := "before" + done := false + t.Run("replace", func(t *testing.T) { + Replace(t, &before, "after") + if before != "after" { + t.Errorf("before = %q; want %q", before, "after") + } + done = true + }) + if !done { + t.Fatal("subtest didn't run") + } + if before != "before" { + t.Errorf("before = %q; want %q", before, "before") + } +}