all: use tstest.Replace more

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/7457/head
Maisem Ali 2 years ago committed by Maisem Ali
parent 57a44846ae
commit 1a30b2d73f

@ -11,14 +11,12 @@ import (
"net/url" "net/url"
"reflect" "reflect"
"testing" "testing"
"tailscale.com/tstest"
) )
func BenchmarkHandleBootstrapDNS(b *testing.B) { func BenchmarkHandleBootstrapDNS(b *testing.B) {
prev := *bootstrapDNS tstest.Replace(b, bootstrapDNS, "log.tailscale.io,login.tailscale.com,controlplane.tailscale.com,login.us.tailscale.com")
*bootstrapDNS = "log.tailscale.io,login.tailscale.com,controlplane.tailscale.com,login.us.tailscale.com"
defer func() {
*bootstrapDNS = prev
}()
refreshBootstrapDNS() refreshBootstrapDNS()
w := new(bitbucketResponseWriter) w := new(bitbucketResponseWriter)
req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.io"), nil) req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.io"), nil)

@ -13,6 +13,7 @@ import (
"go4.org/mem" "go4.org/mem"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
"tailscale.com/types/opt" "tailscale.com/types/opt"
@ -21,12 +22,11 @@ import (
) )
func TestUndeltaPeers(t *testing.T) { func TestUndeltaPeers(t *testing.T) {
defer func(old func() time.Time) { clockNow = old }(clockNow)
var curTime time.Time var curTime time.Time
clockNow = func() time.Time { tstest.Replace(t, &clockNow, func() time.Time {
return curTime return curTime
} })
online := func(v bool) func(*tailcfg.Node) { online := func(v bool) func(*tailcfg.Node) {
return func(n *tailcfg.Node) { return func(n *tailcfg.Node) {
n.Online = &v n.Online = &v

@ -497,8 +497,7 @@ func (panicOnUseTransport) RoundTrip(*http.Request) (*http.Response, error) {
// Issue 1573: don't generate a machine key if we don't want to be running. // Issue 1573: don't generate a machine key if we don't want to be running.
func TestLazyMachineKeyGeneration(t *testing.T) { func TestLazyMachineKeyGeneration(t *testing.T) {
defer func(old func() bool) { panicOnMachineKeyGeneration = old }(panicOnMachineKeyGeneration) tstest.Replace(t, &panicOnMachineKeyGeneration, func() bool { return true })
panicOnMachineKeyGeneration = func() bool { return true }
var logf logger.Logf = logger.Discard var logf logger.Logf = logger.Discard
store := new(mem.Store) store := new(mem.Store)

@ -14,6 +14,7 @@ import (
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
"tailscale.com/hostinfo" "tailscale.com/hostinfo"
"tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/ipnlocal"
"tailscale.com/tstest"
) )
func TestValidHost(t *testing.T) { func TestValidHost(t *testing.T) {
@ -42,10 +43,7 @@ func TestValidHost(t *testing.T) {
} }
func TestSetPushDeviceToken(t *testing.T) { func TestSetPushDeviceToken(t *testing.T) {
validLocalHostForTesting = true tstest.Replace(t, &validLocalHostForTesting, true)
defer func() {
validLocalHostForTesting = false
}()
h := &Handler{ h := &Handler{
PermitWrite: true, PermitWrite: true,

@ -12,6 +12,8 @@ import (
"testing" "testing"
"unicode" "unicode"
"unsafe" "unsafe"
"tailscale.com/tstest"
) )
type filchTest struct { type filchTest struct {
@ -177,10 +179,7 @@ func TestFilchStderr(t *testing.T) {
defer pipeR.Close() defer pipeR.Close()
defer pipeW.Close() defer pipeW.Close()
stderrFD = int(pipeW.Fd()) tstest.Replace(t, &stderrFD, int(pipeW.Fd()))
defer func() {
stderrFD = 2
}()
filePrefix := t.TempDir() filePrefix := t.TempDir()
f := newFilchTest(t, filePrefix, Options{ReplaceStderr: true}) f := newFilchTest(t, filePrefix, Options{ReplaceStderr: true})

@ -997,11 +997,8 @@ func TestMarshalResponseFormatError(t *testing.T) {
} }
func TestForwardLinkSelection(t *testing.T) { func TestForwardLinkSelection(t *testing.T) {
old := initListenConfig
defer func() { initListenConfig = old }()
configCall := make(chan string, 1) configCall := make(chan string, 1)
initListenConfig = func(nc *net.ListenConfig, mon *monitor.Mon, tunName string) error { tstest.Replace(t, &initListenConfig, func(nc *net.ListenConfig, mon *monitor.Mon, tunName string) error {
select { select {
case configCall <- tunName: case configCall <- tunName:
return nil return nil
@ -1009,7 +1006,7 @@ func TestForwardLinkSelection(t *testing.T) {
t.Error("buffer full") t.Error("buffer full")
return errors.New("buffer full") return errors.New("buffer full")
} }
} })
// specialIP is some IP we pretend that our link selector // specialIP is some IP we pretend that our link selector
// routes differently. // routes differently.

@ -10,14 +10,14 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"tailscale.com/tstest"
) )
// test the specific /proc/net/route path as found on Google Cloud Run instances // test the specific /proc/net/route path as found on Google Cloud Run instances
func TestGoogleCloudRunDefaultRouteInterface(t *testing.T) { func TestGoogleCloudRunDefaultRouteInterface(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
savedProcNetRoutePath := procNetRoutePath tstest.Replace(t, &procNetRoutePath, filepath.Join(dir, "CloudRun"))
defer func() { procNetRoutePath = savedProcNetRoutePath }()
procNetRoutePath = filepath.Join(dir, "CloudRun")
buf := []byte("Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT\n" + buf := []byte("Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT\n" +
"eth0\t8008FEA9\t00000000\t0001\t0\t0\t0\t01FFFFFF\t0\t0\t0\n" + "eth0\t8008FEA9\t00000000\t0001\t0\t0\t0\t01FFFFFF\t0\t0\t0\n" +
"eth1\t00000000\t00000000\t0001\t0\t0\t0\t00000000\t0\t0\t0\n") "eth1\t00000000\t00000000\t0001\t0\t0\t0\t00000000\t0\t0\t0\n")
@ -39,9 +39,7 @@ func TestGoogleCloudRunDefaultRouteInterface(t *testing.T) {
// size can be handled. // size can be handled.
func TestExtremelyLongProcNetRoute(t *testing.T) { func TestExtremelyLongProcNetRoute(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
savedProcNetRoutePath := procNetRoutePath tstest.Replace(t, &procNetRoutePath, filepath.Join(dir, "VeryLong"))
defer func() { procNetRoutePath = savedProcNetRoutePath }()
procNetRoutePath = filepath.Join(dir, "VeryLong")
f, err := os.Create(procNetRoutePath) f, err := os.Create(procNetRoutePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -76,9 +74,7 @@ func TestExtremelyLongProcNetRoute(t *testing.T) {
// test the specific /proc/net/route path as found on AWS App Runner instances // test the specific /proc/net/route path as found on AWS App Runner instances
func TestAwsAppRunnerDefaultRouteInterface(t *testing.T) { func TestAwsAppRunnerDefaultRouteInterface(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
savedProcNetRoutePath := procNetRoutePath tstest.Replace(t, &procNetRoutePath, filepath.Join(dir, "CloudRun"))
defer func() { procNetRoutePath = savedProcNetRoutePath }()
procNetRoutePath = filepath.Join(dir, "CloudRun")
buf := []byte("Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT\n" + buf := []byte("Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT\n" +
"eth0\t00000000\tF9AFFEA9\t0003\t0\t0\t0\t00000000\t0\t0\t0\n" + "eth0\t00000000\tF9AFFEA9\t0003\t0\t0\t0\t00000000\t0\t0\t0\n" +
"*\tFEA9FEA9\t00000000\t0005\t0\t0\t0\tFFFFFFFF\t0\t0\t0\n" + "*\tFEA9FEA9\t00000000\t0005\t0\t0\t0\tFFFFFFFF\t0\t0\t0\n" +

@ -16,6 +16,8 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"tailscale.com/tstest"
) )
func TestSynologyProxyFromConfigCached(t *testing.T) { func TestSynologyProxyFromConfigCached(t *testing.T) {
@ -24,9 +26,7 @@ func TestSynologyProxyFromConfigCached(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
var orig string tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf"))
orig, synologyProxyConfigPath = synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf")
defer func() { synologyProxyConfigPath = orig }()
t.Run("no config file", func(t *testing.T) { t.Run("no config file", func(t *testing.T) {
if _, err := os.Stat(synologyProxyConfigPath); err == nil { if _, err := os.Stat(synologyProxyConfigPath); err == nil {
@ -160,11 +160,9 @@ func TestSynologyProxiesFromConfig(t *testing.T) {
openReader io.ReadCloser openReader io.ReadCloser
openErr error openErr error
) )
var origOpen func() (io.ReadCloser, error) tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) {
origOpen, openSynologyProxyConf = openSynologyProxyConf, func() (io.ReadCloser, error) {
return openReader, openErr return openReader, openErr
} })
defer func() { openSynologyProxyConf = origOpen }()
t.Run("with config", func(t *testing.T) { t.Run("with config", func(t *testing.T) {
mc := &mustCloser{Reader: strings.NewReader(` mc := &mustCloser{Reader: strings.NewReader(`

@ -15,10 +15,11 @@ import (
// Replace replaces the value of target with val. // Replace replaces the value of target with val.
// The old value is restored when the test ends. // The old value is restored when the test ends.
func Replace[T any](t *testing.T, target *T, val T) { func Replace[T any](t testing.TB, target *T, val T) {
t.Helper() t.Helper()
if target == nil { if target == nil {
t.Fatalf("Replace: nil pointer") t.Fatalf("Replace: nil pointer")
panic("unreachable") // pacify staticcheck
} }
old := *target old := *target
t.Cleanup(func() { t.Cleanup(func() {

@ -600,10 +600,9 @@ foo_foo_b 1
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
defer func() { expvarDo = expvar.Do }() tstest.Replace(t, &expvarDo, func(f func(expvar.KeyValue)) {
expvarDo = func(f func(expvar.KeyValue)) {
f(expvar.KeyValue{Key: tt.k, Value: tt.v}) f(expvar.KeyValue{Key: tt.k, Value: tt.v})
} })
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
VarzHandler(rec, httptest.NewRequest("GET", "/", nil)) VarzHandler(rec, httptest.NewRequest("GET", "/", nil))
if got := rec.Body.Bytes(); string(got) != tt.want { if got := rec.Body.Bytes(); string(got) != tt.want {
@ -792,11 +791,10 @@ func TestSortedStructAllocs(t *testing.T) {
} }
func TestVarzHandlerSorting(t *testing.T) { func TestVarzHandlerSorting(t *testing.T) {
defer func() { expvarDo = expvar.Do }() tstest.Replace(t, &expvarDo, func(f func(expvar.KeyValue)) {
expvarDo = func(f func(expvar.KeyValue)) {
f(expvar.KeyValue{Key: "counter_zz", Value: new(expvar.Int)}) f(expvar.KeyValue{Key: "counter_zz", Value: new(expvar.Int)})
f(expvar.KeyValue{Key: "gauge_aa", Value: new(expvar.Int)}) f(expvar.KeyValue{Key: "gauge_aa", Value: new(expvar.Int)})
} })
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
VarzHandler(rec, req) VarzHandler(rec, req)

@ -13,6 +13,7 @@ import (
"testing" "testing"
"go4.org/mem" "go4.org/mem"
"tailscale.com/tstest"
) )
func TestWalkShallowOSSpecific(t *testing.T) { func TestWalkShallowOSSpecific(t *testing.T) {
@ -28,9 +29,7 @@ func TestWalkShallowPortable(t *testing.T) {
func testWalkShallow(t *testing.T, portable bool) { func testWalkShallow(t *testing.T, portable bool) {
if portable { if portable {
old := osWalkShallow tstest.Replace(t, &osWalkShallow, nil)
defer func() { osWalkShallow = old }()
osWalkShallow = nil
} }
d := t.TempDir() d := t.TempDir()

@ -419,14 +419,8 @@ func TestOmitDropLogging(t *testing.T) {
} }
func TestLoggingPrivacy(t *testing.T) { func TestLoggingPrivacy(t *testing.T) {
oldDrop := dropBucket tstest.Replace(t, &dropBucket, rate.NewLimiter(2^32, 2^32))
oldAccept := acceptBucket tstest.Replace(t, &acceptBucket, dropBucket)
dropBucket = rate.NewLimiter(2^32, 2^32)
acceptBucket = dropBucket
defer func() {
dropBucket = oldDrop
acceptBucket = oldAccept
}()
var ( var (
logged bool logged bool

Loading…
Cancel
Save