diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 5adc43efc..26770b19b 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -1158,7 +1158,8 @@ func servfailResponse(req packet) (res packet, err error) { h := p.Header h.Response = true - h.Authoritative = true + // Correct behavior for SERVFAIL is to set the Authoritative flag to 0. + h.Authoritative = false h.RCode = dns.RCodeServerFailure b := dns.NewBuilder(nil, h) b.StartQuestions() diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index ec491c581..f2bdfc90c 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -506,9 +506,19 @@ func makeTestRequest(tb testing.TB, domain string) []byte { func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte { tb.Helper() name := dns.MustNewName(domain) + + // The correct value for the Authoritative bit is complicated. + // However, in all cases where a SERVFAIL is returned, it should be false. + // Since the servfailResponse() function correctly sets this bit to false, + // this test needs to also return false for RCodeServerFailure. + authoritative := true + if code == dns.RCodeServerFailure { + authoritative = false + } + builder := dns.NewBuilder(nil, dns.Header{ Response: true, - Authoritative: true, + Authoritative: authoritative, RCode: code, }) builder.StartQuestions() diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 3185cbe2b..771ca4f70 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -315,17 +315,40 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net default: } + reqPacket := packet{bs: bs, family: family, addr: from} out, err := r.respond(bs) if err == errNotOurName { responses := make(chan packet, 1) ctx, cancel := context.WithTimeout(ctx, dnsQueryTimeout) defer close(responses) defer cancel() - err = r.forwarder.forwardWithDestChan(ctx, packet{bs, family, from}, responses) + err = r.forwarder.forwardWithDestChan(ctx, reqPacket, responses) if err != nil { return nil, err } - return (<-responses).bs, nil + out = (<-responses).bs + err = nil + } + + // Only perform truncation/EDNS0 processing for UDP queries. + if err == nil && family == "udp" && out != nil { + // Determine client's advertised UDP size via EDNS0, default to 512 + maxResponseSize := uint16(512) + if edns := extractEDNS0UDPSize(bs); edns > 0 { + maxResponseSize = edns + } + if len(out) > int(maxResponseSize) { + tr, terr := truncateDNSResponse(out, maxResponseSize) + if terr != nil { + // Can't safely truncate; return SERVFAIL + serv, berr := servfailResponse(reqPacket) + if berr != nil { + return nil, terr + } + return serv.bs, nil + } + out = tr + } } return out, err diff --git a/net/dns/resolver/tsdns_server_test.go b/net/dns/resolver/tsdns_server_test.go index 82fd3bebf..31edd3b2f 100644 --- a/net/dns/resolver/tsdns_server_test.go +++ b/net/dns/resolver/tsdns_server_test.go @@ -302,10 +302,15 @@ func dnsHandler(answers ...any) dns.HandlerFunc { } } -func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server { +func serveDNS(tb testing.TB, addr string, family string, records ...any) *dns.Server { if len(records)%2 != 0 { panic("must have an even number of record values") } + switch family { + case "udp", "tcp": + default: + panic("family must be udp or tcp") + } mux := dns.NewServeMux() for i := 0; i < len(records); i += 2 { name := records[i].(string) @@ -315,7 +320,7 @@ func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server { waitch := make(chan struct{}) server := &dns.Server{ Addr: addr, - Net: "udp", + Net: family, Handler: mux, NotifyStartedFunc: func() { close(waitch) }, ReusePort: true, diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index f0dbb48b3..aa1ab03bf 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -92,15 +92,16 @@ func dnspacket(domain dnsname.FQDN, tp dns.Type, ednsSize uint16) []byte { } type dnsResponse struct { - ip netip.Addr - txt []string - name dnsname.FQDN - rcode dns.RCode - truncated bool - requestEdns bool - requestEdnsSize uint16 - responseEdns bool - responseEdnsSize uint16 + ip netip.Addr + txt []string + name dnsname.FQDN + rcode dns.RCode + truncated bool + retryTCPonTruncation bool + requestEdns bool + requestEdnsSize uint16 + responseEdns bool + responseEdnsSize uint16 } func unpackResponse(payload []byte) (dnsResponse, error) { @@ -233,8 +234,13 @@ func unpackResponse(payload []byte) (dnsResponse, error) { return response, nil } -func syncRespond(r *Resolver, query []byte) ([]byte, error) { - return r.Query(context.Background(), query, "udp", netip.AddrPort{}) +func syncRespond(r *Resolver, family string, query []byte) ([]byte, error) { + switch family { + case "udp", "tcp": + default: + return nil, fmt.Errorf("Invalid family %q", family) + } + return r.Query(context.Background(), query, family, netip.AddrPort{}) } func mustIP(str string) netip.Addr { @@ -538,22 +544,37 @@ func TestDelegate(t *testing.T) { "xlarge.txt.", resolveToTXT(xlargeTXT, 8000), "huge.txt.", resolveToTXT(hugeTXT, 65527), } - v4server := serveDNS(t, "127.0.0.1:0", records...) - defer v4server.Shutdown() - v6server := serveDNS(t, "[::1]:0", records...) - defer v6server.Shutdown() - - r := newResolver(t) - defer r.Close() + v4UDPServer := serveDNS(t, "127.0.0.1:0", "udp", records...) + defer v4UDPServer.Shutdown() + v6UDPServer := serveDNS(t, "[::1]:0", "udp", records...) + defer v6UDPServer.Shutdown() + v4TCPServer := serveDNS(t, "127.0.0.1:0", "udp", records...) + defer v4TCPServer.Shutdown() + v6TCPServer := serveDNS(t, "[::1]:0", "udp", records...) + defer v6TCPServer.Shutdown() + + udpResolver := newResolver(t) + defer udpResolver.Close() + tcpResolver := newResolver(t) + defer tcpResolver.Close() + + udpcfg := dnsCfg + udpcfg.Routes = map[dnsname.FQDN][]*dnstype.Resolver{ + ".": { + &dnstype.Resolver{Addr: v4UDPServer.PacketConn.LocalAddr().String()}, + &dnstype.Resolver{Addr: v6UDPServer.PacketConn.LocalAddr().String()}, + }, + } + udpResolver.SetConfig(udpcfg) - cfg := dnsCfg - cfg.Routes = map[dnsname.FQDN][]*dnstype.Resolver{ + tcpcfg := dnsCfg + tcpcfg.Routes = map[dnsname.FQDN][]*dnstype.Resolver{ ".": { - &dnstype.Resolver{Addr: v4server.PacketConn.LocalAddr().String()}, - &dnstype.Resolver{Addr: v6server.PacketConn.LocalAddr().String()}, + &dnstype.Resolver{Addr: v4TCPServer.PacketConn.LocalAddr().String()}, + &dnstype.Resolver{Addr: v6TCPServer.PacketConn.LocalAddr().String()}, }, } - r.SetConfig(cfg) + tcpResolver.SetConfig(tcpcfg) tests := []struct { title string @@ -616,44 +637,44 @@ func TestDelegate(t *testing.T) { "medtxt", dnspacket("med.txt.", dns.TypeTXT, 2000), dnsResponse{ - txt: medTXT, - rcode: dns.RCodeSuccess, - requestEdns: true, - requestEdnsSize: 2000, - responseEdns: true, - responseEdnsSize: 1500, + txt: medTXT, + rcode: dns.RCodeSuccess, + retryTCPonTruncation: true, + requestEdns: true, + requestEdnsSize: 2000, + responseEdns: true, + responseEdnsSize: 1500, }, }, { "largetxt", dnspacket("large.txt.", dns.TypeTXT, maxResponseBytes), dnsResponse{ - txt: largeTXT, - rcode: dns.RCodeSuccess, - requestEdns: true, - requestEdnsSize: maxResponseBytes, - responseEdns: true, - responseEdnsSize: maxResponseBytes, + txt: largeTXT, + rcode: dns.RCodeSuccess, + retryTCPonTruncation: true, + requestEdns: true, + requestEdnsSize: maxResponseBytes, + responseEdns: true, + responseEdnsSize: maxResponseBytes, }, }, { "xlargetxt", dnspacket("xlarge.txt.", dns.TypeTXT, 8000), dnsResponse{ - rcode: dns.RCodeSuccess, - truncated: true, - // request/response EDNS fields will be unset because of - // they were truncated away + rcode: dns.RCodeSuccess, + truncated: true, + retryTCPonTruncation: true, }, }, { "hugetxt", dnspacket("huge.txt.", dns.TypeTXT, 8000), dnsResponse{ - rcode: dns.RCodeSuccess, - truncated: true, - // request/response EDNS fields will be unset because of - // they were truncated away + rcode: dns.RCodeSuccess, + truncated: true, + retryTCPonTruncation: true, }, }, } @@ -663,7 +684,9 @@ func TestDelegate(t *testing.T) { if tt.title == "hugetxt" && runtime.GOOS == "darwin" { t.Skip("known to not work on macOS: https://github.com/tailscale/tailscale/issues/2229") } - payload, err := syncRespond(r, tt.query) + + runEDNSSizeChecks := true + payload, err := syncRespond(udpResolver, "udp", tt.query) if err != nil { t.Errorf("err = %v; want nil", err) return @@ -673,6 +696,27 @@ func TestDelegate(t *testing.T) { t.Errorf("extract: err = %v; want nil (in %x)", err, payload) return } + // If truncated and the test is configured to do so, retry over TCP. + // Additionally, some of the tests may result in a SERVFAIL response + // when queried over UDP because the total response is larger than + // the maximum supported buffer size. This results in a byte sequence + // that fails to parse correctly. In that case, we also retry over TCP. + if (response.truncated || response.rcode == dns.RCodeServerFailure) && tt.response.retryTCPonTruncation { + // Retry over TCP. + t.Logf("Retrying over TCP for %q", tt.title) + payload, err = syncRespond(tcpResolver, "tcp", tt.query) + if err != nil { + t.Errorf("TCP retry: err = %v; want nil", err) + return + } + response, err = unpackResponse(payload) + if err != nil { + t.Errorf("extract: err = %v; want nil (in %x)", err, payload) + return + } + // On TCP, EDNS size is not applicable. + runEDNSSizeChecks = false + } if response.rcode != tt.response.rcode { t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode) } @@ -691,17 +735,19 @@ func TestDelegate(t *testing.T) { } } } - if response.requestEdns != tt.response.requestEdns { - t.Errorf("requestEdns = %v; want %v", response.requestEdns, tt.response.requestEdns) - } - if response.requestEdnsSize != tt.response.requestEdnsSize { - t.Errorf("requestEdnsSize = %v; want %v", response.requestEdnsSize, tt.response.requestEdnsSize) - } - if response.responseEdns != tt.response.responseEdns { - t.Errorf("responseEdns = %v; want %v", response.requestEdns, tt.response.requestEdns) - } - if response.responseEdnsSize != tt.response.responseEdnsSize { - t.Errorf("responseEdnsSize = %v; want %v", response.responseEdnsSize, tt.response.responseEdnsSize) + if runEDNSSizeChecks { + if response.requestEdns != tt.response.requestEdns { + t.Errorf("requestEdns = %v; want %v", response.requestEdns, tt.response.requestEdns) + } + if response.requestEdnsSize != tt.response.requestEdnsSize { + t.Errorf("requestEdnsSize = %v; want %v", response.requestEdnsSize, tt.response.requestEdnsSize) + } + if response.responseEdns != tt.response.responseEdns { + t.Errorf("responseEdns = %v; want %v", response.requestEdns, tt.response.requestEdns) + } + if response.responseEdnsSize != tt.response.responseEdnsSize { + t.Errorf("responseEdnsSize = %v; want %v", response.responseEdnsSize, tt.response.responseEdnsSize) + } } }) } @@ -711,10 +757,10 @@ func TestDelegateSplitRoute(t *testing.T) { test4 := netip.MustParseAddr("2.3.4.5") test6 := netip.MustParseAddr("ff::1") - server1 := serveDNS(t, "127.0.0.1:0", + server1 := serveDNS(t, "127.0.0.1:0", "udp", "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) defer server1.Shutdown() - server2 := serveDNS(t, "127.0.0.1:0", + server2 := serveDNS(t, "127.0.0.1:0", "udp", "test.other.", resolveToIP(test4, test6, "dns.other.")) defer server2.Shutdown() @@ -747,7 +793,7 @@ func TestDelegateSplitRoute(t *testing.T) { for _, tt := range tests { t.Run(tt.title, func(t *testing.T) { - payload, err := syncRespond(r, tt.query) + payload, err := syncRespond(r, "udp", tt.query) if err != nil { t.Errorf("err = %v; want nil", err) return @@ -942,7 +988,7 @@ func TestFull(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - response, err := syncRespond(r, tt.request) + response, err := syncRespond(r, "udp", tt.request) if err != nil { t.Errorf("err = %v; want nil", err) } @@ -974,7 +1020,7 @@ func TestAllocs(t *testing.T) { for _, tt := range tests { err := tstest.MinAllocsPerRun(t, tt.want, func() { - syncRespond(r, tt.query) + syncRespond(r, "udp", tt.query) }) if err != nil { t.Errorf("%s: %v", tt.name, err) @@ -1006,7 +1052,7 @@ func TestTrimRDNSBonjourPrefix(t *testing.T) { } func BenchmarkFull(b *testing.B) { - server := serveDNS(b, "127.0.0.1:0", + server := serveDNS(b, "127.0.0.1:0", "udp", "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) defer server.Shutdown() @@ -1031,7 +1077,7 @@ func BenchmarkFull(b *testing.B) { b.Run(tt.name, func(b *testing.B) { b.ReportAllocs() for range b.N { - syncRespond(r, tt.request) + syncRespond(r, "udp", tt.request) } }) } @@ -1159,7 +1205,7 @@ func TestHandleExitNodeDNSQueryWithNetPkg(t *testing.T) { "ns.test.", dnsHandler(miekdns.NS{Ns: "ns1.foo."}, miekdns.NS{Ns: "ns2.bar."}), } - v4server := serveDNS(t, "127.0.0.1:0", records...) + v4server := serveDNS(t, "127.0.0.1:0", "udp", records...) defer v4server.Shutdown() // backendResolver is the resolver between @@ -1485,7 +1531,7 @@ func TestUnARPA(t *testing.T) { // // See: https://github.com/tailscale/tailscale/issues/4722 func TestServfail(t *testing.T) { - server := serveDNS(t, "127.0.0.1:0", "test.site.", miekdns.HandlerFunc(func(w miekdns.ResponseWriter, req *miekdns.Msg) { + server := serveDNS(t, "127.0.0.1:0", "udp", "test.site.", miekdns.HandlerFunc(func(w miekdns.ResponseWriter, req *miekdns.Msg) { m := new(miekdns.Msg) m.Rcode = miekdns.RcodeServerFailure w.WriteMsg(m) @@ -1501,14 +1547,14 @@ func TestServfail(t *testing.T) { } r.SetConfig(cfg) - pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns)) + pkt, err := syncRespond(r, "udp", dnspacket("test.site.", dns.TypeA, noEdns)) if err != nil { t.Fatalf("err = %v, want nil", err) } wantPkt := []byte{ 0x00, 0x00, // transaction id: 0 - 0x84, 0x02, // flags: response, authoritative, error: servfail + 0x80, 0x02, // flags: response, error: servfail 0x00, 0x01, // one question 0x00, 0x00, // no answers 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs diff --git a/net/dns/resolver/udp_trunc.go b/net/dns/resolver/udp_trunc.go new file mode 100644 index 000000000..9ca6b98e1 --- /dev/null +++ b/net/dns/resolver/udp_trunc.go @@ -0,0 +1,294 @@ +package resolver + +import ( + "errors" + + "golang.org/x/net/dns/dnsmessage" +) + +// extractOPTResource parses a DNS message and returns the OPT resource if present. +func extractOPTResource(msg []byte) *dnsmessage.Resource { + var p dnsmessage.Parser + if _, err := p.Start(msg); err != nil { + return nil + } + + var optRes *dnsmessage.Resource + optRes = nil + + // Fast-forward to find OPT + if err := p.SkipAllQuestions(); err == nil { + if err := p.SkipAllAnswers(); err == nil { + if err := p.SkipAllAuthorities(); err == nil { + for { + r, err := p.Additional() + if err != nil { + break + } + if r.Header.Type == dnsmessage.TypeOPT { + optRes = &r + break + } + } + } + } + } + return optRes +} + +const minEDNS0Size = 512 // per RFC 6891 Section 6.2.5 +const maxEDNS0Size = 1232 // per DNS Flag Day 2020 recommendation + +// extractEDNS0UDPSize extracts the advertised UDP buffer size from an EDNS0 OPT record +// in a DNS query packet. If no EDNS0 record is present or the packet is malformed, +// it returns 0, indicating the default 512-byte limit should be used. +func extractEDNS0UDPSize(query []byte) uint16 { + size := uint16(0) + optRes := extractOPTResource(query) + + if optRes != nil { + // UDP payload size is encoded in the CLASS field of the OPT header. + // Per RFC 6891 §6.2.5, treat any advertised UDP size smaller than 512 + // as 512. Per DNS Flag Day 2020 (https://www.dnsflagday.net/2020/), + // the cap should be 1232 bytes, and newer versions of resolvers + // have set 1232 as their default limit. + size = uint16(optRes.Header.Class) + if size < minEDNS0Size { + size = minEDNS0Size + } + if size > maxEDNS0Size { + size = maxEDNS0Size + } + } + return size +} + +// truncateDNSResponse performs RFC-compliant truncation of a DNS +// response message. It preserves the question section and as many +// resource records as possible in the answer, authority, and +// additional sections, setting the TC (truncated) bit if truncation +// occurs. It enforces RFC 6891 Section 7 (preserving the OPT record +// in truncated responses). +func truncateDNSResponse(resp []byte, maxSize uint16) ([]byte, error) { + // Sanity check on maxSize. It must be at least large enough + // to hold a minimal DNS header (12 bytes) and at least one + // question (5 bytes). + if maxSize < 12+5 { + return nil, errors.New("maxSize too small to hold minimal DNS message") + } + + var p dnsmessage.Parser + + header, err := p.Start(resp) + if err != nil { + return nil, err + } + + // 1. Extract all records into slices so we can manage them. + questions, err := p.AllQuestions() + if err != nil { + return nil, err + } + + var answers, authorities, additionals []dnsmessage.Resource + var optRes *dnsmessage.Resource + + // Helper to extract resources from a section + extractSection := func(sectionName string) ([]dnsmessage.Resource, error) { + var extracted []dnsmessage.Resource + for { + var r dnsmessage.Resource + var err error + switch sectionName { + case "Ans": + r, err = p.Answer() + case "Auth": + r, err = p.Authority() + case "Add": + r, err = p.Additional() + } + if err == dnsmessage.ErrSectionDone { + return extracted, nil + } + if err != nil { + return nil, err + } + + // Identify and isolate the OPT record + if r.Header.Type == dnsmessage.TypeOPT { + // We found the OPT record. Save it separately. + // (RFC 6891: Only one OPT record is allowed) + optRes = &r + } else { + extracted = append(extracted, r) + } + } + } + + // We must parse sections in order: Skip Questions (already got them), then Ans, Auth, Add. + // Note: p.AllQuestions() already advanced the parser past questions. + + if answers, err = extractSection("Ans"); err != nil { + return nil, err + } + if authorities, err = extractSection("Auth"); err != nil { + return nil, err + } + if additionals, err = extractSection("Add"); err != nil { + return nil, err + } + + // 2. Try to build the FULL packet first (Happy Path). + // If it fits, we avoid the expensive iterative logic. + fullPacket, err := buildResponse(header, questions, answers, authorities, additionals, optRes) + if err == nil && uint16(len(fullPacket)) <= maxSize { + return fullPacket, nil + } + + // 3. Truncation Path. + // The packet is too big. We must rebuild it record-by-record until full. + // We MUST set the TC bit. + header.Truncated = true + + // We start with empty sections. + var finalAns, finalAuth, finalAdd []dnsmessage.Resource + + // Define the order of candidates we want to try adding. + // (Answers first, then Authorities, then Additionals) + // We use a list of *slices* to iterate section by section. + sections := []struct { + candidates []dnsmessage.Resource + target *[]dnsmessage.Resource // Pointer to the slice we are building + }{ + {answers, &finalAns}, + {authorities, &finalAuth}, + {additionals, &finalAdd}, + } + + for _, section := range sections { + for _, candidate := range section.candidates { + // Speculatively add this candidate to the target list + *section.target = append(*section.target, candidate) + + // Build the packet with the current set of records + Mandatory OPT + testPacket, err := buildResponse(header, questions, finalAns, finalAuth, finalAdd, optRes) + if err != nil { + return nil, err // Should not happen with valid resources + } + + // Check size + if uint16(len(testPacket)) > maxSize { + // Stop! This record broke the limit. + // Remove the last added record (backtrack). + *section.target = (*section.target)[:len(*section.target)-1] + + // We are full. Return the last valid build. + // Note: We need to rebuild one last time or save the previous successful 'testPacket'. + // To be safe/clean, let's just rebuild the "safe" state. + return buildResponse(header, questions, finalAns, finalAuth, finalAdd, optRes) + } + + // If it fits, continue loop to add next candidate. + } + } + + // If we somehow finish the loop (unlikely given we failed the "Full" check), return what we have. + return buildResponse(header, questions, finalAns, finalAuth, finalAdd, optRes) +} + +// buildResponse constructs a binary DNS message from the provided slices. +// It handles the complex state machine of dnsmessage.Builder. +func buildResponse( + h dnsmessage.Header, + qs []dnsmessage.Question, + ans, auths, adds []dnsmessage.Resource, + opt *dnsmessage.Resource, +) ([]byte, error) { + // Start with a nil buffer; Builder will allocate. + b := dnsmessage.NewBuilder(nil, h) + b.EnableCompression() + + // 1. Questions + if err := b.StartQuestions(); err != nil { + return nil, err + } + for _, q := range qs { + if err := b.Question(q); err != nil { + return nil, err + } + } + + // 2. Answers + if err := b.StartAnswers(); err != nil { + return nil, err + } + for _, r := range ans { + if err := addResource(&b, r); err != nil { + return nil, err + } + } + + // 3. Authorities + if err := b.StartAuthorities(); err != nil { + return nil, err + } + for _, r := range auths { + if err := addResource(&b, r); err != nil { + return nil, err + } + } + + // 4. Additionals + if err := b.StartAdditionals(); err != nil { + return nil, err + } + for _, r := range adds { + if err := addResource(&b, r); err != nil { + return nil, err + } + } + + // Always append the OPT record if it exists (RFC 6891) + if opt != nil { + if err := addResource(&b, *opt); err != nil { + return nil, err + } + } + + // Finish and return the bytes + return b.Finish() +} + +// addResource is a helper to handle the various resource types +// when adding individual resources to the Builder. +func addResource(b *dnsmessage.Builder, r dnsmessage.Resource) error { + switch body := r.Body.(type) { + case *dnsmessage.AResource: + return b.AResource(r.Header, *body) + case *dnsmessage.AAAAResource: + return b.AAAAResource(r.Header, *body) + case *dnsmessage.CNAMEResource: + return b.CNAMEResource(r.Header, *body) + case *dnsmessage.HTTPSResource: + return b.HTTPSResource(r.Header, *body) + case *dnsmessage.NSResource: + return b.NSResource(r.Header, *body) + case *dnsmessage.PTRResource: + return b.PTRResource(r.Header, *body) + case *dnsmessage.SOAResource: + return b.SOAResource(r.Header, *body) + case *dnsmessage.MXResource: + return b.MXResource(r.Header, *body) + case *dnsmessage.TXTResource: + return b.TXTResource(r.Header, *body) + case *dnsmessage.SRVResource: + return b.SRVResource(r.Header, *body) + case *dnsmessage.OPTResource: + return b.OPTResource(r.Header, *body) + case *dnsmessage.UnknownResource: + // Handles unsupported/generic types + return b.UnknownResource(r.Header, *body) + default: + return errors.New("unsupported resource body type") + } +} diff --git a/net/dns/resolver/udp_trunc_test.go b/net/dns/resolver/udp_trunc_test.go new file mode 100644 index 000000000..1fe680275 --- /dev/null +++ b/net/dns/resolver/udp_trunc_test.go @@ -0,0 +1,276 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package resolver + +import ( + "context" + "fmt" + "net/netip" + "testing" + + dns "golang.org/x/net/dns/dnsmessage" + "tailscale.com/types/dnstype" + "tailscale.com/util/dnsname" +) + +// Note: This test file uses helper builders already present in other resolver +// tests (e.g., makeTestRequest/makeTestResponse/dnspacket) since they are in +// the same package test space. + +func TestExtractValidEDNS0UDPSize(t *testing.T) { + q := dnspacket("example.com.", dns.TypeA, 917) + got := extractEDNS0UDPSize(q) + if got != 917 { + t.Fatalf("expected 917, got %v", got) + } +} + +func TestExtractSmallEDNS0UDPSize(t *testing.T) { + q := dnspacket("example.com.", dns.TypeA, 100) + got := extractEDNS0UDPSize(q) + // extractEDNS0UDPSize enforces minimum of 512 per RFC 6891 §6.2.5 + if got != minEDNS0Size { + t.Fatalf("expected %v, got %v", minEDNS0Size, got) + } +} + +func TestExtractLargeEDNS0UDPSize(t *testing.T) { + q := dnspacket("example.com.", dns.TypeA, 5000) + got := extractEDNS0UDPSize(q) + // extractEDNS0UDPSize caps at maxEDNS0Size + if got != maxEDNS0Size { + t.Fatalf("expected %v, got %v", maxEDNS0Size, got) + } +} + +func TestTruncateNonEDNS(t *testing.T) { + // Build a very large response (many A records) without EDNS + // Create response with many answers + name := dns.MustNewName("example.com.") + b := dns.NewBuilder(nil, dns.Header{Response: true, Authoritative: true, RCode: dns.RCodeSuccess}) + if err := b.StartQuestions(); err != nil { + t.Fatal(err) + } + if err := b.Question(dns.Question{Name: name, Type: dns.TypeA, Class: dns.ClassINET}); err != nil { + t.Fatal(err) + } + if err := b.StartAnswers(); err != nil { + t.Fatal(err) + } + // add enough A records to exceed 512 bytes + for i := 0; i < 200; i++ { + b.AResource(dns.ResourceHeader{Name: name, Class: dns.ClassINET, TTL: 60}, dns.AResource{A: [4]byte{192, 0, 2, byte(i % 255)}}) + } + resp, err := b.Finish() + if err != nil { + t.Fatal(err) + } + if len(resp) <= 512 { + t.Fatalf("response not large enough for test: %d", len(resp)) + } + + tr, err := truncateDNSResponse(resp, 512) + if err != nil { + t.Fatalf("truncate failed: %v", err) + } + if len(tr) > 512 { + t.Fatalf("truncated response too large: %d", len(tr)) + } + // Check TC bit set + var p dns.Parser + h, err := p.Start(tr) + if err != nil { + t.Fatalf("parse truncated: %v", err) + } + if !h.Truncated { + t.Fatalf("expected Truncated bit set") + } +} + +func TestEDNSAllowsLarger(t *testing.T) { + // Build request that advertises EDNS size 1232 + ednsSize := uint16(1232) + q := dnspacket("example.com.", dns.TypeA, ednsSize) + if got := extractEDNS0UDPSize(q); got != ednsSize { + t.Fatalf("expected 1232, got %v", got) + } + + // Build response of size >512 but <1232 + name := dns.MustNewName("example.com.") + b := dns.NewBuilder(nil, dns.Header{Response: true, Authoritative: true, RCode: dns.RCodeSuccess}) + b.EnableCompression() + b.StartQuestions() + b.Question(dns.Question{Name: name, Type: dns.TypeA, Class: dns.ClassINET}) + b.StartAnswers() + for i := 0; i < 50; i++ { + b.AResource(dns.ResourceHeader{Name: name, Class: dns.ClassINET, TTL: 60}, dns.AResource{A: [4]byte{10, 0, 0, byte(i)}}) + } + resp, err := b.Finish() + if err != nil { + t.Fatal(err) + } + if len(resp) <= 512 || len(resp) >= int(ednsSize) { + t.Fatalf("invalid response size %d", len(resp)) + } + + tr, err := truncateDNSResponse(resp, ednsSize) + if err != nil { + t.Fatalf("truncate failed: %v", err) + } + if len(tr) != len(resp) { + t.Fatalf("unexpected truncation when EDNS allows large: %d vs %d", len(tr), len(resp)) + } +} + +// TestTruncateDNSResponseImpossible verifies that truncateDNSResponse +// returns an error when the provided maxSize is too small to even encode +// the header+question portion of the message. +func TestTruncateDNSResponseImpossible(t *testing.T) { + // Build a normal query packet and attempt to truncate it to a very small + // size that cannot contain the header+question. + req := makeTestRequest(t, "example.com.") + if len(req) < 20 { + t.Fatalf("test request unexpectedly small: %d", len(req)) + } + + // Choose a maxSize smaller than the request's header+question length. + // Using 10 bytes is guaranteed to be too small. + if _, err := truncateDNSResponse(req, 10); err == nil { + t.Fatalf("expected error truncating to impossibly small size, got nil") + } +} + +// TestTruncateDNSResponseDirectCall tests truncateDNSResponse with a large +// well-formed DNS response. This directly verifies that +// truncateDNSResponse produces a syntactically valid truncated response +// with the TC bit set. +func TestTruncateDNSResponseDirectCall(t *testing.T) { + const domain = "example.com." + + // Build a very large DNS response (many A records) + name := dns.MustNewName(domain) + b := dns.NewBuilder(nil, dns.Header{Response: true, Authoritative: true, RCode: dns.RCodeSuccess}) + b.EnableCompression() + if err := b.StartQuestions(); err != nil { + t.Fatal(err) + } + if err := b.Question(dns.Question{Name: name, Type: dns.TypeA, Class: dns.ClassINET}); err != nil { + t.Fatal(err) + } + if err := b.StartAnswers(); err != nil { + t.Fatal(err) + } + // Add enough A records to exceed 512 bytes significantly. + // Each A record is roughly 20 bytes, so 150 records will be ~3000 bytes. + for i := 0; i < 150; i++ { + err := b.AResource( + dns.ResourceHeader{Name: name, Class: dns.ClassINET, TTL: 60}, + dns.AResource{A: [4]byte{10, 0, 0, byte(i % 256)}}, + ) + if err != nil { + t.Fatalf("failed to add A record: %v", err) + } + } + largeResp, err := b.Finish() + if err != nil { + t.Fatalf("failed to build large response: %v", err) + } + + // Verify the response is large enough for truncation. + if len(largeResp) <= 512 { + t.Fatalf("test response not large enough for truncation: %d bytes", len(largeResp)) + } + + tr, err := truncateDNSResponse(largeResp, 512) + if err != nil { + t.Fatalf("truncateDNSResponse failed: %v", err) + } + + // Verify the truncated response: + // 1. Fits within 512 bytes + if len(tr) > 512 { + t.Fatalf("truncated response exceeds 512 bytes: got %d", len(tr)) + } + + // 2. Is syntactically valid + var p dns.Parser + h, err := p.Start(tr) + if err != nil { + t.Fatalf("failed to parse truncated response: %v", err) + } + + // 3. Has TC (Truncated) bit set + if !h.Truncated { + t.Fatalf("expected TC (Truncated) bit to be set in truncated response") + } +} + +// TestResolverSERVFAILOnImpossibleTruncation ensures that when a client +// advertises a tiny EDNS buffer size such that the resolver cannot safely +// encode even the header+question within that size, the resolver returns a +// SERVFAIL response rather than an invalid/truncated packet. +func TestResolverSERVFAILOnImpossibleTruncation(t *testing.T) { + const domain = "srvfail.example.com." + + // Build a request that advertises a very small EDNS size (50 bytes). + // This is small enough to require truncation but large enough for header+question. + request := dnspacket(domain, dns.TypeA, 50) + + // Verify EDNS extraction enforces the RFC 6891 minimum of 512. + ednsSize := extractEDNS0UDPSize(request) + if ednsSize != 512 { + t.Fatalf("EDNS extraction failed: expected 512, got %d", ednsSize) + } + + // Build a very large upstream response for the same domain so that the + // resolver will attempt truncation and fail. + _, largeResponse := makeLargeResponse(t, domain) + + // Run a test DNS server returning the large response. + port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) { + // DNS server received a request; just ensure the server is reachable + }) + + // Configure resolver to forward queries to our server. + r := newResolver(t) + defer r.Close() + cfg := Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{ + dnsname.FQDN("."): {{Addr: fmt.Sprintf("127.0.0.1:%d", port)}}, + }, + } + if err := r.SetConfig(cfg); err != nil { + t.Fatalf("SetConfig: %v", err) + } + + // Query the resolver over UDP with the tiny EDNS size. + ctx := context.Background() + out, err := r.Query(ctx, request, "udp", netip.MustParseAddrPort("127.0.0.1:12345")) + if err != nil { + t.Fatalf("Query failed: %v", err) + } + + // The response should be either: + // 1. A SERVFAIL (if truncation was impossible), or + // 2. A response that fits within the effective EDNS size (512 bytes) with TC bit set. + var p dns.Parser + h, err := p.Start(out) + if err != nil { + t.Fatalf("parse response: %v", err) + } + + if h.RCode == dns.RCodeServerFailure { + // Good - impossible truncation was handled correctly + return + } + + // Otherwise the response must fit within 512 bytes and have TC set. + if len(out) > 512 { + t.Fatalf("expected SERVFAIL or <=512 byte response, got %d bytes with RCode=%v", + len(out), h.RCode) + } + if !h.Truncated { + t.Fatalf("expected TC bit set for truncated response") + } +}