From e590877667fff296719cb0a886d899cd29015935 Mon Sep 17 00:00:00 2001 From: Brendan Creane Date: Mon, 8 Dec 2025 11:11:05 -0800 Subject: [PATCH] net/dns/resolver: set TC flag when UDP responses exceed size limits The forwarder was not setting the Truncated (TC) flag when UDP DNS responses exceeded either the EDNS buffer size (if present) or the RFC 1035 default 512-byte limit. This affected DoH, TCP fallback, and UDP response paths. The fix ensures checkResponseSizeAndSetTC is called in all code paths that return UDP responses, enforcing both EDNS and default UDP size limits. Added comprehensive unit tests and consolidated duplicate test helpers. Updates #18107 Signed-off-by: Brendan Creane --- net/dns/resolver/forwarder.go | 137 +++++++-- net/dns/resolver/forwarder_test.go | 475 ++++++++++++++++++++++++++--- net/dns/resolver/tsdns.go | 7 +- net/dns/resolver/tsdns_test.go | 99 ++++++ 4 files changed, 654 insertions(+), 64 deletions(-) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 797c5272a..43b151130 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -63,6 +63,17 @@ func truncatedFlagSet(pkt []byte) bool { return (binary.BigEndian.Uint16(pkt[2:4]) & dnsFlagTruncated) != 0 } +// setTCFlag sets the TC (truncated) flag in the DNS packet header. +// The packet must be at least headerBytes in length. +func setTCFlag(packet []byte) { + if len(packet) < headerBytes { + return + } + flags := binary.BigEndian.Uint16(packet[2:4]) + flags |= dnsFlagTruncated + binary.BigEndian.PutUint16(packet[2:4], flags) +} + const ( // dohIdleConnTimeout is how long to keep idle HTTP connections // open to DNS-over-HTTPS servers. 10 seconds is a sensible @@ -131,47 +142,59 @@ func getRCode(packet []byte) dns.RCode { return dns.RCode(packet[3] & 0x0F) } -// clampEDNSSize attempts to limit the maximum EDNS response size. This is not -// an exhaustive solution, instead only easy cases are currently handled in the -// interest of speed and reduced complexity. Only OPT records at the very end of -// the message with no option codes are addressed. -// TODO: handle more situations if we discover that they happen often -func clampEDNSSize(packet []byte, maxSize uint16) { - // optFixedBytes is the size of an OPT record with no option codes. - const optFixedBytes = 11 - const edns0Version = 0 +// findOPTRecord finds and validates the EDNS OPT record at the end of a DNS packet. +// Returns the requested buffer size and a pointer to the OPT record bytes if valid, +// or (0, nil) if no valid OPT record is found. +// The OPT record must be at the very end of the packet with no option codes. +func findOPTRecord(packet []byte) (requestedSize uint16, opt []byte) { + const optFixedBytes = 11 // size of an OPT record with no option codes + const edns0Version = 0 // EDNS version number (currently only version 0 is defined) if len(packet) < headerBytes+optFixedBytes { - return + return 0, nil } arCount := binary.BigEndian.Uint16(packet[10:12]) if arCount == 0 { // OPT shows up in an AR, so there must be no OPT - return + return 0, nil } // https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.2 - opt := packet[len(packet)-optFixedBytes:] + opt = packet[len(packet)-optFixedBytes:] if opt[0] != 0 { // OPT NAME must be 0 (root domain) - return + return 0, nil } if dns.Type(binary.BigEndian.Uint16(opt[1:3])) != dns.TypeOPT { // Not an OPT record - return + return 0, nil } - requestedSize := binary.BigEndian.Uint16(opt[3:5]) + requestedSize = binary.BigEndian.Uint16(opt[3:5]) // Ignore extended RCODE in opt[5] if opt[6] != edns0Version { // Be conservative and don't touch unknown versions. - return + return 0, nil } // Ignore flags in opt[6:9] if binary.BigEndian.Uint16(opt[9:11]) != 0 { // RDLEN must be 0 (no variable length data). We're at the end of the - // packet so this should be 0 anyway).. + // packet so this should be 0 anyway. + return 0, nil + } + + return requestedSize, opt +} + +// clampEDNSSize attempts to limit the maximum EDNS response size. This is not +// an exhaustive solution, instead only easy cases are currently handled in the +// interest of speed and reduced complexity. Only OPT records at the very end of +// the message with no option codes are addressed. +// TODO: handle more situations if we discover that they happen often +func clampEDNSSize(packet []byte, maxSize uint16) { + requestedSize, opt := findOPTRecord(packet) + if opt == nil { return } @@ -183,6 +206,57 @@ func clampEDNSSize(packet []byte, maxSize uint16) { binary.BigEndian.PutUint16(opt[3:5], maxSize) } +// getEDNSBufferSize extracts the EDNS buffer size from a DNS request packet. +// Returns (bufferSize, true) if a valid EDNS OPT record is found, +// or (0, false) if no EDNS OPT record is found or if there's an error. +func getEDNSBufferSize(packet []byte) (uint16, bool) { + requestedSize, opt := findOPTRecord(packet) + return requestedSize, opt != nil +} + +// checkResponseSizeAndSetTC sets the TC (truncated) flag in the DNS header when +// the response exceeds the maximum UDP size. If no EDNS OPT record is present +// in the request, it sets the TC flag when the response is bigger than 512 bytes +// per RFC 1035. If an EDNS OPT record is present, it sets the TC flag when the +// response is bigger than the EDNS buffer size. The response buffer is not +// truncated; only the TC flag is set. Returns the response unchanged except for +// the TC flag being set if needed. +func checkResponseSizeAndSetTC(response []byte, request []byte, family string, logf logger.Logf) []byte { + const defaultUDPSize = 512 // default maximum UDP DNS packet size per RFC 1035 + + // Only check for UDP queries; TCP can handle larger responses + if family != "udp" { + return response + } + + // Check if TC flag is already set + if len(response) < headerBytes { + return response + } + if truncatedFlagSet(response) { + // TC flag already set, nothing to do + return response + } + + ednsSize, hasEDNS := getEDNSBufferSize(request) + + // Determine maximum allowed size + var maxSize int + if hasEDNS { + maxSize = int(ednsSize) + } else { + // No EDNS: enforce default UDP size limit per RFC 1035 + maxSize = defaultUDPSize + } + + // Check if response exceeds maximum size + if len(response) > maxSize { + setTCFlag(response) + } + + return response +} + // dnsForwarderFailing should be raised when the forwarder is unable to reach the // upstream resolvers. This is a high severity warning as it results in "no internet". // This warning must be cleared when the forwarder is working again. @@ -535,7 +609,13 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe if !buildfeatures.HasPeerAPIClient { return nil, feature.ErrUnavailable } - return f.sendDoH(ctx, rr.name.Addr, f.dialer.PeerAPIHTTPClient(), fq.packet) + res, err := f.sendDoH(ctx, rr.name.Addr, f.dialer.PeerAPIHTTPClient(), fq.packet) + if err != nil { + return nil, err + } + // Check response size and set TC flag if needed (only for UDP queries) + res = checkResponseSizeAndSetTC(res, fq.packet, fq.family, f.logf) + return res, nil } if strings.HasPrefix(rr.name.Addr, "https://") { // Only known DoH providers are supported currently. Specifically, we @@ -546,7 +626,13 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe // them. urlBase := rr.name.Addr if hc, ok := f.getKnownDoHClientForProvider(urlBase); ok { - return f.sendDoH(ctx, urlBase, hc, fq.packet) + res, err := f.sendDoH(ctx, urlBase, hc, fq.packet) + if err != nil { + return nil, err + } + // Check response size and set TC flag if needed (only for UDP queries) + res = checkResponseSizeAndSetTC(res, fq.packet, fq.family, f.logf) + return res, nil } metricDNSFwdErrorType.Add(1) return nil, fmt.Errorf("arbitrary https:// resolvers not supported yet") @@ -710,12 +796,15 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn f.logf("recv: packet too small (%d bytes)", n) } out = out[:n] + tcFlagAlreadySet := truncatedFlagSet(out) + txid := getTxID(out) if txid != fq.txid { metricDNSFwdUDPErrorTxID.Add(1) return nil, errTxIDMismatch } rcode := getRCode(out) + // don't forward transient errors back to the client when the server fails if rcode == dns.RCodeServerFailure { f.logf("recv: response code indicating server failure: %d", rcode) @@ -723,11 +812,9 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn return nil, errServerFailure } - if truncated { - // Set the truncated bit if it wasn't already. - flags := binary.BigEndian.Uint16(out[2:4]) - flags |= dnsFlagTruncated - binary.BigEndian.PutUint16(out[2:4], flags) + // Set the truncated bit if buffer was truncated during read and the flag isn't already set + if truncated && !tcFlagAlreadySet { + setTCFlag(out) // TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2 // states that truncation should head drop so that the authority @@ -736,6 +823,8 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn // best we can do. } + out = checkResponseSizeAndSetTC(out, fq.packet, fq.family, f.logf) + if truncatedFlagSet(out) { metricDNSFwdTruncated.Add(1) } diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index 0b38008c8..9510b4294 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -34,6 +34,46 @@ func (rr resolverAndDelay) String() string { return fmt.Sprintf("%v+%v", rr.name, rr.startDelay) } +// setTCFlagInPacket sets the TC flag in a DNS packet (for testing). +func setTCFlagInPacket(packet []byte) { + if len(packet) < headerBytes { + return + } + flags := binary.BigEndian.Uint16(packet[2:4]) + flags |= dnsFlagTruncated + binary.BigEndian.PutUint16(packet[2:4], flags) +} + +// clearTCFlagInPacket clears the TC flag in a DNS packet (for testing). +func clearTCFlagInPacket(packet []byte) { + if len(packet) < headerBytes { + return + } + flags := binary.BigEndian.Uint16(packet[2:4]) + flags &^= dnsFlagTruncated + binary.BigEndian.PutUint16(packet[2:4], flags) +} + +// verifyEDNSBufferSize verifies a request has the expected EDNS buffer size. +func verifyEDNSBufferSize(t *testing.T, request []byte, expectedSize uint16) { + t.Helper() + ednsSize, hasEDNS := getEDNSBufferSize(request) + if !hasEDNS { + t.Fatalf("request should have EDNS OPT record") + } + if ednsSize != expectedSize { + t.Fatalf("request EDNS size = %d, want %d", ednsSize, expectedSize) + } +} + +// setupForwarderWithTCPRetriesDisabled returns a forwarder modifier that disables TCP retries. +func setupForwarderWithTCPRetriesDisabled() func(*forwarder) { + return func(fwd *forwarder) { + fwd.controlKnobs = &controlknobs.Knobs{} + fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true) + } +} + func TestResolversWithDelays(t *testing.T) { // query q := func(ss ...string) (ipps []*dnstype.Resolver) { @@ -428,22 +468,16 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) } // Our request is a single A query for the domain in the answer, above. - builder = dns.NewBuilder(nil, dns.Header{}) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: dns.MustNewName(domain), - Type: dns.TypeA, - Class: dns.ClassINET, - }) - request, err = builder.Finish() - if err != nil { - tb.Fatal(err) - } + request = makeTestRequest(tb, domain, dns.TypeA, 0) return } func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) { + return runTestQueryWithFamily(tb, request, "udp", modify, ports...) +} + +func runTestQueryWithFamily(tb testing.TB, request []byte, family string, modify func(*forwarder), ports ...uint16) ([]byte, error) { logf := tstest.WhileTestRunningLogger(tb) bus := eventbustest.NewBus(tb) netMon, err := netmon.New(bus, logf) @@ -467,7 +501,7 @@ func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports rpkt := packet{ bs: request, - family: "tcp", + family: family, addr: netip.MustParseAddrPort("127.0.0.1:12345"), } @@ -483,17 +517,29 @@ func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports } } -// makeTestRequest returns a new TypeA request for the given domain. -func makeTestRequest(tb testing.TB, domain string) []byte { +// makeTestRequest returns a new DNS request for the given domain. +// If queryType is 0, it defaults to TypeA. If ednsSize > 0, it adds an EDNS OPT record. +func makeTestRequest(tb testing.TB, domain string, queryType dns.Type, ednsSize uint16) []byte { tb.Helper() + if queryType == 0 { + queryType = dns.TypeA + } name := dns.MustNewName(domain) builder := dns.NewBuilder(nil, dns.Header{}) builder.StartQuestions() builder.Question(dns.Question{ Name: name, - Type: dns.TypeA, + Type: queryType, Class: dns.ClassINET, }) + if ednsSize > 0 { + builder.StartAdditionals() + builder.OPTResource(dns.ResourceHeader{ + Name: dns.MustNewName("."), + Type: dns.TypeOPT, + Class: dns.Class(ednsSize), + }, dns.OPTResource{}) + } request, err := builder.Finish() if err != nil { tb.Fatal(err) @@ -549,6 +595,371 @@ func beVerbose(f *forwarder) { f.verboseFwd = true } +// makeTestRequestWithEDNS returns a new TypeTXT request for the given domain with EDNS buffer size. +// Deprecated: Use makeTestRequest with queryType and ednsSize parameters instead. +func makeTestRequestWithEDNS(tb testing.TB, domain string, ednsSize uint16) []byte { + return makeTestRequest(tb, domain, dns.TypeTXT, ednsSize) +} + +// makeEDNSResponse creates a DNS response of approximately the specified size +// with TXT records and an OPT record. The response will NOT have the TC flag set +// (simulating a non-compliant server that doesn't set TC when response exceeds EDNS buffer). +// The actual size may vary significantly due to DNS packet structure constraints. +func makeEDNSResponse(tb testing.TB, domain string, targetSize int) []byte { + tb.Helper() + // Use makeResponseOfSize with includeOPT=true + // Allow significant variance since DNS packet sizes are hard to predict exactly + // Use a combination of fixed tolerance (200 bytes) and percentage (25%) for larger targets + response := makeResponseOfSize(tb, domain, targetSize, true) + actualSize := len(response) + maxVariance := 200 + if targetSize > 400 { + // For larger targets, allow 25% variance + maxVariance = targetSize * 25 / 100 + } + if actualSize < targetSize-maxVariance || actualSize > targetSize+maxVariance { + tb.Fatalf("response size = %d, want approximately %d (variance: %d, allowed: ±%d)", + actualSize, targetSize, actualSize-targetSize, maxVariance) + } + return response +} + +func TestEDNSBufferSizeTruncation(t *testing.T) { + const domain = "edns-test.example.com." + const ednsBufferSize = 500 // Small EDNS buffer + const responseSize = 800 // Response exceeds EDNS but < maxResponseBytes + + // Create a response that exceeds EDNS buffer size but doesn't have TC flag set + response := makeEDNSResponse(t, domain, responseSize) + + // Create a request with EDNS buffer size + request := makeTestRequest(t, domain, dns.TypeTXT, ednsBufferSize) + verifyEDNSBufferSize(t, request, ednsBufferSize) + + // Verify response doesn't have TC flag set initially + if truncatedFlagSet(response) { + t.Fatal("test response should not have TC flag set initially") + } + + // Set up test DNS server + port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { + verifyEDNSBufferSize(t, gotRequest, ednsBufferSize) + }) + + // Disable TCP retries to ensure we test UDP path + resp := mustRunTestQuery(t, request, setupForwarderWithTCPRetriesDisabled(), port) + + // Verify the response has TC flag set by forwarder + if !truncatedFlagSet(resp) { + t.Errorf("TC flag not set in response (response size=%d, EDNS=%d)", len(resp), ednsBufferSize) + } + + // Verify response size is preserved (not truncated by buffer) + if len(resp) != len(response) { + t.Errorf("response size = %d, want %d (response should not be truncated by buffer)", len(resp), len(response)) + } + + // Verify response size exceeds EDNS buffer + if len(resp) <= int(ednsBufferSize) { + t.Errorf("response size = %d, should exceed EDNS buffer size %d", len(resp), ednsBufferSize) + } +} + +// makeResponseOfSize creates a DNS response of approximately the specified size +// with TXT records. The response will NOT have the TC flag set initially. +// If includeOPT is true, an OPT record is added to the response. +func makeResponseOfSize(tb testing.TB, domain string, targetSize int, includeOPT bool) []byte { + tb.Helper() + name := dns.MustNewName(domain) + + // Estimate how many TXT records we need + // Each TXT record with ~200 bytes of data adds roughly 220-230 bytes to the packet + // (including DNS headers, name compression, etc.) + bytesPerRecord := 220 + baseSize := 50 // Approximate base packet size (header + question) + if includeOPT { + baseSize += 11 // OPT record adds ~11 bytes + } + estimatedRecords := (targetSize - baseSize) / bytesPerRecord + if estimatedRecords < 1 { + estimatedRecords = 1 + } + + // Start with estimated records and adjust + txtLen := 200 + var response []byte + var err error + + for attempt := 0; attempt < 10; attempt++ { + testBuilder := dns.NewBuilder(nil, dns.Header{ + Response: true, + Authoritative: true, + RCode: dns.RCodeSuccess, + }) + testBuilder.StartQuestions() + testBuilder.Question(dns.Question{ + Name: name, + Type: dns.TypeTXT, + Class: dns.ClassINET, + }) + testBuilder.StartAnswers() + + for i := 0; i < estimatedRecords; i++ { + txtValue := strings.Repeat("x", txtLen) + testBuilder.TXTResource(dns.ResourceHeader{ + Name: name, + Type: dns.TypeTXT, + Class: dns.ClassINET, + TTL: 300, + }, dns.TXTResource{ + TXT: []string{txtValue}, + }) + } + + // Optionally add OPT record + if includeOPT { + testBuilder.StartAdditionals() + testBuilder.OPTResource(dns.ResourceHeader{ + Name: dns.MustNewName("."), + Type: dns.TypeOPT, + Class: dns.Class(4096), + }, dns.OPTResource{}) + } + + response, err = testBuilder.Finish() + if err != nil { + tb.Fatal(err) + } + + actualSize := len(response) + // Stop if we've reached or slightly exceeded the target + // Allow up to 20% overshoot to avoid excessive iterations + if actualSize >= targetSize && actualSize <= targetSize*120/100 { + break + } + // If we've overshot significantly, we're done (better than undershooting) + if actualSize > targetSize*120/100 { + break + } + + // Adjust for next attempt + needed := targetSize - actualSize + additionalRecords := (needed / bytesPerRecord) + 1 + estimatedRecords += additionalRecords + if estimatedRecords > 200 { + // If we need too many records, increase TXT length instead + txtLen = 255 // Max single TXT string length + bytesPerRecord = 280 // Adjusted estimate + estimatedRecords = (targetSize - baseSize) / bytesPerRecord + if estimatedRecords < 1 { + estimatedRecords = 1 + } + } + } + + // Ensure TC flag is NOT set initially + clearTCFlagInPacket(response) + + return response +} + +func TestCheckResponseSizeAndSetTC(t *testing.T) { + const domain = "test.example.com." + logf := func(format string, args ...any) { + // Silent logger for tests + } + + tests := []struct { + name string + responseSize int + requestHasEDNS bool + ednsSize uint16 + family string + responseTCSet bool // Whether response has TC flag set initially + wantTCSet bool // Whether TC flag should be set after function call + skipIfNotExact bool // Skip test if we can't hit exact size (for edge cases) + }{ + // Default UDP size (512 bytes) without EDNS + { + name: "UDP_noEDNS_small_should_not_set_TC", + responseSize: 400, + requestHasEDNS: false, + family: "udp", + wantTCSet: false, + }, + { + name: "UDP_noEDNS_512bytes_should_not_set_TC", + responseSize: 512, + requestHasEDNS: false, + family: "udp", + wantTCSet: false, + skipIfNotExact: true, + }, + { + name: "UDP_noEDNS_513bytes_should_set_TC", + responseSize: 513, + requestHasEDNS: false, + family: "udp", + wantTCSet: true, + skipIfNotExact: true, + }, + { + name: "UDP_noEDNS_large_should_set_TC", + responseSize: 600, + requestHasEDNS: false, + family: "udp", + wantTCSet: true, + }, + + // EDNS edge cases + { + name: "UDP_EDNS_small_under_limit_should_not_set_TC", + responseSize: 450, + requestHasEDNS: true, + ednsSize: 500, + family: "udp", + wantTCSet: false, + }, + { + name: "UDP_EDNS_at_limit_should_not_set_TC", + responseSize: 500, + requestHasEDNS: true, + ednsSize: 500, + family: "udp", + wantTCSet: false, + }, + { + name: "UDP_EDNS_over_limit_should_set_TC", + responseSize: 550, + requestHasEDNS: true, + ednsSize: 500, + family: "udp", + wantTCSet: true, + }, + { + name: "UDP_EDNS_large_over_limit_should_set_TC", + responseSize: 1500, + requestHasEDNS: true, + ednsSize: 1200, + family: "udp", + wantTCSet: true, + }, + + // Early return paths + { + name: "TCP_query_should_skip", + responseSize: 1000, + family: "tcp", + wantTCSet: false, + }, + { + name: "response_too_small_should_skip", + responseSize: headerBytes - 1, + family: "udp", + wantTCSet: false, + }, + { + name: "response_exactly_headerBytes_should_not_set_TC", + responseSize: headerBytes, + family: "udp", + wantTCSet: false, + }, + { + name: "response_TC_already_set_should_skip", + responseSize: 600, + family: "udp", + responseTCSet: true, + wantTCSet: true, // Should remain set + }, + { + name: "UDP_noEDNS_large_TC_already_set_should_skip", + responseSize: 600, + requestHasEDNS: false, + family: "udp", + responseTCSet: true, + wantTCSet: true, // Should remain set + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var response []byte + + // Create response of specified size + if tt.responseSize < headerBytes { + // For too-small test, create minimal invalid packet + response = make([]byte, tt.responseSize) + // Don't set any flags, just make it too small + } else { + response = makeResponseOfSize(t, domain, tt.responseSize, false) + actualSize := len(response) + + // Only adjust expectations for UDP queries that go through size checking + // TCP queries and other early-return cases should keep their original expectations + if tt.family == "udp" && !tt.responseTCSet && actualSize >= headerBytes { + // Determine the maximum allowed size based on request + var maxSize int + if tt.requestHasEDNS { + maxSize = int(tt.ednsSize) + } else { + maxSize = 512 // default UDP size per RFC 1035 + } + + // For edge cases where exact size matters, verify we're close enough + if tt.skipIfNotExact { + // For 512/513 byte tests, we need to be very close + if actualSize < tt.responseSize-10 || actualSize > tt.responseSize+10 { + t.Skipf("skipping: could not create response close to target size %d (got %d)", tt.responseSize, actualSize) + } + // Function sets TC if response > maxSize, so adjust expectation based on actual size + tt.wantTCSet = actualSize > maxSize + } else { + // For non-exact tests, adjust expectation based on actual response size + // The function sets TC if actualSize > maxSize + tt.wantTCSet = actualSize > maxSize + } + } + } + + // Set TC flag initially if requested + if tt.responseTCSet { + setTCFlagInPacket(response) + } + + // Create request with or without EDNS + var ednsSize uint16 + if tt.requestHasEDNS { + ednsSize = tt.ednsSize + } + request := makeTestRequest(t, domain, dns.TypeTXT, ednsSize) + + // Call the function + result := checkResponseSizeAndSetTC(response, request, tt.family, logf) + + // Verify response size is preserved (function should not truncate, only set flag) + if len(result) != len(response) { + t.Errorf("response size changed: got %d, want %d", len(result), len(response)) + } + + // Verify TC flag state + if len(result) >= headerBytes { + hasTC := truncatedFlagSet(result) + if hasTC != tt.wantTCSet { + t.Errorf("TC flag: got %v, want %v (response size=%d)", hasTC, tt.wantTCSet, len(result)) + } + } else if tt.responseSize >= headerBytes { + // If we expected a valid response but got too small, that's unexpected + t.Errorf("response too small (%d bytes) but expected valid response", len(result)) + } + + // Verify response pointer is same (should be in-place modification) + if &result[0] != &response[0] { + t.Errorf("function should modify response in place, but got new slice") + } + }) + } +} + func TestForwarderTCPFallback(t *testing.T) { const domain = "large-dns-response.tailscale.com." @@ -569,7 +980,10 @@ func TestForwarderTCPFallback(t *testing.T) { } }) - resp := mustRunTestQuery(t, request, beVerbose, port) + resp, err := runTestQueryWithFamily(t, request, "tcp", beVerbose, port) + if err != nil { + t.Fatalf("error making request: %v", err) + } if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -636,17 +1050,13 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) { resp := mustRunTestQuery(t, request, func(fwd *forwarder) { fwd.verboseFwd = true - // Disable retries for this test. - fwd.controlKnobs = &controlknobs.Knobs{} - fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true) + setupForwarderWithTCPRetriesDisabled()(fwd) }, port) wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...) // Set the truncated flag on the expected response, since that's what we expect. - flags := binary.BigEndian.Uint16(wantResp[2:4]) - flags |= dnsFlagTruncated - binary.BigEndian.PutUint16(wantResp[2:4], flags) + setTCFlagInPacket(wantResp) if !bytes.Equal(resp, wantResp) { t.Errorf("invalid response\ngot (%d): %+v\nwant (%d): %+v", len(resp), resp, len(wantResp), wantResp) @@ -664,7 +1074,7 @@ func TestForwarderTCPFallbackError(t *testing.T) { response := makeTestResponse(t, domain, dns.RCodeServerFailure) // Our request is a single A query for the domain in the answer, above. - request := makeTestRequest(t, domain) + request := makeTestRequest(t, domain, dns.TypeA, 0) var sawRequest atomic.Bool port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { @@ -695,7 +1105,7 @@ func TestForwarderTCPFallbackError(t *testing.T) { // returns a successful response, we propagate it. func TestForwarderWithManyResolvers(t *testing.T) { const domain = "example.com." - request := makeTestRequest(t, domain) + request := makeTestRequest(t, domain, dns.TypeA, 0) tests := []struct { name string @@ -837,20 +1247,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) { }() // Our request is a single PTR query for the domain in the answer, above. - request := func() []byte { - builder := dns.NewBuilder(nil, dns.Header{}) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: dns.MustNewName(domain), - Type: dns.TypePTR, - Class: dns.ClassINET, - }) - request, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return request - }() + request := makeTestRequest(t, domain, dns.TypePTR, 0) port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { }) @@ -868,7 +1265,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) { func TestForwarderVerboseLogs(t *testing.T) { const domain = "test.tailscale.com." response := makeTestResponse(t, domain, dns.RCodeServerFailure) - request := makeTestRequest(t, domain) + request := makeTestRequest(t, domain, dns.TypeA, 0) port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { if !bytes.Equal(request, gotRequest) { diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 3185cbe2b..38da40b91 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -328,7 +328,12 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net return (<-responses).bs, nil } - return out, err + if err != nil { + return out, err + } + + out = checkResponseSizeAndSetTC(out, bs, family, r.logf) + return out, nil } // GetUpstreamResolvers returns the resolvers that would be used to resolve diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index f0dbb48b3..508153544 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -1521,3 +1521,102 @@ func TestServfail(t *testing.T) { t.Errorf("response was %X, want %X", pkt, wantPkt) } } + +// TestLocalResponseTCFlagIntegration tests that checkResponseSizeAndSetTC is +// correctly applied to local DNS responses through the Resolver.Query integration path. +// This complements the unit test in forwarder_test.go by verifying the end-to-end behavior. +func TestLocalResponseTCFlagIntegration(t *testing.T) { + r := newResolver(t) + defer r.Close() + + r.SetConfig(dnsCfg) + + tests := []struct { + name string + query []byte + family string + wantTCSet bool + desc string + }{ + { + name: "UDP_small_local_response_no_TC", + query: dnspacket("test1.ipn.dev.", dns.TypeA, noEdns), + family: "udp", + wantTCSet: false, + desc: "Small local response (< 512 bytes) should not have TC flag set", + }, + { + name: "TCP_local_response_no_TC", + query: dnspacket("test1.ipn.dev.", dns.TypeA, noEdns), + family: "tcp", + wantTCSet: false, + desc: "TCP queries should skip TC flag setting (even for large responses)", + }, + { + name: "UDP_EDNS_request_small_response", + query: dnspacket("test1.ipn.dev.", dns.TypeA, 1500), + family: "udp", + wantTCSet: false, + desc: "Small response with EDNS request should not have TC flag set", + }, + { + name: "UDP_IPv6_response_no_TC", + query: dnspacket("test2.ipn.dev.", dns.TypeAAAA, noEdns), + family: "udp", + wantTCSet: false, + desc: "Small IPv6 local response should not have TC flag set", + }, + { + name: "UDP_reverse_lookup_no_TC", + query: dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR, noEdns), + family: "udp", + wantTCSet: false, + desc: "Small reverse lookup response should not have TC flag set", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + response, err := r.Query(context.Background(), tt.query, tt.family, netip.AddrPort{}) + if err != nil { + t.Fatalf("Query failed: %v", err) + } + + if len(response) < headerBytes { + t.Fatalf("Response too small: %d bytes", len(response)) + } + + hasTC := truncatedFlagSet(response) + if hasTC != tt.wantTCSet { + t.Errorf("%s: TC flag = %v, want %v (response size=%d bytes)", tt.desc, hasTC, tt.wantTCSet, len(response)) + } + + // Verify response is valid by parsing it (if possible) + // Note: unpackResponse may not support all record types (e.g., PTR) + parsed, err := unpackResponse(response) + if err == nil { + // Verify the truncated field in parsed response matches the flag + if parsed.truncated != hasTC { + t.Errorf("Parsed truncated field (%v) doesn't match TC flag (%v)", parsed.truncated, hasTC) + } + } else { + // For unsupported types, just verify we can parse the header + var parser dns.Parser + h, err := parser.Start(response) + if err != nil { + t.Errorf("Failed to parse DNS header: %v", err) + } else { + // Verify header truncated flag matches + if h.Truncated != hasTC { + t.Errorf("Header truncated field (%v) doesn't match TC flag (%v)", h.Truncated, hasTC) + } + } + } + + // Verify response size is reasonable (local responses are typically small) + if len(response) > 1000 { + t.Logf("Warning: Local response is unusually large: %d bytes", len(response)) + } + }) + } +}