ryankoski 6 days ago committed by GitHub
commit d735dcbc3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1160,7 +1160,8 @@ func servfailResponse(req packet) (res packet, err error) {
h := p.Header
h.Response = true
h.Authoritative = true
// Correct behavior for SERVFAIL is to set the Authoritative flag to 0.
h.Authoritative = false
h.RCode = dns.RCodeServerFailure
b := dns.NewBuilder(nil, h)
b.StartQuestions()

@ -506,9 +506,19 @@ func makeTestRequest(tb testing.TB, domain string) []byte {
func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte {
tb.Helper()
name := dns.MustNewName(domain)
// The correct value for the Authoritative bit is complicated.
// However, in all cases where a SERVFAIL is returned, it should be false.
// Since the servfailResponse() function correctly sets this bit to false,
// this test needs to also return false for RCodeServerFailure.
authoritative := true
if code == dns.RCodeServerFailure {
authoritative = false
}
builder := dns.NewBuilder(nil, dns.Header{
Response: true,
Authoritative: true,
Authoritative: authoritative,
RCode: code,
})
builder.StartQuestions()

@ -315,17 +315,40 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net
default:
}
reqPacket := packet{bs: bs, family: family, addr: from}
out, err := r.respond(bs)
if err == errNotOurName {
responses := make(chan packet, 1)
ctx, cancel := context.WithTimeout(ctx, dnsQueryTimeout)
defer close(responses)
defer cancel()
err = r.forwarder.forwardWithDestChan(ctx, packet{bs, family, from}, responses)
err = r.forwarder.forwardWithDestChan(ctx, reqPacket, responses)
if err != nil {
return nil, err
}
return (<-responses).bs, nil
out = (<-responses).bs
err = nil
}
// Only perform truncation/EDNS0 processing for UDP queries.
if err == nil && family == "udp" && out != nil {
// Determine client's advertised UDP size via EDNS0, default to 512
maxResponseSize := uint16(512)
if edns := extractEDNS0UDPSize(bs); edns > 0 {
maxResponseSize = edns
}
if len(out) > int(maxResponseSize) {
tr, terr := truncateDNSResponse(out, maxResponseSize)
if terr != nil {
// Can't safely truncate; return SERVFAIL
serv, berr := servfailResponse(reqPacket)
if berr != nil {
return nil, terr
}
return serv.bs, nil
}
out = tr
}
}
return out, err

@ -302,10 +302,15 @@ func dnsHandler(answers ...any) dns.HandlerFunc {
}
}
func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server {
func serveDNS(tb testing.TB, addr string, family string, records ...any) *dns.Server {
if len(records)%2 != 0 {
panic("must have an even number of record values")
}
switch family {
case "udp", "tcp":
default:
panic("family must be udp or tcp")
}
mux := dns.NewServeMux()
for i := 0; i < len(records); i += 2 {
name := records[i].(string)
@ -315,7 +320,7 @@ func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server {
waitch := make(chan struct{})
server := &dns.Server{
Addr: addr,
Net: "udp",
Net: family,
Handler: mux,
NotifyStartedFunc: func() { close(waitch) },
ReusePort: true,

@ -92,15 +92,16 @@ func dnspacket(domain dnsname.FQDN, tp dns.Type, ednsSize uint16) []byte {
}
type dnsResponse struct {
ip netip.Addr
txt []string
name dnsname.FQDN
rcode dns.RCode
truncated bool
requestEdns bool
requestEdnsSize uint16
responseEdns bool
responseEdnsSize uint16
ip netip.Addr
txt []string
name dnsname.FQDN
rcode dns.RCode
truncated bool
retryTCPonTruncation bool
requestEdns bool
requestEdnsSize uint16
responseEdns bool
responseEdnsSize uint16
}
func unpackResponse(payload []byte) (dnsResponse, error) {
@ -233,8 +234,13 @@ func unpackResponse(payload []byte) (dnsResponse, error) {
return response, nil
}
func syncRespond(r *Resolver, query []byte) ([]byte, error) {
return r.Query(context.Background(), query, "udp", netip.AddrPort{})
func syncRespond(r *Resolver, family string, query []byte) ([]byte, error) {
switch family {
case "udp", "tcp":
default:
return nil, fmt.Errorf("Invalid family %q", family)
}
return r.Query(context.Background(), query, family, netip.AddrPort{})
}
func mustIP(str string) netip.Addr {
@ -538,22 +544,37 @@ func TestDelegate(t *testing.T) {
"xlarge.txt.", resolveToTXT(xlargeTXT, 8000),
"huge.txt.", resolveToTXT(hugeTXT, 65527),
}
v4server := serveDNS(t, "127.0.0.1:0", records...)
defer v4server.Shutdown()
v6server := serveDNS(t, "[::1]:0", records...)
defer v6server.Shutdown()
r := newResolver(t)
defer r.Close()
v4UDPServer := serveDNS(t, "127.0.0.1:0", "udp", records...)
defer v4UDPServer.Shutdown()
v6UDPServer := serveDNS(t, "[::1]:0", "udp", records...)
defer v6UDPServer.Shutdown()
v4TCPServer := serveDNS(t, "127.0.0.1:0", "udp", records...)
defer v4TCPServer.Shutdown()
v6TCPServer := serveDNS(t, "[::1]:0", "udp", records...)
defer v6TCPServer.Shutdown()
udpResolver := newResolver(t)
defer udpResolver.Close()
tcpResolver := newResolver(t)
defer tcpResolver.Close()
udpcfg := dnsCfg
udpcfg.Routes = map[dnsname.FQDN][]*dnstype.Resolver{
".": {
&dnstype.Resolver{Addr: v4UDPServer.PacketConn.LocalAddr().String()},
&dnstype.Resolver{Addr: v6UDPServer.PacketConn.LocalAddr().String()},
},
}
udpResolver.SetConfig(udpcfg)
cfg := dnsCfg
cfg.Routes = map[dnsname.FQDN][]*dnstype.Resolver{
tcpcfg := dnsCfg
tcpcfg.Routes = map[dnsname.FQDN][]*dnstype.Resolver{
".": {
&dnstype.Resolver{Addr: v4server.PacketConn.LocalAddr().String()},
&dnstype.Resolver{Addr: v6server.PacketConn.LocalAddr().String()},
&dnstype.Resolver{Addr: v4TCPServer.PacketConn.LocalAddr().String()},
&dnstype.Resolver{Addr: v6TCPServer.PacketConn.LocalAddr().String()},
},
}
r.SetConfig(cfg)
tcpResolver.SetConfig(tcpcfg)
tests := []struct {
title string
@ -616,44 +637,44 @@ func TestDelegate(t *testing.T) {
"medtxt",
dnspacket("med.txt.", dns.TypeTXT, 2000),
dnsResponse{
txt: medTXT,
rcode: dns.RCodeSuccess,
requestEdns: true,
requestEdnsSize: 2000,
responseEdns: true,
responseEdnsSize: 1500,
txt: medTXT,
rcode: dns.RCodeSuccess,
retryTCPonTruncation: true,
requestEdns: true,
requestEdnsSize: 2000,
responseEdns: true,
responseEdnsSize: 1500,
},
},
{
"largetxt",
dnspacket("large.txt.", dns.TypeTXT, maxResponseBytes),
dnsResponse{
txt: largeTXT,
rcode: dns.RCodeSuccess,
requestEdns: true,
requestEdnsSize: maxResponseBytes,
responseEdns: true,
responseEdnsSize: maxResponseBytes,
txt: largeTXT,
rcode: dns.RCodeSuccess,
retryTCPonTruncation: true,
requestEdns: true,
requestEdnsSize: maxResponseBytes,
responseEdns: true,
responseEdnsSize: maxResponseBytes,
},
},
{
"xlargetxt",
dnspacket("xlarge.txt.", dns.TypeTXT, 8000),
dnsResponse{
rcode: dns.RCodeSuccess,
truncated: true,
// request/response EDNS fields will be unset because of
// they were truncated away
rcode: dns.RCodeSuccess,
truncated: true,
retryTCPonTruncation: true,
},
},
{
"hugetxt",
dnspacket("huge.txt.", dns.TypeTXT, 8000),
dnsResponse{
rcode: dns.RCodeSuccess,
truncated: true,
// request/response EDNS fields will be unset because of
// they were truncated away
rcode: dns.RCodeSuccess,
truncated: true,
retryTCPonTruncation: true,
},
},
}
@ -663,7 +684,9 @@ func TestDelegate(t *testing.T) {
if tt.title == "hugetxt" && runtime.GOOS == "darwin" {
t.Skip("known to not work on macOS: https://github.com/tailscale/tailscale/issues/2229")
}
payload, err := syncRespond(r, tt.query)
runEDNSSizeChecks := true
payload, err := syncRespond(udpResolver, "udp", tt.query)
if err != nil {
t.Errorf("err = %v; want nil", err)
return
@ -673,6 +696,27 @@ func TestDelegate(t *testing.T) {
t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
return
}
// If truncated and the test is configured to do so, retry over TCP.
// Additionally, some of the tests may result in a SERVFAIL response
// when queried over UDP because the total response is larger than
// the maximum supported buffer size. This results in a byte sequence
// that fails to parse correctly. In that case, we also retry over TCP.
if (response.truncated || response.rcode == dns.RCodeServerFailure) && tt.response.retryTCPonTruncation {
// Retry over TCP.
t.Logf("Retrying over TCP for %q", tt.title)
payload, err = syncRespond(tcpResolver, "tcp", tt.query)
if err != nil {
t.Errorf("TCP retry: err = %v; want nil", err)
return
}
response, err = unpackResponse(payload)
if err != nil {
t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
return
}
// On TCP, EDNS size is not applicable.
runEDNSSizeChecks = false
}
if response.rcode != tt.response.rcode {
t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode)
}
@ -691,17 +735,19 @@ func TestDelegate(t *testing.T) {
}
}
}
if response.requestEdns != tt.response.requestEdns {
t.Errorf("requestEdns = %v; want %v", response.requestEdns, tt.response.requestEdns)
}
if response.requestEdnsSize != tt.response.requestEdnsSize {
t.Errorf("requestEdnsSize = %v; want %v", response.requestEdnsSize, tt.response.requestEdnsSize)
}
if response.responseEdns != tt.response.responseEdns {
t.Errorf("responseEdns = %v; want %v", response.requestEdns, tt.response.requestEdns)
}
if response.responseEdnsSize != tt.response.responseEdnsSize {
t.Errorf("responseEdnsSize = %v; want %v", response.responseEdnsSize, tt.response.responseEdnsSize)
if runEDNSSizeChecks {
if response.requestEdns != tt.response.requestEdns {
t.Errorf("requestEdns = %v; want %v", response.requestEdns, tt.response.requestEdns)
}
if response.requestEdnsSize != tt.response.requestEdnsSize {
t.Errorf("requestEdnsSize = %v; want %v", response.requestEdnsSize, tt.response.requestEdnsSize)
}
if response.responseEdns != tt.response.responseEdns {
t.Errorf("responseEdns = %v; want %v", response.requestEdns, tt.response.requestEdns)
}
if response.responseEdnsSize != tt.response.responseEdnsSize {
t.Errorf("responseEdnsSize = %v; want %v", response.responseEdnsSize, tt.response.responseEdnsSize)
}
}
})
}
@ -711,10 +757,10 @@ func TestDelegateSplitRoute(t *testing.T) {
test4 := netip.MustParseAddr("2.3.4.5")
test6 := netip.MustParseAddr("ff::1")
server1 := serveDNS(t, "127.0.0.1:0",
server1 := serveDNS(t, "127.0.0.1:0", "udp",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server1.Shutdown()
server2 := serveDNS(t, "127.0.0.1:0",
server2 := serveDNS(t, "127.0.0.1:0", "udp",
"test.other.", resolveToIP(test4, test6, "dns.other."))
defer server2.Shutdown()
@ -747,7 +793,7 @@ func TestDelegateSplitRoute(t *testing.T) {
for _, tt := range tests {
t.Run(tt.title, func(t *testing.T) {
payload, err := syncRespond(r, tt.query)
payload, err := syncRespond(r, "udp", tt.query)
if err != nil {
t.Errorf("err = %v; want nil", err)
return
@ -942,7 +988,7 @@ func TestFull(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
response, err := syncRespond(r, tt.request)
response, err := syncRespond(r, "udp", tt.request)
if err != nil {
t.Errorf("err = %v; want nil", err)
}
@ -974,7 +1020,7 @@ func TestAllocs(t *testing.T) {
for _, tt := range tests {
err := tstest.MinAllocsPerRun(t, tt.want, func() {
syncRespond(r, tt.query)
syncRespond(r, "udp", tt.query)
})
if err != nil {
t.Errorf("%s: %v", tt.name, err)
@ -1006,7 +1052,7 @@ func TestTrimRDNSBonjourPrefix(t *testing.T) {
}
func BenchmarkFull(b *testing.B) {
server := serveDNS(b, "127.0.0.1:0",
server := serveDNS(b, "127.0.0.1:0", "udp",
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server.Shutdown()
@ -1031,7 +1077,7 @@ func BenchmarkFull(b *testing.B) {
b.Run(tt.name, func(b *testing.B) {
b.ReportAllocs()
for range b.N {
syncRespond(r, tt.request)
syncRespond(r, "udp", tt.request)
}
})
}
@ -1159,7 +1205,7 @@ func TestHandleExitNodeDNSQueryWithNetPkg(t *testing.T) {
"ns.test.",
dnsHandler(miekdns.NS{Ns: "ns1.foo."}, miekdns.NS{Ns: "ns2.bar."}),
}
v4server := serveDNS(t, "127.0.0.1:0", records...)
v4server := serveDNS(t, "127.0.0.1:0", "udp", records...)
defer v4server.Shutdown()
// backendResolver is the resolver between
@ -1485,7 +1531,7 @@ func TestUnARPA(t *testing.T) {
//
// See: https://github.com/tailscale/tailscale/issues/4722
func TestServfail(t *testing.T) {
server := serveDNS(t, "127.0.0.1:0", "test.site.", miekdns.HandlerFunc(func(w miekdns.ResponseWriter, req *miekdns.Msg) {
server := serveDNS(t, "127.0.0.1:0", "udp", "test.site.", miekdns.HandlerFunc(func(w miekdns.ResponseWriter, req *miekdns.Msg) {
m := new(miekdns.Msg)
m.Rcode = miekdns.RcodeServerFailure
w.WriteMsg(m)
@ -1501,14 +1547,14 @@ func TestServfail(t *testing.T) {
}
r.SetConfig(cfg)
pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns))
pkt, err := syncRespond(r, "udp", dnspacket("test.site.", dns.TypeA, noEdns))
if err != nil {
t.Fatalf("err = %v, want nil", err)
}
wantPkt := []byte{
0x00, 0x00, // transaction id: 0
0x84, 0x02, // flags: response, authoritative, error: servfail
0x80, 0x02, // flags: response, error: servfail
0x00, 0x01, // one question
0x00, 0x00, // no answers
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs

@ -0,0 +1,294 @@
package resolver
import (
"errors"
"golang.org/x/net/dns/dnsmessage"
)
// extractOPTResource parses a DNS message and returns the OPT resource if present.
func extractOPTResource(msg []byte) *dnsmessage.Resource {
var p dnsmessage.Parser
if _, err := p.Start(msg); err != nil {
return nil
}
var optRes *dnsmessage.Resource
optRes = nil
// Fast-forward to find OPT
if err := p.SkipAllQuestions(); err == nil {
if err := p.SkipAllAnswers(); err == nil {
if err := p.SkipAllAuthorities(); err == nil {
for {
r, err := p.Additional()
if err != nil {
break
}
if r.Header.Type == dnsmessage.TypeOPT {
optRes = &r
break
}
}
}
}
}
return optRes
}
const minEDNS0Size = 512 // per RFC 6891 Section 6.2.5
const maxEDNS0Size = 1232 // per DNS Flag Day 2020 recommendation
// extractEDNS0UDPSize extracts the advertised UDP buffer size from an EDNS0 OPT record
// in a DNS query packet. If no EDNS0 record is present or the packet is malformed,
// it returns 0, indicating the default 512-byte limit should be used.
func extractEDNS0UDPSize(query []byte) uint16 {
size := uint16(0)
optRes := extractOPTResource(query)
if optRes != nil {
// UDP payload size is encoded in the CLASS field of the OPT header.
// Per RFC 6891 §6.2.5, treat any advertised UDP size smaller than 512
// as 512. Per DNS Flag Day 2020 (https://www.dnsflagday.net/2020/),
// the cap should be 1232 bytes, and newer versions of resolvers
// have set 1232 as their default limit.
size = uint16(optRes.Header.Class)
if size < minEDNS0Size {
size = minEDNS0Size
}
if size > maxEDNS0Size {
size = maxEDNS0Size
}
}
return size
}
// truncateDNSResponse performs RFC-compliant truncation of a DNS
// response message. It preserves the question section and as many
// resource records as possible in the answer, authority, and
// additional sections, setting the TC (truncated) bit if truncation
// occurs. It enforces RFC 6891 Section 7 (preserving the OPT record
// in truncated responses).
func truncateDNSResponse(resp []byte, maxSize uint16) ([]byte, error) {
// Sanity check on maxSize. It must be at least large enough
// to hold a minimal DNS header (12 bytes) and at least one
// question (5 bytes).
if maxSize < 12+5 {
return nil, errors.New("maxSize too small to hold minimal DNS message")
}
var p dnsmessage.Parser
header, err := p.Start(resp)
if err != nil {
return nil, err
}
// 1. Extract all records into slices so we can manage them.
questions, err := p.AllQuestions()
if err != nil {
return nil, err
}
var answers, authorities, additionals []dnsmessage.Resource
var optRes *dnsmessage.Resource
// Helper to extract resources from a section
extractSection := func(sectionName string) ([]dnsmessage.Resource, error) {
var extracted []dnsmessage.Resource
for {
var r dnsmessage.Resource
var err error
switch sectionName {
case "Ans":
r, err = p.Answer()
case "Auth":
r, err = p.Authority()
case "Add":
r, err = p.Additional()
}
if err == dnsmessage.ErrSectionDone {
return extracted, nil
}
if err != nil {
return nil, err
}
// Identify and isolate the OPT record
if r.Header.Type == dnsmessage.TypeOPT {
// We found the OPT record. Save it separately.
// (RFC 6891: Only one OPT record is allowed)
optRes = &r
} else {
extracted = append(extracted, r)
}
}
}
// We must parse sections in order: Skip Questions (already got them), then Ans, Auth, Add.
// Note: p.AllQuestions() already advanced the parser past questions.
if answers, err = extractSection("Ans"); err != nil {
return nil, err
}
if authorities, err = extractSection("Auth"); err != nil {
return nil, err
}
if additionals, err = extractSection("Add"); err != nil {
return nil, err
}
// 2. Try to build the FULL packet first (Happy Path).
// If it fits, we avoid the expensive iterative logic.
fullPacket, err := buildResponse(header, questions, answers, authorities, additionals, optRes)
if err == nil && uint16(len(fullPacket)) <= maxSize {
return fullPacket, nil
}
// 3. Truncation Path.
// The packet is too big. We must rebuild it record-by-record until full.
// We MUST set the TC bit.
header.Truncated = true
// We start with empty sections.
var finalAns, finalAuth, finalAdd []dnsmessage.Resource
// Define the order of candidates we want to try adding.
// (Answers first, then Authorities, then Additionals)
// We use a list of *slices* to iterate section by section.
sections := []struct {
candidates []dnsmessage.Resource
target *[]dnsmessage.Resource // Pointer to the slice we are building
}{
{answers, &finalAns},
{authorities, &finalAuth},
{additionals, &finalAdd},
}
for _, section := range sections {
for _, candidate := range section.candidates {
// Speculatively add this candidate to the target list
*section.target = append(*section.target, candidate)
// Build the packet with the current set of records + Mandatory OPT
testPacket, err := buildResponse(header, questions, finalAns, finalAuth, finalAdd, optRes)
if err != nil {
return nil, err // Should not happen with valid resources
}
// Check size
if uint16(len(testPacket)) > maxSize {
// Stop! This record broke the limit.
// Remove the last added record (backtrack).
*section.target = (*section.target)[:len(*section.target)-1]
// We are full. Return the last valid build.
// Note: We need to rebuild one last time or save the previous successful 'testPacket'.
// To be safe/clean, let's just rebuild the "safe" state.
return buildResponse(header, questions, finalAns, finalAuth, finalAdd, optRes)
}
// If it fits, continue loop to add next candidate.
}
}
// If we somehow finish the loop (unlikely given we failed the "Full" check), return what we have.
return buildResponse(header, questions, finalAns, finalAuth, finalAdd, optRes)
}
// buildResponse constructs a binary DNS message from the provided slices.
// It handles the complex state machine of dnsmessage.Builder.
func buildResponse(
h dnsmessage.Header,
qs []dnsmessage.Question,
ans, auths, adds []dnsmessage.Resource,
opt *dnsmessage.Resource,
) ([]byte, error) {
// Start with a nil buffer; Builder will allocate.
b := dnsmessage.NewBuilder(nil, h)
b.EnableCompression()
// 1. Questions
if err := b.StartQuestions(); err != nil {
return nil, err
}
for _, q := range qs {
if err := b.Question(q); err != nil {
return nil, err
}
}
// 2. Answers
if err := b.StartAnswers(); err != nil {
return nil, err
}
for _, r := range ans {
if err := addResource(&b, r); err != nil {
return nil, err
}
}
// 3. Authorities
if err := b.StartAuthorities(); err != nil {
return nil, err
}
for _, r := range auths {
if err := addResource(&b, r); err != nil {
return nil, err
}
}
// 4. Additionals
if err := b.StartAdditionals(); err != nil {
return nil, err
}
for _, r := range adds {
if err := addResource(&b, r); err != nil {
return nil, err
}
}
// Always append the OPT record if it exists (RFC 6891)
if opt != nil {
if err := addResource(&b, *opt); err != nil {
return nil, err
}
}
// Finish and return the bytes
return b.Finish()
}
// addResource is a helper to handle the various resource types
// when adding individual resources to the Builder.
func addResource(b *dnsmessage.Builder, r dnsmessage.Resource) error {
switch body := r.Body.(type) {
case *dnsmessage.AResource:
return b.AResource(r.Header, *body)
case *dnsmessage.AAAAResource:
return b.AAAAResource(r.Header, *body)
case *dnsmessage.CNAMEResource:
return b.CNAMEResource(r.Header, *body)
case *dnsmessage.HTTPSResource:
return b.HTTPSResource(r.Header, *body)
case *dnsmessage.NSResource:
return b.NSResource(r.Header, *body)
case *dnsmessage.PTRResource:
return b.PTRResource(r.Header, *body)
case *dnsmessage.SOAResource:
return b.SOAResource(r.Header, *body)
case *dnsmessage.MXResource:
return b.MXResource(r.Header, *body)
case *dnsmessage.TXTResource:
return b.TXTResource(r.Header, *body)
case *dnsmessage.SRVResource:
return b.SRVResource(r.Header, *body)
case *dnsmessage.OPTResource:
return b.OPTResource(r.Header, *body)
case *dnsmessage.UnknownResource:
// Handles unsupported/generic types
return b.UnknownResource(r.Header, *body)
default:
return errors.New("unsupported resource body type")
}
}

@ -0,0 +1,276 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package resolver
import (
"context"
"fmt"
"net/netip"
"testing"
dns "golang.org/x/net/dns/dnsmessage"
"tailscale.com/types/dnstype"
"tailscale.com/util/dnsname"
)
// Note: This test file uses helper builders already present in other resolver
// tests (e.g., makeTestRequest/makeTestResponse/dnspacket) since they are in
// the same package test space.
func TestExtractValidEDNS0UDPSize(t *testing.T) {
q := dnspacket("example.com.", dns.TypeA, 917)
got := extractEDNS0UDPSize(q)
if got != 917 {
t.Fatalf("expected 917, got %v", got)
}
}
func TestExtractSmallEDNS0UDPSize(t *testing.T) {
q := dnspacket("example.com.", dns.TypeA, 100)
got := extractEDNS0UDPSize(q)
// extractEDNS0UDPSize enforces minimum of 512 per RFC 6891 §6.2.5
if got != minEDNS0Size {
t.Fatalf("expected %v, got %v", minEDNS0Size, got)
}
}
func TestExtractLargeEDNS0UDPSize(t *testing.T) {
q := dnspacket("example.com.", dns.TypeA, 5000)
got := extractEDNS0UDPSize(q)
// extractEDNS0UDPSize caps at maxEDNS0Size
if got != maxEDNS0Size {
t.Fatalf("expected %v, got %v", maxEDNS0Size, got)
}
}
func TestTruncateNonEDNS(t *testing.T) {
// Build a very large response (many A records) without EDNS
// Create response with many answers
name := dns.MustNewName("example.com.")
b := dns.NewBuilder(nil, dns.Header{Response: true, Authoritative: true, RCode: dns.RCodeSuccess})
if err := b.StartQuestions(); err != nil {
t.Fatal(err)
}
if err := b.Question(dns.Question{Name: name, Type: dns.TypeA, Class: dns.ClassINET}); err != nil {
t.Fatal(err)
}
if err := b.StartAnswers(); err != nil {
t.Fatal(err)
}
// add enough A records to exceed 512 bytes
for i := 0; i < 200; i++ {
b.AResource(dns.ResourceHeader{Name: name, Class: dns.ClassINET, TTL: 60}, dns.AResource{A: [4]byte{192, 0, 2, byte(i % 255)}})
}
resp, err := b.Finish()
if err != nil {
t.Fatal(err)
}
if len(resp) <= 512 {
t.Fatalf("response not large enough for test: %d", len(resp))
}
tr, err := truncateDNSResponse(resp, 512)
if err != nil {
t.Fatalf("truncate failed: %v", err)
}
if len(tr) > 512 {
t.Fatalf("truncated response too large: %d", len(tr))
}
// Check TC bit set
var p dns.Parser
h, err := p.Start(tr)
if err != nil {
t.Fatalf("parse truncated: %v", err)
}
if !h.Truncated {
t.Fatalf("expected Truncated bit set")
}
}
func TestEDNSAllowsLarger(t *testing.T) {
// Build request that advertises EDNS size 1232
ednsSize := uint16(1232)
q := dnspacket("example.com.", dns.TypeA, ednsSize)
if got := extractEDNS0UDPSize(q); got != ednsSize {
t.Fatalf("expected 1232, got %v", got)
}
// Build response of size >512 but <1232
name := dns.MustNewName("example.com.")
b := dns.NewBuilder(nil, dns.Header{Response: true, Authoritative: true, RCode: dns.RCodeSuccess})
b.EnableCompression()
b.StartQuestions()
b.Question(dns.Question{Name: name, Type: dns.TypeA, Class: dns.ClassINET})
b.StartAnswers()
for i := 0; i < 50; i++ {
b.AResource(dns.ResourceHeader{Name: name, Class: dns.ClassINET, TTL: 60}, dns.AResource{A: [4]byte{10, 0, 0, byte(i)}})
}
resp, err := b.Finish()
if err != nil {
t.Fatal(err)
}
if len(resp) <= 512 || len(resp) >= int(ednsSize) {
t.Fatalf("invalid response size %d", len(resp))
}
tr, err := truncateDNSResponse(resp, ednsSize)
if err != nil {
t.Fatalf("truncate failed: %v", err)
}
if len(tr) != len(resp) {
t.Fatalf("unexpected truncation when EDNS allows large: %d vs %d", len(tr), len(resp))
}
}
// TestTruncateDNSResponseImpossible verifies that truncateDNSResponse
// returns an error when the provided maxSize is too small to even encode
// the header+question portion of the message.
func TestTruncateDNSResponseImpossible(t *testing.T) {
// Build a normal query packet and attempt to truncate it to a very small
// size that cannot contain the header+question.
req := makeTestRequest(t, "example.com.")
if len(req) < 20 {
t.Fatalf("test request unexpectedly small: %d", len(req))
}
// Choose a maxSize smaller than the request's header+question length.
// Using 10 bytes is guaranteed to be too small.
if _, err := truncateDNSResponse(req, 10); err == nil {
t.Fatalf("expected error truncating to impossibly small size, got nil")
}
}
// TestTruncateDNSResponseDirectCall tests truncateDNSResponse with a large
// well-formed DNS response. This directly verifies that
// truncateDNSResponse produces a syntactically valid truncated response
// with the TC bit set.
func TestTruncateDNSResponseDirectCall(t *testing.T) {
const domain = "example.com."
// Build a very large DNS response (many A records)
name := dns.MustNewName(domain)
b := dns.NewBuilder(nil, dns.Header{Response: true, Authoritative: true, RCode: dns.RCodeSuccess})
b.EnableCompression()
if err := b.StartQuestions(); err != nil {
t.Fatal(err)
}
if err := b.Question(dns.Question{Name: name, Type: dns.TypeA, Class: dns.ClassINET}); err != nil {
t.Fatal(err)
}
if err := b.StartAnswers(); err != nil {
t.Fatal(err)
}
// Add enough A records to exceed 512 bytes significantly.
// Each A record is roughly 20 bytes, so 150 records will be ~3000 bytes.
for i := 0; i < 150; i++ {
err := b.AResource(
dns.ResourceHeader{Name: name, Class: dns.ClassINET, TTL: 60},
dns.AResource{A: [4]byte{10, 0, 0, byte(i % 256)}},
)
if err != nil {
t.Fatalf("failed to add A record: %v", err)
}
}
largeResp, err := b.Finish()
if err != nil {
t.Fatalf("failed to build large response: %v", err)
}
// Verify the response is large enough for truncation.
if len(largeResp) <= 512 {
t.Fatalf("test response not large enough for truncation: %d bytes", len(largeResp))
}
tr, err := truncateDNSResponse(largeResp, 512)
if err != nil {
t.Fatalf("truncateDNSResponse failed: %v", err)
}
// Verify the truncated response:
// 1. Fits within 512 bytes
if len(tr) > 512 {
t.Fatalf("truncated response exceeds 512 bytes: got %d", len(tr))
}
// 2. Is syntactically valid
var p dns.Parser
h, err := p.Start(tr)
if err != nil {
t.Fatalf("failed to parse truncated response: %v", err)
}
// 3. Has TC (Truncated) bit set
if !h.Truncated {
t.Fatalf("expected TC (Truncated) bit to be set in truncated response")
}
}
// TestResolverSERVFAILOnImpossibleTruncation ensures that when a client
// advertises a tiny EDNS buffer size such that the resolver cannot safely
// encode even the header+question within that size, the resolver returns a
// SERVFAIL response rather than an invalid/truncated packet.
func TestResolverSERVFAILOnImpossibleTruncation(t *testing.T) {
const domain = "srvfail.example.com."
// Build a request that advertises a very small EDNS size (50 bytes).
// This is small enough to require truncation but large enough for header+question.
request := dnspacket(domain, dns.TypeA, 50)
// Verify EDNS extraction enforces the RFC 6891 minimum of 512.
ednsSize := extractEDNS0UDPSize(request)
if ednsSize != 512 {
t.Fatalf("EDNS extraction failed: expected 512, got %d", ednsSize)
}
// Build a very large upstream response for the same domain so that the
// resolver will attempt truncation and fail.
_, largeResponse := makeLargeResponse(t, domain)
// Run a test DNS server returning the large response.
port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
// DNS server received a request; just ensure the server is reachable
})
// Configure resolver to forward queries to our server.
r := newResolver(t)
defer r.Close()
cfg := Config{
Routes: map[dnsname.FQDN][]*dnstype.Resolver{
dnsname.FQDN("."): {{Addr: fmt.Sprintf("127.0.0.1:%d", port)}},
},
}
if err := r.SetConfig(cfg); err != nil {
t.Fatalf("SetConfig: %v", err)
}
// Query the resolver over UDP with the tiny EDNS size.
ctx := context.Background()
out, err := r.Query(ctx, request, "udp", netip.MustParseAddrPort("127.0.0.1:12345"))
if err != nil {
t.Fatalf("Query failed: %v", err)
}
// The response should be either:
// 1. A SERVFAIL (if truncation was impossible), or
// 2. A response that fits within the effective EDNS size (512 bytes) with TC bit set.
var p dns.Parser
h, err := p.Start(out)
if err != nil {
t.Fatalf("parse response: %v", err)
}
if h.RCode == dns.RCodeServerFailure {
// Good - impossible truncation was handled correctly
return
}
// Otherwise the response must fit within 512 bytes and have TC set.
if len(out) > 512 {
t.Fatalf("expected SERVFAIL or <=512 byte response, got %d bytes with RCode=%v",
len(out), h.RCode)
}
if !h.Truncated {
t.Fatalf("expected TC bit set for truncated response")
}
}
Loading…
Cancel
Save