From bdb69d1b1fc4ee08cfb13b5d0b7bab79e162bd4e Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 29 Sep 2025 14:03:32 -0700 Subject: [PATCH] net/dns/resolver: fix data race in test Fixes #17339 Change-Id: I486d2a0e0931d701923c1e0f8efbda99510ab19b Signed-off-by: Brad Fitzpatrick --- net/dns/resolver/forwarder.go | 20 ++++++++++-------- net/dns/resolver/forwarder_test.go | 34 +++++++++--------------------- 2 files changed, 21 insertions(+), 33 deletions(-) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index c87fbd504..105229fb8 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -217,11 +217,12 @@ type resolverAndDelay struct { // forwarder forwards DNS packets to a number of upstream nameservers. type forwarder struct { - logf logger.Logf - netMon *netmon.Monitor // always non-nil - linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it - dialer *tsdial.Dialer - health *health.Tracker // always non-nil + logf logger.Logf + netMon *netmon.Monitor // always non-nil + linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it + dialer *tsdial.Dialer + health *health.Tracker // always non-nil + verboseFwd bool // if true, log all DNS forwarding controlKnobs *controlknobs.Knobs // or nil @@ -258,6 +259,7 @@ func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkS dialer: dialer, health: health, controlKnobs: knobs, + verboseFwd: verboseDNSForward(), } f.ctx, f.ctxCancel = context.WithCancel(context.Background()) return f @@ -515,7 +517,7 @@ var ( // // send expects the reply to have the same txid as txidOut. func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) { - if verboseDNSForward() { + if f.verboseFwd { id := forwarderCount.Add(1) domain, typ, _ := nameFromQuery(fq.packet) f.logf("forwarder.send(%q, %d, %v, %d) [%d] ...", rr.name.Addr, fq.txid, typ, len(domain), id) @@ -978,7 +980,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo } defer fq.closeOnCtxDone.Close() - if verboseDNSForward() { + if f.verboseFwd { domainSha256 := sha256.Sum256([]byte(domain)) domainSig := base64.RawStdEncoding.EncodeToString(domainSha256[:3]) f.logf("request(%d, %v, %d, %s) %d...", fq.txid, typ, len(domain), domainSig, len(fq.packet)) @@ -1023,7 +1025,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo metricDNSFwdErrorContext.Add(1) return fmt.Errorf("waiting to send response: %w", ctx.Err()) case responseChan <- packet{v, query.family, query.addr}: - if verboseDNSForward() { + if f.verboseFwd { f.logf("response(%d, %v, %d) = %d, nil", fq.txid, typ, len(domain), len(v)) } metricDNSFwdSuccess.Add(1) @@ -1053,7 +1055,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo } f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")}) case responseChan <- res: - if verboseDNSForward() { + if f.verboseFwd { f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr) } return nil diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index f77388ca7..b5cc7d018 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -12,7 +12,6 @@ import ( "io" "net" "net/netip" - "os" "reflect" "slices" "strings" @@ -23,7 +22,6 @@ import ( dns "golang.org/x/net/dns/dnsmessage" "tailscale.com/control/controlknobs" - "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" @@ -400,13 +398,6 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on return } -func enableDebug(tb testing.TB) { - const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND" - oldVal := os.Getenv(debugKnob) - envknob.Setenv(debugKnob, "true") - tb.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) }) -} - func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) { name := dns.MustNewName(domain) @@ -554,9 +545,11 @@ func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), po return resp } -func TestForwarderTCPFallback(t *testing.T) { - enableDebug(t) +func beVerbose(f *forwarder) { + f.verboseFwd = true +} +func TestForwarderTCPFallback(t *testing.T) { const domain = "large-dns-response.tailscale.com." // Make a response that's very large, containing a bunch of localhost addresses. @@ -576,7 +569,7 @@ func TestForwarderTCPFallback(t *testing.T) { } }) - resp := mustRunTestQuery(t, request, nil, port) + resp := mustRunTestQuery(t, request, beVerbose, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -592,8 +585,6 @@ func TestForwarderTCPFallback(t *testing.T) { // Test to ensure that if the UDP listener is unresponsive, we always make a // TCP request even if we never get a response. func TestForwarderTCPFallbackTimeout(t *testing.T) { - enableDebug(t) - const domain = "large-dns-response.tailscale.com." // Make a response that's very large, containing a bunch of localhost addresses. @@ -614,7 +605,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) { } }) - resp := mustRunTestQuery(t, request, nil, port) + resp := mustRunTestQuery(t, request, beVerbose, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -624,8 +615,6 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) { } func TestForwarderTCPFallbackDisabled(t *testing.T) { - enableDebug(t) - const domain = "large-dns-response.tailscale.com." // Make a response that's very large, containing a bunch of localhost addresses. @@ -646,6 +635,7 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) { }) resp := mustRunTestQuery(t, request, func(fwd *forwarder) { + fwd.verboseFwd = true // Disable retries for this test. fwd.controlKnobs = &controlknobs.Knobs{} fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true) @@ -668,8 +658,6 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) { // Test to ensure that we propagate DNS errors func TestForwarderTCPFallbackError(t *testing.T) { - enableDebug(t) - const domain = "error-response.tailscale.com." // Our response is a SERVFAIL @@ -686,7 +674,7 @@ func TestForwarderTCPFallbackError(t *testing.T) { } }) - resp, err := runTestQuery(t, request, nil, port) + resp, err := runTestQuery(t, request, beVerbose, port) if !sawRequest.Load() { t.Error("did not see DNS request") } @@ -706,8 +694,6 @@ func TestForwarderTCPFallbackError(t *testing.T) { // Test to ensure that if we have more than one resolver, and at least one of them // returns a successful response, we propagate it. func TestForwarderWithManyResolvers(t *testing.T) { - enableDebug(t) - const domain = "example.com." request := makeTestRequest(t, domain) @@ -810,7 +796,7 @@ func TestForwarderWithManyResolvers(t *testing.T) { for i := range tt.responses { ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {}) } - gotResponse, err := runTestQuery(t, request, nil, ports...) + gotResponse, err := runTestQuery(t, request, beVerbose, ports...) if err != nil { t.Fatalf("wanted nil, got %v", err) } @@ -869,7 +855,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) { port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { }) - res, err := runTestQuery(t, request, nil, port) + res, err := runTestQuery(t, request, beVerbose, port) if err != nil { t.Fatal(err) }