pull/18157/merge
Brendan Creane 23 hours ago committed by GitHub
commit 94a2a9ee43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -63,6 +63,17 @@ func truncatedFlagSet(pkt []byte) bool {
return (binary.BigEndian.Uint16(pkt[2:4]) & dnsFlagTruncated) != 0
}
// setTCFlag sets the TC (truncated) flag in the DNS packet header.
// The packet must be at least headerBytes in length.
func setTCFlag(packet []byte) {
if len(packet) < headerBytes {
return
}
flags := binary.BigEndian.Uint16(packet[2:4])
flags |= dnsFlagTruncated
binary.BigEndian.PutUint16(packet[2:4], flags)
}
const (
// dohIdleConnTimeout is how long to keep idle HTTP connections
// open to DNS-over-HTTPS servers. 10 seconds is a sensible
@ -131,47 +142,59 @@ func getRCode(packet []byte) dns.RCode {
return dns.RCode(packet[3] & 0x0F)
}
// clampEDNSSize attempts to limit the maximum EDNS response size. This is not
// an exhaustive solution, instead only easy cases are currently handled in the
// interest of speed and reduced complexity. Only OPT records at the very end of
// the message with no option codes are addressed.
// TODO: handle more situations if we discover that they happen often
func clampEDNSSize(packet []byte, maxSize uint16) {
// optFixedBytes is the size of an OPT record with no option codes.
const optFixedBytes = 11
const edns0Version = 0
// findOPTRecord finds and validates the EDNS OPT record at the end of a DNS packet.
// Returns the requested buffer size and a pointer to the OPT record bytes if valid,
// or (0, nil) if no valid OPT record is found.
// The OPT record must be at the very end of the packet with no option codes.
func findOPTRecord(packet []byte) (requestedSize uint16, opt []byte) {
const optFixedBytes = 11 // size of an OPT record with no option codes
const edns0Version = 0 // EDNS version number (currently only version 0 is defined)
if len(packet) < headerBytes+optFixedBytes {
return
return 0, nil
}
arCount := binary.BigEndian.Uint16(packet[10:12])
if arCount == 0 {
// OPT shows up in an AR, so there must be no OPT
return
return 0, nil
}
// https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.2
opt := packet[len(packet)-optFixedBytes:]
opt = packet[len(packet)-optFixedBytes:]
if opt[0] != 0 {
// OPT NAME must be 0 (root domain)
return
return 0, nil
}
if dns.Type(binary.BigEndian.Uint16(opt[1:3])) != dns.TypeOPT {
// Not an OPT record
return
return 0, nil
}
requestedSize := binary.BigEndian.Uint16(opt[3:5])
requestedSize = binary.BigEndian.Uint16(opt[3:5])
// Ignore extended RCODE in opt[5]
if opt[6] != edns0Version {
// Be conservative and don't touch unknown versions.
return
return 0, nil
}
// Ignore flags in opt[6:9]
if binary.BigEndian.Uint16(opt[9:11]) != 0 {
// RDLEN must be 0 (no variable length data). We're at the end of the
// packet so this should be 0 anyway)..
// packet so this should be 0 anyway.
return 0, nil
}
return requestedSize, opt
}
// clampEDNSSize attempts to limit the maximum EDNS response size. This is not
// an exhaustive solution, instead only easy cases are currently handled in the
// interest of speed and reduced complexity. Only OPT records at the very end of
// the message with no option codes are addressed.
// TODO: handle more situations if we discover that they happen often
func clampEDNSSize(packet []byte, maxSize uint16) {
requestedSize, opt := findOPTRecord(packet)
if opt == nil {
return
}
@ -183,6 +206,57 @@ func clampEDNSSize(packet []byte, maxSize uint16) {
binary.BigEndian.PutUint16(opt[3:5], maxSize)
}
// getEDNSBufferSize extracts the EDNS buffer size from a DNS request packet.
// Returns (bufferSize, true) if a valid EDNS OPT record is found,
// or (0, false) if no EDNS OPT record is found or if there's an error.
func getEDNSBufferSize(packet []byte) (uint16, bool) {
requestedSize, opt := findOPTRecord(packet)
return requestedSize, opt != nil
}
// checkResponseSizeAndSetTC sets the TC (truncated) flag in the DNS header when
// the response exceeds the maximum UDP size. If no EDNS OPT record is present
// in the request, it sets the TC flag when the response is bigger than 512 bytes
// per RFC 1035. If an EDNS OPT record is present, it sets the TC flag when the
// response is bigger than the EDNS buffer size. The response buffer is not
// truncated; only the TC flag is set. Returns the response unchanged except for
// the TC flag being set if needed.
func checkResponseSizeAndSetTC(response []byte, request []byte, family string, logf logger.Logf) []byte {
const defaultUDPSize = 512 // default maximum UDP DNS packet size per RFC 1035
// Only check for UDP queries; TCP can handle larger responses
if family != "udp" {
return response
}
// Check if TC flag is already set
if len(response) < headerBytes {
return response
}
if truncatedFlagSet(response) {
// TC flag already set, nothing to do
return response
}
ednsSize, hasEDNS := getEDNSBufferSize(request)
// Determine maximum allowed size
var maxSize int
if hasEDNS {
maxSize = int(ednsSize)
} else {
// No EDNS: enforce default UDP size limit per RFC 1035
maxSize = defaultUDPSize
}
// Check if response exceeds maximum size
if len(response) > maxSize {
setTCFlag(response)
}
return response
}
// dnsForwarderFailing should be raised when the forwarder is unable to reach the
// upstream resolvers. This is a high severity warning as it results in "no internet".
// This warning must be cleared when the forwarder is working again.
@ -535,7 +609,13 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe
if !buildfeatures.HasPeerAPIClient {
return nil, feature.ErrUnavailable
}
return f.sendDoH(ctx, rr.name.Addr, f.dialer.PeerAPIHTTPClient(), fq.packet)
res, err := f.sendDoH(ctx, rr.name.Addr, f.dialer.PeerAPIHTTPClient(), fq.packet)
if err != nil {
return nil, err
}
// Check response size and set TC flag if needed (only for UDP queries)
res = checkResponseSizeAndSetTC(res, fq.packet, fq.family, f.logf)
return res, nil
}
if strings.HasPrefix(rr.name.Addr, "https://") {
// Only known DoH providers are supported currently. Specifically, we
@ -546,7 +626,13 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe
// them.
urlBase := rr.name.Addr
if hc, ok := f.getKnownDoHClientForProvider(urlBase); ok {
return f.sendDoH(ctx, urlBase, hc, fq.packet)
res, err := f.sendDoH(ctx, urlBase, hc, fq.packet)
if err != nil {
return nil, err
}
// Check response size and set TC flag if needed (only for UDP queries)
res = checkResponseSizeAndSetTC(res, fq.packet, fq.family, f.logf)
return res, nil
}
metricDNSFwdErrorType.Add(1)
return nil, fmt.Errorf("arbitrary https:// resolvers not supported yet")
@ -710,12 +796,15 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
f.logf("recv: packet too small (%d bytes)", n)
}
out = out[:n]
tcFlagAlreadySet := truncatedFlagSet(out)
txid := getTxID(out)
if txid != fq.txid {
metricDNSFwdUDPErrorTxID.Add(1)
return nil, errTxIDMismatch
}
rcode := getRCode(out)
// don't forward transient errors back to the client when the server fails
if rcode == dns.RCodeServerFailure {
f.logf("recv: response code indicating server failure: %d", rcode)
@ -723,11 +812,9 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
return nil, errServerFailure
}
if truncated {
// Set the truncated bit if it wasn't already.
flags := binary.BigEndian.Uint16(out[2:4])
flags |= dnsFlagTruncated
binary.BigEndian.PutUint16(out[2:4], flags)
// Set the truncated bit if buffer was truncated during read and the flag isn't already set
if truncated && !tcFlagAlreadySet {
setTCFlag(out)
// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
// states that truncation should head drop so that the authority
@ -736,6 +823,8 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
// best we can do.
}
out = checkResponseSizeAndSetTC(out, fq.packet, fq.family, f.logf)
if truncatedFlagSet(out) {
metricDNSFwdTruncated.Add(1)
}

@ -34,6 +34,46 @@ func (rr resolverAndDelay) String() string {
return fmt.Sprintf("%v+%v", rr.name, rr.startDelay)
}
// setTCFlagInPacket sets the TC flag in a DNS packet (for testing).
func setTCFlagInPacket(packet []byte) {
if len(packet) < headerBytes {
return
}
flags := binary.BigEndian.Uint16(packet[2:4])
flags |= dnsFlagTruncated
binary.BigEndian.PutUint16(packet[2:4], flags)
}
// clearTCFlagInPacket clears the TC flag in a DNS packet (for testing).
func clearTCFlagInPacket(packet []byte) {
if len(packet) < headerBytes {
return
}
flags := binary.BigEndian.Uint16(packet[2:4])
flags &^= dnsFlagTruncated
binary.BigEndian.PutUint16(packet[2:4], flags)
}
// verifyEDNSBufferSize verifies a request has the expected EDNS buffer size.
func verifyEDNSBufferSize(t *testing.T, request []byte, expectedSize uint16) {
t.Helper()
ednsSize, hasEDNS := getEDNSBufferSize(request)
if !hasEDNS {
t.Fatalf("request should have EDNS OPT record")
}
if ednsSize != expectedSize {
t.Fatalf("request EDNS size = %d, want %d", ednsSize, expectedSize)
}
}
// setupForwarderWithTCPRetriesDisabled returns a forwarder modifier that disables TCP retries.
func setupForwarderWithTCPRetriesDisabled() func(*forwarder) {
return func(fwd *forwarder) {
fwd.controlKnobs = &controlknobs.Knobs{}
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
}
}
func TestResolversWithDelays(t *testing.T) {
// query
q := func(ss ...string) (ipps []*dnstype.Resolver) {
@ -428,22 +468,16 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte)
}
// Our request is a single A query for the domain in the answer, above.
builder = dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: dns.MustNewName(domain),
Type: dns.TypeA,
Class: dns.ClassINET,
})
request, err = builder.Finish()
if err != nil {
tb.Fatal(err)
}
request = makeTestRequest(tb, domain, dns.TypeA, 0)
return
}
func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) {
return runTestQueryWithFamily(tb, request, "udp", modify, ports...)
}
func runTestQueryWithFamily(tb testing.TB, request []byte, family string, modify func(*forwarder), ports ...uint16) ([]byte, error) {
logf := tstest.WhileTestRunningLogger(tb)
bus := eventbustest.NewBus(tb)
netMon, err := netmon.New(bus, logf)
@ -467,7 +501,7 @@ func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports
rpkt := packet{
bs: request,
family: "tcp",
family: family,
addr: netip.MustParseAddrPort("127.0.0.1:12345"),
}
@ -483,17 +517,29 @@ func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports
}
}
// makeTestRequest returns a new TypeA request for the given domain.
func makeTestRequest(tb testing.TB, domain string) []byte {
// makeTestRequest returns a new DNS request for the given domain.
// If queryType is 0, it defaults to TypeA. If ednsSize > 0, it adds an EDNS OPT record.
func makeTestRequest(tb testing.TB, domain string, queryType dns.Type, ednsSize uint16) []byte {
tb.Helper()
if queryType == 0 {
queryType = dns.TypeA
}
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: name,
Type: dns.TypeA,
Type: queryType,
Class: dns.ClassINET,
})
if ednsSize > 0 {
builder.StartAdditionals()
builder.OPTResource(dns.ResourceHeader{
Name: dns.MustNewName("."),
Type: dns.TypeOPT,
Class: dns.Class(ednsSize),
}, dns.OPTResource{})
}
request, err := builder.Finish()
if err != nil {
tb.Fatal(err)
@ -549,6 +595,371 @@ func beVerbose(f *forwarder) {
f.verboseFwd = true
}
// makeTestRequestWithEDNS returns a new TypeTXT request for the given domain with EDNS buffer size.
// Deprecated: Use makeTestRequest with queryType and ednsSize parameters instead.
func makeTestRequestWithEDNS(tb testing.TB, domain string, ednsSize uint16) []byte {
return makeTestRequest(tb, domain, dns.TypeTXT, ednsSize)
}
// makeEDNSResponse creates a DNS response of approximately the specified size
// with TXT records and an OPT record. The response will NOT have the TC flag set
// (simulating a non-compliant server that doesn't set TC when response exceeds EDNS buffer).
// The actual size may vary significantly due to DNS packet structure constraints.
func makeEDNSResponse(tb testing.TB, domain string, targetSize int) []byte {
tb.Helper()
// Use makeResponseOfSize with includeOPT=true
// Allow significant variance since DNS packet sizes are hard to predict exactly
// Use a combination of fixed tolerance (200 bytes) and percentage (25%) for larger targets
response := makeResponseOfSize(tb, domain, targetSize, true)
actualSize := len(response)
maxVariance := 200
if targetSize > 400 {
// For larger targets, allow 25% variance
maxVariance = targetSize * 25 / 100
}
if actualSize < targetSize-maxVariance || actualSize > targetSize+maxVariance {
tb.Fatalf("response size = %d, want approximately %d (variance: %d, allowed: ±%d)",
actualSize, targetSize, actualSize-targetSize, maxVariance)
}
return response
}
func TestEDNSBufferSizeTruncation(t *testing.T) {
const domain = "edns-test.example.com."
const ednsBufferSize = 500 // Small EDNS buffer
const responseSize = 800 // Response exceeds EDNS but < maxResponseBytes
// Create a response that exceeds EDNS buffer size but doesn't have TC flag set
response := makeEDNSResponse(t, domain, responseSize)
// Create a request with EDNS buffer size
request := makeTestRequest(t, domain, dns.TypeTXT, ednsBufferSize)
verifyEDNSBufferSize(t, request, ednsBufferSize)
// Verify response doesn't have TC flag set initially
if truncatedFlagSet(response) {
t.Fatal("test response should not have TC flag set initially")
}
// Set up test DNS server
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
verifyEDNSBufferSize(t, gotRequest, ednsBufferSize)
})
// Disable TCP retries to ensure we test UDP path
resp := mustRunTestQuery(t, request, setupForwarderWithTCPRetriesDisabled(), port)
// Verify the response has TC flag set by forwarder
if !truncatedFlagSet(resp) {
t.Errorf("TC flag not set in response (response size=%d, EDNS=%d)", len(resp), ednsBufferSize)
}
// Verify response size is preserved (not truncated by buffer)
if len(resp) != len(response) {
t.Errorf("response size = %d, want %d (response should not be truncated by buffer)", len(resp), len(response))
}
// Verify response size exceeds EDNS buffer
if len(resp) <= int(ednsBufferSize) {
t.Errorf("response size = %d, should exceed EDNS buffer size %d", len(resp), ednsBufferSize)
}
}
// makeResponseOfSize creates a DNS response of approximately the specified size
// with TXT records. The response will NOT have the TC flag set initially.
// If includeOPT is true, an OPT record is added to the response.
func makeResponseOfSize(tb testing.TB, domain string, targetSize int, includeOPT bool) []byte {
tb.Helper()
name := dns.MustNewName(domain)
// Estimate how many TXT records we need
// Each TXT record with ~200 bytes of data adds roughly 220-230 bytes to the packet
// (including DNS headers, name compression, etc.)
bytesPerRecord := 220
baseSize := 50 // Approximate base packet size (header + question)
if includeOPT {
baseSize += 11 // OPT record adds ~11 bytes
}
estimatedRecords := (targetSize - baseSize) / bytesPerRecord
if estimatedRecords < 1 {
estimatedRecords = 1
}
// Start with estimated records and adjust
txtLen := 200
var response []byte
var err error
for attempt := 0; attempt < 10; attempt++ {
testBuilder := dns.NewBuilder(nil, dns.Header{
Response: true,
Authoritative: true,
RCode: dns.RCodeSuccess,
})
testBuilder.StartQuestions()
testBuilder.Question(dns.Question{
Name: name,
Type: dns.TypeTXT,
Class: dns.ClassINET,
})
testBuilder.StartAnswers()
for i := 0; i < estimatedRecords; i++ {
txtValue := strings.Repeat("x", txtLen)
testBuilder.TXTResource(dns.ResourceHeader{
Name: name,
Type: dns.TypeTXT,
Class: dns.ClassINET,
TTL: 300,
}, dns.TXTResource{
TXT: []string{txtValue},
})
}
// Optionally add OPT record
if includeOPT {
testBuilder.StartAdditionals()
testBuilder.OPTResource(dns.ResourceHeader{
Name: dns.MustNewName("."),
Type: dns.TypeOPT,
Class: dns.Class(4096),
}, dns.OPTResource{})
}
response, err = testBuilder.Finish()
if err != nil {
tb.Fatal(err)
}
actualSize := len(response)
// Stop if we've reached or slightly exceeded the target
// Allow up to 20% overshoot to avoid excessive iterations
if actualSize >= targetSize && actualSize <= targetSize*120/100 {
break
}
// If we've overshot significantly, we're done (better than undershooting)
if actualSize > targetSize*120/100 {
break
}
// Adjust for next attempt
needed := targetSize - actualSize
additionalRecords := (needed / bytesPerRecord) + 1
estimatedRecords += additionalRecords
if estimatedRecords > 200 {
// If we need too many records, increase TXT length instead
txtLen = 255 // Max single TXT string length
bytesPerRecord = 280 // Adjusted estimate
estimatedRecords = (targetSize - baseSize) / bytesPerRecord
if estimatedRecords < 1 {
estimatedRecords = 1
}
}
}
// Ensure TC flag is NOT set initially
clearTCFlagInPacket(response)
return response
}
func TestCheckResponseSizeAndSetTC(t *testing.T) {
const domain = "test.example.com."
logf := func(format string, args ...any) {
// Silent logger for tests
}
tests := []struct {
name string
responseSize int
requestHasEDNS bool
ednsSize uint16
family string
responseTCSet bool // Whether response has TC flag set initially
wantTCSet bool // Whether TC flag should be set after function call
skipIfNotExact bool // Skip test if we can't hit exact size (for edge cases)
}{
// Default UDP size (512 bytes) without EDNS
{
name: "UDP_noEDNS_small_should_not_set_TC",
responseSize: 400,
requestHasEDNS: false,
family: "udp",
wantTCSet: false,
},
{
name: "UDP_noEDNS_512bytes_should_not_set_TC",
responseSize: 512,
requestHasEDNS: false,
family: "udp",
wantTCSet: false,
skipIfNotExact: true,
},
{
name: "UDP_noEDNS_513bytes_should_set_TC",
responseSize: 513,
requestHasEDNS: false,
family: "udp",
wantTCSet: true,
skipIfNotExact: true,
},
{
name: "UDP_noEDNS_large_should_set_TC",
responseSize: 600,
requestHasEDNS: false,
family: "udp",
wantTCSet: true,
},
// EDNS edge cases
{
name: "UDP_EDNS_small_under_limit_should_not_set_TC",
responseSize: 450,
requestHasEDNS: true,
ednsSize: 500,
family: "udp",
wantTCSet: false,
},
{
name: "UDP_EDNS_at_limit_should_not_set_TC",
responseSize: 500,
requestHasEDNS: true,
ednsSize: 500,
family: "udp",
wantTCSet: false,
},
{
name: "UDP_EDNS_over_limit_should_set_TC",
responseSize: 550,
requestHasEDNS: true,
ednsSize: 500,
family: "udp",
wantTCSet: true,
},
{
name: "UDP_EDNS_large_over_limit_should_set_TC",
responseSize: 1500,
requestHasEDNS: true,
ednsSize: 1200,
family: "udp",
wantTCSet: true,
},
// Early return paths
{
name: "TCP_query_should_skip",
responseSize: 1000,
family: "tcp",
wantTCSet: false,
},
{
name: "response_too_small_should_skip",
responseSize: headerBytes - 1,
family: "udp",
wantTCSet: false,
},
{
name: "response_exactly_headerBytes_should_not_set_TC",
responseSize: headerBytes,
family: "udp",
wantTCSet: false,
},
{
name: "response_TC_already_set_should_skip",
responseSize: 600,
family: "udp",
responseTCSet: true,
wantTCSet: true, // Should remain set
},
{
name: "UDP_noEDNS_large_TC_already_set_should_skip",
responseSize: 600,
requestHasEDNS: false,
family: "udp",
responseTCSet: true,
wantTCSet: true, // Should remain set
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var response []byte
// Create response of specified size
if tt.responseSize < headerBytes {
// For too-small test, create minimal invalid packet
response = make([]byte, tt.responseSize)
// Don't set any flags, just make it too small
} else {
response = makeResponseOfSize(t, domain, tt.responseSize, false)
actualSize := len(response)
// Only adjust expectations for UDP queries that go through size checking
// TCP queries and other early-return cases should keep their original expectations
if tt.family == "udp" && !tt.responseTCSet && actualSize >= headerBytes {
// Determine the maximum allowed size based on request
var maxSize int
if tt.requestHasEDNS {
maxSize = int(tt.ednsSize)
} else {
maxSize = 512 // default UDP size per RFC 1035
}
// For edge cases where exact size matters, verify we're close enough
if tt.skipIfNotExact {
// For 512/513 byte tests, we need to be very close
if actualSize < tt.responseSize-10 || actualSize > tt.responseSize+10 {
t.Skipf("skipping: could not create response close to target size %d (got %d)", tt.responseSize, actualSize)
}
// Function sets TC if response > maxSize, so adjust expectation based on actual size
tt.wantTCSet = actualSize > maxSize
} else {
// For non-exact tests, adjust expectation based on actual response size
// The function sets TC if actualSize > maxSize
tt.wantTCSet = actualSize > maxSize
}
}
}
// Set TC flag initially if requested
if tt.responseTCSet {
setTCFlagInPacket(response)
}
// Create request with or without EDNS
var ednsSize uint16
if tt.requestHasEDNS {
ednsSize = tt.ednsSize
}
request := makeTestRequest(t, domain, dns.TypeTXT, ednsSize)
// Call the function
result := checkResponseSizeAndSetTC(response, request, tt.family, logf)
// Verify response size is preserved (function should not truncate, only set flag)
if len(result) != len(response) {
t.Errorf("response size changed: got %d, want %d", len(result), len(response))
}
// Verify TC flag state
if len(result) >= headerBytes {
hasTC := truncatedFlagSet(result)
if hasTC != tt.wantTCSet {
t.Errorf("TC flag: got %v, want %v (response size=%d)", hasTC, tt.wantTCSet, len(result))
}
} else if tt.responseSize >= headerBytes {
// If we expected a valid response but got too small, that's unexpected
t.Errorf("response too small (%d bytes) but expected valid response", len(result))
}
// Verify response pointer is same (should be in-place modification)
if &result[0] != &response[0] {
t.Errorf("function should modify response in place, but got new slice")
}
})
}
}
func TestForwarderTCPFallback(t *testing.T) {
const domain = "large-dns-response.tailscale.com."
@ -569,7 +980,10 @@ func TestForwarderTCPFallback(t *testing.T) {
}
})
resp := mustRunTestQuery(t, request, beVerbose, port)
resp, err := runTestQueryWithFamily(t, request, "tcp", beVerbose, port)
if err != nil {
t.Fatalf("error making request: %v", err)
}
if !bytes.Equal(resp, largeResponse) {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
}
@ -636,17 +1050,13 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) {
resp := mustRunTestQuery(t, request, func(fwd *forwarder) {
fwd.verboseFwd = true
// Disable retries for this test.
fwd.controlKnobs = &controlknobs.Knobs{}
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
setupForwarderWithTCPRetriesDisabled()(fwd)
}, port)
wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...)
// Set the truncated flag on the expected response, since that's what we expect.
flags := binary.BigEndian.Uint16(wantResp[2:4])
flags |= dnsFlagTruncated
binary.BigEndian.PutUint16(wantResp[2:4], flags)
setTCFlagInPacket(wantResp)
if !bytes.Equal(resp, wantResp) {
t.Errorf("invalid response\ngot (%d): %+v\nwant (%d): %+v", len(resp), resp, len(wantResp), wantResp)
@ -664,7 +1074,7 @@ func TestForwarderTCPFallbackError(t *testing.T) {
response := makeTestResponse(t, domain, dns.RCodeServerFailure)
// Our request is a single A query for the domain in the answer, above.
request := makeTestRequest(t, domain)
request := makeTestRequest(t, domain, dns.TypeA, 0)
var sawRequest atomic.Bool
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
@ -695,7 +1105,7 @@ func TestForwarderTCPFallbackError(t *testing.T) {
// returns a successful response, we propagate it.
func TestForwarderWithManyResolvers(t *testing.T) {
const domain = "example.com."
request := makeTestRequest(t, domain)
request := makeTestRequest(t, domain, dns.TypeA, 0)
tests := []struct {
name string
@ -837,20 +1247,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) {
}()
// Our request is a single PTR query for the domain in the answer, above.
request := func() []byte {
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: dns.MustNewName(domain),
Type: dns.TypePTR,
Class: dns.ClassINET,
})
request, err := builder.Finish()
if err != nil {
t.Fatal(err)
}
return request
}()
request := makeTestRequest(t, domain, dns.TypePTR, 0)
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
})
@ -868,7 +1265,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) {
func TestForwarderVerboseLogs(t *testing.T) {
const domain = "test.tailscale.com."
response := makeTestResponse(t, domain, dns.RCodeServerFailure)
request := makeTestRequest(t, domain)
request := makeTestRequest(t, domain, dns.TypeA, 0)
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
if !bytes.Equal(request, gotRequest) {

@ -328,7 +328,12 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net
return (<-responses).bs, nil
}
return out, err
if err != nil {
return out, err
}
out = checkResponseSizeAndSetTC(out, bs, family, r.logf)
return out, nil
}
// GetUpstreamResolvers returns the resolvers that would be used to resolve

@ -1521,3 +1521,102 @@ func TestServfail(t *testing.T) {
t.Errorf("response was %X, want %X", pkt, wantPkt)
}
}
// TestLocalResponseTCFlagIntegration tests that checkResponseSizeAndSetTC is
// correctly applied to local DNS responses through the Resolver.Query integration path.
// This complements the unit test in forwarder_test.go by verifying the end-to-end behavior.
func TestLocalResponseTCFlagIntegration(t *testing.T) {
r := newResolver(t)
defer r.Close()
r.SetConfig(dnsCfg)
tests := []struct {
name string
query []byte
family string
wantTCSet bool
desc string
}{
{
name: "UDP_small_local_response_no_TC",
query: dnspacket("test1.ipn.dev.", dns.TypeA, noEdns),
family: "udp",
wantTCSet: false,
desc: "Small local response (< 512 bytes) should not have TC flag set",
},
{
name: "TCP_local_response_no_TC",
query: dnspacket("test1.ipn.dev.", dns.TypeA, noEdns),
family: "tcp",
wantTCSet: false,
desc: "TCP queries should skip TC flag setting (even for large responses)",
},
{
name: "UDP_EDNS_request_small_response",
query: dnspacket("test1.ipn.dev.", dns.TypeA, 1500),
family: "udp",
wantTCSet: false,
desc: "Small response with EDNS request should not have TC flag set",
},
{
name: "UDP_IPv6_response_no_TC",
query: dnspacket("test2.ipn.dev.", dns.TypeAAAA, noEdns),
family: "udp",
wantTCSet: false,
desc: "Small IPv6 local response should not have TC flag set",
},
{
name: "UDP_reverse_lookup_no_TC",
query: dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR, noEdns),
family: "udp",
wantTCSet: false,
desc: "Small reverse lookup response should not have TC flag set",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
response, err := r.Query(context.Background(), tt.query, tt.family, netip.AddrPort{})
if err != nil {
t.Fatalf("Query failed: %v", err)
}
if len(response) < headerBytes {
t.Fatalf("Response too small: %d bytes", len(response))
}
hasTC := truncatedFlagSet(response)
if hasTC != tt.wantTCSet {
t.Errorf("%s: TC flag = %v, want %v (response size=%d bytes)", tt.desc, hasTC, tt.wantTCSet, len(response))
}
// Verify response is valid by parsing it (if possible)
// Note: unpackResponse may not support all record types (e.g., PTR)
parsed, err := unpackResponse(response)
if err == nil {
// Verify the truncated field in parsed response matches the flag
if parsed.truncated != hasTC {
t.Errorf("Parsed truncated field (%v) doesn't match TC flag (%v)", parsed.truncated, hasTC)
}
} else {
// For unsupported types, just verify we can parse the header
var parser dns.Parser
h, err := parser.Start(response)
if err != nil {
t.Errorf("Failed to parse DNS header: %v", err)
} else {
// Verify header truncated flag matches
if h.Truncated != hasTC {
t.Errorf("Header truncated field (%v) doesn't match TC flag (%v)", h.Truncated, hasTC)
}
}
}
// Verify response size is reasonable (local responses are typically small)
if len(response) > 1000 {
t.Logf("Warning: Local response is unusually large: %d bytes", len(response))
}
})
}
}

Loading…
Cancel
Save