net/dns/resolver: Truncate UDP DNS responses

Perform RFC-compliant truncation of DNS responses
which are to be delivered over UDP when they
exceed the maximum allowed size.

The maximum size is determined by the EDNS0
options in the DNS request, or defaults to
512 bytes if no EDNS0 options are present (or
the ENDS0 value in the request is less than 512).
Per DNS Flag Day 2010 recommendations, we cap
the maximum UDP DNS response size at 1232 bytes.

Fixes #13601 and #18107

Signed-off-by: Ryan Koski <ryan@koski.org>
pull/18259/head
Ryan Koski 4 weeks ago
parent 951d711054
commit ef2e3c64aa

@ -1158,7 +1158,8 @@ func servfailResponse(req packet) (res packet, err error) {
h := p.Header
h.Response = true
h.Authoritative = true
// Correct behavior for SERVFAIL is to set the Authoritative flag to 0.
h.Authoritative = false
h.RCode = dns.RCodeServerFailure
b := dns.NewBuilder(nil, h)
b.StartQuestions()

@ -506,9 +506,19 @@ func makeTestRequest(tb testing.TB, domain string) []byte {
func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte {
tb.Helper()
name := dns.MustNewName(domain)
// The correct value for the Authoritative bit is complicated.
// However, in all cases where a SERVFAIL is returned, it should be false.
// Since the servfailResponse() function correctly sets this bit to false,
// this test needs to also return false for RCodeServerFailure.
authoritative := true
if code == dns.RCodeServerFailure {
authoritative = false
}
builder := dns.NewBuilder(nil, dns.Header{
Response: true,
Authoritative: true,
Authoritative: authoritative,
RCode: code,
})
builder.StartQuestions()

@ -315,17 +315,40 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net
default:
}
reqPacket := packet{bs: bs, family: family, addr: from}
out, err := r.respond(bs)
if err == errNotOurName {
responses := make(chan packet, 1)
ctx, cancel := context.WithTimeout(ctx, dnsQueryTimeout)
defer close(responses)
defer cancel()
err = r.forwarder.forwardWithDestChan(ctx, packet{bs, family, from}, responses)
err = r.forwarder.forwardWithDestChan(ctx, reqPacket, responses)
if err != nil {
return nil, err
}
return (<-responses).bs, nil
out = (<-responses).bs
err = nil
}
// Only perform truncation/EDNS0 processing for UDP queries.
if err == nil && family == "udp" && out != nil {
// Determine client's advertised UDP size via EDNS0, default to 512
maxResponseSize := uint16(512)
if edns := extractEDNS0UDPSize(bs); edns > 0 {
maxResponseSize = edns
}
if len(out) > int(maxResponseSize) {
tr, terr := truncateDNSResponse(out, maxResponseSize)
if terr != nil {
// Can't safely truncate; return SERVFAIL
serv, berr := servfailResponse(reqPacket)
if berr != nil {
return nil, terr
}
return serv.bs, nil
}
out = tr
}
}
return out, err

@ -302,10 +302,15 @@ func dnsHandler(answers ...any) dns.HandlerFunc {
}
}
func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server {
func serveDNS(tb testing.TB, addr string, family string, records ...any) *dns.Server {
if len(records)%2 != 0 {
panic("must have an even number of record values")
}
switch family {
case "udp", "tcp":
default:
panic("family must be udp or tcp")
}
mux := dns.NewServeMux()
for i := 0; i < len(records); i += 2 {
name := records[i].(string)
@ -315,7 +320,7 @@ func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server {
waitch := make(chan struct{})
server := &dns.Server{
Addr: addr,
Net: "udp",
Net: family,
Handler: mux,
NotifyStartedFunc: func() { close(waitch) },
ReusePort: true,

@ -92,15 +92,16 @@ func dnspacket(domain dnsname.FQDN, tp dns.Type, ednsSize uint16) []byte {
}
type dnsResponse struct {
ip netip.Addr
txt []string
name dnsname.FQDN
rcode dns.RCode
truncated bool
requestEdns bool
requestEdnsSize uint16
responseEdns bool
responseEdnsSize uint16
ip netip.Addr
txt []string
name dnsname.FQDN
rcode dns.RCode
truncated bool
retryTCPonTruncation bool
requestEdns bool
requestEdnsSize uint16
responseEdns bool
responseEdnsSize uint16
}
func unpackResponse(payload []byte) (dnsResponse, error) {
@ -233,8 +234,13 @@ func unpackResponse(payload []byte) (dnsResponse, error) {
return response, nil
}
func syncRespond(r *Resolver, query []byte) ([]byte, error) {
return r.Query(context.Background(), query, "udp", netip.AddrPort{})
func syncRespond(r *Resolver, family string, query []byte) ([]byte, error) {
switch family {
case "udp", "tcp":
default:
return nil, fmt.Errorf("Invalid family %q", family)
}
return r.Query(context.Background(), query, family, netip.AddrPort{})
}
func mustIP(str string) netip.Addr {
@ -538,22 +544,37 @@ func TestDelegate(t *testing.T) {
"xlarge.txt.", resolveToTXT(xlargeTXT, 8000),
"huge.txt.", resolveToTXT(hugeTXT, 65527),
}
v4server := serveDNS(t, "127.0.0.1:0", records...)
defer v4server.Shutdown()
v6server := serveDNS(t, "[::1]:0", records...)
defer v6server.Shutdown()
r := newResolver(t)
defer r.Close()
v4UDPServer := serveDNS(t, "127.0.0.1:0", "udp", records...)
defer v4UDPServer.Shutdown()
v6UDPServer := serveDNS(t, "[::1]:0", "udp", records...)
defer v6UDPServer.Shutdown()
v4TCPServer := serveDNS(t, "127.0.0.1:0", "udp", records...)
defer v4TCPServer.Shutdown()
v6TCPServer := serveDNS(t, "[::1]:0", "udp", records...)
defer v6TCPServer.Shutdown()
udpResolver := newResolver(t)
defer udpResolver.Close()
tcpResolver := newResolver(t)
defer tcpResolver.Close()
udpcfg := dnsCfg
udpcfg.Routes = map[dnsname.FQDN][]*dnstype.Resolver{
".": {
&dnstype.Resolver{Addr: v4UDPServer.PacketConn.LocalAddr().String()},
&dnstype.Resolver{Addr: v6UDPServer.PacketConn.LocalAddr().String()},
},
}
udpResolver.SetConfig(udpcfg)
cfg := dnsCfg
cfg.Routes = map[dnsname.FQDN][]*dnstype.Resolver{
tcpcfg := dnsCfg
tcpcfg.Routes = map[dnsname.FQDN][]*dnstype.Resolver{
".": {
&dnstype.Resolver{Addr: v4server.PacketConn.LocalAddr().String()},
&dnstype.Resolver{Addr: v6server.PacketConn.LocalAddr().String()},
&dnstype.Resolver{Addr: v4TCPServer.PacketConn.LocalAddr().String()},
&dnstype.Resolver{Addr: v6TCPServer.PacketConn.LocalAddr().String()},
},
}
r.SetConfig(cfg)
tcpResolver.SetConfig(tcpcfg)
tests := []struct {
title string
@ -616,44 +637,44 @@ func TestDelegate(t *testing.T) {
"medtxt",
dnspacket("med.txt.", dns.TypeTXT, 2000),
dnsResponse{
txt: medTXT,
rcode: dns.RCodeSuccess,
requestEdns: true,
requestEdnsSize: 2000,
responseEdns: true,
responseEdnsSize: 1500,
txt: medTXT,
rcode: dns.RCodeSuccess,
retryTCPonTruncation: true,
requestEdns: true,
requestEdnsSize: 2000,
responseEdns: true,
responseEdnsSize: 1500,
},
},
{
"largetxt",
dnspacket("large.txt.", dns.TypeTXT, maxResponseBytes),
dnsResponse{
txt: largeTXT,
rcode: dns.RCodeSuccess,
requestEdns: true,
requestEdnsSize: maxResponseBytes,
responseEdns: true,
responseEdnsSize: maxResponseBytes,
txt: largeTXT,
rcode: dns.RCodeSuccess,
retryTCPonTruncation: true,
requestEdns: true,
requestEdnsSize: maxResponseBytes,
responseEdns: true,
responseEdnsSize: maxResponseBytes,
},
},
{
"xlargetxt",
dnspacket("xlarge.txt.", dns.TypeTXT, 8000),
dnsResponse{
rcode: dns.RCodeSuccess,
truncated: true,
// request/response EDNS fields will be unset because of
// they were truncated away
rcode: dns.RCodeSuccess,
truncated: true,
retryTCPonTruncation: true,
},
},
{
"hugetxt",
dnspacket("huge.txt.", dns.TypeTXT, 8000),
dnsResponse{
rcode: dns.RCodeSuccess,
truncated: true,
// request/response EDNS fields will be unset because of
// they were truncated away
rcode: dns.RCodeSuccess,
truncated: true,
retryTCPonTruncation: true,
},
},
}
@ -663,7 +684,9 @@ func TestDelegate(t *testing.T) {
if tt.title == "hugetxt" && runtime.GOOS == "darwin" {
t.Skip("known to not work on macOS: https://github.com/tailscale/tailscale/issues/2229")
}
payload, err := syncRespond(r, tt.query)
runEDNSSizeChecks := true
payload, err := syncRespond(udpResolver, "udp", tt.query)
if err != nil {
t.Errorf("err = %v; want nil", err)
return
@ -673,6 +696,27 @@ func TestDelegate(t *testing.T) {
t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
return
}
// If truncated and the test is configured to do so, retry over TCP.
// Additionally, some of the tests may result in a SERVFAIL response
// when queried over UDP because the total response is larger than
// the maximum supported buffer size. This results in a byte sequence
// that fails to parse correctly. In that case, we also retry over TCP.
if (response.truncated || response.rcode == dns.RCodeServerFailure) && tt.response.retryTCPonTruncation {
// Retry over TCP.
t.Logf("Retrying over TCP for %q", tt.title)
payload, err = syncRespond(tcpResolver, "tcp", tt.query)
if err != nil {
t.Errorf("TCP retry: err = %v; want nil", err)
return
}
response, err = unpackResponse(payload)
if err != nil {
t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
return
}
// On TCP, EDNS size is not applicable.
runEDNSSizeChecks = false
}
if response.rcode != tt.response.rcode {
t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode)
}
@ -691,17 +735,19 @@ func TestDelegate(t *testing.T) {
}
}
}
if response.requestEdns != tt.response.requestEdns {
t.Errorf("requestEdns = %v; want %v", response.requestEdns, tt.response.requestEdns)
}
if response.requestEdnsSize != tt.response.requestEdnsSize {
t.Errorf("requestEdnsSize = %v; want %v", response.requestEdnsSize, tt.response.requestEdnsSize)
}
if response.responseEdns != tt.response.responseEdns {
t.Errorf("responseEdns = %v; want %v", response.requestEdns, tt.response.requestEdns)
}
if response.responseEdnsSize != tt.response.responseEdnsSize {
t.Errorf("responseEdnsSize = %v; want %v", response.responseEdnsSize, tt.response.responseEdnsSize)
if runEDNSSizeChecks {
if response.requestEdns != tt.response.requestEdns {
t.Errorf("requestEdns = %v; want %v", response.requestEdns, tt.response.requestEdns)
}
if response.requestEdnsSize != tt.response.requestEdnsSize {
t.Errorf("requestEdnsSize = %v; want %v", response.requestEdnsSize, tt.response.requestEdnsSize)
}
if response.responseEdns != tt.response.responseEdns {
t.Errorf("responseEdns = %v; want %v", response.requestEdns, tt.response.requestEdns)
}
if response.responseEdnsSize != tt.response.responseEdnsSize {
t.Errorf("responseEdnsSize = %v; want %v", response.responseEdnsSize, tt.response.responseEdnsSize)
}
}
})
}
@ -711,10 +757,10 @@ func TestDelegateSplitRoute(t *testing.T) {
test4 := netip.MustParseAddr("2.3.4.5")
test6 := netip.MustParseAddr("ff::1")
server1 := serveDNS(t, "127.0.0.1:0",
server1 := serveDNS(t, "127.0.0.1:0", "udp",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server1.Shutdown()
server2 := serveDNS(t, "127.0.0.1:0",
server2 := serveDNS(t, "127.0.0.1:0", "udp",
"test.other.", resolveToIP(test4, test6, "dns.other."))
defer server2.Shutdown()
@ -747,7 +793,7 @@ func TestDelegateSplitRoute(t *testing.T) {
for _, tt := range tests {
t.Run(tt.title, func(t *testing.T) {
payload, err := syncRespond(r, tt.query)
payload, err := syncRespond(r, "udp", tt.query)
if err != nil {
t.Errorf("err = %v; want nil", err)
return
@ -942,7 +988,7 @@ func TestFull(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
response, err := syncRespond(r, tt.request)
response, err := syncRespond(r, "udp", tt.request)
if err != nil {
t.Errorf("err = %v; want nil", err)
}
@ -974,7 +1020,7 @@ func TestAllocs(t *testing.T) {
for _, tt := range tests {
err := tstest.MinAllocsPerRun(t, tt.want, func() {
syncRespond(r, tt.query)
syncRespond(r, "udp", tt.query)
})
if err != nil {
t.Errorf("%s: %v", tt.name, err)
@ -1006,7 +1052,7 @@ func TestTrimRDNSBonjourPrefix(t *testing.T) {
}
func BenchmarkFull(b *testing.B) {
server := serveDNS(b, "127.0.0.1:0",
server := serveDNS(b, "127.0.0.1:0", "udp",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server.Shutdown()
@ -1031,7 +1077,7 @@ func BenchmarkFull(b *testing.B) {
b.Run(tt.name, func(b *testing.B) {
b.ReportAllocs()
for range b.N {
syncRespond(r, tt.request)
syncRespond(r, "udp", tt.request)
}
})
}
@ -1159,7 +1205,7 @@ func TestHandleExitNodeDNSQueryWithNetPkg(t *testing.T) {
"ns.test.",
dnsHandler(miekdns.NS{Ns: "ns1.foo."}, miekdns.NS{Ns: "ns2.bar."}),
}
v4server := serveDNS(t, "127.0.0.1:0", records...)
v4server := serveDNS(t, "127.0.0.1:0", "udp", records...)
defer v4server.Shutdown()
// backendResolver is the resolver between
@ -1485,7 +1531,7 @@ func TestUnARPA(t *testing.T) {
//
// See: https://github.com/tailscale/tailscale/issues/4722
func TestServfail(t *testing.T) {
server := serveDNS(t, "127.0.0.1:0", "test.site.", miekdns.HandlerFunc(func(w miekdns.ResponseWriter, req *miekdns.Msg) {
server := serveDNS(t, "127.0.0.1:0", "udp", "test.site.", miekdns.HandlerFunc(func(w miekdns.ResponseWriter, req *miekdns.Msg) {
m := new(miekdns.Msg)
m.Rcode = miekdns.RcodeServerFailure
w.WriteMsg(m)
@ -1501,14 +1547,14 @@ func TestServfail(t *testing.T) {
}
r.SetConfig(cfg)
pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns))
pkt, err := syncRespond(r, "udp", dnspacket("test.site.", dns.TypeA, noEdns))
if err != nil {
t.Fatalf("err = %v, want nil", err)
}
wantPkt := []byte{
0x00, 0x00, // transaction id: 0
0x84, 0x02, // flags: response, authoritative, error: servfail
0x80, 0x02, // flags: response, error: servfail
0x00, 0x01, // one question
0x00, 0x00, // no answers
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs

@ -0,0 +1,294 @@
package resolver
import (
"errors"
"golang.org/x/net/dns/dnsmessage"
)
// extractOPTResource parses a DNS message and returns the OPT resource if present.
func extractOPTResource(msg []byte) *dnsmessage.Resource {
var p dnsmessage.Parser
if _, err := p.Start(msg); err != nil {
return nil
}
var optRes *dnsmessage.Resource
optRes = nil
// Fast-forward to find OPT
if err := p.SkipAllQuestions(); err == nil {
if err := p.SkipAllAnswers(); err == nil {
if err := p.SkipAllAuthorities(); err == nil {
for {
r, err := p.Additional()
if err != nil {
break
}
if r.Header.Type == dnsmessage.TypeOPT {
optRes = &r
break
}
}
}
}
}
return optRes
}
const minEDNS0Size = 512 // per RFC 6891 Section 6.2.5
const maxEDNS0Size = 1232 // per DNS Flag Day 2020 recommendation
// extractEDNS0UDPSize extracts the advertised UDP buffer size from an EDNS0 OPT record
// in a DNS query packet. If no EDNS0 record is present or the packet is malformed,
// it returns 0, indicating the default 512-byte limit should be used.
func extractEDNS0UDPSize(query []byte) uint16 {
size := uint16(0)
optRes := extractOPTResource(query)
if optRes != nil {
// UDP payload size is encoded in the CLASS field of the OPT header.
// Per RFC 6891 §6.2.5, treat any advertised UDP size smaller than 512
// as 512. Per DNS Flag Day 2020 (https://www.dnsflagday.net/2020/),
// the cap should be 1232 bytes, and newer versions of resolvers
// have set 1232 as their default limit.
size = uint16(optRes.Header.Class)
if size < minEDNS0Size {
size = minEDNS0Size
}
if size > maxEDNS0Size {
size = maxEDNS0Size
}
}
return size
}
// truncateDNSResponse performs RFC-compliant truncation of a DNS
// response message. It preserves the question section and as many
// resource records as possible in the answer, authority, and
// additional sections, setting the TC (truncated) bit if truncation
// occurs. It enforces RFC 6891 Section 7 (preserving the OPT record
// in truncated responses).
func truncateDNSResponse(resp []byte, maxSize uint16) ([]byte, error) {
// Sanity check on maxSize. It must be at least large enough
// to hold a minimal DNS header (12 bytes) and at least one
// question (5 bytes).
if maxSize < 12+5 {
return nil, errors.New("maxSize too small to hold minimal DNS message")
}
var p dnsmessage.Parser
header, err := p.Start(resp)
if err != nil {
return nil, err
}
// 1. Extract all records into slices so we can manage them.
questions, err := p.AllQuestions()
if err != nil {
return nil, err
}
var answers, authorities, additionals []dnsmessage.Resource
var optRes *dnsmessage.Resource
// Helper to extract resources from a section
extractSection := func(sectionName string) ([]dnsmessage.Resource, error) {
var extracted []dnsmessage.Resource
for {
var r dnsmessage.Resource
var err error
switch sectionName {
case "Ans":
r, err = p.Answer()
case "Auth":
r, err = p.Authority()
case "Add":
r, err = p.Additional()
}
if err == dnsmessage.ErrSectionDone {
return extracted, nil
}
if err != nil {
return nil, err
}
// Identify and isolate the OPT record
if r.Header.Type == dnsmessage.TypeOPT {
// We found the OPT record. Save it separately.
// (RFC 6891: Only one OPT record is allowed)
optRes = &r
} else {
extracted = append(extracted, r)
}
}
}
// We must parse sections in order: Skip Questions (already got them), then Ans, Auth, Add.
// Note: p.AllQuestions() already advanced the parser past questions.
if answers, err = extractSection("Ans"); err != nil {
return nil, err
}
if authorities, err = extractSection("Auth"); err != nil {
return nil, err
}
if additionals, err = extractSection("Add"); err != nil {
return nil, err
}
// 2. Try to build the FULL packet first (Happy Path).
// If it fits, we avoid the expensive iterative logic.
fullPacket, err := buildResponse(header, questions, answers, authorities, additionals, optRes)
if err == nil && uint16(len(fullPacket)) <= maxSize {
return fullPacket, nil
}
// 3. Truncation Path.
// The packet is too big. We must rebuild it record-by-record until full.
// We MUST set the TC bit.
header.Truncated = true
// We start with empty sections.
var finalAns, finalAuth, finalAdd []dnsmessage.Resource
// Define the order of candidates we want to try adding.
// (Answers first, then Authorities, then Additionals)
// We use a list of *slices* to iterate section by section.
sections := []struct {
candidates []dnsmessage.Resource
target *[]dnsmessage.Resource // Pointer to the slice we are building
}{
{answers, &finalAns},
{authorities, &finalAuth},
{additionals, &finalAdd},
}
for _, section := range sections {
for _, candidate := range section.candidates {
// Speculatively add this candidate to the target list
*section.target = append(*section.target, candidate)
// Build the packet with the current set of records + Mandatory OPT
testPacket, err := buildResponse(header, questions, finalAns, finalAuth, finalAdd, optRes)
if err != nil {
return nil, err // Should not happen with valid resources
}
// Check size
if uint16(len(testPacket)) > maxSize {
// Stop! This record broke the limit.
// Remove the last added record (backtrack).
*section.target = (*section.target)[:len(*section.target)-1]
// We are full. Return the last valid build.
// Note: We need to rebuild one last time or save the previous successful 'testPacket'.
// To be safe/clean, let's just rebuild the "safe" state.
return buildResponse(header, questions, finalAns, finalAuth, finalAdd, optRes)
}
// If it fits, continue loop to add next candidate.
}
}
// If we somehow finish the loop (unlikely given we failed the "Full" check), return what we have.
return buildResponse(header, questions, finalAns, finalAuth, finalAdd, optRes)
}
// buildResponse constructs a binary DNS message from the provided slices.
// It handles the complex state machine of dnsmessage.Builder.
func buildResponse(
h dnsmessage.Header,
qs []dnsmessage.Question,
ans, auths, adds []dnsmessage.Resource,
opt *dnsmessage.Resource,
) ([]byte, error) {
// Start with a nil buffer; Builder will allocate.
b := dnsmessage.NewBuilder(nil, h)
b.EnableCompression()
// 1. Questions
if err := b.StartQuestions(); err != nil {
return nil, err
}
for _, q := range qs {
if err := b.Question(q); err != nil {
return nil, err
}
}
// 2. Answers
if err := b.StartAnswers(); err != nil {
return nil, err
}
for _, r := range ans {
if err := addResource(&b, r); err != nil {
return nil, err
}
}
// 3. Authorities
if err := b.StartAuthorities(); err != nil {
return nil, err
}
for _, r := range auths {
if err := addResource(&b, r); err != nil {
return nil, err
}
}
// 4. Additionals
if err := b.StartAdditionals(); err != nil {
return nil, err
}
for _, r := range adds {
if err := addResource(&b, r); err != nil {
return nil, err
}
}
// Always append the OPT record if it exists (RFC 6891)
if opt != nil {
if err := addResource(&b, *opt); err != nil {
return nil, err
}
}
// Finish and return the bytes
return b.Finish()
}
// addResource is a helper to handle the various resource types
// when adding individual resources to the Builder.
func addResource(b *dnsmessage.Builder, r dnsmessage.Resource) error {
switch body := r.Body.(type) {
case *dnsmessage.AResource:
return b.AResource(r.Header, *body)
case *dnsmessage.AAAAResource:
return b.AAAAResource(r.Header, *body)
case *dnsmessage.CNAMEResource:
return b.CNAMEResource(r.Header, *body)
case *dnsmessage.HTTPSResource:
return b.HTTPSResource(r.Header, *body)
case *dnsmessage.NSResource:
return b.NSResource(r.Header, *body)
case *dnsmessage.PTRResource:
return b.PTRResource(r.Header, *body)
case *dnsmessage.SOAResource:
return b.SOAResource(r.Header, *body)
case *dnsmessage.MXResource:
return b.MXResource(r.Header, *body)
case *dnsmessage.TXTResource:
return b.TXTResource(r.Header, *body)
case *dnsmessage.SRVResource:
return b.SRVResource(r.Header, *body)
case *dnsmessage.OPTResource:
return b.OPTResource(r.Header, *body)
case *dnsmessage.UnknownResource:
// Handles unsupported/generic types
return b.UnknownResource(r.Header, *body)
default:
return errors.New("unsupported resource body type")
}
}

@ -0,0 +1,276 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package resolver
import (
"context"
"fmt"
"net/netip"
"testing"
dns "golang.org/x/net/dns/dnsmessage"
"tailscale.com/types/dnstype"
"tailscale.com/util/dnsname"
)
// Note: This test file uses helper builders already present in other resolver
// tests (e.g., makeTestRequest/makeTestResponse/dnspacket) since they are in
// the same package test space.
func TestExtractValidEDNS0UDPSize(t *testing.T) {
q := dnspacket("example.com.", dns.TypeA, 917)
got := extractEDNS0UDPSize(q)
if got != 917 {
t.Fatalf("expected 917, got %v", got)
}
}
func TestExtractSmallEDNS0UDPSize(t *testing.T) {
q := dnspacket("example.com.", dns.TypeA, 100)
got := extractEDNS0UDPSize(q)
// extractEDNS0UDPSize enforces minimum of 512 per RFC 6891 §6.2.5
if got != minEDNS0Size {
t.Fatalf("expected %v, got %v", minEDNS0Size, got)
}
}
func TestExtractLargeEDNS0UDPSize(t *testing.T) {
q := dnspacket("example.com.", dns.TypeA, 5000)
got := extractEDNS0UDPSize(q)
// extractEDNS0UDPSize caps at maxEDNS0Size
if got != maxEDNS0Size {
t.Fatalf("expected %v, got %v", maxEDNS0Size, got)
}
}
func TestTruncateNonEDNS(t *testing.T) {
// Build a very large response (many A records) without EDNS
// Create response with many answers
name := dns.MustNewName("example.com.")
b := dns.NewBuilder(nil, dns.Header{Response: true, Authoritative: true, RCode: dns.RCodeSuccess})
if err := b.StartQuestions(); err != nil {
t.Fatal(err)
}
if err := b.Question(dns.Question{Name: name, Type: dns.TypeA, Class: dns.ClassINET}); err != nil {
t.Fatal(err)
}
if err := b.StartAnswers(); err != nil {
t.Fatal(err)
}
// add enough A records to exceed 512 bytes
for i := 0; i < 200; i++ {
b.AResource(dns.ResourceHeader{Name: name, Class: dns.ClassINET, TTL: 60}, dns.AResource{A: [4]byte{192, 0, 2, byte(i % 255)}})
}
resp, err := b.Finish()
if err != nil {
t.Fatal(err)
}
if len(resp) <= 512 {
t.Fatalf("response not large enough for test: %d", len(resp))
}
tr, err := truncateDNSResponse(resp, 512)
if err != nil {
t.Fatalf("truncate failed: %v", err)
}
if len(tr) > 512 {
t.Fatalf("truncated response too large: %d", len(tr))
}
// Check TC bit set
var p dns.Parser
h, err := p.Start(tr)
if err != nil {
t.Fatalf("parse truncated: %v", err)
}
if !h.Truncated {
t.Fatalf("expected Truncated bit set")
}
}
func TestEDNSAllowsLarger(t *testing.T) {
// Build request that advertises EDNS size 1232
ednsSize := uint16(1232)
q := dnspacket("example.com.", dns.TypeA, ednsSize)
if got := extractEDNS0UDPSize(q); got != ednsSize {
t.Fatalf("expected 1232, got %v", got)
}
// Build response of size >512 but <1232
name := dns.MustNewName("example.com.")
b := dns.NewBuilder(nil, dns.Header{Response: true, Authoritative: true, RCode: dns.RCodeSuccess})
b.EnableCompression()
b.StartQuestions()
b.Question(dns.Question{Name: name, Type: dns.TypeA, Class: dns.ClassINET})
b.StartAnswers()
for i := 0; i < 50; i++ {
b.AResource(dns.ResourceHeader{Name: name, Class: dns.ClassINET, TTL: 60}, dns.AResource{A: [4]byte{10, 0, 0, byte(i)}})
}
resp, err := b.Finish()
if err != nil {
t.Fatal(err)
}
if len(resp) <= 512 || len(resp) >= int(ednsSize) {
t.Fatalf("invalid response size %d", len(resp))
}
tr, err := truncateDNSResponse(resp, ednsSize)
if err != nil {
t.Fatalf("truncate failed: %v", err)
}
if len(tr) != len(resp) {
t.Fatalf("unexpected truncation when EDNS allows large: %d vs %d", len(tr), len(resp))
}
}
// TestTruncateDNSResponseImpossible verifies that truncateDNSResponse
// returns an error when the provided maxSize is too small to even encode
// the header+question portion of the message.
func TestTruncateDNSResponseImpossible(t *testing.T) {
// Build a normal query packet and attempt to truncate it to a very small
// size that cannot contain the header+question.
req := makeTestRequest(t, "example.com.")
if len(req) < 20 {
t.Fatalf("test request unexpectedly small: %d", len(req))
}
// Choose a maxSize smaller than the request's header+question length.
// Using 10 bytes is guaranteed to be too small.
if _, err := truncateDNSResponse(req, 10); err == nil {
t.Fatalf("expected error truncating to impossibly small size, got nil")
}
}
// TestTruncateDNSResponseDirectCall tests truncateDNSResponse with a large
// well-formed DNS response. This directly verifies that
// truncateDNSResponse produces a syntactically valid truncated response
// with the TC bit set.
func TestTruncateDNSResponseDirectCall(t *testing.T) {
const domain = "example.com."
// Build a very large DNS response (many A records)
name := dns.MustNewName(domain)
b := dns.NewBuilder(nil, dns.Header{Response: true, Authoritative: true, RCode: dns.RCodeSuccess})
b.EnableCompression()
if err := b.StartQuestions(); err != nil {
t.Fatal(err)
}
if err := b.Question(dns.Question{Name: name, Type: dns.TypeA, Class: dns.ClassINET}); err != nil {
t.Fatal(err)
}
if err := b.StartAnswers(); err != nil {
t.Fatal(err)
}
// Add enough A records to exceed 512 bytes significantly.
// Each A record is roughly 20 bytes, so 150 records will be ~3000 bytes.
for i := 0; i < 150; i++ {
err := b.AResource(
dns.ResourceHeader{Name: name, Class: dns.ClassINET, TTL: 60},
dns.AResource{A: [4]byte{10, 0, 0, byte(i % 256)}},
)
if err != nil {
t.Fatalf("failed to add A record: %v", err)
}
}
largeResp, err := b.Finish()
if err != nil {
t.Fatalf("failed to build large response: %v", err)
}
// Verify the response is large enough for truncation.
if len(largeResp) <= 512 {
t.Fatalf("test response not large enough for truncation: %d bytes", len(largeResp))
}
tr, err := truncateDNSResponse(largeResp, 512)
if err != nil {
t.Fatalf("truncateDNSResponse failed: %v", err)
}
// Verify the truncated response:
// 1. Fits within 512 bytes
if len(tr) > 512 {
t.Fatalf("truncated response exceeds 512 bytes: got %d", len(tr))
}
// 2. Is syntactically valid
var p dns.Parser
h, err := p.Start(tr)
if err != nil {
t.Fatalf("failed to parse truncated response: %v", err)
}
// 3. Has TC (Truncated) bit set
if !h.Truncated {
t.Fatalf("expected TC (Truncated) bit to be set in truncated response")
}
}
// TestResolverSERVFAILOnImpossibleTruncation ensures that when a client
// advertises a tiny EDNS buffer size such that the resolver cannot safely
// encode even the header+question within that size, the resolver returns a
// SERVFAIL response rather than an invalid/truncated packet.
func TestResolverSERVFAILOnImpossibleTruncation(t *testing.T) {
const domain = "srvfail.example.com."
// Build a request that advertises a very small EDNS size (50 bytes).
// This is small enough to require truncation but large enough for header+question.
request := dnspacket(domain, dns.TypeA, 50)
// Verify EDNS extraction enforces the RFC 6891 minimum of 512.
ednsSize := extractEDNS0UDPSize(request)
if ednsSize != 512 {
t.Fatalf("EDNS extraction failed: expected 512, got %d", ednsSize)
}
// Build a very large upstream response for the same domain so that the
// resolver will attempt truncation and fail.
_, largeResponse := makeLargeResponse(t, domain)
// Run a test DNS server returning the large response.
port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
// DNS server received a request; just ensure the server is reachable
})
// Configure resolver to forward queries to our server.
r := newResolver(t)
defer r.Close()
cfg := Config{
Routes: map[dnsname.FQDN][]*dnstype.Resolver{
dnsname.FQDN("."): {{Addr: fmt.Sprintf("127.0.0.1:%d", port)}},
},
}
if err := r.SetConfig(cfg); err != nil {
t.Fatalf("SetConfig: %v", err)
}
// Query the resolver over UDP with the tiny EDNS size.
ctx := context.Background()
out, err := r.Query(ctx, request, "udp", netip.MustParseAddrPort("127.0.0.1:12345"))
if err != nil {
t.Fatalf("Query failed: %v", err)
}
// The response should be either:
// 1. A SERVFAIL (if truncation was impossible), or
// 2. A response that fits within the effective EDNS size (512 bytes) with TC bit set.
var p dns.Parser
h, err := p.Start(out)
if err != nil {
t.Fatalf("parse response: %v", err)
}
if h.RCode == dns.RCodeServerFailure {
// Good - impossible truncation was handled correctly
return
}
// Otherwise the response must fit within 512 bytes and have TC set.
if len(out) > 512 {
t.Fatalf("expected SERVFAIL or <=512 byte response, got %d bytes with RCode=%v",
len(out), h.RCode)
}
if !h.Truncated {
t.Fatalf("expected TC bit set for truncated response")
}
}
Loading…
Cancel
Save